From e85411a5fd68359e63347f42046d095a2f6bdbc1 Mon Sep 17 00:00:00 2001 From: jatinwadhwa921 <110383850+jatinwadhwa921@users.noreply.github.com> Date: Wed, 18 Dec 2024 16:37:59 +0530 Subject: [PATCH 001/138] Updated Internal CI (#581) --- .github/workflows/internal_ci.yml | 49 +++++++++++++++++++++++++++++++ 1 file changed, 49 insertions(+) create mode 100644 .github/workflows/internal_ci.yml diff --git a/.github/workflows/internal_ci.yml b/.github/workflows/internal_ci.yml new file mode 100644 index 0000000000000..3d0a8d8bcaf42 --- /dev/null +++ b/.github/workflows/internal_ci.yml @@ -0,0 +1,49 @@ +name : Internal CI + +on: + pull_request_target: + branches: + - '**' # Triggers on a PR to any Branch + +permissions: + contents: read + pull-requests: read + +jobs: + build: + + if: github.event.pull_request.draft == false + runs-on: [self-hosted, Linux, X64] # Runs on a Lunar lake + env: + BUILD_SOURCESDIRECTORY: ${{ github.workspace }} + BUILD_BINARIESDIRECTORY: ${{ github.workspace }}/build + + steps: + - name: Check PR Author Authorization + run: | + if [[ "${{ github.event.pull_request.head.repo.full_name }}" != "${{ github.repository }}" ]]; then + echo "PR is from a fork: ${{ github.event.pull_request.head.repo.full_name }}" + fi + + - name: Checkout PR Branch + uses: actions/checkout@v4 + with: + ref: ${{ github.event.pull_request.head.ref }} + repository: ${{ github.event.pull_request.head.repo.full_name }} + fetch-depth: 1 # checkout the pr branch + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: '3.10' + + - name: Create build directory + run: | + mkdir -p ${{ env.BUILD_BINARIESDIRECTORY }} + chmod -R 777 ${{ env.BUILD_BINARIESDIRECTORY }} + + - name: Running Internal CI # Trigger Internal CI on the pr branch + run: | + cd tools/ci_build/github/linux/ + dir + ./run_dockerbuild.sh -o ubuntu22.04 -p 3.10 -d openvino -v 2024.5.0 -x "--config Release --use_openvino CPU --build_wheel --build_shared_lib --parallel " From 0d42af9c45437f57aaff38719e509ef869fbbfdb Mon Sep 17 00:00:00 2001 From: jatinwadhwa921 <110383850+jatinwadhwa921@users.noreply.github.com> Date: Tue, 25 Feb 2025 20:23:03 +0530 Subject: [PATCH 002/138] Updated Internal CI OV version (#594) --- .github/workflows/internal_ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/internal_ci.yml b/.github/workflows/internal_ci.yml index 3d0a8d8bcaf42..6ece42bb90571 100644 --- a/.github/workflows/internal_ci.yml +++ b/.github/workflows/internal_ci.yml @@ -46,4 +46,4 @@ jobs: run: | cd tools/ci_build/github/linux/ dir - ./run_dockerbuild.sh -o ubuntu22.04 -p 3.10 -d openvino -v 2024.5.0 -x "--config Release --use_openvino CPU --build_wheel --build_shared_lib --parallel " + ./run_dockerbuild.sh -o ubuntu22.04 -p 3.10 -d openvino -v 2025.0.0 -x "--config Release --use_openvino CPU --build_wheel --build_shared_lib --parallel " From 3dc24efd7a4b88d67478e89c4f797095374335ad Mon Sep 17 00:00:00 2001 From: jatinwadhwa921 <110383850+jatinwadhwa921@users.noreply.github.com> Date: Thu, 27 Feb 2025 14:08:10 +0530 Subject: [PATCH 003/138] Updated ov version in pipeline (#595) --- .../github/azure-pipelines/linux-openvino-ci-pipeline.yml | 2 +- .../github/linux/docker/Dockerfile.ubuntu_openvino | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/tools/ci_build/github/azure-pipelines/linux-openvino-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-openvino-ci-pipeline.yml index c7b814f3dd52c..da333774b8496 100644 --- a/tools/ci_build/github/azure-pipelines/linux-openvino-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-openvino-ci-pipeline.yml @@ -33,5 +33,5 @@ jobs: parameters: AgentPool : 'Linux-CPU-2019' JobName: 'Linux_CI_Dev' - RunDockerBuildArgs: '-o ubuntu22.04 -p 3.10 -d openvino -v 2024.5.0 -x "--use_openvino CPU --build_wheel"' + RunDockerBuildArgs: '-o ubuntu22.04 -p 3.10 -d openvino -v 2025.0.0 -x "--use_openvino CPU --build_wheel"' TimeoutInMinutes: 120 diff --git a/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_openvino b/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_openvino index 7b1e3fa677375..b53a2302be403 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_openvino +++ b/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_openvino @@ -1,7 +1,7 @@ ARG UBUNTU_VERSION=22.04 FROM ubuntu:${UBUNTU_VERSION} -ARG OPENVINO_VERSION=2024.5.0 +ARG OPENVINO_VERSION=2025.0.0 ARG PYTHON_VERSION=3.10 ADD scripts /tmp/scripts @@ -19,9 +19,9 @@ ENV IE_PLUGINS_PATH=$INTEL_OPENVINO_DIR/runtime/lib/intel64 ENV DEBIAN_FRONTEND=noninteractive RUN cd /opt && mkdir -p intel && cd intel && \ - wget https://storage.openvinotoolkit.org/repositories/openvino/packages/2024.5/linux/l_openvino_toolkit_ubuntu22_2024.5.0.17288.7975fa5da0c_x86_64.tgz && \ - tar xzf l_openvino_toolkit_ubuntu22_2024.5.0.17288.7975fa5da0c_x86_64.tgz && rm -rf l_openvino_toolkit_ubuntu22_2024.5.0.17288.7975fa5da0c_x86_64.tgz && \ - mv l_openvino_toolkit_ubuntu22_2024.5.0.17288.7975fa5da0c_x86_64 openvino_2024.5.0 && \ + wget https://storage.openvinotoolkit.org/repositories/openvino/packages/2025.0/linux/openvino_toolkit_ubuntu22_2025.0.0.17942.1f68be9f594_x86_64.tgz && \ + tar xzf openvino_toolkit_ubuntu22_2025.0.0.17942.1f68be9f594_x86_64.tgz && rm -rf openvino_toolkit_ubuntu22_2025.0.0.17942.1f68be9f594_x86_64.tgz && \ + mv openvino_toolkit_ubuntu22_2025.0.0.17942.1f68be9f594_x86_64 openvino_2025.0.0 && \ cd $INTEL_OPENVINO_DIR/install_dependencies && ./install_openvino_dependencies.sh -y WORKDIR /root From 9c2fee5a1d54e142a1e540b9f0e261f5d2827b5d Mon Sep 17 00:00:00 2001 From: jatinwadhwa921 <110383850+jatinwadhwa921@users.noreply.github.com> Date: Fri, 28 Feb 2025 15:15:30 +0530 Subject: [PATCH 004/138] [OVEP] Fix for deprecated OV element type (#597) --- onnxruntime/core/providers/openvino/backends/basic_backend.cc | 2 +- .../custom_op_openvino_wrapper_library/openvino_wrapper.cc | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/openvino/backends/basic_backend.cc b/onnxruntime/core/providers/openvino/backends/basic_backend.cc index 3ac4d22f5453c..2e5fbf208e924 100644 --- a/onnxruntime/core/providers/openvino/backends/basic_backend.cc +++ b/onnxruntime/core/providers/openvino/backends/basic_backend.cc @@ -167,7 +167,7 @@ void BasicBackend::PopulateConfigValue(ov::AnyMap& device_config) { if (session_context_.precision.find("ACCURACY") != std::string::npos && session_context_.device_type.find("GPU") != std::string::npos) { if (session_context_.OpenVINO_Version.at(0) >= 2024) { - device_config.emplace(ov::hint::inference_precision(ov::element::undefined)); + device_config.emplace(ov::hint::inference_precision(ov::element::dynamic)); device_config.emplace(ov::hint::execution_mode(ov::hint::ExecutionMode::ACCURACY)); } else { if (!subgraph_context_.model_precision.empty()) diff --git a/onnxruntime/test/testdata/custom_op_openvino_wrapper_library/openvino_wrapper.cc b/onnxruntime/test/testdata/custom_op_openvino_wrapper_library/openvino_wrapper.cc index 27d5c59439243..d4ce3320e13ca 100644 --- a/onnxruntime/test/testdata/custom_op_openvino_wrapper_library/openvino_wrapper.cc +++ b/onnxruntime/test/testdata/custom_op_openvino_wrapper_library/openvino_wrapper.cc @@ -35,7 +35,7 @@ static ov::element::Type ConvertONNXToOVType(ONNXTensorElementDataType onnx_type case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16: return ov::element::bf16; default: - return ov::element::undefined; + return ov::element::dynamic; } } From 60ee27ad5928ce66753303b464551526d41ce792 Mon Sep 17 00:00:00 2001 From: sfatimar Date: Fri, 28 Feb 2025 20:45:44 +0530 Subject: [PATCH 005/138] Sahar/session option develop (#601) Changes to make sure to honor SessionOptions API Contract --- .../core/providers/openvino/contexts.h | 1 + .../openvino/openvino_provider_factory.cc | 52 ++++++++++++------- .../core/session/provider_bridge_ort.cc | 12 +++-- 3 files changed, 41 insertions(+), 24 deletions(-) diff --git a/onnxruntime/core/providers/openvino/contexts.h b/onnxruntime/core/providers/openvino/contexts.h index 216fc5b132696..a1a756a9baef7 100644 --- a/onnxruntime/core/providers/openvino/contexts.h +++ b/onnxruntime/core/providers/openvino/contexts.h @@ -101,6 +101,7 @@ struct ProviderInfo { bool so_context_embed_mode{false}; // ORT session option bool so_share_ep_contexts{false}; // ORT session option fs::path so_context_file_path{}; // ORT session option + const ConfigOptions* config_options{NULL}; }; // Holds context applicable to the entire EP instance. diff --git a/onnxruntime/core/providers/openvino/openvino_provider_factory.cc b/onnxruntime/core/providers/openvino/openvino_provider_factory.cc index 95e039f8b6d5f..c4fe16e035241 100644 --- a/onnxruntime/core/providers/openvino/openvino_provider_factory.cc +++ b/onnxruntime/core/providers/openvino/openvino_provider_factory.cc @@ -14,12 +14,22 @@ namespace onnxruntime { namespace openvino_ep { -void ParseConfigOptions(ProviderInfo& pi, const ConfigOptions& config_options) { - pi.so_disable_cpu_ep_fallback = config_options.GetConfigOrDefault(kOrtSessionOptionsDisableCPUEPFallback, "0") == "1"; - pi.so_context_enable = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextEnable, "0") == "1"; - pi.so_context_embed_mode = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextEmbedMode, "0") == "1"; - pi.so_share_ep_contexts = config_options.GetConfigOrDefault(kOrtSessionOptionShareEpContexts, "0") == "1"; - pi.so_context_file_path = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextFilePath, ""); +void ParseConfigOptions(ProviderInfo& pi) { + if(pi.config_options==NULL) + return; + + pi.so_disable_cpu_ep_fallback = pi.config_options->GetConfigOrDefault(kOrtSessionOptionsDisableCPUEPFallback, "0") == "1"; + pi.so_context_enable = pi.config_options->GetConfigOrDefault(kOrtSessionOptionEpContextEnable, "0") == "1"; + pi.so_context_embed_mode = pi.config_options->GetConfigOrDefault(kOrtSessionOptionEpContextEmbedMode, "0") == "1"; + pi.so_share_ep_contexts = pi.config_options->GetConfigOrDefault(kOrtSessionOptionShareEpContexts, "0") == "1"; + pi.so_context_file_path = pi.config_options->GetConfigOrDefault(kOrtSessionOptionEpContextFilePath, ""); + + if (pi.so_share_ep_contexts) { + ov::AnyMap map; + map["NPU_COMPILATION_MODE_PARAMS"] = "enable-wd-blockarg-input=true compute-layers-with-higher-precision=Sqrt,Power,ReduceSum"; + pi.load_config["NPU"] = std::move(map); + } + } void* ParseUint64(const ProviderOptions& provider_options, std::string option_name) { @@ -166,6 +176,7 @@ struct OpenVINOProviderFactory : IExecutionProviderFactory { ~OpenVINOProviderFactory() override {} std::unique_ptr CreateProvider() override { + ParseConfigOptions(provider_info_); return std::make_unique(provider_info_, shared_context_); } @@ -184,13 +195,23 @@ struct OpenVINO_Provider : Provider { void* GetInfo() override { return &info_; } std::shared_ptr CreateExecutionProviderFactory(const void* void_params) override { - // Extract the void_params into ProviderOptions and ConfigOptions - using ConfigBuffer = std::pair; - const ConfigBuffer* buffer = reinterpret_cast(void_params); - const auto& provider_options = *buffer->first; - const auto& config_options = buffer->second; + if (void_params == nullptr) { + LOGS_DEFAULT(ERROR) << "[OpenVINO EP] Passed NULL options to CreateExecutionProviderFactory()"; + return nullptr; + } + + std::array pointers_array = *reinterpret_cast*>(void_params); + const ProviderOptions* provider_options_ptr = reinterpret_cast(pointers_array[0]); + const ConfigOptions* config_options = reinterpret_cast(pointers_array[1]); + + if(provider_options_ptr == NULL) { + LOGS_DEFAULT(ERROR) << "[OpenVINO EP] Passed NULL ProviderOptions to CreateExecutionProviderFactory()"; + return nullptr; + } + const ProviderOptions provider_options = *provider_options_ptr; ProviderInfo pi; + pi.config_options = config_options; std::string bool_flag = ""; @@ -326,20 +347,11 @@ struct OpenVINO_Provider : Provider { pi.disable_dynamic_shapes = ParseBooleanOption(provider_options, "disable_dynamic_shapes"); - ParseConfigOptions(pi, config_options); - // Always true for NPU plugin or when passed . if (pi.device_type.find("NPU") != std::string::npos) { pi.disable_dynamic_shapes = true; } - // Append values to config to support weight-as-inputs conversion for shared contexts - if (pi.so_share_ep_contexts) { - ov::AnyMap map; - map["NPU_COMPILATION_MODE_PARAMS"] = "enable-wd-blockarg-input=true compute-layers-with-higher-precision=Sqrt,Power,ReduceSum"; - pi.load_config["NPU"] = std::move(map); - } - return std::make_shared(pi, SharedContext::Get()); } diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index 1d25ceb9af8a3..77c6d4c371f69 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -1998,10 +1998,14 @@ std::shared_ptr QNNProviderFactoryCreator::Create(con std::shared_ptr OpenVINOProviderFactoryCreator::Create( const ProviderOptions* provider_options_map, const SessionOptions* session_options) { // Append session options applicable for EP to EP Provider options. - std::pair config_buffer = {provider_options_map, - session_options->config_options}; - const void* obj = reinterpret_cast(&config_buffer); - return s_library_openvino.Get().CreateExecutionProviderFactory(obj); + const ConfigOptions* config_options = nullptr; + if (session_options !=nullptr) { + config_options = &session_options->config_options; + } + + std::array configs_array = {provider_options_map, config_options}; + const void* arg = reinterpret_cast(&configs_array); + return s_library_openvino.Get().CreateExecutionProviderFactory(arg); } std::shared_ptr DnnlProviderFactoryCreator::Create(const OrtDnnlProviderOptions* dnnl_options) { From ec62bf33c4070c360b9434f1ced5098fb475b7a7 Mon Sep 17 00:00:00 2001 From: Jaskaran Singh Nagi Date: Mon, 3 Mar 2025 19:35:10 -0800 Subject: [PATCH 006/138] Use absolute paths for libraries loaded with LOAD_WITH_ALTERED_SEARCH_PATH' (#602) --- onnxruntime/test/python/onnxruntime_test_python.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/test/python/onnxruntime_test_python.py b/onnxruntime/test/python/onnxruntime_test_python.py index 3af6e8ccacfb8..f3ebc92409f77 100644 --- a/onnxruntime/test/python/onnxruntime_test_python.py +++ b/onnxruntime/test/python/onnxruntime_test_python.py @@ -1307,7 +1307,7 @@ def test_session_options_add_external_initializers(self): def test_register_custom_ops_library(self): if sys.platform.startswith("win"): - shared_library = "custom_op_library.dll" + shared_library = os.path.abspath("custom_op_library.dll") if not os.path.exists(shared_library): raise FileNotFoundError(f"Unable to find '{shared_library}'") @@ -1724,7 +1724,7 @@ def test_register_custom_e_ps_library(self): return if sys.platform.startswith("win"): - shared_library = "test_execution_provider.dll" + shared_library = os.path.abspath("test_execution_provider.dll") elif sys.platform.startswith("darwin"): # exclude for macos From bd32f5140eb29980a2d8705ef34f3e3c4cb6365e Mon Sep 17 00:00:00 2001 From: Eric Crawford Date: Tue, 4 Mar 2025 06:50:55 -0800 Subject: [PATCH 007/138] Remove unintended model copies during compilation (#584) Co-authored-by: sfatimar --- .../core/providers/openvino/backend_utils.cc | 4 ++-- .../core/providers/openvino/backend_utils.h | 2 +- .../openvino/backends/basic_backend.cc | 22 +++++++------------ .../core/providers/openvino/ov_interface.cc | 4 ++-- .../core/providers/openvino/ov_interface.h | 2 +- 5 files changed, 14 insertions(+), 20 deletions(-) diff --git a/onnxruntime/core/providers/openvino/backend_utils.cc b/onnxruntime/core/providers/openvino/backend_utils.cc index 255154b8788ad..2ee5e9ec3e3a9 100644 --- a/onnxruntime/core/providers/openvino/backend_utils.cc +++ b/onnxruntime/core/providers/openvino/backend_utils.cc @@ -137,14 +137,14 @@ bool IsCILogEnabled() { } std::shared_ptr -CreateOVModel(const std::string model, +CreateOVModel(std::string&& model, const SessionContext& session_context, std::map>& const_outputs_map) { if (IsCILogEnabled()) { std::cout << "CreateNgraphFunc" << std::endl; } try { - auto ov_model = OVCore::Get()->ReadModel(model, session_context.onnx_model_path_name.string()); + auto ov_model = OVCore::Get()->ReadModel(std::move(model), session_context.onnx_model_path_name.string()); // Check for Constant Folding if ((session_context.device_type != "NPU") && !session_context.is_wholly_supported_graph) { diff --git a/onnxruntime/core/providers/openvino/backend_utils.h b/onnxruntime/core/providers/openvino/backend_utils.h index a4e6fc0828f79..f13b1b05ced67 100644 --- a/onnxruntime/core/providers/openvino/backend_utils.h +++ b/onnxruntime/core/providers/openvino/backend_utils.h @@ -62,7 +62,7 @@ void FillOutputBlob(OVTensorPtr outputBlob, Ort::UnownedValue& output_tensor, size_t batch_slice_idx); std::shared_ptr -CreateOVModel(const std::string model, +CreateOVModel(std::string&& model, const SessionContext& session_context, std::map>& const_outputs_map); diff --git a/onnxruntime/core/providers/openvino/backends/basic_backend.cc b/onnxruntime/core/providers/openvino/backends/basic_backend.cc index 2e5fbf208e924..9d4ad88e2c2b3 100644 --- a/onnxruntime/core/providers/openvino/backends/basic_backend.cc +++ b/onnxruntime/core/providers/openvino/backends/basic_backend.cc @@ -69,14 +69,11 @@ BasicBackend::BasicBackend(std::unique_ptr& model_pr subgraph_context_.subgraph_name); model_stream.reset(); // Delete stream after it is no longer needed } else { - std::shared_ptr ov_model; - { - const std::string model = model_proto->SerializeAsString(); - if (!subgraph_context.has_dynamic_input_shape) { - delete model_proto.release(); - } - ov_model = CreateOVModel(model, session_context_, const_outputs_map_); + std::string model = model_proto->SerializeAsString(); + if (!subgraph_context.has_dynamic_input_shape) { + model_proto.reset() } + auto ov_model = CreateOVModel(std::move(model), session_context_, const_outputs_map_); LOGS_DEFAULT(INFO) << log_tag << "IO Buffering Enabled"; exe_network_ = OVCore::Get()->CompileModel( ov_model, remote_context_, subgraph_context_.subgraph_name); @@ -108,14 +105,11 @@ BasicBackend::BasicBackend(std::unique_ptr& model_pr subgraph_context_.subgraph_name); } else { // For all other types use ov::ov_core read_model() to generate OV IR // followed by ov::ov_core compile_model() - std::shared_ptr ov_model; - { - const std::string model = model_proto->SerializeAsString(); - if (!subgraph_context.has_dynamic_input_shape) { - delete model_proto.release(); - } - ov_model = CreateOVModel(std::move(model), session_context_, const_outputs_map_); + std::string model = model_proto->SerializeAsString(); + if (!subgraph_context.has_dynamic_input_shape) { + model_proto.reset(); } + auto ov_model = CreateOVModel(std::move(model), session_context_, const_outputs_map_); exe_network_ = OVCore::Get()->CompileModel( ov_model, hw_target, device_config, subgraph_context_.subgraph_name); } diff --git a/onnxruntime/core/providers/openvino/ov_interface.cc b/onnxruntime/core/providers/openvino/ov_interface.cc index 37f9e1c4e9201..9208f6a76e0bc 100644 --- a/onnxruntime/core/providers/openvino/ov_interface.cc +++ b/onnxruntime/core/providers/openvino/ov_interface.cc @@ -46,9 +46,9 @@ void printDebugInfo(const ov::CompiledModel& obj) { } #endif -std::shared_ptr OVCore::ReadModel(const std::string& model, const std::string& model_path) { +std::shared_ptr OVCore::ReadModel(std::string&& model, const std::string& model_path) { try { - std::istringstream modelStringStream(model); + std::istringstream modelStringStream(std::move(model)); std::istream& modelStream = modelStringStream; // Try to load with FrontEndManager ov::frontend::FrontEndManager manager; diff --git a/onnxruntime/core/providers/openvino/ov_interface.h b/onnxruntime/core/providers/openvino/ov_interface.h index 0ed51394a6ffa..f58b05e6017ec 100644 --- a/onnxruntime/core/providers/openvino/ov_interface.h +++ b/onnxruntime/core/providers/openvino/ov_interface.h @@ -67,7 +67,7 @@ struct OVCore : WeakSingleton { ov::Core core; // OV Interface For Reading Model - std::shared_ptr ReadModel(const std::string& model_stream, const std::string& model_path); + std::shared_ptr ReadModel(std::string&& model_stream, const std::string& model_path); // OV Interface for Compiling OV Model Type OVExeNetwork CompileModel(std::shared_ptr& ie_cnn_network, From a6cdf62176c116e3a1e07f6cec1681c041d653b9 Mon Sep 17 00:00:00 2001 From: jatinwadhwa921 <110383850+jatinwadhwa921@users.noreply.github.com> Date: Mon, 10 Mar 2025 12:12:14 +0530 Subject: [PATCH 008/138] Rebasing with msft commits (#607) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Fix flash attention for GQA (Phi4) (#23850) ### Description This change fixes GQA for Flash Attention on Nvidia GPUs. The root cause appears to be `k_start + capped_sg_id < seq_causal_length` check. This is either because, a. seq_causal_length varies per lane, so the check becomes non uniform control flow, which is having interactions with subgroupShuffle. or b. The check itself is incorrect and is wiping out values of v based on the source lane's seq_causal_length. While in actualness values of v need to be causal as per the lane that is going to multiply it with qkt. qkt is already causal because earlier values of qk for out of bounds k are set to min_value, and exp(<-4) are 0. This fix works by removing that causal check and relying on the qk being wiped out earlier. The documentation for causality behavior for GQA is missing to determine which of this reason is the true reason. Prior to this prompts with sequence length > 16 < 32 or 1k would break with Phi 4 but smaller prompts would work. Tested on Intel Alderlake, Nvidia 4070. * Model Builder API (#23223) ### Description Supports creating a model programmatically using the ORT C or C++ API. Supports augmenting an existing model to add nodes. ### Motivation and Context * Fix typo: change `Upample` to `Upsample`. (#23838) ### Description Fixed a typo in function names related to the Upsample CUDA kernel. Changed incorrect spelling Upample to Upsample across relevant functions. ### Motivation and Context This change is necessary to maintain consistency and prevent potential confusion caused by incorrect function names. * [doc] Fix typos in csharp/src/Microsoft.ML.OnnxRuntime/ (#23848) ### Description Fix typos in csharp/src/Microsoft.ML.OnnxRuntime/ ### Motivation and Context * Quant tool: Consistent `get_qdq_config` and `get_qnn_qdq_config` behavior (#23856) * Change the logic to generate the default ep context file name (#23788) Change the logic to generate the default ep context file name ### Description Applies to all EPs: replace the .onnx to _ctx.onnx, instead of directly append extra string _ctx.onnx to existing model path. In QNN EP, also make the context binary .bin file shorter by removing QNNExecutionProvider_ from the file name. * Make Nuget QNN package pipeline 1ES compliant (#23805) ### Description Make [QNN_Nuget_Windows](https://aiinfra.visualstudio.com/Lotus/_build?definitionId=1234)1ES compliant ### Motivation and Context * [js/common] allows using Uint16Array as data for float16 tensor (#23827) ### Description Resolve #23817 ### Motivation and Context * [js/webgpu] Reland the optimization of ConvTranspose (#23858) This PR fixes the errors in the ConvTranspose optimization and adds tests to ensure the correctness of the implementation. * [OpenVINO] Fix a build warning (#23877) ### Description Fix a warning with std::move usage ### Motivation and Context Possibly allow building without --compile_no_warning_as_error flag * Change gsl::byte to std::byte (#23872) To be compatible with the latest GSL library. Without this fix we will get: ``` onnxruntime\core\providers\cpu\controlflow\loop.cc(247): error C4996: 'gsl::byte': Use std::byte instead. ``` * Allow using extended minimal build for several EPs (#23834) ### Description #### Background From code search, the following EPs use `onnxruntime::GetCpuPreferredNodes()` in their `GetCapabilities()` methods: - CANN - CUDA - DML - JS - ROCM - WebGPU However, the source file that implements `onnxruntime::GetCpuPreferredNodes()` is excluded when minimal build is ON: https://github.com/microsoft/onnxruntime/blob/6df0973e58ba5399fcaa98686f70ed9a9e59aaef/cmake/onnxruntime_framework.cmake#L38-L42 This means that all EPs mentioned above is not able to compile with minimal build. #### Solution The excluded file `core/framework/fallback_cpu_capability.cc` cannot build in minimal build because some of its dependencies are not included in the minimal build. However, in extended minimal build mode, all dependencies are available. This PR looses the restrict and allows to compile this file when it is extended minimal build. After this change, those EPs are able to compile in extended minimal build. * Add dawn to ThirdPartyNotices (#23876) ### Description Add `dawn` to ThirdPartyNotices. * Enable QNN EP weight sharing generation using public API (#23702) ### Description Enable QNN EP weight sharing generation using public API instead of internal interfaces, so that user can integrate into their own toolchain. The change is to share the QnnBackendManager across ORT sessions if ep.share_ep_contexts is enabled. And there is extra option to end the share so that we know when to remove the shared QnnBackendManager from the singleton. Change the tool name from onnxruntime_qnn_ctx_gen to ep_weight_sharing_ctx_gen, so that it can be shared for other EPs. * [QNN-EP]: Fix inference failures while running with htp_shared_memory (#23892) ### Description When using the enable_htp_shared_memory feature, we see that the address of the buffer passed to rpcmem_free is incorrect. So the rpc buffers are not freed leading to memory exhaustion. ### Motivation and Context When using the enable_htp_shared_memory_allocator feature for QNN in GenAI extensions, it leads to inference failures during the second prompt. As GenAI memory asks are higher, it surfaces sooner in gen AI use cases. Co-authored-by: Ashish Garg * Fix enable_pix_capture build for WebGPU (#23857) The build option --enable_pix_capture is broken. This fixes the problem. --------- Co-authored-by: wp * [WebGPU-EP Native] Add ReduceMean (#23860) ### Description ### Motivation and Context * [WebGPU EP] introduce BiasAdd contrib op (#23861) Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Dynamo export and improve benchmark script for SAM2 encoder (#23887) ### Description * Add dynamo export for Sam2 image encoder * Verify fp32 onnx model with CPU EP (to avoid error message from TRT EP). * Update benchmark script: - output ORT profiling - output torch compiled code and unique kernel name for compiled kernel - add an option for nightly package installation - uninstall existing ort packages before installing The node metadata of dynamo exported model can help mapping node in onnx model back to pytorch modeling script. Currently, the graph optimization is not done on dynamo exported model, so it is experimental right now. ### Motivation and Context To support profiling of torch compiled CUDA kernel. * [js/web] improve workaround for bundlers (#23902) ### Description This PR improves the workaround for bundlers in onnxruntime-web. Specifically, the following changes have been made: - Use [this workaround](https://github.com/xenova/onnxruntime/commit/9c50aa2c63bad4cb73ad77ff1c43e0c43da0907f) as suggested by @xenova in https://github.com/huggingface/transformers.js/pull/1161#issuecomment-2695785730 - Use `url > "file:" && url < "file;"` instead of `url.startsWith("file:")` to allow minifiers to remove dead code correctly. This change allows to remove unnecessary dependencies of file parsed from `new URL("ort.bundle.min.js", import.meta.url)` in Vite, and optimize code like `if("file://filepath.js".startsWith("file:")) {do_sth1(); } else {do_sth2();}` into `do_sth1()` for webpack/terser usages. Resolves https://github.com/huggingface/transformers.js/pull/1161 * [webgpu] Restore MatMulNBits workgroup size for Phi-3.5 (#23349) ### Description This change restores the MatMulNBits workgroup size from (8, 8, 1) back to (16, 8, 1) to resolve a performance regression observed on Intel iGPUs during token generation (M=1). ### Motivation and Context As above. Signed-off-by: Jianhui Dai * [webgpu] support Pad operator (#23141) ### Description ### Motivation and Context * [WebNN] Accept Float16Array for float16 data type if it is available (#23894) Float16Array is now shipping and WebNN Chromium implementation has accepted it. We should allow it in WebNN EP as well. * Ensure that the 'cmake_minimum_required' is version 3.5 or greater (#23888) ### Description CMake 4.0 release candidate 2.0 is available, and it cannot compile all of OnnxRuntime out-of-the-box. There's portions of the OnnxRuntime codebase that specify a `cmake_minimum_required` version of 3.0, and CMake 4.0 has removed support for compatibility with CMake < 3.5 - the following error is reported: ``` CMake Error at winml_sdk_helpers.cmake:4 (cmake_minimum_required): Compatibility with CMake < 3.5 has been removed from CMake. Update the VERSION argument value. Or, use the ... syntax to tell CMake that the project requires at least but has been updated to work with policies introduced by or earlier. Or, add -DCMAKE_POLICY_VERSION_MINIMUM=3.5 to try configuring anyway. ``` Since CMake 3.5 appears to have shipped in 2016, it seems reasonable to set that as a minimum version to fix the error. The root CMakeLists.txt does ask for a minimum version of 3.28, so we could snap to that, but I'm still ramping up on the build, so wanted to propose a minimally sufficient fix. ### Motivation and Context Being able to build with the latest CMake - when it ships - reduces the barrier to entry to building OnnxRuntime, and allows the OnnxRuntime to leverage the latest and greatest tooling. * WebGPU: Remove deprecated subgroups-f16 from WebGPU native and JS EP (#23898) This PR removes the deprecated subgroups-f16 from WebGPU native and JS EP, and also remove the unused deviceInfo in WebGPU JS EP. * [JSEP/WebGPU] Fixed error in softmax dispatch. (#23906) ### Description Fixed an error softmax dispatch ### Motivation and Context Produce expected results for LlaMA model * enable WebGPU EP in WebAssembly build (#23913) ### Description This PR is the first step for migrating the webgpu backend of onnxruntime-web from JSEP based to WebGPU EP based. In this change, we enable building WebGPU EP in a wasm build (ie. `--build_wasm` `--use_webgpu` `--use_jsep`). However, the old build flags should still keep previous behavior. * Adding OpenVINO Windows CI Pipeline (#23919) ### Description Enable an OpenVINO Windows CI pipeline. This includes: - Downloading the OpenVINO toolkit for Windows from an external source. - Setting up OpenVINO environment variables. - Building the ONNX Runtime OpenVINO Execution Provider. - Running unit tests. ### Motivation and Context This change is required to run checks on precommit and commit in the ONNX Runtime project. It ensures that the code is tested with the OpenVINO toolkit on Windows, improving the reliability and compatibility of the project. * [WebGPU EP] SoftMax Implementation (#23538) Increase coverage for WebGPU Op * Exclude MAUI projects from GPU C# packaging builds (#23923) ### Description Use 'desktop only' solution in GPU C# packaging builds. We don't need to include any MAUI support for those builds. ### Motivation and Context * Support all block sizes that are multiples of 32 for DP4A (#23907) ### Description Simple change 1. The DP4A shader actually supports all block sizes that are multiples of 32, relaxing the restriction and making a small tweak to support sizes other than 32. 2. Moved the shader to a separate file for maintainability. --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Example custom op with output type inferencing (#23916) ### Description Add example of a custom op that is required to do type inference for the output type for the model load to work. Also acts as an example of how to override an ONNX op with a custom implementation. ### Motivation and Context #23891 * Enabling L2+ Optimizations for EPs (#23517) There are some requirements to modify the graph which are specific to the EP/hardware. ORT has the hardcoded EP list for optimizations but that can't scale and it's hard be extended to enable EP custom optimizations. Here is the prototype to enable L2+ optimizations for EPs (The original overview is provided by @skottmckay) as well as the TRT EP implementation for the ConstantFoldingDQ optimization. Signatures for selection and optimization functions: ```` - Selection: std::function>(const GraphViewer&, const KeyValueConfig&)> - Optimization: std::function ```` GetCapability - call (new) provider bridge API to lookup pre-defined optimizer by name and get selection function - ComputeCapability.optimize_func, i.e. optimization function, would be set by the optimizer to the function that does the optimization - EP has to update the returning ComputeCapability to include the optimization ComputeCapability in nodes_to_optimize. So that later ORT can perform optimization/transformation accordingly. GraphPartitioner - After assigning the ComputeCapability to the EP and prior to Compile, if the ComputeCapability has nodes_to_optimize, iterate that list - optimization function needs to be called with - a mutable Graph instance - the ComputeCapability for the individual optimization - the overall ComputeCapability so it can be updated * fix binplace file in web pipeline (#23930) * Updated run_CIs_for_external_pr.py to support the Windows OpenVINO CI pipeline (#23931) * Fix ConvInteger handling of optional inputs. (#23935) ### Description Fix ConvInteger handling of optional inputs. Need to check Exists() and not just the number of inputs. ### Motivation and Context #23927 * Updated ov version in pipeline (#595) (#23882) ### Description This PR updates the OpenVINO version used in the pipeline from 2024.5.0 to 2025.0.0 Co-authored-by: jatinwadhwa921 <110383850+jatinwadhwa921@users.noreply.github.com> * [AIX] External data handling (#23859) ### Description In BE system, model tensor data coming from external file is not handled properly. This was found during the debugging of (https://github.com/microsoft/onnxruntime-genai/issues/1104)(url) This PR changes do the endianness conversion of data loaded from external file in BE system. * Create a packaging pipeline for a custom nuget package (#23918) * Fix license in example test code. (#23936) * replace usage of gsl::narrow and gsl::narrow_cast in WebGPU EP (#23926) ### Description `gsl::narrow` does not work in no exception build. - use `onnxruntime::narrow` if necessary; - or change to `static_cast` if it's obviously safe. also apply the changes to usage of `gsl::narrow_cast`, which does not apply checks. * VCPKG improvement: set VCPKG_OSX_DEPLOYMENT_TARGET (#23933) ### Description 1. Set VCPKG_OSX_DEPLOYMENT_TARGET for macOS targets 2. Enable VCPKG in more pipelines. * Allow using a different version of flatbuffers when building with vcpkg (#23946) ### Description Allow using a different version of flatbuffers when building with vcpkg, so that users do not need to pin flatbuffer's version, which provides more flexibility in the build process. Delete utf8_range from the dependencies, because it is an indirect dependency of protobuf, which is already included in the build process. ### Motivation and Context * Make python package pipeline 1ES compliant (#23800) ### Description Make [Python packaging pipeline](https://aiinfra.visualstudio.com/530acbc4-21bc-487d-8cd8-348ff451d2ff/_build?definitionId=841) 1ES compliant ### Motivation and Context ### Checklist - [x] Make Onnxruntime-QNNEP-Windows-2022-CPU stateless * Delete ROCM Nuget Publishing Pipeline (#23948) * Bump SixLabors.ImageSharp from 2.1.9 to 2.1.10 in /csharp/sample/Microsoft.ML.OnnxRuntime.FasterRcnnSample (#23924) Bumps [SixLabors.ImageSharp](https://github.com/SixLabors/ImageSharp) from 2.1.9 to 2.1.10.
Release notes

Sourced from SixLabors.ImageSharp's releases.

v2.1.10

What's Changed

Full Changelog: https://github.com/SixLabors/ImageSharp/compare/v2.1.9...v2.1.10

Commits

[![Dependabot compatibility score](https://dependabot-badges.githubapp.com/badges/compatibility_score?dependency-name=SixLabors.ImageSharp&package-manager=nuget&previous-version=2.1.9&new-version=2.1.10)](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores) Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting `@dependabot rebase`. [//]: # (dependabot-automerge-start) [//]: # (dependabot-automerge-end) ---
Dependabot commands and options
You can trigger Dependabot actions by commenting on this PR: - `@dependabot rebase` will rebase this PR - `@dependabot recreate` will recreate this PR, overwriting any edits that have been made to it - `@dependabot merge` will merge this PR after your CI passes on it - `@dependabot squash and merge` will squash and merge this PR after your CI passes on it - `@dependabot cancel merge` will cancel a previously requested merge and block automerging - `@dependabot reopen` will reopen this PR if it is closed - `@dependabot close` will close this PR and stop Dependabot recreating it. You can achieve the same result by closing it manually - `@dependabot show ignore conditions` will show all of the ignore conditions of the specified dependency - `@dependabot ignore this major version` will close this PR and stop Dependabot creating any more for this major version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this minor version` will close this PR and stop Dependabot creating any more for this minor version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this dependency` will close this PR and stop Dependabot creating any more for this dependency (unless you reopen the PR or upgrade to it yourself) You can disable automated security fix PRs for this repo from the [Security Alerts page](https://github.com/microsoft/onnxruntime/network/alerts).
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --------- Signed-off-by: Jianhui Dai Signed-off-by: dependabot[bot] Co-authored-by: Sushanth Rajasankar <44513542+sushraja-msft@users.noreply.github.com> Co-authored-by: Scott McKay Co-authored-by: Seungtaek Kim Co-authored-by: co63oc Co-authored-by: Jambay Kinley Co-authored-by: Hector Li Co-authored-by: Jian Chen Co-authored-by: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Co-authored-by: Jiajia Qin Co-authored-by: Alessio Soldano Co-authored-by: Changming Sun Co-authored-by: Ashish Garg Co-authored-by: Ashish Garg Co-authored-by: Jie Chen Co-authored-by: wp Co-authored-by: Satya Kumar Jandhyala Co-authored-by: Prathik Rao Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: Tianlei Wu Co-authored-by: Jianhui Dai Co-authored-by: xhcao Co-authored-by: Wanming Lin Co-authored-by: Mark Schofield Co-authored-by: jiangzhaoming Co-authored-by: Yi-Hong Lyu Co-authored-by: vraspar Co-authored-by: Chi Lo <54722500+chilo-ms@users.noreply.github.com> Co-authored-by: saurabh Co-authored-by: Ranjit Ranjan <165394499+ranjitshs@users.noreply.github.com> Co-authored-by: Baiju Meswani Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- ThirdPartyNotices.txt | 35 + cmake/deps.txt | 1 - .../external/onnxruntime_external_deps.cmake | 54 +- cmake/nuget_helpers.cmake | 2 +- cmake/onnxruntime_framework.cmake | 5 +- cmake/onnxruntime_optimizer.cmake | 1 + cmake/onnxruntime_providers_js.cmake | 6 +- cmake/onnxruntime_python.cmake | 2 +- cmake/onnxruntime_session.cmake | 1 + cmake/onnxruntime_unittests.cmake | 43 +- cmake/onnxruntime_webassembly.cmake | 37 +- cmake/patches/dawn/dawn.patch | 113 ++- cmake/winml_sdk_helpers.cmake | 2 +- ...oft.ML.OnnxRuntime.FasterRcnnSample.csproj | 2 +- .../ManagedProjections.shared.cs | 3 +- .../NativeMethods.shared.cs | 4 +- .../core/framework/execution_provider.h | 16 + include/onnxruntime/core/graph/graph.h | 32 +- include/onnxruntime/core/graph/graph_viewer.h | 6 + .../core/graph/indexed_sub_graph.h | 6 + .../core/session/onnxruntime_c_api.h | 491 +++++++++++- .../core/session/onnxruntime_cxx_api.h | 261 ++++++- .../core/session/onnxruntime_cxx_inline.h | 350 ++++++++- .../onnxruntime_session_options_config_keys.h | 5 +- js/build_webgpu.bat | 79 ++ js/common/lib/tensor-impl-type-mapping.ts | 9 +- js/common/lib/tensor-impl.ts | 7 + js/common/package.json | 3 +- js/common/test/unit-tests/common.ts | 5 +- .../test/unit-tests/tensor/constructor-f16.ts | 62 ++ .../unit-tests/tensor/constructor-type.ts | 8 - js/web/lib/build-def.d.ts | 7 + js/web/lib/wasm/jsep/backend-webgpu.ts | 28 +- js/web/lib/wasm/jsep/backend-webnn.ts | 3 +- js/web/lib/wasm/jsep/init.ts | 144 ++-- .../ops/3rd-party/conv_backprop_webgpu.ts | 96 ++- js/web/lib/wasm/jsep/webgpu/ops/attention.ts | 2 +- .../lib/wasm/jsep/webgpu/program-manager.ts | 1 - js/web/lib/wasm/jsep/webgpu/types.ts | 10 - js/web/lib/wasm/proxy-wrapper.ts | 8 +- js/web/lib/wasm/session-options.ts | 116 ++- js/web/lib/wasm/wasm-core-impl.ts | 97 ++- js/web/lib/wasm/wasm-types.ts | 68 +- js/web/lib/wasm/wasm-utils-import.ts | 50 +- js/web/script/build.ts | 36 +- js/web/test/data/ops/conv-transpose.jsonc | 122 +++ js/web/test/e2e/exports/main.js | 11 +- js/web/test/e2e/exports/test.js | 22 + .../contrib_ops/webgpu/bert/bias_add.cc | 80 ++ .../contrib_ops/webgpu/bert/bias_add.h | 32 + .../contrib_ops/webgpu/bert/fast_gelu.cc | 4 +- .../webgpu/bert/flash_attention.cc | 4 +- .../webgpu/bert/rotary_embedding.cc | 14 +- .../webgpu/bert/skip_layer_norm.cc | 4 +- .../webgpu/quantization/dp4a_matmul_nbits.cc | 326 ++++++++ .../webgpu/quantization/dp4a_matmul_nbits.h | 56 ++ .../webgpu/quantization/matmul_nbits.cc | 322 +------- .../webgpu/quantization/matmul_nbits.h | 19 - .../subgroup_matrix_matmul_nbits.cc | 8 +- .../webgpu/webgpu_contrib_kernels.cc | 4 +- .../core/framework/compute_capability.h | 20 + .../core/framework/execution_provider.cc | 1 + .../core/framework/external_data_loader.cc | 7 +- .../core/framework/external_data_loader.h | 2 +- .../core/framework/fallback_cpu_capability.cc | 4 + .../core/framework/fallback_cpu_capability.h | 4 + .../core/framework/graph_partitioner.cc | 248 ++++--- .../core/framework/graph_partitioner.h | 9 +- .../core/framework/onnxruntime_typeinfo.cc | 71 +- .../core/framework/onnxruntime_typeinfo.h | 2 +- .../core/framework/session_state_utils.cc | 35 +- .../core/framework/tensor_type_and_shape.cc | 35 +- .../core/framework/tensorprotoutils.cc | 29 +- onnxruntime/core/framework/tensorprotoutils.h | 10 +- onnxruntime/core/graph/graph.cc | 295 +++++++- .../core/graph/graph_flatbuffers_utils.cc | 14 +- onnxruntime/core/graph/model.cc | 32 +- onnxruntime/core/graph/model.h | 8 +- .../core/graph/model_editor_api_types.h | 47 ++ .../core/optimizer/constant_folding.cc | 13 +- onnxruntime/core/optimizer/constant_folding.h | 18 + .../optimizer/graph_optimizer_registry.cc | 49 ++ .../core/optimizer/graph_optimizer_registry.h | 77 ++ .../constant_folding_dq_node.cc | 26 + .../constant_folding_dq_node.h | 37 + .../selection_and_optimization_func.cc | 99 +++ .../selection_and_optimization_func.h | 31 + .../providers/acl/acl_execution_provider.cc | 1 + .../providers/acl/acl_execution_provider.h | 1 + .../providers/cann/cann_execution_provider.cc | 1 + .../providers/cann/cann_execution_provider.h | 1 + .../coreml/coreml_execution_provider.cc | 1 + .../coreml/coreml_execution_provider.h | 1 + .../core/providers/cpu/controlflow/loop.cc | 4 +- .../cpu/quantization/conv_integer.cc | 7 +- .../core/providers/cuda/controlflow/loop.cc | 4 +- .../providers/cuda/cuda_execution_provider.cc | 1 + .../providers/cuda/cuda_execution_provider.h | 1 + .../core/providers/cuda/tensor/upsample.cc | 20 +- .../providers/cuda/tensor/upsample_impl.cu | 94 +-- .../providers/cuda/tensor/upsample_impl.h | 20 +- .../src/ExecutionProvider.cpp | 6 +- .../src/ExecutionProvider.h | 3 + .../providers/dnnl/dnnl_execution_provider.cc | 1 + .../providers/dnnl/dnnl_execution_provider.h | 1 + .../providers/js/js_execution_provider.cc | 1 + .../core/providers/js/js_execution_provider.h | 1 + .../migraphx/migraphx_execution_provider.cc | 1 + .../migraphx/migraphx_execution_provider.h | 1 + .../nnapi_builtin/nnapi_execution_provider.cc | 1 + .../nnapi_builtin/nnapi_execution_provider.h | 1 + .../openvino/backends/basic_backend.cc | 2 +- .../openvino/openvino_execution_provider.cc | 1 + .../openvino/openvino_execution_provider.h | 1 + .../qnn/builder/onnx_ctx_model_helper.cc | 38 +- .../qnn/builder/onnx_ctx_model_helper.h | 7 +- .../qnn/builder/qnn_backend_manager.cc | 2 + .../core/providers/qnn/qnn_allocator.cc | 4 +- .../providers/qnn/qnn_execution_provider.cc | 73 +- .../providers/qnn/qnn_execution_provider.h | 2 + .../core/providers/qnn/shared_context.h | 26 + .../rknpu/rknpu_execution_provider.cc | 1 + .../rknpu/rknpu_execution_provider.h | 1 + .../providers/rocm/rocm_execution_provider.cc | 1 + .../providers/rocm/rocm_execution_provider.h | 1 + .../providers/shared_library/provider_api.h | 1 + .../provider_bridge_provider.cc | 3 +- .../shared_library/provider_interfaces.h | 9 + .../shared_library/provider_wrappedtypes.h | 3 + .../providers/snpe/snpe_execution_provider.cc | 1 + .../providers/snpe/snpe_execution_provider.h | 1 + .../tensorrt/tensorrt_execution_provider.cc | 55 +- .../tensorrt/tensorrt_execution_provider.h | 31 + .../tensorrt_execution_provider_helper.cc | 129 ++++ .../vitisai/vitisai_execution_provider.cc | 2 +- .../vitisai/vitisai_execution_provider.h | 1 + .../vsinpu/vsinpu_execution_provider.cc | 1 + .../vsinpu/vsinpu_execution_provider.h | 1 + .../providers/webgpu/external_data_loader.cc | 40 + .../providers/webgpu/external_data_loader.h | 30 + .../core/providers/webgpu/generator/range.cc | 2 +- .../webgpu/math/binary_elementwise_ops.cc | 2 +- .../core/providers/webgpu/math/softmax.cc | 238 ++++++ .../core/providers/webgpu/math/softmax.h | 54 ++ .../webgpu/math/unary_elementwise_ops.cc | 2 +- .../core/providers/webgpu/nn/layer_norm.cc | 6 +- onnxruntime/core/providers/webgpu/program.cc | 20 + onnxruntime/core/providers/webgpu/program.h | 1 + .../core/providers/webgpu/program_manager.cc | 10 +- .../webgpu/reduction/reduction_ops.cc | 168 +++++ .../webgpu/reduction/reduction_ops.h | 62 ++ .../core/providers/webgpu/shader_helper.cc | 3 - .../core/providers/webgpu/shader_variable.cc | 2 +- .../core/providers/webgpu/tensor/cast.cc | 2 +- .../core/providers/webgpu/tensor/cast.h | 2 +- .../core/providers/webgpu/tensor/concat.cc | 2 +- .../core/providers/webgpu/tensor/expand.cc | 2 +- .../core/providers/webgpu/tensor/gather.cc | 2 +- .../core/providers/webgpu/tensor/pad.cc | 261 +++++++ .../core/providers/webgpu/tensor/pad.h | 40 + .../providers/webgpu/tensor/resize_impl.cc | 8 +- .../core/providers/webgpu/tensor/split.cc | 6 +- .../core/providers/webgpu/tensor/transpose.cc | 62 +- .../core/providers/webgpu/tensor/transpose.h | 2 + .../core/providers/webgpu/tensor/where.cc | 2 +- .../core/providers/webgpu/webgpu_context.cc | 61 +- .../webgpu/webgpu_execution_provider.cc | 38 +- .../webgpu/webgpu_execution_provider.h | 4 + .../webgpu/webgpu_pix_frame_generator.cc | 4 +- .../webgpu/webgpu_pix_frame_generator.h | 2 +- .../webgpu/webgpu_provider_factory.cc | 6 + .../impl/rotaryEmbedding_op_builder.cc | 14 +- .../providers/webnn/builders/model_builder.cc | 6 +- .../providers/webnn/builders/model_builder.h | 10 +- .../webnn/webnn_execution_provider.cc | 1 + .../webnn/webnn_execution_provider.h | 1 + .../xnnpack/xnnpack_execution_provider.cc | 1 + .../xnnpack/xnnpack_execution_provider.h | 1 + .../core/session/abi_session_options.cc | 17 +- onnxruntime/core/session/api_utils.cc | 25 - onnxruntime/core/session/api_utils.h | 9 - onnxruntime/core/session/custom_ops.cc | 25 +- onnxruntime/core/session/inference_session.cc | 78 +- onnxruntime/core/session/inference_session.h | 35 +- onnxruntime/core/session/model_editor_api.h | 65 ++ .../core/session/model_editor_c_api.cc | 358 +++++++++ onnxruntime/core/session/onnxruntime_c_api.cc | 328 ++++---- onnxruntime/core/session/ort_apis.h | 16 + .../core/session/provider_bridge_ort.cc | 23 +- onnxruntime/core/session/utils.cc | 125 ++++ onnxruntime/core/session/utils.h | 28 + .../execution_providers/qnn/quant_config.py | 6 +- .../python/tools/quantization/quantize.py | 32 +- .../tools/transformers/models/sam2/README.md | 31 +- .../models/sam2/benchmark_sam2.py | 15 +- .../models/sam2/benchmark_sam2.sh | 310 +++++--- .../models/sam2/convert_to_onnx.py | 14 +- .../transformers/models/sam2/image_decoder.py | 2 +- .../transformers/models/sam2/image_encoder.py | 74 +- .../transformers/models/sam2/mask_decoder.py | 2 +- .../models/sam2/prompt_encoder.py | 2 +- .../README.md | 10 +- .../command_args_parser.cc | 47 +- .../command_args_parser.h | 0 .../test/ep_weight_sharing_ctx_gen/main.cc | 247 ++++++ .../test_configuration.h | 7 +- .../test/framework/inference_session_test.cc | 1 + .../test/framework/session_state_test.cc | 27 +- onnxruntime/test/framework/type_info_test.cc | 26 +- onnxruntime/test/providers/base_tester.cc | 6 +- onnxruntime/test/providers/base_tester.h | 6 +- .../test/providers/cpu/math/softmax_test.cc | 13 +- .../providers/cpu/nn/conv_integer_test.cc | 40 + .../internal_testing_execution_provider.cc | 1 + .../internal_testing_execution_provider.h | 1 + .../test/providers/qnn/qnn_ep_context_test.cc | 267 ++++--- .../test/providers/qnn/qnn_test_utils.cc | 7 +- .../quantization/test_get_qdq_config.py | 56 ++ onnxruntime/test/qnn_ctx_gen/main.cc | 250 ------- .../test/shared_lib/custom_op_utils.cc | 20 + onnxruntime/test/shared_lib/custom_op_utils.h | 67 +- onnxruntime/test/shared_lib/test_inference.cc | 192 +++-- .../test/shared_lib/test_model_builder_api.cc | 701 ++++++++++++++++++ .../test/shared_lib/test_ort_format_models.cc | 14 +- onnxruntime/test/shared_lib/utils.h | 52 ++ .../test/testdata/cast_float_to_double.onnx | Bin 0 -> 136 bytes .../my_execution_provider.cc | 2 +- .../my_execution_provider.h | 2 +- onnxruntime/wasm/api.cc | 26 +- onnxruntime/wasm/api.h | 24 +- onnxruntime/wasm/js_post_js.js | 2 - onnxruntime/wasm/js_post_js_64.js | 2 - onnxruntime/wasm/post-webgpu.js | 261 +++++++ onnxruntime/wasm/pre-async.js | 132 ++++ onnxruntime/wasm/pre-jsep.js | 308 +++----- onnxruntime/wasm/pre.js | 15 +- setup.py | 2 +- tools/ci_build/build.py | 21 +- .../custom-nuget-packaging-pipeline.yml | 142 ++++ .../py-package-test-pipeline.yml | 2 + .../azure-pipelines/py-packaging-pipeline.yml | 50 +- .../qnn-ep-nuget-packaging-pipeline.yml | 148 ++-- .../rocm-nuget-packaging-pipeline.yml | 339 --------- .../rocm-publish-nuget-pipeline.yml | 21 - .../stages/nuget-cuda-packaging-stage.yml | 15 +- .../stages/nuget-qnn-packaging-stage.yml | 76 ++ .../stages/py-cpu-packaging-stage.yml | 124 ++-- ...acts-package-and-publish-steps-windows.yml | 16 + .../templates/jobs/download_win_openvino.yml | 64 ++ .../templates/linux-web-init-and-check.yml | 8 + .../templates/py-linux-qnn.yml | 118 +-- .../azure-pipelines/templates/py-linux.yml | 144 ++-- .../templates/py-package-smoking-test.yml | 13 +- .../templates/py-packaging-linux-test-cpu.yml | 18 +- .../templates/py-win-arm64-qnn.yml | 273 +++---- .../templates/py-win-arm64ec-qnn.yml | 241 +++--- .../templates/py-win-x64-qnn.yml | 21 +- .../azure-pipelines/templates/qnn-ep-win.yml | 260 ++++--- .../templates/react-native-ci.yml | 12 +- .../azure-pipelines/templates/web-ci.yml | 3 - .../azure-pipelines/templates/win-ci.yml | 2 +- .../azure-pipelines/templates/win-web-ci.yml | 12 +- .../templates/win-web-multi-browsers.yml | 12 +- .../templates/windowsai-steps.yml | 2 +- .../win-gpu-webgpu-ci-pipeline.yml | 28 + .../win-openvino-ci-pipeline.yml | 116 +++ .../win-qnn-arm64-ci-pipeline.yml | 2 +- .../azure-pipelines/win-qnn-ci-pipeline.yml | 2 +- tools/ci_build/set-trigger-rules.py | 3 +- .../nuget/generate_nuspec_for_custom_nuget.py | 150 ++++ tools/python/run_CIs_for_external_pr.py | 1 + tools/python/util/__init__.py | 3 +- tools/python/util/vcpkg_helpers.py | 78 +- winml/adapter/winml_adapter_model.cpp | 18 +- 274 files changed, 10047 insertions(+), 3285 deletions(-) create mode 100644 js/build_webgpu.bat create mode 100644 js/common/test/unit-tests/tensor/constructor-f16.ts create mode 100644 onnxruntime/contrib_ops/webgpu/bert/bias_add.cc create mode 100644 onnxruntime/contrib_ops/webgpu/bert/bias_add.h create mode 100644 onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc create mode 100644 onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h create mode 100644 onnxruntime/core/graph/model_editor_api_types.h create mode 100644 onnxruntime/core/optimizer/graph_optimizer_registry.cc create mode 100644 onnxruntime/core/optimizer/graph_optimizer_registry.h create mode 100644 onnxruntime/core/optimizer/qdq_transformer/constant_folding_dq_node.cc create mode 100644 onnxruntime/core/optimizer/qdq_transformer/constant_folding_dq_node.h create mode 100644 onnxruntime/core/optimizer/selection_and_optimization_func.cc create mode 100644 onnxruntime/core/optimizer/selection_and_optimization_func.h create mode 100644 onnxruntime/core/providers/webgpu/external_data_loader.cc create mode 100644 onnxruntime/core/providers/webgpu/external_data_loader.h create mode 100644 onnxruntime/core/providers/webgpu/math/softmax.cc create mode 100644 onnxruntime/core/providers/webgpu/math/softmax.h create mode 100644 onnxruntime/core/providers/webgpu/reduction/reduction_ops.cc create mode 100644 onnxruntime/core/providers/webgpu/reduction/reduction_ops.h create mode 100644 onnxruntime/core/providers/webgpu/tensor/pad.cc create mode 100644 onnxruntime/core/providers/webgpu/tensor/pad.h delete mode 100644 onnxruntime/core/session/api_utils.cc delete mode 100644 onnxruntime/core/session/api_utils.h create mode 100644 onnxruntime/core/session/model_editor_api.h create mode 100644 onnxruntime/core/session/model_editor_c_api.cc create mode 100644 onnxruntime/core/session/utils.cc create mode 100644 onnxruntime/core/session/utils.h rename onnxruntime/test/{qnn_ctx_gen => ep_weight_sharing_ctx_gen}/README.md (82%) rename onnxruntime/test/{qnn_ctx_gen => ep_weight_sharing_ctx_gen}/command_args_parser.cc (68%) rename onnxruntime/test/{qnn_ctx_gen => ep_weight_sharing_ctx_gen}/command_args_parser.h (100%) create mode 100644 onnxruntime/test/ep_weight_sharing_ctx_gen/main.cc rename onnxruntime/test/{qnn_ctx_gen => ep_weight_sharing_ctx_gen}/test_configuration.h (75%) delete mode 100644 onnxruntime/test/qnn_ctx_gen/main.cc create mode 100644 onnxruntime/test/shared_lib/test_model_builder_api.cc create mode 100644 onnxruntime/test/testdata/cast_float_to_double.onnx create mode 100644 onnxruntime/wasm/post-webgpu.js create mode 100644 onnxruntime/wasm/pre-async.js create mode 100644 tools/ci_build/github/azure-pipelines/custom-nuget-packaging-pipeline.yml delete mode 100644 tools/ci_build/github/azure-pipelines/rocm-nuget-packaging-pipeline.yml delete mode 100644 tools/ci_build/github/azure-pipelines/rocm-publish-nuget-pipeline.yml create mode 100644 tools/ci_build/github/azure-pipelines/stages/nuget-qnn-packaging-stage.yml create mode 100644 tools/ci_build/github/azure-pipelines/templates/jobs/download_win_openvino.yml create mode 100644 tools/ci_build/github/azure-pipelines/win-openvino-ci-pipeline.yml create mode 100644 tools/nuget/generate_nuspec_for_custom_nuget.py diff --git a/ThirdPartyNotices.txt b/ThirdPartyNotices.txt index 26084ab42ec1c..a449e42f6bf19 100644 --- a/ThirdPartyNotices.txt +++ b/ThirdPartyNotices.txt @@ -6045,3 +6045,38 @@ https://github.com/intel/neural-speed terms, and open source software license terms. These separate license terms govern your use of the third party programs as set forth in the "THIRD-PARTY-PROGRAMS" file. + +_____ + +dawn + +https://dawn.googlesource.com/dawn + + BSD 3-Clause License + + Copyright 2017-2023 The Dawn & Tint Authors + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are met: + + 1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + + 2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + + 3. Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/cmake/deps.txt b/cmake/deps.txt index d0bab93d3c16f..c7db8ef51505d 100644 --- a/cmake/deps.txt +++ b/cmake/deps.txt @@ -53,7 +53,6 @@ re2;https://github.com/google/re2/archive/refs/tags/2024-07-02.zip;646e1728269cd safeint;https://github.com/dcleblanc/SafeInt/archive/refs/tags/3.0.28.zip;23f252040ff6cb9f1fd18575b32fa8fb5928daac tensorboard;https://github.com/tensorflow/tensorboard/archive/373eb09e4c5d2b3cc2493f0949dc4be6b6a45e81.zip;67b833913605a4f3f499894ab11528a702c2b381 cutlass;https://github.com/NVIDIA/cutlass/archive/refs/tags/v3.5.1.zip;e49b2b964163d27765a5002d210a2f3c73771835 -utf8_range;https://github.com/protocolbuffers/utf8_range/archive/72c943dea2b9240cd09efde15191e144bc7c7d38.zip;9925739c9debc0efa2adcb194d371a35b6a03156 extensions;https://github.com/microsoft/onnxruntime-extensions/archive/c24b7bab0c12f53da76d0c31b03b9f0f8ec8f3b4.zip;239063aee4946a9af147b473a4c3da78ba7413b4 composable_kernel;https://github.com/ROCmSoftwarePlatform/composable_kernel/archive/204da9c522cebec5220bba52cd3542ebcaf99e7a.zip;1827348efd47831c13074245274d41b7cae8a557 directx_headers;https://github.com/microsoft/DirectX-Headers/archive/refs/tags/v1.613.1.zip;47653509a3371eabb156360f42faf582f314bf2e diff --git a/cmake/external/onnxruntime_external_deps.cmake b/cmake/external/onnxruntime_external_deps.cmake index ebf20ab21bbd2..a477d6edb3a3f 100644 --- a/cmake/external/onnxruntime_external_deps.cmake +++ b/cmake/external/onnxruntime_external_deps.cmake @@ -107,23 +107,6 @@ if(onnxruntime_USE_MIMALLOC) FetchContent_MakeAvailable(mimalloc) endif() -#Protobuf depends on utf8_range -onnxruntime_fetchcontent_declare( - utf8_range - URL ${DEP_URL_utf8_range} - URL_HASH SHA1=${DEP_SHA1_utf8_range} - EXCLUDE_FROM_ALL - FIND_PACKAGE_ARGS NAMES utf8_range -) - -set(utf8_range_ENABLE_TESTS OFF CACHE BOOL "Build test suite" FORCE) -set(utf8_range_ENABLE_INSTALL OFF CACHE BOOL "Configure installation" FORCE) - -# The next line will generate an error message "fatal: not a git repository", but it is ok. It is from flatbuffers -onnxruntime_fetchcontent_makeavailable(utf8_range) -# protobuf's cmake/utf8_range.cmake has the following line -include_directories(${utf8_range_SOURCE_DIR}) - # Download a protoc binary from Internet if needed if(NOT ONNX_CUSTOM_PROTOC_EXECUTABLE AND NOT onnxruntime_USE_VCPKG) # This part of code is only for users' convenience. The code couldn't handle all cases. Users always can manually @@ -304,7 +287,7 @@ if(NOT TARGET Boost::mp11) EXCLUDE_FROM_ALL FIND_PACKAGE_ARGS NAMES Boost ) - onnxruntime_fetchcontent_makeavailable(mp11) + onnxruntime_fetchcontent_makeavailable(mp11) if(NOT TARGET Boost::mp11) add_library(Boost::mp11 ALIAS Boost::headers) endif() @@ -442,6 +425,9 @@ target_include_directories(safeint_interface INTERFACE ${safeint_SOURCE_DIR}) # Flatbuffers +if(onnxruntime_USE_VCPKG) + find_package(flatbuffers REQUIRED) +else() # We do not need to build flatc for iOS or Android Cross Compile if (CMAKE_SYSTEM_NAME STREQUAL "iOS" OR CMAKE_SYSTEM_NAME STREQUAL "Android" OR CMAKE_SYSTEM_NAME STREQUAL "Emscripten") set(FLATBUFFERS_BUILD_FLATC OFF CACHE BOOL "FLATBUFFERS_BUILD_FLATC" FORCE) @@ -492,6 +478,7 @@ namespace std { using ::getenv; } endif() endif() endif() +endif() # ONNX if (NOT onnxruntime_USE_FULL_PROTOBUF) @@ -672,17 +659,10 @@ if (onnxruntime_USE_WEBGPU) # disable things we don't use set(DAWN_DXC_ENABLE_ASSERTS_IN_NDEBUG OFF) - set(DAWN_ENABLE_DESKTOP_GL OFF CACHE BOOL "" FORCE) - set(DAWN_ENABLE_OPENGLES OFF CACHE BOOL "" FORCE) - set(DAWN_SUPPORTS_GLFW_FOR_WINDOWING OFF CACHE BOOL "" FORCE) - set(DAWN_USE_GLFW OFF CACHE BOOL "" FORCE) - set(DAWN_USE_WINDOWS_UI OFF CACHE BOOL "" FORCE) set(DAWN_USE_X11 OFF CACHE BOOL "" FORCE) set(TINT_BUILD_TESTS OFF CACHE BOOL "" FORCE) set(TINT_BUILD_CMD_TOOLS OFF CACHE BOOL "" FORCE) - set(TINT_BUILD_GLSL_WRITER OFF CACHE BOOL "" FORCE) - set(TINT_BUILD_GLSL_VALIDATOR OFF CACHE BOOL "" FORCE) set(TINT_BUILD_IR_BINARY OFF CACHE BOOL "" FORCE) set(TINT_BUILD_SPV_READER OFF CACHE BOOL "" FORCE) # don't need. disabling is a large binary size saving set(TINT_BUILD_WGSL_WRITER ON CACHE BOOL "" FORCE) # needed to create cache key. runtime error if not enabled. @@ -732,7 +712,29 @@ if (onnxruntime_USE_WEBGPU) # # if we need to apply patches in the future, we can uncomment the following line. # # The dawn.patch contains the following changes: - # - https://dawn-review.googlesource.com/c/dawn/+/225514 + # + # - (public) CMake fix to support Emscripten v4.0.3+ + # This change allows Dawn to find the file "gen_struct_info.py" in the correct location. + # https://dawn-review.googlesource.com/c/dawn/+/225514 + # + # - (public) Fix emwgpu C++ implementation for buffer destroy + # In native implementation, wgpuBufferRelease will trigger the buffer destroy (if refcount decreased to 0). But + # in emwgpu implementation, the buffer destroy won't happen. This change fixes the bug. + # https://dawn-review.googlesource.com/c/dawn/+/226315 + # + # - (private) Allow "external" buffer in emwgpu C++ implementation + # This change allows WGPUBufferImpl to destroy the buffer when the refcount decreased to 0 only for non-external + # buffer. + # "external buffer" means the GPUBuffer instance created in JavaScript and imported to C++ by `importJsBuffer`. + # + # - (private) Remove hard-coded CMAKE_OSX_DEPLOYMENT_TARGET in Dawn's CMake files + # https://github.com/microsoft/onnxruntime/pull/23729 + # + # - (private) Fix external ref count for "external" device in emwgpu C++ implementation + # This change fixes the incorrect external ref count for class WGPUDeviceImpl when used with "external" device. + # "external device" means the GPUDevice instance created in JavaScript and imported to C++ by `importJsDevice`. + # + # PATCH_COMMAND ${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PROJECT_SOURCE_DIR}/patches/dawn/dawn.patch EXCLUDE_FROM_ALL ) diff --git a/cmake/nuget_helpers.cmake b/cmake/nuget_helpers.cmake index 22143ac422e9f..b066d1e9fb50e 100644 --- a/cmake/nuget_helpers.cmake +++ b/cmake/nuget_helpers.cmake @@ -1,7 +1,7 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -cmake_minimum_required(VERSION 3.0) +cmake_minimum_required(VERSION 3.5) # Determines the version of a native nuget package from the root packages.config. # diff --git a/cmake/onnxruntime_framework.cmake b/cmake/onnxruntime_framework.cmake index b1e98a9e0411c..9c9a25f8ee77e 100644 --- a/cmake/onnxruntime_framework.cmake +++ b/cmake/onnxruntime_framework.cmake @@ -36,10 +36,7 @@ elseif(onnxruntime_ENABLE_TRITON) endif() if (onnxruntime_MINIMAL_BUILD) - set(onnxruntime_framework_src_exclude - "${ONNXRUNTIME_ROOT}/core/framework/fallback_cpu_capability.h" - "${ONNXRUNTIME_ROOT}/core/framework/fallback_cpu_capability.cc" - ) + set(onnxruntime_framework_src_exclude) # custom ops support must be explicitly enabled in a minimal build. exclude if not. if (NOT onnxruntime_MINIMAL_BUILD_CUSTOM_OPS) diff --git a/cmake/onnxruntime_optimizer.cmake b/cmake/onnxruntime_optimizer.cmake index 9d680cd04af10..173c872d4cc06 100644 --- a/cmake/onnxruntime_optimizer.cmake +++ b/cmake/onnxruntime_optimizer.cmake @@ -9,6 +9,7 @@ if (onnxruntime_MINIMAL_BUILD) list(APPEND onnxruntime_optimizer_src_patterns "${ONNXRUNTIME_INCLUDE_DIR}/core/optimizer/graph_transformer.h" "${ONNXRUNTIME_ROOT}/core/optimizer/graph_transformer.cc" + "${ONNXRUNTIME_ROOT}/core/optimizer/graph_optimizer_registry.cc" ) if (onnxruntime_EXTENDED_MINIMAL_BUILD) diff --git a/cmake/onnxruntime_providers_js.cmake b/cmake/onnxruntime_providers_js.cmake index 9811eae611463..fefbab5082da4 100644 --- a/cmake/onnxruntime_providers_js.cmake +++ b/cmake/onnxruntime_providers_js.cmake @@ -1,6 +1,10 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. + if (onnxruntime_MINIMAL_BUILD AND NOT onnxruntime_EXTENDED_MINIMAL_BUILD) + message(FATAL_ERROR "JSEP can not be used in a basic minimal build. Please build with '--minimal_build extended'") + endif() + add_compile_definitions(USE_JSEP=1) file(GLOB_RECURSE onnxruntime_providers_js_cc_srcs @@ -18,4 +22,4 @@ onnxruntime_common onnxruntime_framework onnx onnx_proto ${PROTOBUF_LIB} flatbuffers Boost::mp11 Eigen3::Eigen ) - add_dependencies(onnxruntime_providers_js ${onnxruntime_EXTERNAL_DEPENDENCIES}) \ No newline at end of file + add_dependencies(onnxruntime_providers_js ${onnxruntime_EXTERNAL_DEPENDENCIES}) diff --git a/cmake/onnxruntime_python.cmake b/cmake/onnxruntime_python.cmake index aee6d2ff7655c..64b53c2912be0 100644 --- a/cmake/onnxruntime_python.cmake +++ b/cmake/onnxruntime_python.cmake @@ -1029,7 +1029,7 @@ if (onnxruntime_USE_QNN) add_custom_command( TARGET onnxruntime_pybind11_state POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy - $ + $ $/onnxruntime/capi/ ) if (EXISTS "${onnxruntime_QNN_HOME}/Qualcomm AI Hub Proprietary License.pdf") diff --git a/cmake/onnxruntime_session.cmake b/cmake/onnxruntime_session.cmake index 3d63285d50e72..2c2c59091fae5 100644 --- a/cmake/onnxruntime_session.cmake +++ b/cmake/onnxruntime_session.cmake @@ -22,6 +22,7 @@ endif() if (onnxruntime_MINIMAL_BUILD) set(onnxruntime_session_src_exclude "${ONNXRUNTIME_ROOT}/core/session/provider_bridge_ort.cc" + "${ONNXRUNTIME_ROOT}/core/session/model_builder_c_api.cc" ) list(REMOVE_ITEM onnxruntime_session_srcs ${onnxruntime_session_src_exclude}) diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index 0916aeb3dd92c..87aee2a174fab 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -236,14 +236,14 @@ function(AddTest) ) endif() # Set test timeout to 3 hours. - set_tests_properties(${_UT_TARGET} PROPERTIES TIMEOUT 7200) + set_tests_properties(${_UT_TARGET} PROPERTIES TIMEOUT 10800) else() add_test(NAME ${_UT_TARGET} COMMAND ${_UT_TARGET} ${TEST_ARGS} WORKING_DIRECTORY $ ) # Set test timeout to 3 hours. - set_tests_properties(${_UT_TARGET} PROPERTIES TIMEOUT 7200) + set_tests_properties(${_UT_TARGET} PROPERTIES TIMEOUT 10800) endif() endif() endfunction(AddTest) @@ -503,6 +503,7 @@ set (onnxruntime_shared_lib_test_SRC if (NOT onnxruntime_MINIMAL_BUILD) list(APPEND onnxruntime_shared_lib_test_SRC ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_inference.cc) + list(APPEND onnxruntime_shared_lib_test_SRC ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_model_builder_api.cc) endif() if(onnxruntime_RUN_ONNX_TESTS) @@ -1288,31 +1289,34 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) if(onnxruntime_USE_QNN) #qnn ctx generator - set(onnxruntime_qnn_ctx_gen_src_dir ${TEST_SRC_DIR}/qnn_ctx_gen) - set(onnxruntime_qnn_ctx_gen_src_patterns - "${onnxruntime_qnn_ctx_gen_src_dir}/*.cc" - "${onnxruntime_qnn_ctx_gen_src_dir}/*.h") + set(ep_weight_sharing_ctx_gen_src_dir ${TEST_SRC_DIR}/ep_weight_sharing_ctx_gen) + set(ep_weight_sharing_ctx_gen_src_patterns + "${ep_weight_sharing_ctx_gen_src_dir}/*.cc" + "${ep_weight_sharing_ctx_gen_src_dir}/*.h") - file(GLOB onnxruntime_qnn_ctx_gen_src CONFIGURE_DEPENDS - ${onnxruntime_qnn_ctx_gen_src_patterns} + file(GLOB ep_weight_sharing_ctx_gen_src CONFIGURE_DEPENDS + ${ep_weight_sharing_ctx_gen_src_patterns} ) - onnxruntime_add_executable(onnxruntime_qnn_ctx_gen ${onnxruntime_qnn_ctx_gen_src}) - target_include_directories(onnxruntime_qnn_ctx_gen PRIVATE ${onnx_test_runner_src_dir} ${ONNXRUNTIME_ROOT} - ${onnxruntime_graph_header} ${onnxruntime_exec_src_dir} - ${CMAKE_CURRENT_BINARY_DIR}) + onnxruntime_add_executable(ep_weight_sharing_ctx_gen ${ep_weight_sharing_ctx_gen_src}) + target_include_directories(ep_weight_sharing_ctx_gen PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR}) if (WIN32) - target_compile_options(onnxruntime_qnn_ctx_gen PRIVATE ${disabled_warnings}) + target_compile_options(ep_weight_sharing_ctx_gen PRIVATE ${disabled_warnings}) if (NOT DEFINED SYS_PATH_LIB) set(SYS_PATH_LIB shlwapi) endif() endif() - if(WIN32) - target_link_libraries(onnxruntime_qnn_ctx_gen PRIVATE debug dbghelp advapi32) + if (onnxruntime_BUILD_SHARED_LIB) + set(ep_weight_sharing_ctx_gen_libs onnxruntime_common onnxruntime ${onnxruntime_EXTERNAL_LIBRARIES} ${GETOPT_LIB_WIDE}) + target_link_libraries(ep_weight_sharing_ctx_gen PRIVATE ${ep_weight_sharing_ctx_gen_libs}) + if (WIN32) + target_link_libraries(ep_weight_sharing_ctx_gen PRIVATE debug dbghelp advapi32) + endif() + else() + target_link_libraries(ep_weight_sharing_ctx_gen PRIVATE onnxruntime_session ${onnxruntime_test_providers_libs} ${onnxruntime_EXTERNAL_LIBRARIES} ${GETOPT_LIB_WIDE}) endif() - target_link_libraries(onnxruntime_qnn_ctx_gen PRIVATE onnx_test_runner_common onnxruntime_test_utils onnxruntime_common onnxruntime_graph onnxruntime_session onnxruntime_providers onnxruntime_framework onnxruntime_util onnxruntime_mlas onnxruntime_optimizer onnxruntime_flatbuffers onnx_test_data_proto ${onnxruntime_test_providers_libs} ${onnxruntime_EXTERNAL_LIBRARIES} ${GETOPT_LIB_WIDE} ${SYS_PATH_LIB} ${CMAKE_DL_LIBS}) - set_target_properties(onnxruntime_qnn_ctx_gen PROPERTIES FOLDER "ONNXRuntimeTest") + set_target_properties(ep_weight_sharing_ctx_gen PROPERTIES FOLDER "ONNXRuntimeTest") endif() # shared lib @@ -1359,14 +1363,19 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) LIBS ${onnxruntime_shared_lib_test_LIBS} DEPENDS ${all_dependencies} ) + + target_include_directories(onnxruntime_shared_lib_test PRIVATE ${ONNXRUNTIME_ROOT}) + if (onnxruntime_USE_CUDA) target_include_directories(onnxruntime_shared_lib_test PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) target_sources(onnxruntime_shared_lib_test PRIVATE ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/cuda_ops.cu) endif() + if (onnxruntime_USE_ROCM) target_include_directories(onnxruntime_shared_lib_test PRIVATE ${onnxruntime_ROCM_HOME}/include) target_compile_definitions(onnxruntime_shared_lib_test PRIVATE __HIP_PLATFORM_AMD__) endif() + if (CMAKE_SYSTEM_NAME STREQUAL "Android") target_sources(onnxruntime_shared_lib_test PRIVATE "${ONNXRUNTIME_ROOT}/core/platform/android/cxa_demangle.cc" diff --git a/cmake/onnxruntime_webassembly.cmake b/cmake/onnxruntime_webassembly.cmake index 8106e46ccf580..f3afaf7033fd1 100644 --- a/cmake/onnxruntime_webassembly.cmake +++ b/cmake/onnxruntime_webassembly.cmake @@ -211,10 +211,14 @@ else() target_link_libraries(onnxruntime_webassembly PRIVATE tensorboard) endif() + set(onnxruntime_webassembly_script_deps "${ONNXRUNTIME_ROOT}/wasm/pre.js") + + set(EXPORTED_FUNCTIONS "_malloc,_free") if (onnxruntime_USE_JSEP) - set(EXPORTED_FUNCTIONS "_malloc,_free,_JsepOutput,_JsepGetNodeName") - else() - set(EXPORTED_FUNCTIONS "_malloc,_free") + string(APPEND EXPORTED_FUNCTIONS ",_JsepOutput,_JsepGetNodeName") + endif() + if (onnxruntime_USE_WEBGPU) + string(APPEND EXPORTED_FUNCTIONS ",_wgpuBufferRelease,_wgpuCreateInstance") endif() if (onnxruntime_ENABLE_WEBASSEMBLY_MEMORY64) @@ -312,13 +316,15 @@ else() target_compile_options(noexcep_operators PRIVATE ${SMEMORY_FLAG} -Wno-experimental) endif() target_link_options(onnxruntime_webassembly PRIVATE - --post-js "${ONNXRUNTIME_ROOT}/wasm/js_post_js_64.js" + "SHELL:--post-js \"${ONNXRUNTIME_ROOT}/wasm/js_post_js_64.js\"" ) + list(APPEND onnxruntime_webassembly_script_deps "${ONNXRUNTIME_ROOT}/wasm/js_post_js_64.js") else () set(MAXIMUM_MEMORY "4294967296") target_link_options(onnxruntime_webassembly PRIVATE - --post-js "${ONNXRUNTIME_ROOT}/wasm/js_post_js.js" + "SHELL:--post-js \"${ONNXRUNTIME_ROOT}/wasm/js_post_js.js\"" ) + list(APPEND onnxruntime_webassembly_script_deps "${ONNXRUNTIME_ROOT}/wasm/js_post_js.js") endif () target_link_options(onnxruntime_webassembly PRIVATE @@ -372,7 +378,6 @@ jsepDownload:_pp_") "SHELL:-s SIGNATURE_CONVERSIONS='${SIGNATURE_CONVERSIONS}'" ) endif () - set_target_properties(onnxruntime_webassembly PROPERTIES LINK_DEPENDS ${ONNXRUNTIME_ROOT}/wasm/pre.js) if (onnxruntime_USE_JSEP) # NOTE: "-s ASYNCIFY=1" is required for JSEP to work with WebGPU @@ -382,10 +387,8 @@ jsepDownload:_pp_") target_compile_definitions(onnxruntime_webassembly PRIVATE USE_JSEP=1) target_link_options(onnxruntime_webassembly PRIVATE "SHELL:--pre-js \"${ONNXRUNTIME_ROOT}/wasm/pre-jsep.js\"" - "SHELL:-s ASYNCIFY=1" - "SHELL:-s ASYNCIFY_STACK_SIZE=65536" ) - set_target_properties(onnxruntime_webassembly PROPERTIES LINK_DEPENDS ${ONNXRUNTIME_ROOT}/wasm/pre-jsep.js) + list(APPEND onnxruntime_webassembly_script_deps "${ONNXRUNTIME_ROOT}/wasm/pre-jsep.js") if (onnxruntime_ENABLE_WEBASSEMBLY_MEMORY64) target_link_options(onnxruntime_webassembly PRIVATE @@ -397,6 +400,20 @@ jsepDownload:_pp_") if (onnxruntime_USE_WEBGPU) target_compile_definitions(onnxruntime_webassembly PRIVATE USE_WEBGPU=1) + target_link_options(onnxruntime_webassembly PRIVATE + "SHELL:--post-js \"${ONNXRUNTIME_ROOT}/wasm/post-webgpu.js\"" + ) + list(APPEND onnxruntime_webassembly_script_deps "${ONNXRUNTIME_ROOT}/wasm/post-webgpu.js") + endif() + + if (onnxruntime_USE_JSEP OR onnxruntime_USE_WEBGPU OR onnxruntime_USE_WEBNN) + # if any of the above is enabled, we need to use the asyncify library + target_link_options(onnxruntime_webassembly PRIVATE + "SHELL:--pre-js \"${ONNXRUNTIME_ROOT}/wasm/pre-async.js\"" + "SHELL:-s ASYNCIFY=1" + "SHELL:-s ASYNCIFY_STACK_SIZE=65536" + ) + list(APPEND onnxruntime_webassembly_script_deps "${ONNXRUNTIME_ROOT}/wasm/pre-async.js") endif() if (onnxruntime_EMSCRIPTEN_SETTINGS) @@ -458,6 +475,8 @@ jsepDownload:_pp_") ) endif() + set_target_properties(onnxruntime_webassembly PROPERTIES LINK_DEPENDS "${onnxruntime_webassembly_script_deps}") + set(target_name_list ort) if (onnxruntime_ENABLE_TRAINING_APIS) diff --git a/cmake/patches/dawn/dawn.patch b/cmake/patches/dawn/dawn.patch index 2f85d5ab473b5..b578b858eac59 100644 --- a/cmake/patches/dawn/dawn.patch +++ b/cmake/patches/dawn/dawn.patch @@ -18,7 +18,7 @@ index 6e8ae37593..633af91eef 100644 @@ -77,9 +77,17 @@ if (${DAWN_ENABLE_EMSCRIPTEN}) "${arg_UNPARSED_ARGUMENTS}") endif() - + + # since Emscripten 4.0.3, file gen_struct_info.py is moved to outside of directory maint. + if (EXISTS "${DAWN_EMSCRIPTEN_TOOLCHAIN}/tools/gen_struct_info.py") + set(EM_GEN_STRUCT_INFO_SCRIPT "${DAWN_EMSCRIPTEN_TOOLCHAIN}/tools/gen_struct_info.py") @@ -34,3 +34,114 @@ index 6e8ae37593..633af91eef 100644 -q "${EM_BUILD_GEN_DIR}/struct_info_webgpu.json" "-I=${EM_BUILD_GEN_DIR}/include" +diff --git a/src/emdawnwebgpu/README.md b/src/emdawnwebgpu/README.md +index efd6491cd6..8ebc5d28b6 100644 +--- a/src/emdawnwebgpu/README.md ++++ b/src/emdawnwebgpu/README.md +@@ -56,7 +56,7 @@ Set up the build directory using emcmake + mkdir out/cmake-wasm + cd out/cmake-wasm + +-# Make sure the path is to the source checkout of Emscripten, not emsdk's release. ++# If using Emscripten v4.0.2 or lower, make sure the path is to the source checkout of Emscripten, not emsdk's release. + emcmake cmake -GNinja -DDAWN_EMSCRIPTEN_TOOLCHAIN="path/to/emscripten" ../.. + + ninja +diff --git a/third_party/emdawnwebgpu/webgpu.cpp b/third_party/emdawnwebgpu/webgpu.cpp +index f1c5a7d50e..16f2495712 100644 +--- a/third_party/emdawnwebgpu/webgpu.cpp ++++ b/third_party/emdawnwebgpu/webgpu.cpp +@@ -131,7 +131,6 @@ class RefCounted : NonMovable { + bool Release() { + if (mRefCount.fetch_sub(1u, std::memory_order_release) == 1u) { + std::atomic_thread_fence(std::memory_order_acquire); +- emwgpuDelete(this); + return true; + } + return false; +@@ -234,6 +233,7 @@ class Ref { + static void Release(T value) { + if (value != nullptr && value->RefCounted::Release()) { + delete value; ++ emwgpuDelete(value); + } + } + +@@ -641,7 +641,8 @@ struct WGPUAdapterImpl final : public EventSource, public RefCounted { + struct WGPUBufferImpl final : public EventSource, + public RefCountedWithExternalCount { + public: +- WGPUBufferImpl(const EventSource* source, bool mappedAtCreation); ++ WGPUBufferImpl(const EventSource* source, bool mappedAtCreation, bool isExternal); ++ ~WGPUBufferImpl(); + + void Destroy(); + const void* GetConstMappedRange(size_t offset, size_t size); +@@ -671,6 +672,7 @@ struct WGPUBufferImpl final : public EventSource, + }; + MapRequest mPendingMapRequest; + WGPUBufferMapState mMapState; ++ bool mIsExternal; + }; + + struct WGPUQueueImpl final : public EventSource, public RefCounted { +@@ -1164,11 +1166,15 @@ WGPUAdapter emwgpuCreateAdapter(const EventSource* source) { + + WGPUBuffer emwgpuCreateBuffer(const EventSource* source, + bool mappedAtCreation = false) { +- return new WGPUBufferImpl(source, mappedAtCreation); ++ return new WGPUBufferImpl(source, mappedAtCreation, true); + } + + WGPUDevice emwgpuCreateDevice(const EventSource* source, WGPUQueue queue) { +- return new WGPUDeviceImpl(source, queue); ++ // This function is only called from JS via `importJsDevice()`, which ++ // needs to increment the external ref count to fix the behavior. ++ WGPUDeviceImpl* device = new WGPUDeviceImpl(source, queue); ++ device->AddExternalRef(); ++ return device; + } + + WGPUQueue emwgpuCreateQueue(const EventSource* source) { +@@ -1275,15 +1281,22 @@ WGPUAdapterImpl::WGPUAdapterImpl(const EventSource* source) + // WGPUBuffer implementations. + // ---------------------------------------------------------------------------- + +-WGPUBufferImpl::WGPUBufferImpl(const EventSource* source, bool mappedAtCreation) ++WGPUBufferImpl::WGPUBufferImpl(const EventSource* source, bool mappedAtCreation, bool isExternal) + : EventSource(source), + mMapState(mappedAtCreation ? WGPUBufferMapState_Mapped +- : WGPUBufferMapState_Unmapped) { ++ : WGPUBufferMapState_Unmapped), ++ mIsExternal(isExternal) { + if (mappedAtCreation) { + mPendingMapRequest = {kNullFutureId, WGPUMapMode_Write}; + } + } + ++WGPUBufferImpl::~WGPUBufferImpl() { ++ if (!mIsExternal) { ++ Destroy(); ++ } ++} ++ + void WGPUBufferImpl::Destroy() { + emwgpuBufferDestroy(this); + AbortPendingMap("Buffer was destroyed before mapping was resolved."); +@@ -1504,6 +1517,7 @@ WGPUFuture WGPUShaderModuleImpl::GetCompilationInfo( + void wgpu##Name##Release(WGPU##Name o) { \ + if (o->Release()) { \ + delete o; \ ++ emwgpuDelete(o); \ + } \ + } + WGPU_OBJECTS(DEFINE_WGPU_DEFAULT_ADDREF_RELEASE) +@@ -1638,7 +1652,7 @@ void wgpuBufferUnmap(WGPUBuffer buffer) { + + WGPUBuffer wgpuDeviceCreateBuffer(WGPUDevice device, + const WGPUBufferDescriptor* descriptor) { +- WGPUBuffer buffer = new WGPUBufferImpl(device, descriptor->mappedAtCreation); ++ WGPUBuffer buffer = new WGPUBufferImpl(device, descriptor->mappedAtCreation, false); + emwgpuDeviceCreateBuffer(device, descriptor, buffer); + return buffer; + } diff --git a/cmake/winml_sdk_helpers.cmake b/cmake/winml_sdk_helpers.cmake index 9241fcd060caf..ca657311b7f14 100644 --- a/cmake/winml_sdk_helpers.cmake +++ b/cmake/winml_sdk_helpers.cmake @@ -1,7 +1,7 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -cmake_minimum_required(VERSION 3.0) +cmake_minimum_required(VERSION 3.5) # utility function(convert_forward_slashes_to_back input output) diff --git a/csharp/sample/Microsoft.ML.OnnxRuntime.FasterRcnnSample/Microsoft.ML.OnnxRuntime.FasterRcnnSample.csproj b/csharp/sample/Microsoft.ML.OnnxRuntime.FasterRcnnSample/Microsoft.ML.OnnxRuntime.FasterRcnnSample.csproj index f00a08a1a3595..b1452a64934c2 100644 --- a/csharp/sample/Microsoft.ML.OnnxRuntime.FasterRcnnSample/Microsoft.ML.OnnxRuntime.FasterRcnnSample.csproj +++ b/csharp/sample/Microsoft.ML.OnnxRuntime.FasterRcnnSample/Microsoft.ML.OnnxRuntime.FasterRcnnSample.csproj @@ -8,7 +8,7 @@ - + diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/ManagedProjections.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/ManagedProjections.shared.cs index 13117f23e8ef9..8916f11919cfe 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/ManagedProjections.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/ManagedProjections.shared.cs @@ -25,7 +25,7 @@ internal class ManagedTypeProjection /// /// /// - /// OrtValye created accoding to the metadata + /// OrtValue created according to the metadata internal static OrtValue CreateProjection(NamedOnnxValue namedOnnxValue, NodeMetadata metadata) { OrtValue result; @@ -191,4 +191,3 @@ private static OrtValue CreateTensorProjection(NamedOnnxValue node, NodeMetadata } } } - diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs index d628b065ceaa7..b64a5c3e5a4a2 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs @@ -847,7 +847,7 @@ internal class NativeLib /// Creates an instance of OrtSession with provided parameters /// /// Native OrtEnv instance - /// Byte array correspoonding to the model + /// Byte array corresponding to the model /// Size of the model in bytes /// Native SessionOptions instance /// Native OrtPrepackedWeightsContainer instance @@ -1258,7 +1258,7 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca /// /// Native SessionOptions instance /// Name of the initializer - /// Native OrtValue instnce + /// Native OrtValue instance [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DOrtAddInitializer(IntPtr /*(OrtSessionOptions*)*/ options, byte[] /*(const char*)*/ name, diff --git a/include/onnxruntime/core/framework/execution_provider.h b/include/onnxruntime/core/framework/execution_provider.h index c9a15de9ef897..2245ff5791feb 100644 --- a/include/onnxruntime/core/framework/execution_provider.h +++ b/include/onnxruntime/core/framework/execution_provider.h @@ -20,6 +20,7 @@ struct ComputeCapability; class KernelRegistry; struct KernelCreateInfo; class Node; +class GraphOptimizerRegistry; } // namespace onnxruntime #else #include @@ -129,10 +130,25 @@ class IExecutionProvider { and decide whether a node will be assigned to <*this> execution provider. For kernels registered in a kernel registry, `kernel_lookup` must be used to find a matching kernel for this EP. + + The graph_optimizer_registry is designed for enabling L2+ graph optimizations tailored for EPs. + These optimizations are applied after the graph partitioner assigns ComputeCapability to the EP + and before EP's "Compile" or fusion. + + Steps to use graph_optimizer_registry and create the optimization ComputeCapability: + 1. Lookup Optimizer: The EP calls provider bridge API to lookup pre-defined optimizer by name and get selection function. + - Example: g_host->GetOptimizerByName(optimizer_name, graph_optimizer_registry, selection_func) + 2. Run Selection Function: The EP executes the selection function to obtain the selection ComputeCapability. + - ComputeCapability.optimize_func would be set by the optimizer to the function that does the optimization. + 3. Create Optimization ComputeCapability: The EP uses the selection ComputeCapability to create the optimization ComputeCapability. + 4. Return ComputeCapability: The EP returns the final ComputeCapability, with nodes_to_optimize set to the optimization ComputeCapability. + + Note: For more detailed implementations of using graph_optimizer_registry, please refer to TensorRT EP. */ virtual std::vector> GetCapability(const onnxruntime::GraphViewer& graph_viewer, const IKernelLookup& kernel_lookup, + const GraphOptimizerRegistry& graph_optimizer_registry, IResourceAccountant* resource_accountant = nullptr) const; /** diff --git a/include/onnxruntime/core/graph/graph.h b/include/onnxruntime/core/graph/graph.h index 7798394b045dc..35b568e3f8e28 100644 --- a/include/onnxruntime/core/graph/graph.h +++ b/include/onnxruntime/core/graph/graph.h @@ -27,6 +27,7 @@ #include "core/common/span_utils.h" #include "core/common/status.h" #include "core/common/logging/logging.h" +#include "core/framework/ort_value.h" #include "core/framework/prepacked_weights_container.h" #include "core/graph/onnx_protobuf.h" #include "core/graph/basic_types.h" @@ -39,6 +40,9 @@ #include "core/graph/node_arg.h" #include "core/graph/ort_format_load_options.h" +// Type from Model Editor API in ORT C API so can't be in a namespace +struct OrtGraph; + namespace onnxruntime { class Graph; struct IndexedSubGraph; @@ -763,6 +767,10 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi */ bool GetInitializedTensor(const std::string& tensor_name, const ONNX_NAMESPACE::TensorProto*& value) const; + /** Populate `value` if an externally allocated OrtValue exists for an initializer with the given name. + */ + bool GetOrtValueInitializer(const std::string& name, OrtValue& value) const; + /** Gets all the initializer tensors in this Graph. */ const InitializedTensorSet& GetAllInitializedTensors() const noexcept { return name_to_initial_tensor_; } @@ -1430,6 +1438,16 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi const OrtFormatLoadOptions& load_options, const logging::Logger& logger, std::unique_ptr& graph); + static Status LoadFromModelEditorApiModel(const OrtGraph& api_graph, + const Model& owning_model, + const std::unordered_map& domain_to_version, + IOnnxRuntimeOpSchemaCollectionPtr schema_registry, + bool strict_shape_type_inference, + const logging::Logger& logger, + std::unique_ptr& graph); + + Status UpdateUsingModelEditorApiModel(const OrtModel& api_model); + #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) const RuntimeOptimizationRecordContainer& RuntimeOptimizations() const { return runtime_optimizations_; @@ -1630,7 +1648,8 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi // Implementation for initializer replacement Status ReplaceInitializedTensorImpl(ONNX_NAMESPACE::TensorProto new_initializer, bool is_external); - std::vector CreateNodeArgs(const google::protobuf::RepeatedPtrField& names, + template // range-initializer returning std::string + std::vector CreateNodeArgs(const StringRange& names, const ArgNameToTypeMap& name_to_type_map); void ToGraphProtoInternal(ONNX_NAMESPACE::GraphProto& graph_proto) const; @@ -1694,6 +1713,8 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi return nodes_[node_index].get(); } + Status LoadFromModelEditorApiModel(const OrtGraph& api_graph, bool updating_existing_graph = false); + const Model& owning_model_; // GraphProto to store name, version, initializer. @@ -1708,6 +1729,12 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi InitializedTensorSet name_to_initial_tensor_; + // Initializers that are external to the Graph. + // e.g. created from existing memory using CreateTensorWithDataAndDeleterAsOrtValue in the ORT API. + // As we need to convert to TensorProto for the optimizers to work and keep the deleter information we store them + // in the Graph instance and retrieve during session state finalization. + std::unordered_map ortvalue_initializers_; + std::unordered_set, std::hash, std::equal_to> sparse_tensor_names_; @@ -1744,6 +1771,7 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi // in some case, a fused sub-graph will happens multiple times in one model, we use a map // to store reusable-schema in lookup. InlinedHashMap> reusable_fused_schema_map_; + #endif // !defined(ORT_MINIMAL_BUILD) // Graph nodes. @@ -1806,7 +1834,7 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi std::unordered_map> node_arg_to_consumer_nodes_; #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) - const std::unordered_map domain_to_version_; + std::unordered_map domain_to_version_; // Model IR version. Version ir_version_{ONNX_NAMESPACE::Version::IR_VERSION}; diff --git a/include/onnxruntime/core/graph/graph_viewer.h b/include/onnxruntime/core/graph/graph_viewer.h index 9385e2f092e58..6a664d8be9c05 100644 --- a/include/onnxruntime/core/graph/graph_viewer.h +++ b/include/onnxruntime/core/graph/graph_viewer.h @@ -193,6 +193,12 @@ class GraphViewer { IOnnxRuntimeOpSchemaCollectionPtr GetSchemaRegistry() const { return graph_->GetSchemaRegistry(); } #endif + /** Populate `value` if an externally allocated OrtValue exists for an initializer with the given name. + */ + bool GetOrtValueInitializer(const std::string& name, OrtValue& value) const { + return graph_->GetOrtValueInitializer(name, value); + } + private: ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(GraphViewer); GraphViewer(const Graph& graph, const IndexedSubGraph* filter_info); diff --git a/include/onnxruntime/core/graph/indexed_sub_graph.h b/include/onnxruntime/core/graph/indexed_sub_graph.h index e457d3dcad1f1..088db79a7e005 100644 --- a/include/onnxruntime/core/graph/indexed_sub_graph.h +++ b/include/onnxruntime/core/graph/indexed_sub_graph.h @@ -72,6 +72,12 @@ struct IndexedSubGraph { return meta_def_.get(); } + /** Gets the mutable meta definition needed to represent this subgraph as a FunctionProto. + @returns MetaDef instance if it has been set. nullptr if not. */ + MetaDef* GetMutableMetaDef() { + return meta_def_.get(); + } + // Check if the accounting is enabled for the current EP bool IsAccountingEnabled() const { return resource_accountant != nullptr && diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 47e6389492f30..098de14bdfd61 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -305,6 +305,10 @@ ORT_RUNTIME_CLASS(OpAttr); ORT_RUNTIME_CLASS(Logger); ORT_RUNTIME_CLASS(ShapeInferContext); ORT_RUNTIME_CLASS(LoraAdapter); +ORT_RUNTIME_CLASS(ValueInfo); +ORT_RUNTIME_CLASS(Node); +ORT_RUNTIME_CLASS(Graph); +ORT_RUNTIME_CLASS(Model); #ifdef _WIN32 typedef _Return_type_success_(return == 0) OrtStatus* OrtStatusPtr; @@ -665,6 +669,9 @@ typedef struct OrtApi OrtApi; struct OrtTrainingApi; typedef struct OrtTrainingApi OrtTrainingApi; +struct OrtModelEditorApi; +typedef struct OrtModelEditorApi OrtModelEditorApi; + /** \brief The helper interface to get the right version of OrtApi * * Get a pointer to this structure through ::OrtGetApiBase @@ -847,7 +854,8 @@ struct OrtApi { * * \snippet{doc} snippets.dox OrtStatus Return Value */ - ORT_API2_STATUS(CreateSessionFromArray, _In_ const OrtEnv* env, _In_ const void* model_data, size_t model_data_length, + ORT_API2_STATUS(CreateSessionFromArray, _In_ const OrtEnv* env, + _In_ const void* model_data, size_t model_data_length, _In_ const OrtSessionOptions* options, _Outptr_ OrtSession** out); /** \brief Run the model in an ::OrtSession @@ -1340,6 +1348,8 @@ struct OrtApi { * Create a tensor with user's buffer. You can fill the buffer either before calling this function or after. * p_data is owned by caller. ReleaseValue won't release p_data. * + * If you wish to transfer ownership of p_data to ORT use CreateTensorWithDataAndDeleterAsOrtValue. + * * \param[in] info Memory description of where the p_data buffer resides (CPU vs GPU etc). * \param[in] p_data Pointer to the data buffer. * \param[in] p_data_len The number of bytes in the data buffer. @@ -1997,7 +2007,8 @@ struct OrtApi { /** \brief Get the value type from an ::OrtMapTypeInfo * * \param[in] map_type_info - * \param[out] type_info + * \param[out] type_info A copy of the OrtTypeInfo for the map value type. + * The user must free this value with ReleaseTypeInfo. * * \snippet{doc} snippets.dox OrtStatus Return Value */ @@ -2012,7 +2023,8 @@ struct OrtApi { * This is used by WinML to support model reflection APIs. * * \param[in] sequence_type_info - * \param[out] type_info + * \param[out] type_info A copy of the OrtTypeInfo for the sequence element type. + * The user must free this value with ReleaseTypeInfo. * * \snippet{doc} snippets.dox OrtStatus Return Value */ @@ -2887,7 +2899,8 @@ struct OrtApi { * \snippet{doc} snippets.dox OrtStatus Return Value */ ORT_API2_STATUS(CreateSessionWithPrepackedWeightsContainer, _In_ const OrtEnv* env, _In_ const ORTCHAR_T* model_path, - _In_ const OrtSessionOptions* options, _Inout_ OrtPrepackedWeightsContainer* prepacked_weights_container, + _In_ const OrtSessionOptions* options, + _Inout_ OrtPrepackedWeightsContainer* prepacked_weights_container, _Outptr_ OrtSession** out); /** \brief Create session from memory with prepacked weights container @@ -2910,7 +2923,8 @@ struct OrtApi { */ ORT_API2_STATUS(CreateSessionFromArrayWithPrepackedWeightsContainer, _In_ const OrtEnv* env, _In_ const void* model_data, size_t model_data_length, - _In_ const OrtSessionOptions* options, _Inout_ OrtPrepackedWeightsContainer* prepacked_weights_container, + _In_ const OrtSessionOptions* options, + _Inout_ OrtPrepackedWeightsContainer* prepacked_weights_container, _Outptr_ OrtSession** out); /// @} @@ -4293,8 +4307,8 @@ struct OrtApi { * specific type that is described by the returned ::OrtTypeInfo. * * \param[in] optional_type_info - * \param[out] out A pointer to the ::OrtTypeInfo for what the optional value could be. - * it is owned by OrtOptionalTypeInfo instance. + * \param[out] out A copy of ::OrtTypeInfo for what the optional value could be. + * The user must free this value with ReleaseTypeInfo. * * \snippet{doc} snippets.dox OrtStatus Return Value * @@ -4786,6 +4800,75 @@ struct OrtApi { */ ORT_API2_STATUS(SetEpDynamicOptions, _Inout_ OrtSession* sess, _In_reads_(kv_len) const char* const* keys, _In_reads_(kv_len) const char* const* values, _In_ size_t kv_len); + + /** \brief Release an OrtValueInfo instance if it was not added to an OrtGraph. + * \since Version 1.21. + */ + ORT_CLASS_RELEASE(ValueInfo); + + /** \brief Release an OrtNode if it was not added to an OrtGraph. + * \since Version 1.21. + */ + ORT_CLASS_RELEASE(Node); + + /** \brief Release an OrtGraph. + * \snippet{doc} snippets.dox OrtStatus Return Value + * \since Version 1.21. + */ + ORT_CLASS_RELEASE(Graph); + + /** \brief Release an OrtModel. + * \snippet{doc} snippets.dox OrtStatus Return Value + * \since Version 1.21. + */ + ORT_CLASS_RELEASE(Model); + + /** \brief Get the value name from an OrtValueInfo instance. + * \param[in] value_info The OrtValueInfo instance. + * \snippet{doc} snippets.dox OrtStatus Return Value + * \since Version 1.21. + */ + ORT_API2_STATUS(GetValueInfoName, _In_ const OrtValueInfo* value_info, _Out_ const char** name); + + /** \brief Get the type information from an OrtValueInfo instance. + * \param[in] value_info The OrtValueInfo instance. + * \snippet{doc} snippets.dox OrtStatus Return Value + * \since Version 1.21. + */ + ORT_API2_STATUS(GetValueInfoTypeInfo, _In_ const OrtValueInfo* value_info, _Outptr_ const OrtTypeInfo** type_info); + + /** \brief Get the Model Editor API instance + * + * Get the Model Editor API instance to create a new model or augment an existing model. + * + * \return Model Editor API struct + * + * \since Version 1.21. + */ + const OrtModelEditorApi*(ORT_API_CALL* GetModelEditorApi)(); + + /** \brief Create an OrtValue for a Tensor that uses pre-existing memory. + * + * ORT will take ownership of the memory and free it using the provided deleter when no longer in use. + * + * \param[in] deleter OrtAllocator instance that will be used to free the memory. + * Only the OrtAllocator:Info and OrtAllocator::Release functions are required. + * The OrtMemoryInfo returned by OrtAllocator::Info must match the location of p_data. + * \param[in] p_data Pointer to the memory that will be used by the Tensor. ORT will take ownership of the memory. + * \param[in] p_data_len Length of the memory in bytes. + * \param[in] shape Dimensions of the Tensor. All values should be > 0. + * \param[in] shape_len Number of dimensions in the shape array. + * \param[in] type Data type of the Tensor. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.21. + */ + ORT_API2_STATUS(CreateTensorWithDataAndDeleterAsOrtValue, _In_ OrtAllocator* deleter, + _In_ void* p_data, size_t p_data_len, + _In_ const int64_t* shape, size_t shape_len, + ONNXTensorElementDataType type, + _Outptr_ OrtValue** out); }; /* @@ -4900,6 +4983,400 @@ struct OrtCustomOp { void(ORT_API_CALL* ReleaseAliasMap)(_Frees_ptr_opt_ int* input_index, _Frees_ptr_opt_ int* output_index); }; +/** + * ORT Model Editor API + */ + +/** + * \brief The OrtModelEditorApi struct provides functions to create or edit an ONNX model. + * + * See onnxruntime/test/shared_lib/test_model_editor_api.cc for example usage. + * + * \since Version 1.21. + */ +struct OrtModelEditorApi { + // Model building/editing requires a full build. We return nullptr from GetModelEditorApi if this is a minimal + // build, so it doesn't matter if there are no function pointers in this struct as a user will never get an + // OrtModelEditorApi instance. We do however need a dummy field to avoid empty struct warning. +#if defined(ORT_MINIMAL_BUILD) + const bool not_defined_in_this_build; +#else + /** \brief Create an OrtTypeInfo instance for a Tensor. + * + * Create an OrtTypeInfo instance for a Tensor to use as graph inputs/outputs with the Model Editor API. + * + * User can release `tensor_info` after creating the OrtTypeInfo. + * + * \param[in] tensor_info Tensor type and shape information. + * \param[out] TypeInfo instance for the tensor. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.21. + */ + ORT_API2_STATUS(CreateTensorTypeInfo, _In_ const OrtTensorTypeAndShapeInfo* tensor_info, + _Outptr_ OrtTypeInfo** type_info); + + /** \brief Create an OrtTypeInfo instance for a SparseTensor. + * + * Create an OrtTypeInfo instance for a SparseTensor to use as graph inputs/outputs with the Model Editor API. + * + * User can release `tensor_info` after creating the OrtTypeInfo. + * + * \param[in] tensor_info SparseTensor type and shape information. + * \param[out] TypeInfo instance for the tensor. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.21. + */ + ORT_API2_STATUS(CreateSparseTensorTypeInfo, _In_ const OrtTensorTypeAndShapeInfo* tensor_info, + _Outptr_ OrtTypeInfo** type_info); + + /** \brief Create an OrtTypeInfo instance for a Map. + * + * Create an OrtTypeInfo instance for a Map to use as graph inputs/outputs with the Model Editor API. + * + * User can release `map_value_type` after creating the OrtTypeInfo. + * + * \param[in] map_key_type Key type for the map. + * \param[in] map_value_type Value type for the map. + * \param[out] TypeInfo instance for the map. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.21. + */ + ORT_API2_STATUS(CreateMapTypeInfo, ONNXTensorElementDataType map_key_type, _In_ const OrtTypeInfo* map_value_type, + _Outptr_ OrtTypeInfo** type_info); + + /** \brief Create an OrtTypeInfo instance for a Sequence. + * + * Create an OrtTypeInfo instance for a Sequence to use as graph inputs/outputs with the Model Editor API. + * + * User can release `sequence_type` after creating the OrtTypeInfo. + * + * \param[in] sequence_type Sequence type and shape information. + * \param[out] TypeInfo instance for the sequence. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.21. + */ + ORT_API2_STATUS(CreateSequenceTypeInfo, _In_ const OrtTypeInfo* sequence_type, _Outptr_ OrtTypeInfo** type_info); + + /** \brief Create an OrtTypeInfo instance for an Optional. + * + * Create an OrtTypeInfo instance for an Optional to use as graph inputs/outputs with the Model Editor API. + * + * User can release `contained_type` after creating the OrtTypeInfo. + * + * \param[in] tensor_info Tensor type and shape information. + * \param[out] TypeInfo instance for the tensor. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.21. + */ + ORT_API2_STATUS(CreateOptionalTypeInfo, _In_ const OrtTypeInfo* contained_type, _Outptr_ OrtTypeInfo** type_info); + + /** \brief Create an OrtValueInfo for use as an OrtGraph input or output. + * + * \param[in] name The name of the input or output. + * \param[in] type_info The type information for the input or output. The provided value is copied. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.21. + */ + ORT_API2_STATUS(CreateValueInfo, _In_ const char* name, _In_ const OrtTypeInfo* type_info, + _Outptr_ OrtValueInfo** value_info); + + /** \brief Create an OrtNode to add to an OrtGraph. + * + * Create an OrtNode. + * + * Create attributes with CreateOpAttr. OrtOpAttr instances are copied. + * + * \param[in] operator_name The name of the operator. + * \param[in] domain_name The domain of the operator. Use an empty string for ONNX operators. + * \param[in] node_name The name of the node. + * \param[in] input_names The names of the inputs. + * \param[in] input_names_len The number of input names. + * \param[in] output_names The names of the outputs. + * \param[in] output_names_len The number of output names. + * \param[in] attributes The optional attributes of the node. + * \param[in] attribs_len The number of attributes. May be zero. + * \param[out] node The OrtNode instance. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.21. + */ + ORT_API2_STATUS(CreateNode, _In_ const char* operator_name, _In_ const char* domain_name, _In_ const char* node_name, + _In_reads_(input_names_len) const char* const* input_names, size_t input_names_len, + _In_reads_(output_names_len) const char* const* output_names, size_t output_names_len, + _In_reads_(attribs_len) _In_opt_ OrtOpAttr** attributes, _In_ size_t attribs_len, + _Outptr_ OrtNode** node); + + /** \brief Create an OrtGraph + * \snippet{doc} snippets.dox OrtStatus Return Value + * \since Version 1.21. + */ + ORT_API2_STATUS(CreateGraph, _Outptr_ OrtGraph** graph); + + /** \brief Set the inputs for the OrtGraph. + * + * Set the graph inputs. This will replace any existing inputs with the new values. + * The OrtGraph takes ownership of the OrtValueInfo instances and you should NOT call ReleaseOrtValueInfo. + * + * \param[in] graph The OrtGraph instance to update. + * \param[in] inputs The input OrtValueInfo instances. + * \param[in] inputs_len The number of input OrtValueInfo instances. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.21. + */ + ORT_API2_STATUS(SetGraphInputs, _Inout_ OrtGraph* graph, + _In_reads_(inputs_len) _In_ OrtValueInfo** inputs, _In_ size_t inputs_len); + + /** \brief Set the outputs for the OrtGraph. + * + * Set the graph outputs. This will replace any existing outputs with the new values. + * The OrtGraph takes ownership of the OrtValueInfo instances provided and you should NOT call ReleaseOrtValueInfo. + * + * \param[in] graph The OrtGraph instance to update. + * \param[in] outputs The output OrtValueInfo instances. + * \param[in] outputs_len The number of output OrtValueInfo instances. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.21. + */ + ORT_API2_STATUS(SetGraphOutputs, _Inout_ OrtGraph* graph, + _In_reads_(outputs_len) _In_ OrtValueInfo** outputs, _In_ size_t outputs_len); + + /** \brief Add an initializer to the OrtGraph + * + * ORT will take ownership of the OrtValue and you should NOT call ReleaseOrtValue. + * + * Two options: + * + * Allocated memory: + * Use CreateTensorAsOrtValue (allocates memory) and populate the tensor with the data. + * Set `data_is_external` to false. + * + * Pre-existing memory: + * Use CreateTensorWithDataAsOrtValue or CreateTensorWithDataAndDeleterAsOrtValue to create an OrtValue + * with a tensor that contains a pointer to the existing data. + * Set `data_is_external` to true. + * + * The pointer must remain valid for the duration of the inference session. + * If using CreateTensorWithDataAsOrtValue you are responsible for freeing the memory after the inference session + * is released. + * If using CreateTensorWithDataAndDeleterAsOrtValue, ORT will free the memory using the provided deleter as + * soon as the OrtValue is no longer in use. + * + * NOTE: A tensor containing pre-existing memory MUST have 128 bytes of data or more. + * For smaller tensors use CreateTensorAsOrtValue. + * + * ONNX shape inferencing does not support external data. An initializer involved in shape inferencing is + * typically small (a single value or limited by the rank of a tensor) and uses less than 128 bytes of + * memory, so this limit acts as a simple catch-all rule to avoid issues. + * e.g. Reshape's `shape`, Clip's `min` and `max`, various ops `axes`. + * + * \param[in] graph The OrtGraph instance to update. + * \param[in] name The value name for the initializer. + * \param[in] tensor The OrtValue instance containing the tensor data. + * \param[in] data_is_external Set to true if the data is external and should not be copied. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.21. + */ + ORT_API2_STATUS(AddInitializerToGraph, _Inout_ OrtGraph* graph, _In_ const char* name, _In_ OrtValue* tensor, + bool data_is_external); + + /** \brief Add an OrtNode to an OrtGraph + * + * Add the node to the graph. The OrtGraph will take ownership of OrtNode and you should NOT call ReleaseOrtNode. + * + * \param[in] graph The OrtGraph instance to update. + * \param[in] node The OrtNode instance to add to the graph. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.21. + */ + ORT_API2_STATUS(AddNodeToGraph, _Inout_ OrtGraph* graph, _In_ OrtNode* node); + + /** \brief Create an OrtModel. + * + * Create an OrtModel. + * + * This can be used to build a new model, or to augment an existing model. + * + * \param[in] domain_names The domain names for the model. + * If augmenting an existing model add additional domains if needed. + * \param[in] opset_versions The opset versions for the model. + * If augmenting an existing model add additional opset versions if needed. + * \param[in] opset_entries_len The number of domain_names and opset_versions entries. + * Domain and opset entries should be 1:1 + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.21. + */ + ORT_API2_STATUS(CreateModel, + _In_reads_(opset_entries_len) const char* const* domain_names, + _In_reads_(opset_entries_len) const int* opset_versions, + size_t opset_entries_len, + _Outptr_ OrtModel** model); + + /** \brief Add an OrtGraph to an OrtModel. + * + * Add the graph to a model. This should be called once when creating a new model. + * + * The OrtModel takes ownership of the OrtGraph and you should NOT call ReleaseOrtGraph. + * + * \param[in] model The OrtModel instance to update. + * \param[in] graph The OrtGraph instance to add to the model. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.21. + */ + ORT_API2_STATUS(AddGraphToModel, _Inout_ OrtModel* model, _In_ OrtGraph* graph); + + /** \brief Create an OrtSession using the OrtModel. + * + * Create an inference session using the OrtModel instance. + * The OrtModel should have been populated with an OrtGraph containing nodes and initializers, and SetGraphInputs + * and SetGraphOutputs must have been called. + * This will validate the model, run optimizers, and prepare the session for inferencing. + * + * ReleaseOrtModel must be called to free the OrtModel after session creation. + * + * \param[in] env The OrtEnv instance. + * \param[in] model The OrtModel instance. + * \param[in] options The OrtSessionOptions instance. + * \param[out] out The OrtSession instance. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.21. + */ + ORT_API2_STATUS(CreateSessionFromModel, _In_ const OrtEnv* env, _In_ const OrtModel* model, + _In_ const OrtSessionOptions* options, _Outptr_ OrtSession** out); + + /** \brief Create an OrtSession to augment an existing model. + * + * Create an OrtSession with an existing model that will be augmented with additional nodes and initializers. + * Nodes can be added before or after the existing nodes in the model. ONNX Runtime will connect the nodes when the + * model is finalized. + * + * To add nodes and initializers to the existing model, first create an OrtModel using CreateModel. + * Add nodes and initializers to the OrtModel using AddNodeToGraph and AddInitializerToGraph. + * Graph inputs/outputs should be updated with SetGraphInputs and SetGraphOutputs as needed to reflect changes made + * by the new nodes. The list of graph inputs/outputs should be for the overall model and not just the new nodes. + * + * Add the new information from the OrtModel to the original model using ApplyModelToSession, and prepare the + * session for inferencing by calling FinalizeModelEditorSession. + * + * \param{in} env The OrtEnv instance. + * \param{in} model_path The path to the existing ONNX model to augment. + * \param{in} options The OrtSessionOptions instance. + * \param{out} out The created OrtSession instance. + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.21. + */ + ORT_API2_STATUS(CreateModelEditorSession, _In_ const OrtEnv* env, _In_ const ORTCHAR_T* model_path, + _In_ const OrtSessionOptions* options, + _Outptr_ OrtSession** out); + + /** \brief Create an OrtSession to augment an existing model. + * + * Create an OrtSession with an existing model that will be augmented with additional nodes and initializers. + * Nodes can be added before or after the existing nodes in the model. ONNX Runtime will connect the nodes when the + * model is finalized. + * + * To add nodes and initializers to the existing model, first create an OrtModel using CreateModel. + * Add nodes and initializers to the OrtModel using AddNodeToGraph and AddInitializerToGraph. + * Graph inputs/outputs should be updated with SetGraphInputs and SetGraphOutputs as needed to reflect changes made + * by the new nodes. The list of graph inputs/outputs should be for the overall model and not just the new nodes. + * + * Add the new information from the OrtModel to the original model using ApplyModelToSession, and prepare the + * session for inferencing by calling FinalizeModelEditorSession. + * + * \param{in} env The OrtEnv instance. + * \param{in} model_data The model data for the existing model to augment. + * \param{in} model_data_length The length of the model data. + * \param{in} options The OrtSessionOptions instance. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.21. + */ + ORT_API2_STATUS(CreateModelEditorSessionFromArray, _In_ const OrtEnv* env, + _In_ const void* model_data, size_t model_data_length, + _In_ const OrtSessionOptions* options, + _Outptr_ OrtSession** out); + + /** \brief Query the session for the opset version of a domain. + * + * When using the Model Editor API to augment a model, any new nodes must conform to the opset version of the + * original model. To do that the user must be able to discover that opset version. + * + * \param[in] session OrtSession to query + * \param[in] domain Domain to query. The ONNX domain is an empty string. + * \param[out] opset The opset version of the domain. + * + * \snippet{doc} snippets.dox OrtStatus Return Value. Returns an error if the domain is not used in the model. + * + * \since Version 1.21. + */ + ORT_API2_STATUS(SessionGetOpsetForDomain, _In_ const OrtSession* session, _In_ const char* domain, _Out_ int* opset); + + /** \brief Apply changes to augment the ONNX model in a session created using CreateModelEditorSession[FromArray] + * + * Adds new nodes and updates graph inputs/outputs using `model` to augment the original ONNX model in the session. + * All changes will be validated. + * Call FinalizeModelEditorSession to prepare the session for inferencing. + * + * Existing input/outputs will only be updated if the OrtGraph inputs/outputs are set in the OrtModel. + * i.e. you don't need to call SetGraphInputs/SetGraphOutputs if they are unchanged. + * + * ReleaseOrtModel must be called to free the OrtModel after it is applied to the session. + * + * \param[in] session OrtSession to update. Session must have been created using CreateModelEditorSession[FromArray]. + * \param[in] model OrtModel containing new nodes, new initializers, and updated graph input and/or output info. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.21. + */ + ORT_API2_STATUS(ApplyModelToModelEditorSession, _Inout_ OrtSession* session, _In_ OrtModel* model); + + /** \brief Finalize the Model Editor session that was created using CreateModelEditorSession[FromArray]. + * + * Finalize the Model Editor session that augmented an ONNX model by adding new nodes. + * This will run optimizers and prepare the session for inferencing. + * + * \param[in] session OrtSession to finalize. Session must have been created using CreateModelEditorSession[FromArray]. + * \param[in] options OrtSessionOptions to use for the session. + * \param[in] Optional prepacked_weights_container OrtPrepackedWeightsContainer to use for the session. + Set to nullptr if not used. + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.21. + */ + ORT_API2_STATUS(FinalizeModelEditorSession, _Inout_ OrtSession* session, _In_ const OrtSessionOptions* options, + _In_opt_ OrtPrepackedWeightsContainer* prepacked_weights_container); +#endif // !defined(ORT_MINIMAL_BUILD) +}; + /* * This is the old way to add the CUDA provider to the session, please use SessionOptionsAppendExecutionProvider_CUDA above to access the latest functionality * This function always exists, but will only succeed if Onnxruntime was built with CUDA support and the CUDA provider shared library exists diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index 123ef98901003..979b478e2fbb4 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -26,16 +26,17 @@ #include "onnxruntime_c_api.h" #include "onnxruntime_float16.h" +#include #include #include -#include #include #include #include -#include +#include #include #include -#include +#include +#include #ifdef ORT_NO_EXCEPTIONS #include @@ -120,7 +121,7 @@ const OrtApi* Global::api_ = OrtGetApiBase()->GetApi(ORT_API_VERSION); #endif #endif -/// This returns a reference to the OrtApi interface in use +/// This returns a reference to the ORT C API. inline const OrtApi& GetApi() noexcept { return *Global::api_; } /// @@ -143,6 +144,20 @@ std::string GetBuildInfoString(); /// vector of strings std::vector GetAvailableProviders(); +/// +/// This returns a reference to the ORT C Model Editor API. Used if building or augmenting a model at runtime. +/// +/// ORT C Model Editor API reference +inline const OrtModelEditorApi& GetModelEditorApi() { + auto* api = GetApi().GetModelEditorApi(); + if (api == nullptr) { + // minimal build + ORT_CXX_API_THROW("Model Editor API is not available in this build", ORT_FAIL); + } + + return *api; +} + /** \brief IEEE 754 half-precision floating point data type * * \details This struct is used for converting float to float16 and back @@ -523,6 +538,10 @@ ORT_DEFINE_RELEASE(Status); ORT_DEFINE_RELEASE(OpAttr); ORT_DEFINE_RELEASE(Op); ORT_DEFINE_RELEASE(KernelInfo); +ORT_DEFINE_RELEASE(ValueInfo); +ORT_DEFINE_RELEASE(Node); +ORT_DEFINE_RELEASE(Graph); +ORT_DEFINE_RELEASE(Model); #undef ORT_DEFINE_RELEASE @@ -559,7 +578,9 @@ struct Base { constexpr Base() = default; constexpr explicit Base(contained_type* p) noexcept : p_{p} {} - ~Base() { OrtRelease(p_); } + ~Base() { + OrtRelease(p_); + } Base(const Base&) = delete; Base& operator=(const Base&) = delete; @@ -635,9 +656,13 @@ struct AllocatedFree { struct AllocatorWithDefaultOptions; struct Env; +struct Graph; +struct Model; +struct Node; +struct ModelMetadata; struct TypeInfo; struct Value; -struct ModelMetadata; +struct ValueInfo; /** \brief unique_ptr typedef used to own strings allocated by OrtAllocators * and release them at the end of the scope. The lifespan of the given allocator @@ -1051,6 +1076,10 @@ struct ConstSessionImpl : Base { size_t GetOutputCount() const; ///< Returns the number of model outputs size_t GetOverridableInitializerCount() const; ///< Returns the number of inputs that have defaults that can be overridden + std::vector GetInputNames() const; + std::vector GetOutputNames() const; + std::vector GetOverridableInitializerNames() const; + /** \brief Returns a copy of input name at the specified index. * * \param index must less than the value returned by GetInputCount() @@ -1084,6 +1113,12 @@ struct ConstSessionImpl : Base { TypeInfo GetInputTypeInfo(size_t index) const; ///< Wraps OrtApi::SessionGetInputTypeInfo TypeInfo GetOutputTypeInfo(size_t index) const; ///< Wraps OrtApi::SessionGetOutputTypeInfo TypeInfo GetOverridableInitializerTypeInfo(size_t index) const; ///< Wraps OrtApi::SessionGetOverridableInitializerTypeInfo + + int GetOpset(const std::string& domain) const; ///< Wraps OrtApi::SessionGetOpsetForDomain + + // Will move before checkin if that's the case. + std::vector GetInputs() const; + std::vector GetOutputs() const; }; template @@ -1161,6 +1196,9 @@ struct SessionImpl : ConstSessionImpl { * \param[in] kv_len Number of elements in the keys and values arrays */ void SetEpDynamicOptions(const char* const* keys, const char* const* values, size_t kv_len); + + void FinalizeModelEditorSession(const Model& model, const SessionOptions& options, + OrtPrepackedWeightsContainer* prepacked_weights_container = nullptr); }; } // namespace detail @@ -1172,13 +1210,34 @@ using UnownedSession = detail::SessionImpl>; * */ struct Session : detail::SessionImpl { - explicit Session(std::nullptr_t) {} ///< Create an empty Session object, must be assigned a valid one to be used - Session(const Env& env, const ORTCHAR_T* model_path, const SessionOptions& options); ///< Wraps OrtApi::CreateSession + /// Create an empty Session object, must be assigned a valid one to be used. Wraps OrtApi::CreateSession + explicit Session(std::nullptr_t) {} + explicit Session(OrtSession* p) : SessionImpl{p} {} ///< C API Interop + + Session(const Env& env, const ORTCHAR_T* model_path, const SessionOptions& options); + + /// Wraps OrtApi::CreateSessionWithPrepackedWeightsContainer Session(const Env& env, const ORTCHAR_T* model_path, const SessionOptions& options, - OrtPrepackedWeightsContainer* prepacked_weights_container); ///< Wraps OrtApi::CreateSessionWithPrepackedWeightsContainer - Session(const Env& env, const void* model_data, size_t model_data_length, const SessionOptions& options); ///< Wraps OrtApi::CreateSessionFromArray + OrtPrepackedWeightsContainer* prepacked_weights_container); + + /// Wraps OrtApi::CreateSessionFromArray + Session(const Env& env, const void* model_data, size_t model_data_length, const SessionOptions& options); + + /// Wraps OrtApi::CreateSessionFromArrayWithPrepackedWeightsContainer Session(const Env& env, const void* model_data, size_t model_data_length, const SessionOptions& options, - OrtPrepackedWeightsContainer* prepacked_weights_container); ///< Wraps OrtApi::CreateSessionFromArrayWithPrepackedWeightsContainer + OrtPrepackedWeightsContainer* prepacked_weights_container); + +#if !defined(ORT_MINIMAL_BUILD) + /// Wraps OrtModelEditorApi::CreateSessionFromModel + Session(const Env& env, const Model& model, const SessionOptions& options); + + /// Wraps OrtModelEditorApi::CreateModelEditorSession + static Session CreateModelEditorSession(const Env& env, const ORTCHAR_T* model_path, const SessionOptions& options); + + /// Wraps OrtModelEditorApi::CreateModelEditorSession + static Session CreateModelEditorSession(const Env& env, const void* model_data, size_t model_data_length, + const SessionOptions& options); +#endif // !defined(ORT_MINIMAL_BUILD) ConstSession GetConst() const { return ConstSession{this->p_}; } UnownedSession GetUnowned() const { return UnownedSession{this->p_}; } @@ -1210,7 +1269,7 @@ using ConstMemoryInfo = detail::MemoryInfoImpl { static MemoryInfo CreateCpu(OrtAllocatorType type, OrtMemType mem_type1); explicit MemoryInfo(std::nullptr_t) {} ///< No instance is created - explicit MemoryInfo(OrtMemoryInfo* p) : MemoryInfoImpl{p} {} ///< Take ownership of a pointer created by C Api + explicit MemoryInfo(OrtMemoryInfo* p) : MemoryInfoImpl{p} {} ///< Take ownership of a pointer created by C API MemoryInfo(const char* name, OrtAllocatorType type, int id, OrtMemType mem_type); ConstMemoryInfo GetConst() const { return ConstMemoryInfo{this->p_}; } }; @@ -1233,6 +1292,7 @@ struct TensorTypeAndShapeInfoImpl : Base { [[deprecated("use GetShape()")]] void GetDimensions(int64_t* values, size_t values_count) const; ///< Wraps OrtApi::GetDimensions void GetSymbolicDimensions(const char** values, size_t values_count) const; ///< Wraps OrtApi::GetSymbolicDimensions + std::vector GetSymbolicDimensions() const; std::vector GetShape() const; ///< Uses GetDimensionsCount & GetDimensions to return a std::vector of the shape }; @@ -1248,8 +1308,18 @@ struct TensorTypeAndShapeInfo : detail::TensorTypeAndShapeInfoImpl; using Base::Base; - explicit TensorTypeAndShapeInfo(std::nullptr_t) {} ///< Create an empty TensorTypeAndShapeInfo object, must be assigned a valid one to be used - explicit TensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* p) : TensorTypeAndShapeInfoImpl{p} {} ///< Used for interop with the C API + /// Create an empty TensorTypeAndShapeInfo object, must be assigned a valid one to be used + explicit TensorTypeAndShapeInfo(std::nullptr_t) {} + /// Used for interop with the C API + explicit TensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* p) : TensorTypeAndShapeInfoImpl{p} {} + + // Create a TensorTypeAndShapeInfo object with the specified element type and dimensions + // symbolic_dims are optional, but should be 1:1 with dims. + // The value in symbolic_dims will be used for all entries in dims that are -1. + explicit TensorTypeAndShapeInfo(ONNXTensorElementDataType element_type, + const std::vector& dims, + const std::vector* symbolic_dims = nullptr); + ConstTensorTypeAndShapeInfo GetConst() const { return ConstTensorTypeAndShapeInfo{this->p_}; } }; @@ -1344,9 +1414,18 @@ struct TypeInfo : detail::TypeInfoImpl { using Base = detail::TypeInfoImpl; using Base::Base; - explicit TypeInfo(std::nullptr_t) {} ///< Create an empty TypeInfo object, must be assigned a valid one to be used + /// Create an empty TypeInfo object, must be assigned a valid one to be used + explicit TypeInfo(std::nullptr_t) {} explicit TypeInfo(OrtTypeInfo* p) : TypeInfoImpl{p} {} ///< C API Interop +#if !defined(ORT_MINIMAL_BUILD) + static TypeInfo CreateTensorInfo(ConstTensorTypeAndShapeInfo tensor_info); + static TypeInfo CreateSparseTensorInfo(ConstTensorTypeAndShapeInfo sparse_tensor_info); + static TypeInfo CreateSequenceTypeInfo(ConstTypeInfo sequence_type); + static TypeInfo CreateMapTypeInfo(ONNXTensorElementDataType key_type, ConstTypeInfo value_type); + static TypeInfo CreateOptionalTypeInfo(ConstTypeInfo contained_type); +#endif // !defined(ORT_MINIMAL_BUILD) + ConstTypeInfo GetConst() const { return ConstTypeInfo{this->p_}; } }; @@ -1701,7 +1780,8 @@ struct Value : detail::ValueImpl { * \param shape_len The number of tensor shape dimensions. */ template - static Value CreateTensor(const OrtMemoryInfo* info, T* p_data, size_t p_data_element_count, const int64_t* shape, size_t shape_len); + static Value CreateTensor(const OrtMemoryInfo* info, T* p_data, size_t p_data_element_count, + const int64_t* shape, size_t shape_len); /** \brief Creates a tensor with a user supplied buffer. Wraps OrtApi::CreateTensorWithDataAsOrtValue. * @@ -1712,11 +1792,25 @@ struct Value : detail::ValueImpl { * \param shape_len The number of tensor shape dimensions. * \param type The data type. */ - static Value CreateTensor(const OrtMemoryInfo* info, void* p_data, size_t p_data_byte_count, const int64_t* shape, size_t shape_len, + static Value CreateTensor(const OrtMemoryInfo* info, void* p_data, size_t p_data_byte_count, + const int64_t* shape, size_t shape_len, + ONNXTensorElementDataType type); + + /** \brief Creates a tensor with a user supplied buffer. Wraps OrtApi::CreateTensorWithDataAndDeleterAsOrtValue. + * + * \param deleter OrtAllocator that will be used to free the buffer when no longer required. + * \param p_data Pointer to the data buffer. + * \param p_data_byte_count The number of bytes in the data buffer. + * \param shape Pointer to the tensor shape dimensions. + * \param shape_len The number of tensor shape dimensions. + * \param type The data type. + */ + static Value CreateTensor(OrtAllocator* deleter, void* p_data, size_t p_data_byte_count, + const int64_t* shape, size_t shape_len, ONNXTensorElementDataType type); /** \brief Creates an OrtValue with a tensor using a supplied OrtAllocator. Wraps OrtApi::CreateTensorAsOrtValue. - * This overload will allocate the buffer for the tensor according to the supplied shape and data type. + * This overload will allocate the buffer for the tensor according to the supplied shape and data type. * The allocated buffer will be owned by the returned OrtValue and will be freed when the OrtValue is released. * The input data would need to be copied into the allocated buffer. * This API is not suitable for strings. @@ -1740,7 +1834,8 @@ struct Value : detail::ValueImpl { * \param shape_len The number of tensor shape dimensions. * \param type The data type. */ - static Value CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len, ONNXTensorElementDataType type); + static Value CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len, + ONNXTensorElementDataType type); /** \brief Creates an OrtValue with a Map Onnx type representation. * The API would ref-count the supplied OrtValues and they will be released @@ -2437,6 +2532,9 @@ struct CustomOpBase : OrtCustomOp { return std::vector{}; } + // Ort::CustomOpBase derived class should provide the following static method with the type/shape inferencing + // implementation if needed: + // static OrtStatusPtr InferOutputShape(Ort::ShapeInferContext& context) template decltype(&C::InferOutputShape) SetShapeInferFn(decltype(&C::InferOutputShape)) { OrtCustomOp::InferOutputShapeFn = [](const OrtCustomOp*, OrtShapeInferContext* ort_ctx) -> OrtStatusPtr { @@ -2459,6 +2557,129 @@ struct CustomOpBase : OrtCustomOp { int end_ver_ = MAX_CUSTOM_OP_END_VER; }; -} // namespace Ort +namespace detail { +template +struct ValueInfoImpl : Ort::detail::Base { + using B = Ort::detail::Base; + using B::B; + + std::string Name() const; + ConstTypeInfo TypeInfo() const; +}; +} // namespace detail + +// Const object holder that does not own the underlying object +using ConstValueInfo = detail::ValueInfoImpl>; +/** \brief Wrapper around ::OrtValueInfo + * + */ +struct ValueInfo : detail::ValueInfoImpl { + explicit ValueInfo(std::nullptr_t) {} ///< No instance is created + /// Take ownership of a pointer created by C API + explicit ValueInfo(OrtValueInfo* p) : ValueInfoImpl{p} {} + + // Create ValueInfo for a tensor + explicit ValueInfo(const std::string& name, const ConstTypeInfo& type_info); + + ConstValueInfo GetConst() const { return ConstValueInfo{this->p_}; } +}; + +namespace detail { +template +struct NodeImpl : Ort::detail::Base { + using B = Ort::detail::Base; + using B::B; +}; +} // namespace detail + +/** \brief Wrapper around ::OrtNode + * + */ +struct Node : detail::NodeImpl { + explicit Node(std::nullptr_t) {} ///< No instance is created + explicit Node(OrtNode* p) : NodeImpl{p} {} ///< Take ownership of a pointer created by C API + +#if !defined(ORT_MINIMAL_BUILD) + Node(const std::string& operator_name, const std::string& operator_domain, + const std::string& node_name, + const std::vector& input_names, + const std::vector& output_names); + + /// + /// Wraps CreateNode. Node takes ownership of attributes on success and updates the OpAttr in `attributes` to do so. + /// + Node(const std::string& operator_name, const std::string& operator_domain, + const std::string& node_name, + const std::vector& input_names, + const std::vector& output_names, + std::vector& attributes); + + private: + static void Init(const std::string& operator_name, const std::string& operator_domain, + const std::string& node_name, + const std::vector& input_names, + const std::vector& output_names, + std::vector& attributes, + OrtNode*& node); +#endif // !defined(ORT_MINIMAL_BUILD) +}; + +namespace detail { +template +struct GraphImpl : Ort::detail::Base { + using B = Ort::detail::Base; + using B::B; + +#if !defined(ORT_MINIMAL_BUILD) + void SetInputs(std::vector& inputs); + void SetOutputs(std::vector& outputs); + void AddInitializer(const std::string& name, Value& initializer, bool data_is_external); // Graph takes ownership of Value + void AddNode(Node& node); // Graph takes ownership of Node +#endif // !defined(ORT_MINIMAL_BUILD) +}; +} // namespace detail + +/** \brief Wrapper around ::OrtGraph + * + */ +struct Graph : detail::GraphImpl { + explicit Graph(std::nullptr_t) {} ///< No instance is created + explicit Graph(OrtGraph* p) : GraphImpl{p} {} ///< Take ownership of a pointer created by C API +#if !defined(ORT_MINIMAL_BUILD) + Graph(); +#endif +}; + +namespace detail { +template +struct ModelImpl : Ort::detail::Base { + using B = Ort::detail::Base; + using B::B; + +#if !defined(ORT_MINIMAL_BUILD) + void AddGraph(Graph& graph); +#endif +}; +} // namespace detail + +// Const object holder that does not own the underlying object +using ConstModel = detail::ModelImpl>; + +/** \brief Wrapper around ::OrtModel + * + */ +struct Model : detail::ModelImpl { + using DomainOpsetPair = std::pair; + + explicit Model(std::nullptr_t) {} ///< No instance is created + explicit Model(OrtModel* p) : ModelImpl{p} {} ///< Take ownership of a pointer created by C API + +#if !defined(ORT_MINIMAL_BUILD) + explicit Model(const std::vector& opsets); +#endif + + ConstModel GetConst() const { return ConstModel{this->p_}; } +}; +} // namespace Ort #include "onnxruntime_cxx_inline.h" diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index 3aeb9412f350e..48c5e52e33c53 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -10,7 +10,9 @@ #include #include #include +#include #include +#include // Convert OrtStatus to Ort::Status and return // instead of throwing @@ -995,6 +997,59 @@ inline size_t ConstSessionImpl::GetOverridableInitializerCount() const { return out; } +template +inline std::vector ConstSessionImpl::GetInputNames() const { + AllocatorWithDefaultOptions allocator; + + auto num_inputs = GetInputCount(); + std::vector input_names; + input_names.reserve(num_inputs); + + for (size_t i = 0; i < num_inputs; ++i) { + char* name = nullptr; + ThrowOnError(GetApi().SessionGetInputName(this->p_, i, allocator, &name)); + input_names.push_back(name); + allocator.Free(name); + } + + return input_names; +} + +template +inline std::vector ConstSessionImpl::GetOutputNames() const { + AllocatorWithDefaultOptions allocator; + + auto num_inputs = GetOutputCount(); + std::vector output_names; + output_names.reserve(num_inputs); + + for (size_t i = 0; i < num_inputs; ++i) { + char* name = nullptr; + ThrowOnError(GetApi().SessionGetOutputName(this->p_, i, allocator, &name)); + output_names.push_back(name); + allocator.Free(name); + } + + return output_names; +} + +template +inline std::vector ConstSessionImpl::GetOverridableInitializerNames() const { + AllocatorWithDefaultOptions allocator; + + auto num_initializers = GetOverridableInitializerCount(); + std::vector initializer_names; + initializer_names.reserve(num_initializers); + + for (size_t i = 0; i < num_initializers; ++i) { + char* name = nullptr; + ThrowOnError(GetApi().SessionGetOverridableInitializerName(this->p_, i, allocator, &name)); + initializer_names.push_back(name); + } + + return initializer_names; +} + template inline AllocatedStringPtr ConstSessionImpl::GetInputNameAllocated(size_t index, OrtAllocator* allocator) const { char* out; @@ -1051,6 +1106,45 @@ inline TypeInfo ConstSessionImpl::GetOverridableInitializerTypeInfo(size_t in return TypeInfo{out}; } +#if !defined(ORT_MINIMAL_BUILD) +template +inline int ConstSessionImpl::GetOpset(const std::string& domain) const { + int opset; + ThrowOnError(GetModelEditorApi().SessionGetOpsetForDomain(this->p_, domain.c_str(), &opset)); + return opset; +} +#endif // !defined(ORT_MINIMAL_BUILD) + +template +std::vector ConstSessionImpl::GetInputs() const { + const std::vector input_names = GetInputNames(); + + std::vector inputs; + inputs.reserve(input_names.size()); + + for (size_t i = 0; i < input_names.size(); ++i) { + auto type_info = GetInputTypeInfo(i); + inputs.emplace_back(ValueInfo{input_names[i], type_info.GetConst()}); + } + + return inputs; +} + +template +std::vector ConstSessionImpl::GetOutputs() const { + const std::vector output_names = GetOutputNames(); + + std::vector outputs; + outputs.reserve(output_names.size()); + + for (size_t i = 0; i < output_names.size(); ++i) { + auto type_info = GetOutputTypeInfo(i); + outputs.emplace_back(ValueInfo{output_names[i], type_info.GetConst()}); + } + + return outputs; +} + template inline std::vector SessionImpl::Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count, const char* const* output_names, size_t output_count) { @@ -1098,6 +1192,15 @@ inline void SessionImpl::SetEpDynamicOptions(const char* const* keys, const c ThrowOnError(GetApi().SetEpDynamicOptions(this->p_, keys, values, kv_len)); } +#if !defined(ORT_MINIMAL_BUILD) +template +inline void SessionImpl::FinalizeModelEditorSession(const Model& model, const SessionOptions& options, + OrtPrepackedWeightsContainer* prepacked_weights_container) { + ThrowOnError(GetModelEditorApi().ApplyModelToModelEditorSession(this->p_, model)); + ThrowOnError(GetModelEditorApi().FinalizeModelEditorSession(this->p_, options, prepacked_weights_container)); +} +#endif // #if !defined(ORT_MINIMAL_BUILD) + } // namespace detail inline SessionOptions::SessionOptions() { @@ -1144,6 +1247,32 @@ inline Session::Session(const Env& env, const void* model_data, size_t model_dat prepacked_weights_container, &this->p_)); } +#if !defined(ORT_MINIMAL_BUILD) +inline Session::Session(const Env& env, const Model& model, const SessionOptions& options) { + ThrowOnError(GetModelEditorApi().CreateSessionFromModel(env, model.GetConst(), options, &this->p_)); +} + +// static +inline Session Session::CreateModelEditorSession(const Env& env, const ORTCHAR_T* model_path, + const SessionOptions& options) { + OrtSession* session = nullptr; + ThrowOnError(GetModelEditorApi().CreateModelEditorSession(env, model_path, options, &session)); + return Session(session); +} + +// static +inline Session Session::CreateModelEditorSession(const Env& env, const void* model_data, size_t model_data_length, + const SessionOptions& options) { + OrtSession* session = nullptr; + ThrowOnError(GetModelEditorApi().CreateModelEditorSessionFromArray(env, model_data, model_data_length, options, + &session)); + return Session(session); +} + +void FinalizeModelEditorSession(const Model& model, const SessionOptions& options, + OrtPrepackedWeightsContainer* prepacked_weights_container); +#endif // #if !defined(ORT_MINIMAL_BUILD) + inline AllocatedStringPtr ModelMetadata::GetProducerNameAllocated(OrtAllocator* allocator) const { char* out; ThrowOnError(GetApi().ModelMetadataGetProducerName(p_, allocator, &out)); @@ -1211,6 +1340,59 @@ inline int64_t ModelMetadata::GetVersion() const { return out; } +inline TensorTypeAndShapeInfo::TensorTypeAndShapeInfo(ONNXTensorElementDataType element_type, + const std::vector& dims, + const std::vector* symbolic_dims) { + ThrowOnError(GetApi().CreateTensorTypeAndShapeInfo(&p_)); + ThrowOnError(GetApi().SetTensorElementType(p_, element_type)); + ThrowOnError(GetApi().SetDimensions(p_, dims.data(), dims.size())); + + if (symbolic_dims) { + std::vector symbolic_dims_cstr; + symbolic_dims_cstr.reserve(symbolic_dims->size()); + std::transform(symbolic_dims->begin(), symbolic_dims->end(), std::back_inserter(symbolic_dims_cstr), + [](const std::string& s) { return s.c_str(); }); + ThrowOnError(GetApi().SetSymbolicDimensions(p_, symbolic_dims_cstr.data(), symbolic_dims_cstr.size())); + } +} + +#if !defined(ORT_MINIMAL_BUILD) +// static +inline TypeInfo TypeInfo::CreateTensorInfo(ConstTensorTypeAndShapeInfo tensor_type_and_shape_info) { + OrtTypeInfo* output = nullptr; + ThrowOnError(GetModelEditorApi().CreateTensorTypeInfo(tensor_type_and_shape_info, &output)); + return TypeInfo{output}; +} + +// static +inline TypeInfo TypeInfo::CreateSparseTensorInfo(ConstTensorTypeAndShapeInfo sparse_tensor_type_and_shape_info) { + OrtTypeInfo* output = nullptr; + ThrowOnError(GetModelEditorApi().CreateSparseTensorTypeInfo(sparse_tensor_type_and_shape_info, &output)); + return TypeInfo{output}; +} + +// static +inline TypeInfo TypeInfo::CreateSequenceTypeInfo(ConstTypeInfo sequence_type) { + OrtTypeInfo* output; + ThrowOnError(GetModelEditorApi().CreateSequenceTypeInfo(sequence_type, &output)); + return TypeInfo{output}; +} + +// static +inline TypeInfo TypeInfo::CreateMapTypeInfo(ONNXTensorElementDataType key_type, ConstTypeInfo value_type) { + OrtTypeInfo* output; + ThrowOnError(GetModelEditorApi().CreateMapTypeInfo(key_type, value_type, &output)); + return TypeInfo{output}; +} + +// static +inline TypeInfo TypeInfo::CreateOptionalTypeInfo(ConstTypeInfo contained_type) { + OrtTypeInfo* output; + ThrowOnError(GetModelEditorApi().CreateOptionalTypeInfo(contained_type, &output)); + return TypeInfo{output}; +} +#endif // #if !defined(ORT_MINIMAL_BUILD) + namespace detail { template @@ -1244,9 +1426,16 @@ inline void TensorTypeAndShapeInfoImpl::GetSymbolicDimensions(const char** va ThrowOnError(GetApi().GetSymbolicDimensions(this->p_, values, values_count)); } +template +inline std::vector TensorTypeAndShapeInfoImpl::GetSymbolicDimensions() const { + std::vector out(GetDimensionsCount(), nullptr); + ThrowOnError(GetApi().GetSymbolicDimensions(this->p_, out.data(), out.size())); + return out; +} + template inline std::vector TensorTypeAndShapeInfoImpl::GetShape() const { - std::vector out(GetDimensionsCount(), 0); + std::vector out(GetDimensionsCount(), -1); ThrowOnError(GetApi().GetDimensions(this->p_, out.data(), out.size())); return out; } @@ -1560,23 +1749,35 @@ void ValueImpl::FillSparseTensorBlockSparse(const OrtMemoryInfo* data_mem_inf } // namespace detail template -inline Value Value::CreateTensor(const OrtMemoryInfo* info, T* p_data, size_t p_data_element_count, const int64_t* shape, size_t shape_len) { +inline Value Value::CreateTensor(const OrtMemoryInfo* info, T* p_data, size_t p_data_element_count, + const int64_t* shape, size_t shape_len) { return CreateTensor(info, p_data, p_data_element_count * sizeof(T), shape, shape_len, TypeToTensorType::type); } -inline Value Value::CreateTensor(const OrtMemoryInfo* info, void* p_data, size_t p_data_byte_count, const int64_t* shape, size_t shape_len, +inline Value Value::CreateTensor(const OrtMemoryInfo* info, void* p_data, size_t p_data_byte_count, + const int64_t* shape, size_t shape_len, ONNXTensorElementDataType type) { OrtValue* out; ThrowOnError(GetApi().CreateTensorWithDataAsOrtValue(info, p_data, p_data_byte_count, shape, shape_len, type, &out)); return Value{out}; } +inline Value Value::CreateTensor(OrtAllocator* deleter, void* p_data, size_t p_data_byte_count, + const int64_t* shape, size_t shape_len, + ONNXTensorElementDataType type) { + OrtValue* out; + ThrowOnError(GetApi().CreateTensorWithDataAndDeleterAsOrtValue(deleter, p_data, p_data_byte_count, + shape, shape_len, type, &out)); + return Value{out}; +} + template inline Value Value::CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len) { return CreateTensor(allocator, shape, shape_len, TypeToTensorType::type); } -inline Value Value::CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len, ONNXTensorElementDataType type) { +inline Value Value::CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len, + ONNXTensorElementDataType type) { OrtValue* out; ThrowOnError(GetApi().CreateTensorAsOrtValue(allocator, shape, shape_len, type, &out)); return Value{out}; @@ -1594,7 +1795,8 @@ inline Value Value::CreateSparseTensor(const OrtMemoryInfo* info, void* p_data, const Shape& values_shape, ONNXTensorElementDataType type) { OrtValue* out; ThrowOnError(GetApi().CreateSparseTensorWithValuesAsOrtValue(info, p_data, dense_shape.shape, dense_shape.shape_len, - values_shape.shape, values_shape.shape_len, type, &out)); + values_shape.shape, values_shape.shape_len, type, + &out)); return Value{out}; } @@ -2167,4 +2369,142 @@ inline const OrtOpAttr* ShapeInferContext::GetAttrHdl(const char* attr_name) con return attr_hdl; } +namespace detail { +inline std::vector StringsToCharPtrs(const std::vector& strings) { + std::vector ptrs; + ptrs.reserve(strings.size()); + std::transform(strings.begin(), strings.end(), std::back_inserter(ptrs), + [](const std::string& s) { return s.c_str(); }); + + return ptrs; +} +} // namespace detail + +#if !defined(ORT_MINIMAL_BUILD) +// static +inline void Node::Init(const std::string& operator_name, const std::string& operator_domain, + const std::string& node_name, + const std::vector& input_names, + const std::vector& output_names, + std::vector& attributes, + OrtNode*& node) { + auto inputs = detail::StringsToCharPtrs(input_names); + auto outputs = detail::StringsToCharPtrs(output_names); + + std::vector attributes_ptrs; + attributes_ptrs.reserve(attributes.size()); + std::transform(attributes.begin(), attributes.end(), std::back_inserter(attributes_ptrs), + [](OpAttr& attr) -> OrtOpAttr* { return attr; }); + + ThrowOnError(GetModelEditorApi().CreateNode(operator_name.c_str(), operator_domain.c_str(), node_name.c_str(), + inputs.data(), inputs.size(), + outputs.data(), outputs.size(), + attributes_ptrs.data(), attributes_ptrs.size(), + &node)); + + // Node now owns the attributes + std::for_each(attributes.begin(), attributes.end(), [](OpAttr& attr) { attr.release(); }); +} + +inline Node::Node(const std::string& operator_name, const std::string& operator_domain, + const std::string& node_name, + const std::vector& input_names, + const std::vector& output_names, + std::vector& attributes) { + Init(operator_name, operator_domain, node_name, input_names, output_names, attributes, p_); +} + +inline Node::Node(const std::string& operator_name, const std::string& operator_domain, + const std::string& node_name, + const std::vector& input_names, + const std::vector& output_names) { + std::vector empty_attributes; + Init(operator_name, operator_domain, node_name, input_names, output_names, empty_attributes, p_); +} + +inline Graph::Graph() { + ThrowOnError(GetModelEditorApi().CreateGraph(&p_)); +} + +inline Model::Model(const std::vector& opsets) { + std::vector domains; + std::vector versions; + domains.reserve(opsets.size()); + versions.reserve(opsets.size()); + + for (const auto& pair : opsets) { + domains.push_back(pair.first.c_str()); + versions.push_back(pair.second); + } + + ThrowOnError(GetModelEditorApi().CreateModel(domains.data(), versions.data(), opsets.size(), &p_)); +} + +inline ValueInfo::ValueInfo(const std::string& name, const ConstTypeInfo& type_info) { + ThrowOnError(GetModelEditorApi().CreateValueInfo(name.c_str(), type_info, &p_)); +} +#endif // !defined(ORT_MINIMAL_BUILD) + +namespace detail { +template <> +inline std::string ValueInfoImpl::Name() const { + const char* name = nullptr; + ThrowOnError(GetApi().GetValueInfoName(this->p_, &name)); + return name; +} + +template <> +inline ConstTypeInfo ValueInfoImpl::TypeInfo() const { + const OrtTypeInfo* type_info = nullptr; + ThrowOnError(GetApi().GetValueInfoTypeInfo(this->p_, &type_info)); + return ConstTypeInfo{type_info}; +} + +#if !defined(ORT_MINIMAL_BUILD) +template <> +inline void GraphImpl::SetInputs(std::vector& inputs) { + std::vector inputs_ptrs; + inputs_ptrs.reserve(inputs.size()); + std::transform(inputs.begin(), inputs.end(), std::back_inserter(inputs_ptrs), + [](ValueInfo& vi) -> OrtValueInfo* { return vi; }); + + ThrowOnError(GetModelEditorApi().SetGraphInputs(p_, inputs_ptrs.data(), inputs_ptrs.size())); + + // Graph now owns the inputs + std::for_each(inputs.begin(), inputs.end(), [](ValueInfo& vi) { vi.release(); }); +} + +template <> +inline void GraphImpl::SetOutputs(std::vector& outputs) { + std::vector outputs_ptrs; + outputs_ptrs.reserve(outputs.size()); + std::transform(outputs.begin(), outputs.end(), std::back_inserter(outputs_ptrs), + [](ValueInfo& vi) -> OrtValueInfo* { return vi; }); + + ThrowOnError(GetModelEditorApi().SetGraphOutputs(p_, outputs_ptrs.data(), outputs_ptrs.size())); + + // Graph now owns the outputs + std::for_each(outputs.begin(), outputs.end(), [](ValueInfo& vi) { vi.release(); }); +} + +template <> +inline void GraphImpl::AddInitializer(const std::string& name, Value& initializer, bool data_is_external) { + // Graph takes ownership of `initializer` + ThrowOnError(GetModelEditorApi().AddInitializerToGraph(p_, name.c_str(), initializer.release(), data_is_external)); +} + +template <> +inline void GraphImpl::AddNode(Node& node) { + // Graph takes ownership of `node` + ThrowOnError(GetModelEditorApi().AddNodeToGraph(p_, node.release())); +} + +template <> +inline void ModelImpl::AddGraph(Graph& graph) { + // Model takes ownership of `graph` + ThrowOnError(GetModelEditorApi().AddGraphToModel(p_, graph.release())); +} +#endif // !defined(ORT_MINIMAL_BUILD) + +} // namespace detail } // namespace Ort diff --git a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h index 117a2cdabca2f..af1f9c04b2831 100644 --- a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h +++ b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h @@ -315,9 +315,12 @@ static const char* const kOrtSessionOptionEpContextEmbedMode = "ep.context_embed // in case user need to merge/connect multiple EPContext nodes in one model static const char* const kOrtSessionOptionEpContextNodeNamePrefix = "ep.context_node_name_prefix"; -// Share EP related resources across EPs +// Share EP related resources across sessions static const char* const kOrtSessionOptionShareEpContexts = "ep.share_ep_contexts"; +// Stop to share EP related resources across sessions from then on +static const char* const kOrtSessionOptionStopShareEpContexts = "ep.stop_share_ep_contexts"; + // Use this config when dumping EP context model with an external initializers file // All initializers will be inside the external data file if specified, otherwise all in Onnx file static const char* const kOrtSessionOptionsEpContextModelExternalInitializersFileName = diff --git a/js/build_webgpu.bat b/js/build_webgpu.bat new file mode 100644 index 0000000000000..95413509e701d --- /dev/null +++ b/js/build_webgpu.bat @@ -0,0 +1,79 @@ +@echo off + +rem build_webgpu.bat --- build onnxruntime-web with WebGPU EP +rem +rem Usage: +rem build_webgpu.bat config [clean] +rem +rem Options: +rem config Build configuration, "d" or "r" +rem clean Perform a clean build, "clean" or empty + +setlocal enabledelayedexpansion + +set ROOT=%~dp0..\ +set BUILD_DIR=%ROOT%build_webgpu + +:arg1 +if ["%~1"]==["d"] ( + set CONFIG=Debug + set CONFIG_EXTRA_FLAG= + @rem --enable_wasm_profiling --wasm_run_tests_in_browser + @rem --cmake_extra_defines onnxruntime_ENABLE_WEBASSEMBLY_OUTPUT_OPTIMIZED_MODEL=1 + @rem --enable_wasm_debug_info + goto :arg2 +) +if ["%~1"]==["r"] ( + set CONFIG=Release + set CONFIG_EXTRA_FLAG= + @rem --enable_wasm_api_exception_catching --disable_rtti + goto :arg2 +) +echo Invalid configuration "%~1", must be "d"(Debug) or "r"(Release) +exit /b 1 + +:arg2 +if ["%~2"]==["clean"] ( + goto :clean +) +if not exist "%ROOT%js\web\dist" ( + goto :npm_ci +) + +goto :build_wasm + +:clean +if exist "%BUILD_DIR%" ( + rd /s /q %BUILD_DIR% +) + +pushd %ROOT% +git submodule sync --recursive +git submodule update --init --recursive +popd + +:npm_ci +pushd %ROOT%js +call npm ci +popd +pushd %ROOT%js\common +call npm ci +popd +pushd %ROOT%js\web +call npm ci +call npm run pull:wasm +popd + +:build_wasm + +set PATH=C:\Program Files\Git\usr\bin;%PATH% + +call %ROOT%build.bat --config %CONFIG% %CONFIG_EXTRA_FLAG% --skip_submodule_sync --build_wasm --target onnxruntime_webassembly --skip_tests^ + --enable_wasm_simd --enable_wasm_threads --use_jsep --use_webnn --use_webgpu --build_dir %BUILD_DIR% + +IF NOT "%ERRORLEVEL%" == "0" ( + exit /b %ERRORLEVEL% +) + +copy /Y %BUILD_DIR%\%CONFIG%\ort-wasm-simd-threaded.jsep.wasm %ROOT%js\web\dist\ +copy /Y %BUILD_DIR%\%CONFIG%\ort-wasm-simd-threaded.jsep.mjs %ROOT%js\web\dist\ diff --git a/js/common/lib/tensor-impl-type-mapping.ts b/js/common/lib/tensor-impl-type-mapping.ts index 14dbdca707220..58f4cc6281b09 100644 --- a/js/common/lib/tensor-impl-type-mapping.ts +++ b/js/common/lib/tensor-impl-type-mapping.ts @@ -44,12 +44,6 @@ export const NUMERIC_TENSOR_TYPEDARRAY_TO_TYPE_MAP = new Map { isTypedArrayChecked = true; const isBigInt64ArrayAvailable = typeof BigInt64Array !== 'undefined' && BigInt64Array.from; const isBigUint64ArrayAvailable = typeof BigUint64Array !== 'undefined' && BigUint64Array.from; + + // eslint-disable-next-line @typescript-eslint/naming-convention, @typescript-eslint/no-explicit-any + const Float16Array = (globalThis as any).Float16Array; const isFloat16ArrayAvailable = typeof Float16Array !== 'undefined' && Float16Array.from; if (isBigInt64ArrayAvailable) { diff --git a/js/common/lib/tensor-impl.ts b/js/common/lib/tensor-impl.ts index 8feb8d7205fa1..2c54bdbfb6874 100644 --- a/js/common/lib/tensor-impl.ts +++ b/js/common/lib/tensor-impl.ts @@ -261,6 +261,13 @@ export class Tensor implements TensorInterface { } else { throw new TypeError(`A Uint8ClampedArray tensor's data must be type of uint8`); } + } else if (arg0 === 'float16' && arg1 instanceof Uint16Array && typedArrayConstructor !== Uint16Array) { + // when Float16Array is available and data is of type Uint16Array. + // We allow Uint16Array to be passed in as data for 'float16' tensor until Float16Array is generally + // supported in JavaScript environment. + + // eslint-disable-next-line @typescript-eslint/no-explicit-any + data = new (globalThis as any).Float16Array(arg1.buffer, arg1.byteOffset, arg1.length); } else { throw new TypeError(`A ${type} tensor's data must be type of ${typedArrayConstructor}`); } diff --git a/js/common/package.json b/js/common/package.json index 3d8d3f6533cfe..2d331bb42e4c7 100644 --- a/js/common/package.json +++ b/js/common/package.json @@ -15,7 +15,8 @@ "build": "node ./build.js", "prepare": "npm run build", "pretest": "tsc --build ./test", - "test": "mocha ./test/**/*.js --timeout 30000" + "test": "mocha \"./test/**/*.js\" --timeout 30000", + "test:f16": "mocha -n js-float16array \"./test/**/*.js\" --timeout 30000" }, "devDependencies": { "typedoc": "^0.25.7" diff --git a/js/common/test/unit-tests/common.ts b/js/common/test/unit-tests/common.ts index 0a6e4e5dd6ebd..bbbceed605bd4 100644 --- a/js/common/test/unit-tests/common.ts +++ b/js/common/test/unit-tests/common.ts @@ -29,9 +29,10 @@ export const NUMBER_COMPATIBLE_NUMERICAL_TYPES = [ export const BIGINT_TYPES = [['int64', BigInt64Array, true] as const, ['uint64', BigUint64Array, true] as const]; /** - * float16 type, data represented by Uint16Array + * float16 type, data represented by Uint16Array/Float16Array */ -export const FLOAT16_TYPE = ['float16', Uint16Array, false] as const; +// eslint-disable-next-line @typescript-eslint/no-explicit-any +export const FLOAT16_TYPE = ['float16', (globalThis as any).Float16Array ?? Uint16Array, false] as const; /** * A list of all numerical types. diff --git a/js/common/test/unit-tests/tensor/constructor-f16.ts b/js/common/test/unit-tests/tensor/constructor-f16.ts new file mode 100644 index 0000000000000..38c6ac037c5f9 --- /dev/null +++ b/js/common/test/unit-tests/tensor/constructor-f16.ts @@ -0,0 +1,62 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import assert from 'assert/strict'; +import { Tensor } from 'onnxruntime-common'; + +// eslint-disable-next-line @typescript-eslint/no-explicit-any +const globalF16 = (globalThis as any).Float16Array; + +(globalF16 ? describe : describe.skip)('Tensor Constructor Tests - check type float16 (Float16Array available)', () => { + it("[float16] new Tensor('float16', numbers, dims): allow number array when Float16Array is available", () => { + const tensor = new Tensor('float16', [1, 2, 3, 4], [2, 2]); + assert.equal(tensor.type, 'float16', "tensor.type should be 'float16'"); + assert(tensor.data instanceof globalF16, "tensor.data should be an instance of 'Float16Array'"); + assert.equal(tensor.data[0], 1, 'tensor.data[0] should be 1'); + assert.equal(tensor.data[1], 2, 'tensor.data[1] should be 2'); + assert.equal(tensor.data[2], 3, 'tensor.data[2] should be 3'); + assert.equal(tensor.data[3], 4, 'tensor.data[3] should be 4'); + assert.equal(tensor.data.length, 4, 'tensor.data.length should be 4'); + }); + + it("[float16] new Tensor('float16', float16array, dims): allow Float16Array when Float16Array is available", () => { + const tensor = new Tensor('float16', new globalF16([1, 2, 3, 4]), [2, 2]); + assert.equal(tensor.type, 'float16', "tensor.type should be 'float16'"); + assert(tensor.data instanceof globalF16, "tensor.data should be an instance of 'Float16Array'"); + assert.equal(tensor.data[0], 1, 'tensor.data[0] should be 1'); + assert.equal(tensor.data[1], 2, 'tensor.data[1] should be 2'); + assert.equal(tensor.data[2], 3, 'tensor.data[2] should be 3'); + assert.equal(tensor.data[3], 4, 'tensor.data[3] should be 4'); + assert.equal(tensor.data.length, 4, 'tensor.data.length should be 4'); + }); + + it("[float16] new Tensor('float16', uint16array, dims): allow Uint16Array when Float16Array is available", () => { + const tensor = new Tensor('float16', new Uint16Array([15360, 16384, 16896, 17408]), [2, 2]); + assert.equal(tensor.type, 'float16', "tensor.type should be 'float16'"); + assert(tensor.data instanceof globalF16, "tensor.data should be an instance of 'Float16Array'"); + assert.equal(tensor.data[0], 1, 'tensor.data[0] should be 1'); + assert.equal(tensor.data[1], 2, 'tensor.data[1] should be 2'); + assert.equal(tensor.data[2], 3, 'tensor.data[2] should be 3'); + assert.equal(tensor.data[3], 4, 'tensor.data[3] should be 4'); + assert.equal(tensor.data.length, 4, 'tensor.data.length should be 4'); + }); +}); + +(globalF16 ? describe.skip : describe)( + 'Tensor Constructor Tests - check type float16 (Float16Array not available)', + () => { + it( + "[float16] new Tensor('float16', numbers, dims): " + + "expect to throw because it's not allowed to construct 'float16' tensor from number array", + () => { + assert.throws(() => new Tensor('float16', [1, 2, 3, 4], [2, 2]), TypeError); + }, + ); + + it("[float16] new Tensor('float16', uint16array, dims): allow Uint16Array", () => { + const tensor = new Tensor('float16', new Uint16Array([15360, 16384, 16896, 17408]), [2, 2]); + assert.equal(tensor.type, 'float16', "tensor.type should be 'float16'"); + assert(tensor.data instanceof Uint16Array, "tensor.data should be an instance of 'Uint16Array'"); + }); + }, +); diff --git a/js/common/test/unit-tests/tensor/constructor-type.ts b/js/common/test/unit-tests/tensor/constructor-type.ts index 02390800e8611..d86e18ba744b8 100644 --- a/js/common/test/unit-tests/tensor/constructor-type.ts +++ b/js/common/test/unit-tests/tensor/constructor-type.ts @@ -105,14 +105,6 @@ describe('Tensor Constructor Tests - check types', () => { assert(tensor.data instanceof Uint8Array, "tensor.data should be an instance of 'Uint8Array'"); }); - it( - "[float16] new Tensor('float16', numbers, dims): " + - "expect to throw because it's not allowed to construct 'float16' tensor from number array", - () => { - assert.throws(() => new Tensor('float16', [1, 2, 3, 4], [2, 2]), TypeError); - }, - ); - it("[badtype] new Tensor('a', numbers, dims): expect to throw because 'a' is an invalid type", () => { assert.throws(() => new TensorAny('a', [1, 2, 3, 4], [2, 2]), TypeError); }); diff --git a/js/web/lib/build-def.d.ts b/js/web/lib/build-def.d.ts index 59f64a3179605..83a52ebaefe05 100644 --- a/js/web/lib/build-def.d.ts +++ b/js/web/lib/build-def.d.ts @@ -40,6 +40,13 @@ interface BuildDefinitions { */ readonly ENABLE_BUNDLE_WASM_JS: boolean; + /** + * defines whether to use WebGPU EP instead of JSEP for WebGPU backend. + * + * This flag requires the corresponding WebAssembly artifact to be built with `--use_webgpu` flag. + */ + readonly USE_WEBGPU_EP: boolean; + // #endregion // #region Build definitions for ESM diff --git a/js/web/lib/wasm/jsep/backend-webgpu.ts b/js/web/lib/wasm/jsep/backend-webgpu.ts index a0010df4643a4..413e89111740e 100644 --- a/js/web/lib/wasm/jsep/backend-webgpu.ts +++ b/js/web/lib/wasm/jsep/backend-webgpu.ts @@ -13,7 +13,6 @@ import { ProgramManager } from './webgpu/program-manager'; import { AdapterInfo, ComputeContext, - DeviceInfo, GpuArchitecture, GpuData, GpuVendor, @@ -135,26 +134,6 @@ class AdapterInfoImpl implements AdapterInfo { } } -class DeviceInfoImpl implements DeviceInfo { - readonly subgroupsSupported: boolean; - readonly subgroupsF16Supported: boolean; - readonly subgroupSizeRange?: readonly [number, number]; - - constructor(device: GPUDevice) { - this.subgroupsSupported = device.features.has('subgroups' as GPUFeatureName); - this.subgroupsF16Supported = device.features.has('subgroups' as GPUFeatureName); - // Currently subgroups feature is still experimental and size attributes are not in the WebGPU IDL, so we have to - // workaround the IDL type checks. - // TODO: clean this after subgroups feature is settled in IDL. - const deviceSubgroupsLimits = device.limits as { minSubgroupSize?: number; maxSubgroupSize?: number }; - if (!this.subgroupsSupported || !deviceSubgroupsLimits.minSubgroupSize || !deviceSubgroupsLimits.maxSubgroupSize) { - this.subgroupSizeRange = undefined; - } else { - this.subgroupSizeRange = [deviceSubgroupsLimits.minSubgroupSize, deviceSubgroupsLimits.maxSubgroupSize]; - } - } -} - /** * this class is designed to store status and being used as a singleton for JSEP. It will be passed to jsepInit() as * the first parameter so that it is stored for future use. @@ -162,7 +141,6 @@ class DeviceInfoImpl implements DeviceInfo { export class WebGpuBackend { adapterInfo: AdapterInfoImpl; device: GPUDevice; - deviceInfo: DeviceInfoImpl; /** * an instance of GpuDataManager to manage a GpuDataId -> GpuBuffer mapping */ @@ -274,13 +252,9 @@ export class WebGpuBackend { } requireFeatureIfAvailable('shader-f16'); // Try subgroups - if (requireFeatureIfAvailable('subgroups' as GPUFeatureName)) { - // If subgroups feature is available, also try subgroups-f16 - requireFeatureIfAvailable('subgroups-f16' as GPUFeatureName); - } + requireFeatureIfAvailable('subgroups' as GPUFeatureName); this.device = await adapter.requestDevice(deviceDescriptor); - this.deviceInfo = new DeviceInfoImpl(this.device); this.adapterInfo = new AdapterInfoImpl(adapter.info || (await adapter.requestAdapterInfo())); this.gpuDataManager = createGpuDataManager(this); this.programManager = new ProgramManager(this); diff --git a/js/web/lib/wasm/jsep/backend-webnn.ts b/js/web/lib/wasm/jsep/backend-webnn.ts index 2b9a9208e2e53..55784ae13ad7a 100644 --- a/js/web/lib/wasm/jsep/backend-webnn.ts +++ b/js/web/lib/wasm/jsep/backend-webnn.ts @@ -314,7 +314,8 @@ export class WebNNBackend { bufferView = new Float32Array(buffer); break; case 'float16': - bufferView = new Uint16Array(buffer); + bufferView = + typeof Float16Array !== 'undefined' && Float16Array.from ? new Float16Array(buffer) : new Uint16Array(buffer); break; case 'int32': bufferView = new Int32Array(buffer); diff --git a/js/web/lib/wasm/jsep/init.ts b/js/web/lib/wasm/jsep/init.ts index b4071eae51c8f..8ab6b054bf8a7 100644 --- a/js/web/lib/wasm/jsep/init.ts +++ b/js/web/lib/wasm/jsep/init.ts @@ -1,23 +1,17 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import { Env } from 'onnxruntime-common'; +import type { Env } from 'onnxruntime-common'; import { calculateTensorSizeInBytes, DataType } from '../wasm-common'; import type { OrtWasmModule } from '../wasm-types'; -import { WebGpuBackend } from './backend-webgpu'; +import type { WebGpuBackend } from './backend-webgpu'; import { LOG_DEBUG } from './log'; -import { TensorView } from './tensor-view'; +import type { TensorView } from './tensor-view'; import { ShapeUtil } from './util'; -import { - AdapterInfo, - ComputeContext, - ComputeContextInputsOutputsMapping, - DeviceInfo, - ProgramInfo, -} from './webgpu/types'; +import type { AdapterInfo, ComputeContext, ComputeContextInputsOutputsMapping, ProgramInfo } from './webgpu/types'; import { WebNNBackend } from './backend-webnn'; /* eslint-disable no-bitwise */ @@ -76,7 +70,6 @@ class TensorViewImpl implements TensorView { class ComputeContextImpl implements ComputeContext { readonly adapterInfo: AdapterInfo; - readonly deviceInfo: DeviceInfo; readonly opKernelContext: number; readonly inputs: readonly TensorView[]; readonly outputCount: number; @@ -94,7 +87,6 @@ class ComputeContextImpl implements ComputeContext { contextDataOffset: number, ) { this.adapterInfo = backend.adapterInfo; - this.deviceInfo = backend.deviceInfo; // extract context data const ptrSize = module.PTR_SIZE; @@ -205,79 +197,83 @@ export const init = async ( } if (name === 'webgpu') { - const backend = new WebGpuBackend(); - await backend.initialize(env, gpuAdapter!); + if (!BUILD_DEFS.USE_WEBGPU_EP) { + // eslint-disable-next-line @typescript-eslint/no-require-imports, @typescript-eslint/no-var-requires + const webGpuBackendImpl = require('./backend-webgpu').WebGpuBackend; + const backend = new webGpuBackendImpl(); + await backend.initialize(env, gpuAdapter!); - jsepInit('webgpu', [ - // backend - backend, + jsepInit('webgpu', [ + // backend + backend, + + // jsepAlloc() + (size: number) => backend.alloc(Number(size)), - // jsepAlloc() - (size: number) => backend.alloc(Number(size)), + // jsepFree() + (ptr: number) => backend.free(ptr), - // jsepFree() - (ptr: number) => backend.free(ptr), + // jsepCopy(src, dst, size, isSourceGpu) + (src: number, dst: number, size: number, isSourceGpu = false) => { + if (isSourceGpu) { + LOG_DEBUG( + 'verbose', + () => `[WebGPU] jsepCopyGpuToGpu: src=${Number(src)}, dst=${Number(dst)}, size=${Number(size)}`, + ); + backend.memcpy(Number(src), Number(dst)); + } else { + LOG_DEBUG( + 'verbose', + () => + `[WebGPU] jsepCopyCpuToGpu: dataOffset=${Number(src)}, gpuDataId=${Number(dst)}, size=${Number(size)}`, + ); + const data = module.HEAPU8.subarray(Number(src >>> 0), Number(src >>> 0) + Number(size)); + backend.upload(Number(dst), data); + } + }, - // jsepCopy(src, dst, size, isSourceGpu) - (src: number, dst: number, size: number, isSourceGpu = false) => { - if (isSourceGpu) { + // jsepCopyAsync(src, dst, size) + async (gpuDataId: number, dataOffset: number, size: number): Promise => { LOG_DEBUG( 'verbose', - () => `[WebGPU] jsepCopyGpuToGpu: src=${Number(src)}, dst=${Number(dst)}, size=${Number(size)}`, + () => `[WebGPU] jsepCopyGpuToCpu: gpuDataId=${gpuDataId}, dataOffset=${dataOffset}, size=${size}`, ); - backend.memcpy(Number(src), Number(dst)); - } else { - LOG_DEBUG( - 'verbose', - () => - `[WebGPU] jsepCopyCpuToGpu: dataOffset=${Number(src)}, gpuDataId=${Number(dst)}, size=${Number(size)}`, - ); - const data = module.HEAPU8.subarray(Number(src >>> 0), Number(src >>> 0) + Number(size)); - backend.upload(Number(dst), data); - } - }, - // jsepCopyAsync(src, dst, size) - async (gpuDataId: number, dataOffset: number, size: number): Promise => { - LOG_DEBUG( - 'verbose', - () => `[WebGPU] jsepCopyGpuToCpu: gpuDataId=${gpuDataId}, dataOffset=${dataOffset}, size=${size}`, - ); - - await backend.download(Number(gpuDataId), () => - module.HEAPU8.subarray(Number(dataOffset) >>> 0, Number(dataOffset + size) >>> 0), - ); - }, + await backend.download(Number(gpuDataId), () => + module.HEAPU8.subarray(Number(dataOffset) >>> 0, Number(dataOffset + size) >>> 0), + ); + }, - // jsepCreateKernel - (kernelType: string, kernelId: number, attribute: unknown) => - backend.createKernel( - kernelType, - Number(kernelId), - attribute, - module.UTF8ToString(module._JsepGetNodeName!(Number(kernelId))), - ), + // jsepCreateKernel + (kernelType: string, kernelId: number, attribute: unknown) => + backend.createKernel( + kernelType, + Number(kernelId), + attribute, + module.UTF8ToString(module._JsepGetNodeName!(Number(kernelId))), + ), - // jsepReleaseKernel - (kernel: number) => backend.releaseKernel(kernel), + // jsepReleaseKernel + (kernel: number) => backend.releaseKernel(kernel), - // jsepRun - (kernel: number, contextDataOffset: number, sessionHandle: number, errors: Array>) => { - LOG_DEBUG( - 'verbose', - () => - `[WebGPU] jsepRun: sessionHandle=${sessionHandle}, kernel=${kernel}, contextDataOffset=${contextDataOffset}`, - ); - const context = new ComputeContextImpl(module, backend, Number(contextDataOffset)); - return backend.computeKernel(Number(kernel), context, errors); - }, - // jsepCaptureBegin - () => backend.captureBegin(), - // jsepCaptureEnd - () => backend.captureEnd(), - // jsepReplay - () => backend.replay(), - ]); + // jsepRun + (kernel: number, contextDataOffset: number, sessionHandle: number, errors: Array>) => { + LOG_DEBUG( + 'verbose', + () => + `[WebGPU] jsepRun: sessionHandle=${sessionHandle}, kernel=${kernel}, contextDataOffset=${contextDataOffset}`, + ); + const context = new ComputeContextImpl(module, backend, Number(contextDataOffset)); + return backend.computeKernel(Number(kernel), context, errors); + }, + // jsepCaptureBegin + () => backend.captureBegin(), + // jsepCaptureEnd + () => backend.captureEnd(), + // jsepReplay + () => backend.replay(), + ]); + } } else { const backend = new WebNNBackend(env); jsepInit('webnn', [ diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts index ad1de42106d6d..50620cea33863 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts @@ -46,6 +46,11 @@ export const createConvTranspose2DProgramInfo = ( const inputChannelsPerGroup = wShape[2] / group; const outputChannelsPerGroup = wShape[3]; const aComponents = isChannelsLast ? getMaxComponents(inputChannelsPerGroup) : 1; + const packInputAs4 = isChannelsLast && outputChannelsPerGroup === 1 && inputChannelsPerGroup >= 4; + const inputChannelsPerGroupInt = packInputAs4 + ? Math.floor(inputChannelsPerGroup / 4) * 4 + : Math.floor(inputChannelsPerGroup / aComponents) * aComponents; + const inputChannelsRemainder = inputChannelsPerGroup - inputChannelsPerGroupInt; const components = isChannelsLast ? getMaxComponents(outputChannelsPerGroup) : 1; const bComponents = isChannelsLast ? (outputChannelsPerGroup === 1 ? aComponents : components) : 1; const outputSize = ShapeUtil.size(outputShape) / components; @@ -78,6 +83,7 @@ export const createConvTranspose2DProgramInfo = ( { type: DataType.uint32, data: dilations }, { type: DataType.uint32, data: effectiveFilterDims }, { type: DataType.int32, data: pads }, + { type: DataType.uint32, data: inputChannelsPerGroupInt }, { type: DataType.uint32, data: inputChannelsPerGroup }, { type: DataType.uint32, data: outputChannelsPerGroup }, ...createTensorShapeVariables(inputs[0].dims, inputs[1].dims), @@ -96,6 +102,7 @@ export const createConvTranspose2DProgramInfo = ( { name: 'dilations', type: 'u32', length: filterDims.length }, { name: 'effective_filter_dims', type: 'u32', length: effectiveFilterDims.length }, { name: 'pads', type: 'i32', length: pads.length }, + { name: 'input_channels_per_group_int', type: 'u32' }, { name: 'input_channels_per_group', type: 'u32' }, { name: 'output_channels_per_group', type: 'u32' }, ]; @@ -114,16 +121,40 @@ export const createConvTranspose2DProgramInfo = ( const calculateResult = (): string => { let calcStr = ''; - if (aComponents === 1) { - calcStr += ` - let w_offset = ${w.indicesToOffset(`${w.type.indices}(u32(wRPerm), u32(wCPerm), inputChannel, wOutChannel)`)}; - let wValue = ${w.getByOffset(`w_offset / ${bComponents}`)}; - dotProd = dotProd + xValue * wValue;`; + if (packInputAs4) { + if (aComponents === 4) { + calcStr += ` + let xValue = ${dy.getByOffset('x_offset')}; + let wValue = ${w.getByOffset('w_offset')}; + dotProd = dotProd + dot(xValue, wValue); + x_offset += 1u; + w_offset += 1u;`; + } else if (aComponents === 2) { + calcStr += ` + dotProd = dotProd + dot(vec4<${dataType}>(${dy.getByOffset('x_offset')}, ${dy.getByOffset('x_offset + 1u')}), vec4<${dataType}>(${w.getByOffset('w_offset')}, ${w.getByOffset('w_offset + 1u')})); + x_offset += 2u; + w_offset += 2u;`; + } else if (aComponents === 1) { + calcStr += ` + dotProd = dotProd + dot(vec4<${dataType}>(${dy.getByOffset('x_offset')}, ${dy.getByOffset('x_offset + 1u')}, ${dy.getByOffset('x_offset + 2u')}, ${dy.getByOffset('x_offset + 3u')}), vec4<${dataType}>(${w.getByOffset('w_offset')}, ${w.getByOffset('w_offset + 1u')}, ${w.getByOffset('w_offset + 2u')}, ${w.getByOffset('w_offset + 3u')})); + x_offset += 4u; + w_offset += 4u;`; + } } else { - if (outputChannelsPerGroup === 1) { + calcStr += ` + let xValue = ${ + isChannelsLast + ? dy.getByOffset( + `${dy.indicesToOffset(`${dy.type.indices}(batch, idyR, idyC, inputChannel)`)} / ${aComponents}`, + ) + : dy.get('batch', 'inputChannel', 'idyR', 'idyC') + }; + `; + if (aComponents === 1) { calcStr += ` - let wValue = ${w.getByOffset(`${w.indicesToOffset(`${w.type.indices}(u32(wRPerm), u32(wCPerm), inputChannel, wOutChannel)`)} / ${bComponents}`)}; - dotProd = dotProd + dot(xValue, wValue);`; + let w_offset = ${w.indicesToOffset(`${w.type.indices}(u32(wRPerm), u32(wCPerm), inputChannel, wOutChannel)`)}; + let wValue = ${w.getByOffset(`w_offset / ${bComponents}`)}; + dotProd = dotProd + xValue * wValue;`; } else { for (let c = 0; c < aComponents; c++) { calcStr += ` @@ -134,6 +165,32 @@ export const createConvTranspose2DProgramInfo = ( } return calcStr; }; + const calculateRemainder = (): string => { + if (inputChannelsRemainder === 0) { + return ''; + } + if (!packInputAs4) { + throw new Error(`packInputAs4 ${packInputAs4} is not true.`); + } + let calcStr = ''; + if (aComponents === 1) { + calcStr += 'dotProd = dotProd'; + for (let i = 0; i < inputChannelsRemainder; i++) { + calcStr += ` + + ${dy.getByOffset(`x_offset + ${i}`)} * ${w.getByOffset(`w_offset + ${i}`)}`; + } + calcStr += ';'; + } else if (aComponents === 2) { + if (inputChannelsRemainder !== 2) { + throw new Error(`Invalid inputChannelsRemainder ${inputChannelsRemainder}.`); + } + calcStr += ` + let xValue = ${dy.getByOffset('x_offset')}; + let wValue = ${w.getByOffset('w_offset')}; + dotProd = dotProd + dot(xValue, wValue);`; + } + return calcStr; + }; const codeSnippet = ` let outputIndices = ${output.offsetToIndices(`global_idx * ${components}`)}; let batch = ${output.indicesGet('outputIndices', 0)}; @@ -169,7 +226,6 @@ export const createConvTranspose2DProgramInfo = ( // Minimum wC >= 0 that satisfies (dyCCorner + wC) % (uniforms.strides.y) == 0 wC = u32(((dyCCorner + i32(uniforms.strides.y) - 1) / i32(uniforms.strides.y)) * i32(uniforms.strides.y) - dyCCorner); } - for (; wC < uniforms.effective_filter_dims.y; wC = wC + 1) { if (wC % uniforms.dilations.y != 0) { continue; @@ -182,17 +238,19 @@ export const createConvTranspose2DProgramInfo = ( } let idyC: u32 = u32(dyC); var inputChannel = groupId * uniforms.input_channels_per_group; - for (var d2: u32 = 0; d2 < uniforms.input_channels_per_group; d2 = d2 + ${aComponents}) { - let xValue = ${ - isChannelsLast - ? dy.getByOffset( - `${dy.indicesToOffset(`${dy.type.indices}(batch, idyR, idyC, inputChannel)`)} / ${aComponents}`, - ) - : dy.get('batch', 'inputChannel', 'idyR', 'idyC') - }; + ${ + packInputAs4 + ? ` + var x_offset = ${dy.indicesToOffset(`${dy.type.indices}(batch, idyR, idyC, inputChannel)`)} / ${aComponents}; + var w_offset = ${w.indicesToOffset(`${w.type.indices}(wRPerm, wCPerm, inputChannel, wOutChannel)`)} / ${bComponents}; + ` + : '' + } + for (var d2: u32 = 0; d2 < uniforms.input_channels_per_group_int; d2 = d2 + ${packInputAs4 ? 4 : aComponents}) { ${calculateResult()} - inputChannel = inputChannel + ${aComponents}; + inputChannel = inputChannel + ${packInputAs4 ? 4 : aComponents}; } + ${calculateRemainder()} wC = wC + uniforms.strides.y - 1; } wR = wR + uniforms.strides[0] - 1; @@ -211,7 +269,7 @@ export const createConvTranspose2DProgramInfo = ( return { name: 'ConvTranspose2D', shaderCache: { - hint: `${attributes.cacheKey};${aComponents}${bComponents}${components}${outputChannelsPerGroup === 1}`, + hint: `${attributes.cacheKey};${aComponents}${bComponents}${components}${packInputAs4}${inputChannelsRemainder}`, inputDependencies, }, getRunData: () => ({ diff --git a/js/web/lib/wasm/jsep/webgpu/ops/attention.ts b/js/web/lib/wasm/jsep/webgpu/ops/attention.ts index 6a78c8ae3b190..6a8dffb73fa08 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/attention.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/attention.ts @@ -433,7 +433,7 @@ const createInPlaceSoftmaxProgramInfo = ( getShaderSource, getRunData: () => ({ outputs: [], - dispatchGroup: { x: Math.ceil(totalSequenceLength / WG), y: sequenceLength, z: batchSize * numHeads }, + dispatchGroup: { x: 1, y: sequenceLength, z: batchSize * numHeads }, programUniforms, }), }; diff --git a/js/web/lib/wasm/jsep/webgpu/program-manager.ts b/js/web/lib/wasm/jsep/webgpu/program-manager.ts index 2c5180c5db3ee..18d505f57655a 100644 --- a/js/web/lib/wasm/jsep/webgpu/program-manager.ts +++ b/js/web/lib/wasm/jsep/webgpu/program-manager.ts @@ -99,7 +99,6 @@ export class ProgramManager { const extensionsInfo: Array<{ feature: GPUFeatureName; extension: string }> = [ { feature: 'shader-f16', extension: 'f16' }, { feature: 'subgroups' as GPUFeatureName, extension: 'subgroups' }, - { feature: 'subgroups-f16' as GPUFeatureName, extension: 'subgroups_f16' }, ]; extensionsInfo.forEach((info) => { if (device.features.has(info.feature)) { diff --git a/js/web/lib/wasm/jsep/webgpu/types.ts b/js/web/lib/wasm/jsep/webgpu/types.ts index 9321ac170d036..f3cfc6cb98cae 100644 --- a/js/web/lib/wasm/jsep/webgpu/types.ts +++ b/js/web/lib/wasm/jsep/webgpu/types.ts @@ -21,11 +21,6 @@ export interface AdapterInfo { isArchitecture: (architecture: GpuArchitecture) => boolean; isVendor: (vendor: GpuVendor) => boolean; } -export interface DeviceInfo { - readonly subgroupsSupported: boolean; - readonly subgroupsF16Supported: boolean; - readonly subgroupSizeRange?: readonly [number, number]; -} export interface GpuData { type: GpuDataType; @@ -165,11 +160,6 @@ export interface ComputeContext { */ readonly adapterInfo: AdapterInfo; - /** - * gpu device info - */ - readonly deviceInfo: DeviceInfo; - /** * stores the pointer to OpKernelContext */ diff --git a/js/web/lib/wasm/proxy-wrapper.ts b/js/web/lib/wasm/proxy-wrapper.ts index 5d97bb83e3475..30b1f5101e5f2 100644 --- a/js/web/lib/wasm/proxy-wrapper.ts +++ b/js/web/lib/wasm/proxy-wrapper.ts @@ -12,7 +12,11 @@ import { } from './proxy-messages'; import * as core from './wasm-core-impl'; import { initializeWebAssembly } from './wasm-factory'; -import { importProxyWorker, inferWasmPathPrefixFromScriptSrc } from './wasm-utils-import'; +import { + importProxyWorker, + inferWasmPathPrefixFromScriptSrc, + isEsmImportMetaUrlHardcodedAsFileUri, +} from './wasm-utils-import'; const isProxy = (): boolean => !!env.wasm.proxy && typeof document !== 'undefined'; let proxyWorker: Worker | undefined; @@ -116,7 +120,7 @@ export const initializeWebAssemblyAndOrtRuntime = async (): Promise => { BUILD_DEFS.IS_ESM && BUILD_DEFS.ENABLE_BUNDLE_WASM_JS && !message.in!.wasm.wasmPaths && - (objectUrl || BUILD_DEFS.ESM_IMPORT_META_URL?.startsWith('file:')) + (objectUrl || isEsmImportMetaUrlHardcodedAsFileUri) ) { // for a build bundled the wasm JS, if either of the following conditions is met: // - the proxy worker is loaded from a blob URL diff --git a/js/web/lib/wasm/session-options.ts b/js/web/lib/wasm/session-options.ts index 17e564247863d..89a4484e5a1c4 100644 --- a/js/web/lib/wasm/session-options.ts +++ b/js/web/lib/wasm/session-options.ts @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import { InferenceSession } from 'onnxruntime-common'; +import type { InferenceSession } from 'onnxruntime-common'; import { getInstance } from './wasm-factory'; import { allocWasmString, checkLastError, iterateExtraOptions } from './wasm-utils'; @@ -54,13 +54,28 @@ const appendDefaultOptions = (options: InferenceSession.SessionOptions): void => } }; -const setExecutionProviders = ( +const appendSessionConfig = (sessionOptionsHandle: number, key: string, value: string, allocs: number[]): void => { + const keyDataOffset = allocWasmString(key, allocs); + const valueDataOffset = allocWasmString(value, allocs); + if (getInstance()._OrtAddSessionConfigEntry(sessionOptionsHandle, keyDataOffset, valueDataOffset) !== 0) { + checkLastError(`Can't set a session config entry: ${key} - ${value}.`); + } +}; + +const appendEpOption = (epOptions: Array<[number, number]>, key: string, value: string, allocs: number[]): void => { + const keyDataOffset = allocWasmString(key, allocs); + const valueDataOffset = allocWasmString(value, allocs); + epOptions.push([keyDataOffset, valueDataOffset]); +}; + +const setExecutionProviders = async ( sessionOptionsHandle: number, executionProviders: readonly InferenceSession.ExecutionProviderConfig[], allocs: number[], -): void => { +): Promise => { for (const ep of executionProviders) { let epName = typeof ep === 'string' ? ep : ep.name; + const epOptions: Array<[number, number]> = []; // check EP name switch (epName) { @@ -71,26 +86,44 @@ const setExecutionProviders = ( // const context = (webnnOptions as InferenceSession.WebNNOptionsWithMLContext)?.context; const deviceType = (webnnOptions as InferenceSession.WebNNContextOptions)?.deviceType; if (deviceType) { - const keyDataOffset = allocWasmString('deviceType', allocs); - const valueDataOffset = allocWasmString(deviceType, allocs); - if (getInstance()._OrtAddSessionConfigEntry(sessionOptionsHandle, keyDataOffset, valueDataOffset) !== 0) { - checkLastError(`Can't set a session config entry: 'deviceType' - ${deviceType}.`); - } + appendSessionConfig(sessionOptionsHandle, 'deviceType', deviceType, allocs); } } break; case 'webgpu': - epName = 'JS'; - if (typeof ep !== 'string') { - const webgpuOptions = ep as InferenceSession.WebGpuExecutionProviderOption; - if (webgpuOptions?.preferredLayout) { - if (webgpuOptions.preferredLayout !== 'NCHW' && webgpuOptions.preferredLayout !== 'NHWC') { - throw new Error(`preferredLayout must be either 'NCHW' or 'NHWC': ${webgpuOptions.preferredLayout}`); + if (BUILD_DEFS.USE_WEBGPU_EP) { + epName = 'WebGPU'; + let customDevice: GPUDevice | undefined; + + if (typeof ep !== 'string') { + const customOptions = ep as unknown as { device: GPUDevice }; + if (customOptions.device) { + if (typeof GPUDevice !== 'undefined' && customOptions.device instanceof GPUDevice) { + customDevice = customOptions.device; + } else { + throw new Error('Invalid GPU device set in WebGPU EP options.'); + } } - const keyDataOffset = allocWasmString('preferredLayout', allocs); - const valueDataOffset = allocWasmString(webgpuOptions.preferredLayout, allocs); - if (getInstance()._OrtAddSessionConfigEntry(sessionOptionsHandle, keyDataOffset, valueDataOffset) !== 0) { - checkLastError(`Can't set a session config entry: 'preferredLayout' - ${webgpuOptions.preferredLayout}.`); + + // TODO: handle more options + } + + const info = getInstance().webgpuRegisterDevice!(customDevice); + if (info) { + const [deviceId, instanceHandle, deviceHandle] = info; + appendEpOption(epOptions, 'deviceId', deviceId.toString(), allocs); + appendEpOption(epOptions, 'webgpuInstance', instanceHandle.toString(), allocs); + appendEpOption(epOptions, 'webgpuDevice', deviceHandle.toString(), allocs); + } + } else { + epName = 'JS'; + if (typeof ep !== 'string') { + const webgpuOptions = ep as InferenceSession.WebGpuExecutionProviderOption; + if (webgpuOptions?.preferredLayout) { + if (webgpuOptions.preferredLayout !== 'NCHW' && webgpuOptions.preferredLayout !== 'NHWC') { + throw new Error(`preferredLayout must be either 'NCHW' or 'NHWC': ${webgpuOptions.preferredLayout}`); + } + appendSessionConfig(sessionOptionsHandle, 'preferredLayout', webgpuOptions.preferredLayout, allocs); } } } @@ -103,13 +136,34 @@ const setExecutionProviders = ( } const epNameDataOffset = allocWasmString(epName, allocs); - if (getInstance()._OrtAppendExecutionProvider(sessionOptionsHandle, epNameDataOffset) !== 0) { + const epOptionsCount = epOptions.length; + let keysOffset = 0; + let valuesOffset = 0; + if (epOptionsCount > 0) { + keysOffset = getInstance()._malloc(epOptionsCount * getInstance().PTR_SIZE); + allocs.push(keysOffset); + valuesOffset = getInstance()._malloc(epOptionsCount * getInstance().PTR_SIZE); + allocs.push(valuesOffset); + for (let i = 0; i < epOptionsCount; i++) { + getInstance().setValue(keysOffset + i * getInstance().PTR_SIZE, epOptions[i][0], '*'); + getInstance().setValue(valuesOffset + i * getInstance().PTR_SIZE, epOptions[i][1], '*'); + } + } + if ( + (await getInstance()._OrtAppendExecutionProvider( + sessionOptionsHandle, + epNameDataOffset, + keysOffset, + valuesOffset, + epOptionsCount, + )) !== 0 + ) { checkLastError(`Can't append execution provider: ${epName}.`); } } }; -export const setSessionOptions = (options?: InferenceSession.SessionOptions): [number, number[]] => { +export const setSessionOptions = async (options?: InferenceSession.SessionOptions): Promise<[number, number[]]> => { const wasm = getInstance(); let sessionOptionsHandle = 0; const allocs: number[] = []; @@ -155,20 +209,19 @@ export const setSessionOptions = (options?: InferenceSession.SessionOptions): [n } if (sessionOptions.executionProviders) { - setExecutionProviders(sessionOptionsHandle, sessionOptions.executionProviders, allocs); + await setExecutionProviders(sessionOptionsHandle, sessionOptions.executionProviders, allocs); } if (sessionOptions.enableGraphCapture !== undefined) { if (typeof sessionOptions.enableGraphCapture !== 'boolean') { throw new Error(`enableGraphCapture must be a boolean value: ${sessionOptions.enableGraphCapture}`); } - const keyDataOffset = allocWasmString('enableGraphCapture', allocs); - const valueDataOffset = allocWasmString(sessionOptions.enableGraphCapture.toString(), allocs); - if (wasm._OrtAddSessionConfigEntry(sessionOptionsHandle, keyDataOffset, valueDataOffset) !== 0) { - checkLastError( - `Can't set a session config entry: 'enableGraphCapture' - ${sessionOptions.enableGraphCapture}.`, - ); - } + appendSessionConfig( + sessionOptionsHandle, + 'enableGraphCapture', + sessionOptions.enableGraphCapture.toString(), + allocs, + ); } if (sessionOptions.freeDimensionOverrides) { @@ -188,12 +241,7 @@ export const setSessionOptions = (options?: InferenceSession.SessionOptions): [n if (sessionOptions.extra !== undefined) { iterateExtraOptions(sessionOptions.extra, '', new WeakSet>(), (key, value) => { - const keyDataOffset = allocWasmString(key, allocs); - const valueDataOffset = allocWasmString(value, allocs); - - if (wasm._OrtAddSessionConfigEntry(sessionOptionsHandle, keyDataOffset, valueDataOffset) !== 0) { - checkLastError(`Can't set a session config entry: ${key} - ${value}.`); - } + appendSessionConfig(sessionOptionsHandle, key, value, allocs); }); } diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts index 4bccfa76fdda3..dbcf80adf3552 100644 --- a/js/web/lib/wasm/wasm-core-impl.ts +++ b/js/web/lib/wasm/wasm-core-impl.ts @@ -102,11 +102,20 @@ export const initRuntime = async (env: Env): Promise => { * @param epName */ export const initEp = async (env: Env, epName: string): Promise => { + // initialize ASYNCIFY support + getInstance().asyncInit?.(); + + if (epName === 'webgpu' && BUILD_DEFS.USE_WEBGPU_EP) { + getInstance().webgpuInit!((device) => { + env.webgpu.device = device; + }); + } + if (!BUILD_DEFS.DISABLE_JSEP) { // eslint-disable-next-line @typescript-eslint/no-require-imports, @typescript-eslint/no-var-requires const initJsep = require('./jsep/init').init; - if (epName === 'webgpu') { + if (epName === 'webgpu' && !BUILD_DEFS.USE_WEBGPU_EP) { // perform WebGPU availability check if (typeof navigator === 'undefined' || !navigator.gpu) { throw new Error('WebGPU is not supported in current environment'); @@ -270,7 +279,7 @@ export const createSession = async ( const outputNamesUTF8Encoded = []; try { - [sessionOptionsHandle, allocs] = setSessionOptions(options); + [sessionOptionsHandle, allocs] = await setSessionOptions(options); if (options?.externalData && wasm.mountExternalData) { const loadingPromises = []; @@ -278,7 +287,7 @@ export const createSession = async ( const path = typeof file === 'string' ? file : file.path; loadingPromises.push( loadFile(typeof file === 'string' ? file : file.data).then((data) => { - wasm.mountExternalData!(path, data); + wasm.mountExternalData(path, data); }), ); } @@ -312,6 +321,7 @@ export const createSession = async ( } sessionHandle = await wasm._OrtCreateSession(modelDataOffset, modelDataLength, sessionOptionsHandle); + wasm.webgpuOnCreateSession?.(sessionHandle); if (sessionHandle === 0) { checkLastError("Can't create a session."); } @@ -444,6 +454,7 @@ export const releaseSession = (sessionId: number): void => { } wasm.jsepOnReleaseSession?.(sessionId); + wasm.webgpuOnReleaseSession?.(sessionId); inputNamesUTF8Encoded.forEach((buf) => wasm._OrtFree(buf)); outputNamesUTF8Encoded.forEach((buf) => wasm._OrtFree(buf)); @@ -491,11 +502,20 @@ export const prepareInputOutputTensor = async ( const gpuBuffer = tensor[2].gpuBuffer; dataByteLength = calculateTensorSizeInBytes(tensorDataTypeStringToEnum(dataType), dims)!; - const registerBuffer = wasm.jsepRegisterBuffer; - if (!registerBuffer) { - throw new Error('Tensor location "gpu-buffer" is not supported without using WebGPU.'); + if (BUILD_DEFS.USE_WEBGPU_EP) { + const registerBuffer = wasm.webgpuRegisterBuffer; + if (!registerBuffer) { + throw new Error('Tensor location "gpu-buffer" is not supported without using WebGPU.'); + } + + rawData = registerBuffer(gpuBuffer, sessionId); + } else { + const registerBuffer = wasm.jsepRegisterBuffer; + if (!registerBuffer) { + throw new Error('Tensor location "gpu-buffer" is not supported without using WebGPU.'); + } + rawData = registerBuffer(sessionId, index, gpuBuffer, dataByteLength); } - rawData = registerBuffer(sessionId, index, gpuBuffer, dataByteLength); } else if (location === 'ml-tensor') { const mlTensor = tensor[2].mlTensor as MLTensor; dataByteLength = calculateTensorSizeInBytes(tensorDataTypeStringToEnum(dataType), dims)!; @@ -791,7 +811,7 @@ export const run = async ( // If a certain output's preferred location is GPU but the tensor is empty, we still need to create a CPU // tensor for it. There is no mapping GPU buffer for an empty tensor. if (preferredLocation === 'gpu-buffer' && size > 0) { - const getBuffer = wasm.jsepGetBuffer; + const getBuffer = BUILD_DEFS.USE_WEBGPU_EP ? wasm.webgpuGetBuffer : wasm.jsepGetBuffer; if (!getBuffer) { throw new Error('preferredLocation "gpu-buffer" is not supported without using WebGPU.'); } @@ -804,20 +824,43 @@ export const run = async ( // do not release the tensor right now. it will be released when user calls tensor.dispose(). keepOutputTensor = true; - output.push([ - type, - dims, - { - gpuBuffer, - download: wasm.jsepCreateDownloader!(gpuBuffer, bufferSize, type), - dispose: () => { - if (wasm._OrtReleaseTensor(tensor) !== 0) { - checkLastError("Can't release tensor."); - } + if (BUILD_DEFS.USE_WEBGPU_EP) { + wasm.webgpuRegisterBuffer!(gpuBuffer, sessionId, dataOffset); + const downloadDataFunction = wasm.webgpuCreateDownloader!(gpuBuffer, bufferSize, sessionId); + output.push([ + type, + dims, + { + gpuBuffer, + download: async () => { + const arrayBuffer = await downloadDataFunction(); + const data = new (tensorTypeToTypedArrayConstructor(type!))(arrayBuffer); + return data as Tensor.DataTypeMap[Tensor.GpuBufferDataTypes]; + }, + dispose: () => { + if (wasm._OrtReleaseTensor(tensor) !== 0) { + checkLastError("Can't release tensor."); + } + }, }, - }, - 'gpu-buffer', - ]); + 'gpu-buffer', + ]); + } else { + output.push([ + type, + dims, + { + gpuBuffer, + download: wasm.jsepCreateDownloader!(gpuBuffer, bufferSize, type), + dispose: () => { + if (wasm._OrtReleaseTensor(tensor) !== 0) { + checkLastError("Can't release tensor."); + } + }, + }, + 'gpu-buffer', + ]); + } } else if (preferredLocation === 'ml-tensor' && size > 0) { const ensureTensor = wasm.jsepEnsureTensor; if (!ensureTensor) { @@ -887,6 +930,18 @@ export const run = async ( } finally { wasm.stackRestore(beforeRunStack); + if (BUILD_DEFS.USE_WEBGPU_EP) { + inputTensors.forEach((t) => { + if (t && t[3] === 'gpu-buffer') { + wasm.webgpuUnregisterBuffer!(t[2].gpuBuffer); + } + }); + outputTensors.forEach((t) => { + if (t && t[3] === 'gpu-buffer') { + wasm.webgpuUnregisterBuffer!(t[2].gpuBuffer); + } + }); + } inputTensorHandles.forEach((v) => wasm._OrtReleaseTensor(v)); outputTensorHandles.forEach((v) => wasm._OrtReleaseTensor(v)); inputOutputAllocs.forEach((p) => wasm._free(p)); diff --git a/js/web/lib/wasm/wasm-types.ts b/js/web/lib/wasm/wasm-types.ts index b4871e145f4d7..9b2ec71fd351d 100644 --- a/js/web/lib/wasm/wasm-types.ts +++ b/js/web/lib/wasm/wasm-types.ts @@ -41,18 +41,6 @@ export declare namespace JSEP { type DownloadTensorFunction = (tensorId: number, dstBuffer: ArrayBufferView | ArrayBuffer) => Promise; export interface Module extends WebGpuModule, WebNnModule { - /** - * Mount the external data file to an internal map, which will be used during session initialization. - * - * @param externalDataFilePath - specify the relative path of the external data file. - * @param externalDataFileData - specify the content data. - */ - mountExternalData(externalDataFilePath: string, externalDataFileData: Uint8Array): void; - /** - * Unmount all external data files from the internal map. - */ - unmountExternalData(): void; - /** * This is the entry of JSEP initialization. This function is called once when initializing ONNX Runtime per * backend. This function initializes Asyncify support. If name is 'webgpu', also initializes WebGPU backend and @@ -294,6 +282,21 @@ export declare namespace JSEP { } } +export declare namespace WebGpu { + export interface Module { + webgpuInit(setDefaultDevice: (device: GPUDevice) => void): void; + webgpuRegisterDevice( + device?: GPUDevice, + ): undefined | [deviceId: number, instanceHandle: number, deviceHandle: number]; + webgpuOnCreateSession(sessionHandle: number): void; + webgpuOnReleaseSession(sessionHandle: number): void; + webgpuRegisterBuffer(buffer: GPUBuffer, sessionHandle: number, bufferHandle?: number): number; + webgpuUnregisterBuffer(buffer: GPUBuffer): void; + webgpuGetBuffer(bufferHandle: number): GPUBuffer; + webgpuCreateDownloader(gpuBuffer: GPUBuffer, size: number, sessionHandle: number): () => Promise; + } +} + export interface OrtInferenceAPIs { _OrtInit(numThreads: number, loggingLevel: number): number; @@ -358,7 +361,13 @@ export interface OrtInferenceAPIs { logVerbosityLevel: number, optimizedModelFilePath: number, ): number; - _OrtAppendExecutionProvider(sessionOptionsHandle: number, name: number): number; + _OrtAppendExecutionProvider( + sessionOptionsHandle: number, + name: number, + providerOptionsKeys: number, + providerOptionsValues: number, + numKeys: number, + ): Promise; _OrtAddFreeDimensionOverride(sessionOptionsHandle: number, name: number, dim: number): number; _OrtAddSessionConfigEntry(sessionOptionsHandle: number, configKey: number, configValue: number): number; _OrtReleaseSessionOptions(sessionOptionsHandle: number): number; @@ -373,8 +382,11 @@ export interface OrtInferenceAPIs { /** * The interface of the WebAssembly module for ONNX Runtime, compiled from C++ source code by Emscripten. */ -export interface OrtWasmModule extends EmscriptenModule, OrtInferenceAPIs, Partial { - PTR_SIZE: number; +export interface OrtWasmModule + extends EmscriptenModule, + OrtInferenceAPIs, + Partial, + Partial { // #region emscripten functions stackSave(): number; stackRestore(stack: number): void; @@ -387,7 +399,31 @@ export interface OrtWasmModule extends EmscriptenModule, OrtInferenceAPIs, Parti stringToUTF8(str: string, offset: number, maxBytes: number): void; // #endregion + // #region ORT shared + + readonly PTR_SIZE: 4 | 8; + + /** + * Mount the external data file to an internal map, which will be used during session initialization. + * + * @param externalDataFilePath - specify the relative path of the external data file. + * @param externalDataFileData - specify the content data. + */ + mountExternalData(externalDataFilePath: string, externalDataFileData: Uint8Array): void; + /** + * Unmount all external data files from the internal map. + */ + unmountExternalData(): void; + + /** + * This function patches the WebAssembly module to support Asyncify. This function should be called at least once + * before any ORT API is called. + */ + asyncInit?(): void; + + // #endregion + // #region config - numThreads?: number; + readonly numThreads?: number; // #endregion } diff --git a/js/web/lib/wasm/wasm-utils-import.ts b/js/web/lib/wasm/wasm-utils-import.ts index 871b575d71edc..a8e27f6f334bc 100644 --- a/js/web/lib/wasm/wasm-utils-import.ts +++ b/js/web/lib/wasm/wasm-utils-import.ts @@ -11,6 +11,39 @@ import { isNode } from './wasm-utils-env'; */ const origin = isNode || typeof location === 'undefined' ? undefined : location.origin; +/** + * Some bundlers (eg. Webpack) will rewrite `import.meta.url` to a file URL at compile time. + * + * This function checks if `import.meta.url` starts with `file:`, but using the `>` and `<` operators instead of + * `startsWith` function so that code minimizers can remove the dead code correctly. + * + * For example, if we use terser to minify the following code: + * ```js + * if ("file://hard-coded-filename".startsWith("file:")) { + * console.log(1) + * } else { + * console.log(2) + * } + * + * if ("file://hard-coded-filename" > "file:" && "file://hard-coded-filename" < "file;") { + * console.log(3) + * } else { + * console.log(4) + * } + * ``` + * + * The minified code will be: + * ```js + * "file://hard-coded-filename".startsWith("file:")?console.log(1):console.log(2),console.log(3); + * ``` + * + * (use Terser 5.39.0 with default options, https://try.terser.org/) + * + * @returns true if the import.meta.url is hardcoded as a file URI. + */ +export const isEsmImportMetaUrlHardcodedAsFileUri = + BUILD_DEFS.IS_ESM && BUILD_DEFS.ESM_IMPORT_META_URL! > 'file:' && BUILD_DEFS.ESM_IMPORT_META_URL! < 'file;'; + const getScriptSrc = (): string | undefined => { // if Nodejs, return undefined if (isNode) { @@ -26,9 +59,22 @@ const getScriptSrc = (): string | undefined => { // new URL('actual-bundle-name.js', import.meta.url).href // ``` // So that bundler can preprocess the URL correctly. - if (BUILD_DEFS.ESM_IMPORT_META_URL?.startsWith('file:')) { + if (isEsmImportMetaUrlHardcodedAsFileUri) { // if the rewritten URL is a relative path, we need to use the origin to resolve the URL. - return new URL(new URL(BUILD_DEFS.BUNDLE_FILENAME, BUILD_DEFS.ESM_IMPORT_META_URL).href, origin).href; + + // The following is a workaround for Vite. + // + // Vite uses a bundler(rollup/rolldown) that does not rewrite `import.meta.url` to a file URL. So in theory, this + // code path should not be executed in Vite. However, the bundler does not know it and it still try to load the + // following pattern: + // - `return new URL('filename', import.meta.url).href` + // + // By replacing the pattern above with the following code, we can skip the resource loading behavior: + // - `const URL2 = URL; return new URL2('filename', import.meta.url).href;` + // + // And it still works in Webpack. + const URL2 = URL; + return new URL(new URL2(BUILD_DEFS.BUNDLE_FILENAME, BUILD_DEFS.ESM_IMPORT_META_URL).href, origin).href; } return BUILD_DEFS.ESM_IMPORT_META_URL; diff --git a/js/web/script/build.ts b/js/web/script/build.ts index 6006de62b41b6..98e61c9f87fbb 100644 --- a/js/web/script/build.ts +++ b/js/web/script/build.ts @@ -27,7 +27,8 @@ const args = minimist(process.argv.slice(2)); * --bundle-mode=node * Build a single ort-web bundle for nodejs. */ -const BUNDLE_MODE: 'prod' | 'dev' | 'perf' | 'node' = args['bundle-mode'] || 'prod'; +const BUNDLE_MODE: 'prod' | 'dev' | 'perf' | 'node' = + process.env.npm_config_bundle_mode || args['bundle-mode'] || 'prod'; /** * --debug @@ -41,7 +42,18 @@ const BUNDLE_MODE: 'prod' | 'dev' | 'perf' | 'node' = args['bundle-mode'] || 'pr * Enable debug mode. In this mode, esbuild metafile feature will be enabled. Full bundle analysis will be saved to a * file as JSON. */ -const DEBUG = args.debug; // boolean|'verbose'|'save' +const DEBUG = process.env.npm_config_debug || args.debug; // boolean|'verbose'|'save' + +/** + * --webgpu-ep + * --no-webgpu-ep (default) + * + * Enable or disable the use of WebGPU EP. If enabled, the WebGPU EP will be used. If disabled, the WebGPU backend will + * be used with JSEP. + * + * (temporary) This flag is used to test the WebGPU EP integration. It will be removed in the future. + */ +const USE_WEBGPU_EP = process.env.npm_config_webgpu_ep ?? args['webgpu-ep'] ?? false; /** * Root folder of the source code: `/js/` @@ -57,6 +69,7 @@ const DEFAULT_DEFINE = { 'BUILD_DEFS.DISABLE_WASM': 'false', 'BUILD_DEFS.DISABLE_WASM_PROXY': 'false', 'BUILD_DEFS.ENABLE_BUNDLE_WASM_JS': 'false', + 'BUILD_DEFS.USE_WEBGPU_EP': JSON.stringify(!!USE_WEBGPU_EP), 'BUILD_DEFS.IS_ESM': 'false', 'BUILD_DEFS.ESM_IMPORT_META_URL': 'undefined', @@ -123,13 +136,17 @@ async function minifyWasmModuleJsForBrowser(filepath: string): Promise { // ``` // with: // ``` - // new Worker(import.meta.url.startsWith('file:') - // ? new URL(BUILD_DEFS.BUNDLE_FILENAME, import.meta.url) - // : new URL(import.meta.url), ... + // new Worker((() => { + // const URL2 = URL; + // return import.meta.url > 'file:' && import.meta.url < 'file;' + // ? new URL2(BUILD_DEFS.BUNDLE_FILENAME, import.meta.url) + // : new URL(import.meta.url); + // })(), ... // ``` // // NOTE: this is a workaround for some bundlers that does not support runtime import.meta.url. - // TODO: in emscripten 3.1.61+, need to update this code. + // + // Check more details in the comment of `isEsmImportMetaUrlHardcodedAsFileUri()` and `getScriptSrc()` in file `lib/wasm/wasm-utils-import.ts`. // First, check if there is exactly one occurrence of "new Worker(new URL(import.meta.url)". const matches = [...contents.matchAll(/new Worker\(new URL\(import\.meta\.url\),/g)]; @@ -142,7 +159,12 @@ async function minifyWasmModuleJsForBrowser(filepath: string): Promise { // Replace the only occurrence. contents = contents.replace( /new Worker\(new URL\(import\.meta\.url\),/, - `new Worker(import.meta.url.startsWith('file:')?new URL(BUILD_DEFS.BUNDLE_FILENAME, import.meta.url):new URL(import.meta.url),`, + `new Worker((() => { + const URL2 = URL; + return (import.meta.url > 'file:' && import.meta.url < 'file;') + ? new URL2(BUILD_DEFS.BUNDLE_FILENAME, import.meta.url) + : new URL(import.meta.url); + })(),`, ); // Use terser to minify the code with special configurations: diff --git a/js/web/test/data/ops/conv-transpose.jsonc b/js/web/test/data/ops/conv-transpose.jsonc index 6429845d23df9..008d58530ee36 100644 --- a/js/web/test/data/ops/conv-transpose.jsonc +++ b/js/web/test/data/ops/conv-transpose.jsonc @@ -348,6 +348,128 @@ } ] }, + { + "name": "ConvTranspose NHWC- group - A", + "operator": "ConvTranspose", + "inputShapeDefinitions": "rankOnly", + "opset": { "domain": "", "version": 17 }, + "attributes": [ + { "name": "kernel_shape", "data": [1, 1], "type": "ints" }, + { "name": "group", "data": 2, "type": "int" } + ], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 18.0, 20.0, 22.0, 24.0, 26.0, 28.0, 30.0, 32.0, 34.0], + "dims": [1, 2, 3, 3], + "type": "float32" + }, + { + "data": [1.0, 2.0], + "dims": [2, 1, 1, 1], + "type": "float32" + } + ], + "outputs": [ + { + "data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 36, 40, 44, 48, 52, 56, 60, 64, 68], + "dims": [1, 2, 3, 3], + "type": "float32" + } + ] + } + ] + }, + { + "name": "ConvTranspose NHWC- group - B", + "operator": "ConvTranspose", + "inputShapeDefinitions": "rankOnly", + "opset": { "domain": "", "version": 17 }, + "attributes": [ + { "name": "kernel_shape", "data": [2, 2], "type": "ints" }, + { "name": "group", "data": 3, "type": "int" } + ], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [ + 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, + 19.0, 20.0, 21.0, 22.0, 23.0, 0, 0, 0 + ], + "dims": [1, 3, 3, 3], + "type": "float32" + }, + { + "data": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0], + "dims": [3, 1, 2, 2], + "type": "float32" + }, + { + "data": [0.125, 0.25, 0.375], + "dims": [3], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 0.125, 1.125, 4.125, 4.125, 3.125, 13.125, 23.125, 18.125, 15.125, 43.125, 53.125, 36.125, 18.125, 45.125, + 52.125, 32.125, 45.25, 104.25, 115.25, 66.25, 123.25, 279.25, 305.25, 172.25, 159.25, 357.25, 383.25, + 214.25, 105.25, 232.25, 247.25, 136.25, 162.375, 351.375, 370.375, 200.375, 387.375, 833.375, 875.375, + 470.375, 231.375, 494.375, 517.375, 276.375, 0.375, 0.375, 0.375, 0.375 + ], + "dims": [1, 3, 4, 4], + "type": "float32" + } + ] + } + ] + }, + { + "name": "ConvTranspose NHWC- group - C", + "operator": "ConvTranspose", + "inputShapeDefinitions": "rankOnly", + "opset": { "domain": "", "version": 17 }, + "attributes": [ + { "name": "kernel_shape", "data": [2, 2], "type": "ints" }, + { "name": "group", "data": 3, "type": "int" } + ], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [ + 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, + 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0, 32.0, 33.0, 34.0, 35.0 + ], + "dims": [1, 3, 3, 4], + "type": "float32" + }, + { + "data": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0], + "dims": [3, 1, 2, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 0, 1, 4, 7, 6, 4, 16, 26, 36, 26, 20, 56, 66, 76, 50, 24, 59, 66, 73, 44, 60, 137, 148, 159, 90, 164, 368, + 394, 420, 234, 212, 472, 498, 524, 290, 140, 307, 322, 337, 184, 216, 465, 484, 503, 270, 516, 1104, 1146, + 1188, 634, 596, 1272, 1314, 1356, 722, 352, 747, 770, 793, 420 + ], + "dims": [1, 3, 4, 5], + "type": "float32" + } + ] + } + ] + }, { "name": "ConvTranspose with bias addition C", "operator": "ConvTranspose", diff --git a/js/web/test/e2e/exports/main.js b/js/web/test/e2e/exports/main.js index 8ed22a6784e7c..d8c7bbf69039f 100644 --- a/js/web/test/e2e/exports/main.js +++ b/js/web/test/e2e/exports/main.js @@ -3,7 +3,7 @@ 'use strict'; -const { runDevTest, runProdTest } = require('./test'); +const { runDevTest, runProdTest, verifyAssets } = require('./test'); const { installOrtPackages } = require('./utils'); /** @@ -29,5 +29,14 @@ module.exports = async function main(PRESERVE, PACKAGES_TO_INSTALL) { await runDevTest('vite-default', '\x1b[32m➜\x1b[39m \x1b[1mLocal\x1b[22m:', 5173); await runProdTest('vite-default', '\x1b[32m➜\x1b[39m \x1b[1mLocal\x1b[22m:', 4173); + + await verifyAssets('vite-default', async (cwd) => { + const globby = await import('globby'); + + return { + test: 'File "dist/assets/**/ort.*.mjs" should not exist', + success: globby.globbySync('dist/assets/**/ort.*.mjs', { cwd }).length === 0, + }; + }); } }; diff --git a/js/web/test/e2e/exports/test.js b/js/web/test/e2e/exports/test.js index 9c5ed745ab0b5..e2bcffea97519 100644 --- a/js/web/test/e2e/exports/test.js +++ b/js/web/test/e2e/exports/test.js @@ -121,7 +121,29 @@ async function runProdTest(testCaseName, ready, port) { await runTest(testCaseName, ['prod'], ready, 'npm run start', port); } +async function verifyAssets(testCaseName, testers) { + testers = Array.isArray(testers) ? testers : [testers]; + const wd = path.join(__dirname, 'testcases', testCaseName); + + console.log(`[${testCaseName}] Verifying assets...`); + + const testResults = []; + + try { + for (const tester of testers) { + testResults.push(await tester(wd)); + } + + if (testResults.some((r) => !r.success)) { + throw new Error(`[${testCaseName}] asset verification failed.`); + } + } finally { + console.log(`[${testCaseName}] asset verification result:`, testResults); + } +} + module.exports = { runDevTest, runProdTest, + verifyAssets, }; diff --git a/onnxruntime/contrib_ops/webgpu/bert/bias_add.cc b/onnxruntime/contrib_ops/webgpu/bert/bias_add.cc new file mode 100644 index 0000000000000..65c14e8cb0bdd --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/bert/bias_add.cc @@ -0,0 +1,80 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/webgpu_supported_types.h" +#include "contrib_ops/webgpu/bert/bias_add.h" +#include "contrib_ops/webgpu/webgpu_contrib_kernels.h" + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +ONNX_OPERATOR_KERNEL_EX( + BiasAdd, + kMSDomain, + 1, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedFloatTypes()), + BiasAdd); + +Status BiasAddProgram::GenerateShaderCode(ShaderHelper& shader) const { + const ShaderVariableHelper& input = shader.AddInput("input"); + const ShaderVariableHelper& bias = shader.AddInput("bias"); + const ShaderVariableHelper& residual = shader.AddInput("residual"); + const ShaderVariableHelper& output = shader.AddOutput("output"); + + shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size") + << "let value = " << input.GetByOffset("global_idx") + << " + " << bias.GetByOffset("global_idx % uniforms.channels") + << " + " << residual.GetByOffset("global_idx") << ";\n" + << output.SetByOffset("global_idx", "value"); + + return Status::OK(); +} + +static int64_t GetMaxComponents(int64_t size) { + if (size % 4 == 0) { + return 4; + } else if (size % 2 == 0) { + return 2; + } + return 1; +} + +Status BiasAdd::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const { + const auto* input = context.Input(0); + const auto* bias = context.Input(1); + const auto* residual = context.Input(2); + + TensorShape input_shape = input->Shape(); + + if (input_shape.NumDimensions() != 3) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "BiasAdd input should have 3 dimensions."); + } + + int64_t channels = input_shape[2]; + int64_t components = GetMaxComponents(channels); + channels /= components; + + TensorShape bias_shape = bias->Shape(); + if (bias_shape.NumDimensions() != 1 || bias_shape[0] != channels) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "BiasAdd bias should have 1 dimension with size equal to the number of channels."); + } + + auto* output = context.Output(0, input_shape); + int64_t output_size = output->Shape().Size() / components; + + BiasAddProgram program{}; + program.AddInputs({{input}, {bias}, {residual}}) + .AddOutput({output}) + .SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) + .AddUniformVariables({{static_cast(output_size)}, + {static_cast(channels)}}); + return context.RunProgram(program); +} + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/bert/bias_add.h b/onnxruntime/contrib_ops/webgpu/bert/bias_add.h new file mode 100644 index 0000000000000..58cc5f09f8003 --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/bert/bias_add.h @@ -0,0 +1,32 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/webgpu/program.h" +#include "core/providers/webgpu/webgpu_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +using namespace onnxruntime::webgpu; +using onnxruntime::webgpu::ComputeContext; + +class BiasAddProgram final : public Program { + public: + BiasAddProgram() : Program{"BiasAdd"} {} + Status GenerateShaderCode(ShaderHelper& sh) const override; + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"output_size", ProgramUniformVariableDataType::Uint32}, + {"channels", ProgramUniformVariableDataType::Uint32}); +}; + +class BiasAdd final : public WebGpuKernel { + public: + BiasAdd(const OpKernelInfo& info) : WebGpuKernel(info) {} + Status ComputeInternal(ComputeContext& context) const override; +}; + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/bert/fast_gelu.cc b/onnxruntime/contrib_ops/webgpu/bert/fast_gelu.cc index a5cae7e7f6747..29ea4f81dd5e1 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/fast_gelu.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/fast_gelu.cc @@ -50,7 +50,7 @@ Status FastGelu::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) c const auto* bias = context.Input(1); auto* output = context.Output(0, input->Shape()); - uint32_t data_size = gsl::narrow(output->Shape().Size()); + uint32_t data_size = onnxruntime::narrow(output->Shape().Size()); if (data_size == 0) { return Status::OK(); } @@ -60,7 +60,7 @@ Status FastGelu::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) c int bias_components = 1; if (bias != nullptr) { - bias_size = gsl::narrow(bias->Shape().Size()); + bias_size = onnxruntime::narrow(bias->Shape().Size()); if (bias_size % 4 == 0) { bias_components = 4; bias_size = bias_size / 4; diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc index 57ae8a7e5ba74..1e95d3d9610ff 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc @@ -98,7 +98,7 @@ Status CopyKVCache(onnxruntime::webgpu::ComputeContext& context, const WebgpuAtt program.AddOutputs({{present_key, ProgramTensorMetadataDependency::Rank, components}, {present_value, ProgramTensorMetadataDependency::Rank, components}}) .AddIndices(valid_present_shape); - program.SetDispatchGroupSize(gsl::narrow(valid_kv_size + 63 / 64)) + program.SetDispatchGroupSize(onnxruntime::narrow(valid_kv_size + 63 / 64)) .SetWorkgroupSize(64) .CacheHint(has_past, parameters.qkv_format_, parameters.past_present_share_buffer_) .AddUniformVariables({{static_cast(valid_kv_size)}, @@ -379,7 +379,7 @@ Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const { if (sg_size > 8) { for (var i:u32 = 0; i < qkv_head_size_vec; i++) { - var val = select(vec4(0), v_tile[capped_sg_id][i], k_start + capped_sg_id < seq_causal_length); + var val = v_tile[capped_sg_id][i]; var sum = subgroupShuffle(val, 0) * qk_1[0]; sum += subgroupShuffle(val, 1) * qk_1[1]; sum += subgroupShuffle(val, 2) * qk_1[2]; diff --git a/onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.cc b/onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.cc index bc8b7493fc916..20e1583e0da8f 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.cc @@ -66,11 +66,11 @@ Status RotaryEmbedding::ComputeInternal(onnxruntime::webgpu::ComputeContext& con const auto* sin_cache = context.Input(3); auto* output = context.Output(0, input_shape); - const auto batch_size = gsl::narrow(input->Shape()[0]); - const auto batch_stride = gsl::narrow(input_shape.SizeFromDimension(1)); - const auto sequence_length = gsl::narrow(input_shape[input_shape.NumDimensions() - 2]); + const auto batch_size = onnxruntime::narrow(input->Shape()[0]); + const auto batch_stride = onnxruntime::narrow(input_shape.SizeFromDimension(1)); + const auto sequence_length = onnxruntime::narrow(input_shape[input_shape.NumDimensions() - 2]); const auto hidden_size = batch_stride / sequence_length; - const auto half_rotary_embedding_dim = gsl::narrow(cos_cache->Shape()[1]); + const auto half_rotary_embedding_dim = onnxruntime::narrow(cos_cache->Shape()[1]); const auto head_size = rotary_embedding_dim_ == 0 ? half_rotary_embedding_dim * 2 : hidden_size / num_heads_; // Rotary embeddings will be calculated in a pair-wise fashion. In accordance, use the shape @@ -85,11 +85,11 @@ Status RotaryEmbedding::ComputeInternal(onnxruntime::webgpu::ComputeContext& con std::vector global_dims(rank); std::vector global_strides(rank); for (size_t j = 0; j < rank; ++j) { - global_dims[j] = gsl::narrow(global_shape[j]); - global_strides[j] = gsl::narrow(global_shape.SizeFromDimension(j + 1)); + global_dims[j] = onnxruntime::narrow(global_shape[j]); + global_strides[j] = onnxruntime::narrow(global_shape.SizeFromDimension(j + 1)); } - const auto output_size = gsl::narrow(global_shape.Size()); + const auto output_size = onnxruntime::narrow(global_shape.Size()); RotaryEmbeddingProgram program{interleaved_}; const auto input_output_strides = input_shape.NumDimensions() == 3 diff --git a/onnxruntime/contrib_ops/webgpu/bert/skip_layer_norm.cc b/onnxruntime/contrib_ops/webgpu/bert/skip_layer_norm.cc index a1840257d734f..d5d4632c01e2a 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/skip_layer_norm.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/skip_layer_norm.cc @@ -122,7 +122,7 @@ Status SkipLayerNorm::ComputeInternal(onnxruntime::webgpu::ComputeCo } const bool is_fp16 = x->GetElementType() == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16; - const uint32_t hidden_size = gsl::narrow(x_shape[x_shape.NumDimensions() - 1]); + const uint32_t hidden_size = onnxruntime::narrow(x_shape[x_shape.NumDimensions() - 1]); const int components = GetMaxComponents(hidden_size); const bool has_input_skip_bias_sum = input_skip_bias_sum != nullptr; @@ -133,7 +133,7 @@ Status SkipLayerNorm::ComputeInternal(onnxruntime::webgpu::ComputeCo .AddInputs({{skip, ProgramTensorMetadataDependency::Type, components}}) .AddInputs({{gamma, ProgramTensorMetadataDependency::Type, components}}) .AddOutputs({{output, ProgramTensorMetadataDependency::None, components}}) - .SetDispatchGroupSize(gsl::narrow(ceil(1.0 * data_size / hidden_size))) + .SetDispatchGroupSize(onnxruntime::narrow(ceil(1.0 * data_size / hidden_size))) .AddUniformVariables({ {static_cast(components)}, }) diff --git a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc new file mode 100644 index 0000000000000..05cbfb1f99c48 --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc @@ -0,0 +1,326 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h" +#include "core/providers/webgpu/shader_helper.h" + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +Status DP4AMatMulQuantizeProgram::GenerateShaderCode(ShaderHelper& shader) const { + shader.AddInput("input_a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); + shader.AddOutput("output", ShaderUsage::UseUniform); + shader.AddOutput("scales", ShaderUsage::UseUniform); + shader.AdditionalImplementation() << R"ADDNL_FN( + fn readInput(offset: u32) -> input_a_value_t + { + if (offset > uniforms.input_size) { + return input_a_value_t(0); + } + return input_a[offset]; + } + )ADDNL_FN"; + shader.MainFunctionBody() << R"MAIN_FN( + var local_a : array, 32>; + var max_value:vec4 = vec4(0); + for (var idx:u32=0;idx<32;idx+=1) + { + local_a[idx] = readInput(workgroup_idx*32 + idx); + max_value = max(max_value, abs(local_a[idx])); + } + var scale = max(max_value.x, max_value.y); + scale = max(scale, max_value.z); + scale = max(scale, max_value.w); + for (var idx:u32=0;idx<32;idx+=1) + { + output[workgroup_idx*32+idx] = pack4x8snorm(vec4(local_a[idx]/scale)); + } + // 127 is the max value of signed int8 [-127,127] used by pack4x8snorm for 1.0f. + scales[workgroup_idx] = scale/127; + )MAIN_FN"; + return Status::OK(); +} + +Status DP4AMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const { + shader.AddInput("input_a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias); + shader.AddInput("scales_a", ShaderUsage::UseUniform); + shader.AddInput("input_b", ShaderUsage::UseUniform); + shader.AddInput("scales_b", ShaderUsage::UseUniform); + shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseElementTypeAlias); + + // This shader implements co-operative matrix multiply. The key idea here is to + // assume there is a primitive for medium size matrix multiply a subgroup can perform, + // using all its lanes and pooling all its registers to keep the values in registry. + // + // The entire workgroup which has N subgroups first loads a tile into shared memory, + // Then each subgroup loads a subtile from shared memory into registers and uses + // the medium size matrix multiply primitive to perform the math. + // The values for tile/subtile size are chosen to conform to the resource limits + // of an alderlake/tiger lake gpu. A tile is 64x64, workgroup is 256 threads - + // therefore there are 16 subgroups and 16 lanes in each subgroup. + // K the hidden dimension is paged in from RAM at k tile size which is 64. + // All this puts the shared memory requirement slightly above 16KB. + // WebGPU limit is 16KB, output is moved to registers instead of SHM to make + // everything fit in shared memory. + // + // Each subgroup performs a 16 x 64 x 16 multiply which is implemented with + // subgroup shuffle as a placeholder for the day the medium matrix mul primitive + // becomes available in WGSL. The registry requirements is ~2KB per subgroup, on + // Alderlake/Tigerlake subgroup has 8KB of registry space pooling the + // 512B of registry from each lane. + // + // The medium size matmul is implemented using dot4I8Packed, so the inputs for + // this shader require A to be int8 quantized with block size 64. B is regular + // matmulnbits input with block size 32. + + shader.AdditionalImplementation() << " const block_size = " << block_size_ << ";"; + + shader.AdditionalImplementation() << R"ADDNL_FN( + const tile_size = 64; + const subtile_size = 16; + const tile_size_k = 32; + const vec_factor = 4; + const u32_factor = 4; + const tile_size_k_vec = 2; + + // Shared memory + var tile_A : array, tile_size>, tile_size_k_vec>; // 64 x 32 + var scale_A : array; // 64 x 1 + var tile_B : array, tile_size>, tile_size_k_vec>; // 64 x 32 + var scale_B : array; // 64 x 1 + + fn loadSHMA(a_global_base:u32, kidx_v:u32, row: u32, col: u32) + { + let a_global = a_global_base + row; + if (a_global >= uniforms.M) + { + return; + } + tile_A[col][row] = input_a[a_global*uniforms.K16+kidx_v+col]; + if (col == 0) + { + // kidx_v - covers 16 values of k + scale_A[row] = scales_a[a_global*(uniforms.K/128) + kidx_v/8]; + } + } + + fn loadSHMB(b_global_base:u32, kidx_v:u32, row: u32, col: u32) + { + let b_global = b_global_base + row; + if (b_global >= uniforms.N) + { + return; + } + + let b_value = input_b[b_global*uniforms.K16+kidx_v+col]; + var b_value_lower = vec4(unpack4xU8(b_value[0] & 0x0F0F0F0Fu)) - vec4(8); + var b_value_upper = vec4(unpack4xU8((b_value[0] >> 4) & 0x0F0F0F0Fu)) - vec4(8); + tile_B[col][row][0] = pack4xI8(vec4(b_value_lower[0], b_value_upper[0], b_value_lower[1], b_value_upper[1])); + tile_B[col][row][1] = pack4xI8(vec4(b_value_lower[2], b_value_upper[2], b_value_lower[3], b_value_upper[3])); + b_value_lower = vec4(unpack4xU8(b_value[1] & 0x0F0F0F0Fu)) - vec4(8); + b_value_upper = vec4(unpack4xU8((b_value[1] >> 4) & 0x0F0F0F0Fu)) - vec4(8); + tile_B[col][row][2] = pack4xI8(vec4(b_value_lower[0], b_value_upper[0], b_value_lower[1], b_value_upper[1])); + tile_B[col][row][3] = pack4xI8(vec4(b_value_lower[2], b_value_upper[2], b_value_lower[3], b_value_upper[3])); + if (col == 0) + { + // kidx_v - each kidx_v covers 16 values of k + scale_B[row] = scales_b[b_global*(uniforms.K/block_size) + kidx_v/(block_size/16)]; + } + } + + // Scaled dot product of 8 packed unsigned integers. + fn SDP8AI(a1:vec4, b1:vec4, a2:vec4, b2:vec4, scale:output_element_t) -> output_element_t + { + var local_sum = dot4I8Packed(a1[0], b1[0]); + local_sum += dot4I8Packed(a1[1], b1[1]); + local_sum += dot4I8Packed(a1[2], b1[2]); + local_sum += dot4I8Packed(a1[3], b1[3]); + local_sum += dot4I8Packed(a2[0], b2[0]); + local_sum += dot4I8Packed(a2[1], b2[1]); + local_sum += dot4I8Packed(a2[2], b2[2]); + local_sum += dot4I8Packed(a2[3], b2[3]); + return output_element_t(local_sum) * scale; + } + )ADDNL_FN"; + + shader.MainFunctionBody() << R"MAIN_FN( + // During the load phase we use all 256 threads to load 64 rows of A/B. + // For each row we load tile_size_k_vec (2) vectorized elements, which are 32 elements of K. + let a_global_base = workgroup_id.x * tile_size; + let b_global_base = workgroup_id.y * tile_size; + let load_AorB = u32(local_idx/128); + let load_row = u32((local_idx%128)/2); + let load_col = u32(local_idx%2); + + // During the compute phase, we have the 64x64 tile split into + // subtiles of 16x16. We have a grid of 4x4 subtiles. + let subtile_id = u32(local_idx / subtile_size); + let subtile_idx = u32(subtile_id / 4); + let subtile_idy = u32(subtile_id % 4); + let base_A = subtile_idx * 16; + let base_B = subtile_idy * 16; + // For each subtile we have 16 threads assigned. + let a_idx = u32(local_idx % subtile_size); + + var lane_output1: vec4; + var lane_output2: vec4; + var lane_output3: vec4; + var lane_output4: vec4; + // K's vectrorization is 16 items per index. See input_a/input_b. + // tile_size_k_vec - is the k tile size in vectorized space (1/16). That is + // k tile size is 32. In vectorized space that is 32/16 = 2. + for (var kidx_v:u32 = 0; kidx_v < uniforms.K16; kidx_v+=tile_size_k_vec) + { + // Load Phase: Populate shared memory for the workgroup. + if (load_AorB == 0) + { + loadSHMA(a_global_base, kidx_v, load_row, load_col); + } + else + { + loadSHMB(b_global_base, kidx_v, load_row, load_col); + } + workgroupBarrier(); + + // Compute phase: Perform matmul for this subtile 16 x 32 x 16. + // Step 1: Load from shared memory into registers across entire subgroup. + var own_a0: vec4 = tile_A[0][base_A + a_idx]; + var own_a1: vec4 = tile_A[1][base_A + a_idx]; + var own_scale_a: output_element_t = scale_A[base_A + a_idx]; + if (sg_size == 16) + { + var own_b0: vec4 = tile_B[0][base_B + sg_id]; + var own_b1: vec4 = tile_B[1][base_B + sg_id]; + var own_scale_b: output_element_t = scale_B[base_B + sg_id]; + // Step 2: Access registers across the subgroup using subgroupShuffle and perform the matmul. + lane_output1[0] += SDP8AI(own_a0, subgroupShuffle(own_b0, 0), own_a1, subgroupShuffle(own_b1, 0), subgroupShuffle(own_scale_b, 0) * own_scale_a); + lane_output1[1] += SDP8AI(own_a0, subgroupShuffle(own_b0, 1), own_a1, subgroupShuffle(own_b1, 1), subgroupShuffle(own_scale_b, 1) * own_scale_a); + lane_output1[2] += SDP8AI(own_a0, subgroupShuffle(own_b0, 2), own_a1, subgroupShuffle(own_b1, 2), subgroupShuffle(own_scale_b, 2) * own_scale_a); + lane_output1[3] += SDP8AI(own_a0, subgroupShuffle(own_b0, 3), own_a1, subgroupShuffle(own_b1, 3), subgroupShuffle(own_scale_b, 3) * own_scale_a); + + lane_output2[0] += SDP8AI(own_a0, subgroupShuffle(own_b0, 4), own_a1, subgroupShuffle(own_b1, 4), subgroupShuffle(own_scale_b, 4) * own_scale_a); + lane_output2[1] += SDP8AI(own_a0, subgroupShuffle(own_b0, 5), own_a1, subgroupShuffle(own_b1, 5), subgroupShuffle(own_scale_b, 5) * own_scale_a); + lane_output2[2] += SDP8AI(own_a0, subgroupShuffle(own_b0, 6), own_a1, subgroupShuffle(own_b1, 6), subgroupShuffle(own_scale_b, 6) * own_scale_a); + lane_output2[3] += SDP8AI(own_a0, subgroupShuffle(own_b0, 7), own_a1, subgroupShuffle(own_b1, 7), subgroupShuffle(own_scale_b, 7) * own_scale_a); + + lane_output3[0] += SDP8AI(own_a0, subgroupShuffle(own_b0, 8), own_a1, subgroupShuffle(own_b1, 8), subgroupShuffle(own_scale_b, 8) * own_scale_a); + lane_output3[1] += SDP8AI(own_a0, subgroupShuffle(own_b0, 9), own_a1, subgroupShuffle(own_b1, 9), subgroupShuffle(own_scale_b, 9) * own_scale_a); + lane_output3[2] += SDP8AI(own_a0, subgroupShuffle(own_b0, 10), own_a1, subgroupShuffle(own_b1, 10), subgroupShuffle(own_scale_b, 10) * own_scale_a); + lane_output3[3] += SDP8AI(own_a0, subgroupShuffle(own_b0, 11), own_a1, subgroupShuffle(own_b1, 11), subgroupShuffle(own_scale_b, 11) * own_scale_a); + + lane_output4[0] += SDP8AI(own_a0, subgroupShuffle(own_b0, 12), own_a1, subgroupShuffle(own_b1, 12), subgroupShuffle(own_scale_b, 12) * own_scale_a); + lane_output4[1] += SDP8AI(own_a0, subgroupShuffle(own_b0, 13), own_a1, subgroupShuffle(own_b1, 13), subgroupShuffle(own_scale_b, 13) * own_scale_a); + lane_output4[2] += SDP8AI(own_a0, subgroupShuffle(own_b0, 14), own_a1, subgroupShuffle(own_b1, 14), subgroupShuffle(own_scale_b, 14) * own_scale_a); + lane_output4[3] += SDP8AI(own_a0, subgroupShuffle(own_b0, 15), own_a1, subgroupShuffle(own_b1, 15), subgroupShuffle(own_scale_b, 15) * own_scale_a); + } + else + { + // Code for other subgroup sizes, simply doesnt use subgroups at all. + // Relies on reads from single location tile_B[][base_B + col] by all + // being optimized by the hardware. + lane_output1[0] += SDP8AI(own_a0, tile_B[0][base_B + 0], own_a1, tile_B[1][base_B + 0], own_scale_a * scale_B[base_B + 0]); + lane_output1[1] += SDP8AI(own_a0, tile_B[0][base_B + 1], own_a1, tile_B[1][base_B + 1], own_scale_a * scale_B[base_B + 1]); + lane_output1[2] += SDP8AI(own_a0, tile_B[0][base_B + 2], own_a1, tile_B[1][base_B + 2], own_scale_a * scale_B[base_B + 2]); + lane_output1[3] += SDP8AI(own_a0, tile_B[0][base_B + 3], own_a1, tile_B[1][base_B + 3], own_scale_a * scale_B[base_B + 3]); + + lane_output2[0] += SDP8AI(own_a0, tile_B[0][base_B + 4], own_a1, tile_B[1][base_B + 4], own_scale_a * scale_B[base_B + 4]); + lane_output2[1] += SDP8AI(own_a0, tile_B[0][base_B + 5], own_a1, tile_B[1][base_B + 5], own_scale_a * scale_B[base_B + 5]); + lane_output2[2] += SDP8AI(own_a0, tile_B[0][base_B + 6], own_a1, tile_B[1][base_B + 6], own_scale_a * scale_B[base_B + 6]); + lane_output2[3] += SDP8AI(own_a0, tile_B[0][base_B + 7], own_a1, tile_B[1][base_B + 7], own_scale_a * scale_B[base_B + 7]); + + lane_output3[0] += SDP8AI(own_a0, tile_B[0][base_B + 8], own_a1, tile_B[1][base_B + 8], own_scale_a * scale_B[base_B + 8]); + lane_output3[1] += SDP8AI(own_a0, tile_B[0][base_B + 9], own_a1, tile_B[1][base_B + 9], own_scale_a * scale_B[base_B + 9]); + lane_output3[2] += SDP8AI(own_a0, tile_B[0][base_B + 10], own_a1, tile_B[1][base_B + 10], own_scale_a * scale_B[base_B + 10]); + lane_output3[3] += SDP8AI(own_a0, tile_B[0][base_B + 11], own_a1, tile_B[1][base_B + 11], own_scale_a * scale_B[base_B + 11]); + + lane_output4[0] += SDP8AI(own_a0, tile_B[0][base_B + 12], own_a1, tile_B[1][base_B + 12], own_scale_a * scale_B[base_B + 12]); + lane_output4[1] += SDP8AI(own_a0, tile_B[0][base_B + 13], own_a1, tile_B[1][base_B + 13], own_scale_a * scale_B[base_B + 13]); + lane_output4[2] += SDP8AI(own_a0, tile_B[0][base_B + 14], own_a1, tile_B[1][base_B + 14], own_scale_a * scale_B[base_B + 14]); + lane_output4[3] += SDP8AI(own_a0, tile_B[0][base_B + 15], own_a1, tile_B[1][base_B + 15], own_scale_a * scale_B[base_B + 15]); + } + workgroupBarrier(); + } + + let a_global = a_global_base + base_A + a_idx; + let b_global = b_global_base + base_B; + let output_idx = ((a_global) * uniforms.N + b_global)/4; + // This creates a shader requirement that uniforms.N % 16 == 0 + if (a_global < uniforms.M && b_global < uniforms.N) + { + output[output_idx] = lane_output1; + output[output_idx+1] = lane_output2; + output[output_idx+2] = lane_output3; + output[output_idx+3] = lane_output4; + } + )MAIN_FN"; + + return Status::OK(); +} + +Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor* scales, + uint32_t M, + uint32_t N, + uint32_t K, + uint32_t block_size, + onnxruntime::webgpu::ComputeContext& context, + Tensor* y) { + constexpr uint32_t kVec4Components = 4; + constexpr uint32_t kVec2Components = 2; + constexpr uint32_t kU32Components = 4; + + constexpr uint32_t kBlockSizeA = 128; + DP4AMatMulQuantizeProgram quantize_program; + quantize_program.SetWorkgroupSize(1); + quantize_program.SetDispatchGroupSize(M * K / kBlockSizeA, 1, 1); + TensorShape a_quant_shape{1, M, K / kU32Components}; + Tensor a_quant = context.CreateGPUTensor(DataTypeImpl::GetType(), a_quant_shape); + TensorShapeVector a_scales_dims({1, 1, M, K / kBlockSizeA}); + Tensor a_scale = context.CreateGPUTensor(a->DataType(), a_scales_dims); + quantize_program.AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, static_cast(kVec4Components)}}) + .AddOutputs({{&a_quant, ProgramTensorMetadataDependency::Rank, a_quant.Shape(), 1}, + {&a_scale, ProgramTensorMetadataDependency::Rank, a_scale.Shape(), 1}}) + .AddUniformVariable({static_cast(M * K / kVec4Components)}); + ORT_RETURN_IF_ERROR(context.RunProgram(quantize_program)); + + constexpr uint32_t kTileSize = 64; + TensorShape reshaped_y_shape{1, M, N / kVec4Components}; + DP4AMatMulNBitsProgram mul_program{block_size}; + mul_program.SetWorkgroupSize(256); + mul_program.SetDispatchGroupSize( + (M + kTileSize - 1) / kTileSize, + (N + kTileSize - 1) / kTileSize, 1); + mul_program.AddInputs({{&a_quant, ProgramTensorMetadataDependency::TypeAndRank, static_cast(kVec4Components)}, + {&a_scale, ProgramTensorMetadataDependency::TypeAndRank, 1}, + {b, ProgramTensorMetadataDependency::TypeAndRank, static_cast(kVec2Components * kU32Components)}, + {scales, ProgramTensorMetadataDependency::TypeAndRank, 1}}) + .AddUniformVariables({{static_cast(M)}, + {static_cast(N)}, + {static_cast(K)}, + {static_cast(K / 8)}, + {static_cast(K / 16)}}) + .AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, reshaped_y_shape, static_cast(kVec4Components)}) + .CacheHint("Block" + std::to_string(block_size)); + return context.RunProgram(mul_program); +} + +bool CanApplyDP4AMatrixMatMulNBits(onnxruntime::webgpu::ComputeContext& context, + uint64_t accuracy_level, + uint32_t block_size, + uint32_t batch_count, + uint32_t N, + uint32_t K, + uint32_t components_k, + bool has_zero_points) { + // macOS - Avoid using dp4a on Metal, as it does not appear to have native dp4a support. + // https://github.com/gpuweb/gpuweb/issues/2677#issuecomment-1713292226 + bool use_dp4a = context.Device().HasFeature(wgpu::FeatureName::Subgroups) && + context.AdapterInfo().backendType != wgpu::BackendType::Metal; + return (accuracy_level == 4 && block_size % 32 == 0 && + batch_count == 1 && components_k == 4 && K % 64 == 0 && N % 16 == 0 && + !has_zero_points && use_dp4a); +} + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h new file mode 100644 index 0000000000000..15b86d78301ad --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h @@ -0,0 +1,56 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/webgpu/program.h" +#include "core/providers/webgpu/webgpu_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +using namespace onnxruntime::webgpu; + +class DP4AMatMulQuantizeProgram final : public Program { + public: + DP4AMatMulQuantizeProgram() : Program{"DP4AMatMulQuantize"} {} + Status GenerateShaderCode(ShaderHelper& sh) const override; + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"input_size", ProgramUniformVariableDataType::Uint32}); +}; + +class DP4AMatMulNBitsProgram final : public Program { + public: + DP4AMatMulNBitsProgram(uint32_t block_size) : Program{"DP4AMatMulNBits"}, block_size_(block_size) {} + Status GenerateShaderCode(ShaderHelper& sh) const override; + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES( + {"M", ProgramUniformVariableDataType::Uint32}, + {"N", ProgramUniformVariableDataType::Uint32}, + {"K", ProgramUniformVariableDataType::Uint32}, + {"K8", ProgramUniformVariableDataType::Uint32}, + {"K16", ProgramUniformVariableDataType::Uint32}); + + private: + uint32_t block_size_; +}; + +Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor* scales, + uint32_t M, + uint32_t N, + uint32_t K, + uint32_t block_size, + onnxruntime::webgpu::ComputeContext& context, + Tensor* y); + +bool CanApplyDP4AMatrixMatMulNBits(onnxruntime::webgpu::ComputeContext& context, + uint64_t accuracy_level, + uint32_t block_size, + uint32_t batch_count, + uint32_t N, + uint32_t K, + uint32_t components_k, + bool has_zero_points); + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc index 28d622b2c9c33..cce10a59fbd4b 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc @@ -5,6 +5,7 @@ #include "contrib_ops/webgpu/quantization/matmul_nbits.h" #include "contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.h" +#include "contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h" #include "contrib_ops/webgpu/webgpu_contrib_kernels.h" #include "core/providers/cpu/math/matmul_helper.h" #include "core/providers/webgpu/shader_helper.h" @@ -371,7 +372,7 @@ Status MatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const { } } else { const std::string quantized_data_type = QuantizedDataType(a.NumComponents()); - const int output_element_number = y.NumComponents() * gsl::narrow(output_number_); + const int output_element_number = y.NumComponents() * onnxruntime::narrow(output_number_); const uint32_t shared_memory_size = output_number_ * WORKGROUP_SIZE; std::string offset = "workgroup_idx * " + std::to_string(output_number_); @@ -532,255 +533,6 @@ Status MatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const { return Status::OK(); } -Status DP4AMatMulQuantizeProgram::GenerateShaderCode(ShaderHelper& shader) const { - shader.AddInput("input_a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); - shader.AddOutput("output", ShaderUsage::UseUniform); - shader.AddOutput("scales", ShaderUsage::UseUniform); - shader.AdditionalImplementation() << R"ADDNL_FN( - fn readInput(offset: u32) -> input_a_value_t - { - if (offset > uniforms.input_size) { - return input_a_value_t(0); - } - return input_a[offset]; - } -)ADDNL_FN"; - shader.MainFunctionBody() << R"MAIN_FN( - var local_a : array, 32>; - var max_value:vec4 = vec4(0); - for (var idx:u32=0;idx<32;idx+=1) - { - local_a[idx] = readInput(workgroup_idx*32 + idx); - max_value = max(max_value, abs(local_a[idx])); - } - var scale = max(max_value.x, max_value.y); - scale = max(scale, max_value.z); - scale = max(scale, max_value.w); - for (var idx:u32=0;idx<32;idx+=1) - { - output[workgroup_idx*32+idx] = pack4x8snorm(vec4(local_a[idx]/scale)); - } - // 127 is the max value of signed int8 [-127,127] used by pack4x8snorm for 1.0f. - scales[workgroup_idx] = scale/127; -)MAIN_FN"; - return Status::OK(); -} - -Status DP4AMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const { - shader.AddInput("input_a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias); - shader.AddInput("scales_a", ShaderUsage::UseUniform); - shader.AddInput("input_b", ShaderUsage::UseUniform); - shader.AddInput("scales_b", ShaderUsage::UseUniform); - shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseElementTypeAlias); - - // This shader implements co-operative matrix multiply. The key idea here is to - // assume there is a primitive for medium size matrix multiply a subgroup can perform, - // using all its lanes and pooling all its registers to keep the values in registry. - // - // The entire workgroup which has N subgroups first loads a tile into shared memory, - // Then each subgroup loads a subtile from shared memory into registers and uses - // the medium size matrix multiply primitive to perform the math. - // The values for tile/subtile size are chosen to conform to the resource limits - // of an alderlake/tiger lake gpu. A tile is 64x64, workgroup is 256 threads - - // therefore there are 16 subgroups and 16 lanes in each subgroup. - // K the hidden dimension is paged in from RAM at k tile size which is 64. - // All this puts the shared memory requirement slightly above 16KB. - // WebGPU limit is 16KB, output is moved to registers instead of SHM to make - // everything fit in shared memory. - // - // Each subgroup performs a 16 x 64 x 16 multiply which is implemented with - // subgroup shuffle as a placeholder for the day the medium matrix mul primitive - // becomes available in WGSL. The registry requirements is ~2KB per subgroup, on - // Alderlake/Tigerlake subgroup has 8KB of registry space pooling the - // 512B of registry from each lane. - // - // The medium size matmul is implemented using dot4I8Packed, so the inputs for - // this shader require A to be int8 quantized with block size 64. B is regular - // matmulnbits input with block size 32. - - shader.AdditionalImplementation() << R"ADDNL_FN( - const tile_size = 64; - const subtile_size = 16; - const tile_size_k = 32; - const vec_factor = 4; - const u32_factor = 4; - const tile_size_k_vec = 2; - const block_size = 32; - - // Shared memory - var tile_A : array, tile_size>, tile_size_k_vec>; // 64 x 32 - var scale_A : array; // 64 x 1 - var tile_B : array, tile_size>, tile_size_k_vec>; // 64 x 32 - var scale_B : array; // 64 x 1 - - fn loadSHMA(a_global_base:u32, kidx_v:u32, row: u32, col: u32) - { - let a_global = a_global_base + row; - if (a_global >= uniforms.M) - { - return; - } - tile_A[col][row] = input_a[a_global*uniforms.K16+kidx_v+col]; - if (col == 0) - { - // kidx_v - covers 16 values of k - scale_A[row] = scales_a[a_global*(uniforms.K/128) + kidx_v/8]; - } - } - - fn loadSHMB(b_global_base:u32, kidx_v:u32, row: u32, col: u32) - { - let b_global = b_global_base + row; - if (b_global >= uniforms.N) - { - return; - } - - let b_value = input_b[b_global*uniforms.K16+kidx_v+col]; - var b_value_lower = vec4(unpack4xU8(b_value[0] & 0x0F0F0F0Fu)) - vec4(8); - var b_value_upper = vec4(unpack4xU8((b_value[0] >> 4) & 0x0F0F0F0Fu)) - vec4(8); - tile_B[col][row][0] = pack4xI8(vec4(b_value_lower[0], b_value_upper[0], b_value_lower[1], b_value_upper[1])); - tile_B[col][row][1] = pack4xI8(vec4(b_value_lower[2], b_value_upper[2], b_value_lower[3], b_value_upper[3])); - b_value_lower = vec4(unpack4xU8(b_value[1] & 0x0F0F0F0Fu)) - vec4(8); - b_value_upper = vec4(unpack4xU8((b_value[1] >> 4) & 0x0F0F0F0Fu)) - vec4(8); - tile_B[col][row][2] = pack4xI8(vec4(b_value_lower[0], b_value_upper[0], b_value_lower[1], b_value_upper[1])); - tile_B[col][row][3] = pack4xI8(vec4(b_value_lower[2], b_value_upper[2], b_value_lower[3], b_value_upper[3])); - if (col == 0) - { - // kidx_v - each kidx_v covers 16 values of k - scale_B[row] = scales_b[b_global*(uniforms.K/32) + kidx_v/2]; - } - } - - // Scaled dot product of 8 packed unsigned integers. - fn SDP8AI(a1:vec4, b1:vec4, a2:vec4, b2:vec4, scale:output_element_t) -> output_element_t - { - var local_sum = dot4I8Packed(a1[0], b1[0]); - local_sum += dot4I8Packed(a1[1], b1[1]); - local_sum += dot4I8Packed(a1[2], b1[2]); - local_sum += dot4I8Packed(a1[3], b1[3]); - local_sum += dot4I8Packed(a2[0], b2[0]); - local_sum += dot4I8Packed(a2[1], b2[1]); - local_sum += dot4I8Packed(a2[2], b2[2]); - local_sum += dot4I8Packed(a2[3], b2[3]); - return output_element_t(local_sum) * scale; - } -)ADDNL_FN"; - - shader.MainFunctionBody() << R"MAIN_FN( - // During the load phase we use all 256 threads to load 64 rows of A/B. - // For each row we load tile_size_k_vec (2) vectorized elements, which are 32 elements of K. - let a_global_base = workgroup_id.x * tile_size; - let b_global_base = workgroup_id.y * tile_size; - let load_AorB = u32(local_idx/128); - let load_row = u32((local_idx%128)/2); - let load_col = u32(local_idx%2); - - // During the compute phase, we have the 64x64 tile split into - // subtiles of 16x16. We have a grid of 4x4 subtiles. - let subtile_id = u32(local_idx / subtile_size); - let subtile_idx = u32(subtile_id / 4); - let subtile_idy = u32(subtile_id % 4); - let base_A = subtile_idx * 16; - let base_B = subtile_idy * 16; - // For each subtile we have 16 threads assigned. - let a_idx = u32(local_idx % subtile_size); - - var lane_output1: vec4; - var lane_output2: vec4; - var lane_output3: vec4; - var lane_output4: vec4; - // K's vectrorization is 16 items per index. See input_a/input_b. - // tile_size_k_vec - is the k tile size in vectorized space (1/16). That is - // k tile size is 32. In vectorized space that is 32/16 = 2. - for (var kidx_v:u32 = 0; kidx_v < uniforms.K16; kidx_v+=tile_size_k_vec) - { - // Load Phase: Populate shared memory for the workgroup. - if (load_AorB == 0) - { - loadSHMA(a_global_base, kidx_v, load_row, load_col); - } - else - { - loadSHMB(b_global_base, kidx_v, load_row, load_col); - } - workgroupBarrier(); - - // Compute phase: Perform matmul for this subtile 16 x 32 x 16. - // Step 1: Load from shared memory into registers across entire subgroup. - var own_a0: vec4 = tile_A[0][base_A + a_idx]; - var own_a1: vec4 = tile_A[1][base_A + a_idx]; - var own_scale_a: output_element_t = scale_A[base_A + a_idx]; - if (sg_size == 16) - { - var own_b0: vec4 = tile_B[0][base_B + sg_id]; - var own_b1: vec4 = tile_B[1][base_B + sg_id]; - var own_scale_b: output_element_t = scale_B[base_B + sg_id]; - // Step 2: Access registers across the subgroup using subgroupShuffle and perform the matmul. - lane_output1[0] += SDP8AI(own_a0, subgroupShuffle(own_b0, 0), own_a1, subgroupShuffle(own_b1, 0), subgroupShuffle(own_scale_b, 0) * own_scale_a); - lane_output1[1] += SDP8AI(own_a0, subgroupShuffle(own_b0, 1), own_a1, subgroupShuffle(own_b1, 1), subgroupShuffle(own_scale_b, 1) * own_scale_a); - lane_output1[2] += SDP8AI(own_a0, subgroupShuffle(own_b0, 2), own_a1, subgroupShuffle(own_b1, 2), subgroupShuffle(own_scale_b, 2) * own_scale_a); - lane_output1[3] += SDP8AI(own_a0, subgroupShuffle(own_b0, 3), own_a1, subgroupShuffle(own_b1, 3), subgroupShuffle(own_scale_b, 3) * own_scale_a); - - lane_output2[0] += SDP8AI(own_a0, subgroupShuffle(own_b0, 4), own_a1, subgroupShuffle(own_b1, 4), subgroupShuffle(own_scale_b, 4) * own_scale_a); - lane_output2[1] += SDP8AI(own_a0, subgroupShuffle(own_b0, 5), own_a1, subgroupShuffle(own_b1, 5), subgroupShuffle(own_scale_b, 5) * own_scale_a); - lane_output2[2] += SDP8AI(own_a0, subgroupShuffle(own_b0, 6), own_a1, subgroupShuffle(own_b1, 6), subgroupShuffle(own_scale_b, 6) * own_scale_a); - lane_output2[3] += SDP8AI(own_a0, subgroupShuffle(own_b0, 7), own_a1, subgroupShuffle(own_b1, 7), subgroupShuffle(own_scale_b, 7) * own_scale_a); - - lane_output3[0] += SDP8AI(own_a0, subgroupShuffle(own_b0, 8), own_a1, subgroupShuffle(own_b1, 8), subgroupShuffle(own_scale_b, 8) * own_scale_a); - lane_output3[1] += SDP8AI(own_a0, subgroupShuffle(own_b0, 9), own_a1, subgroupShuffle(own_b1, 9), subgroupShuffle(own_scale_b, 9) * own_scale_a); - lane_output3[2] += SDP8AI(own_a0, subgroupShuffle(own_b0, 10), own_a1, subgroupShuffle(own_b1, 10), subgroupShuffle(own_scale_b, 10) * own_scale_a); - lane_output3[3] += SDP8AI(own_a0, subgroupShuffle(own_b0, 11), own_a1, subgroupShuffle(own_b1, 11), subgroupShuffle(own_scale_b, 11) * own_scale_a); - - lane_output4[0] += SDP8AI(own_a0, subgroupShuffle(own_b0, 12), own_a1, subgroupShuffle(own_b1, 12), subgroupShuffle(own_scale_b, 12) * own_scale_a); - lane_output4[1] += SDP8AI(own_a0, subgroupShuffle(own_b0, 13), own_a1, subgroupShuffle(own_b1, 13), subgroupShuffle(own_scale_b, 13) * own_scale_a); - lane_output4[2] += SDP8AI(own_a0, subgroupShuffle(own_b0, 14), own_a1, subgroupShuffle(own_b1, 14), subgroupShuffle(own_scale_b, 14) * own_scale_a); - lane_output4[3] += SDP8AI(own_a0, subgroupShuffle(own_b0, 15), own_a1, subgroupShuffle(own_b1, 15), subgroupShuffle(own_scale_b, 15) * own_scale_a); - } - else - { - // Code for other subgroup sizes, simply doesnt use subgroups at all. - // Relies on reads from single location tile_B[][base_B + col] by all - // being optimized by the hardware. - lane_output1[0] += SDP8AI(own_a0, tile_B[0][base_B + 0], own_a1, tile_B[1][base_B + 0], own_scale_a * scale_B[base_B + 0]); - lane_output1[1] += SDP8AI(own_a0, tile_B[0][base_B + 1], own_a1, tile_B[1][base_B + 1], own_scale_a * scale_B[base_B + 1]); - lane_output1[2] += SDP8AI(own_a0, tile_B[0][base_B + 2], own_a1, tile_B[1][base_B + 2], own_scale_a * scale_B[base_B + 2]); - lane_output1[3] += SDP8AI(own_a0, tile_B[0][base_B + 3], own_a1, tile_B[1][base_B + 3], own_scale_a * scale_B[base_B + 3]); - - lane_output2[0] += SDP8AI(own_a0, tile_B[0][base_B + 4], own_a1, tile_B[1][base_B + 4], own_scale_a * scale_B[base_B + 4]); - lane_output2[1] += SDP8AI(own_a0, tile_B[0][base_B + 5], own_a1, tile_B[1][base_B + 5], own_scale_a * scale_B[base_B + 5]); - lane_output2[2] += SDP8AI(own_a0, tile_B[0][base_B + 6], own_a1, tile_B[1][base_B + 6], own_scale_a * scale_B[base_B + 6]); - lane_output2[3] += SDP8AI(own_a0, tile_B[0][base_B + 7], own_a1, tile_B[1][base_B + 7], own_scale_a * scale_B[base_B + 7]); - - lane_output3[0] += SDP8AI(own_a0, tile_B[0][base_B + 8], own_a1, tile_B[1][base_B + 8], own_scale_a * scale_B[base_B + 8]); - lane_output3[1] += SDP8AI(own_a0, tile_B[0][base_B + 9], own_a1, tile_B[1][base_B + 9], own_scale_a * scale_B[base_B + 9]); - lane_output3[2] += SDP8AI(own_a0, tile_B[0][base_B + 10], own_a1, tile_B[1][base_B + 10], own_scale_a * scale_B[base_B + 10]); - lane_output3[3] += SDP8AI(own_a0, tile_B[0][base_B + 11], own_a1, tile_B[1][base_B + 11], own_scale_a * scale_B[base_B + 11]); - - lane_output4[0] += SDP8AI(own_a0, tile_B[0][base_B + 12], own_a1, tile_B[1][base_B + 12], own_scale_a * scale_B[base_B + 12]); - lane_output4[1] += SDP8AI(own_a0, tile_B[0][base_B + 13], own_a1, tile_B[1][base_B + 13], own_scale_a * scale_B[base_B + 13]); - lane_output4[2] += SDP8AI(own_a0, tile_B[0][base_B + 14], own_a1, tile_B[1][base_B + 14], own_scale_a * scale_B[base_B + 14]); - lane_output4[3] += SDP8AI(own_a0, tile_B[0][base_B + 15], own_a1, tile_B[1][base_B + 15], own_scale_a * scale_B[base_B + 15]); - } - workgroupBarrier(); - } - - let a_global = a_global_base + base_A + a_idx; - let b_global = b_global_base + base_B; - let output_idx = ((a_global) * uniforms.N + b_global)/4; - // This creates a shader requirement that uniforms.N % 16 == 0 - if (a_global < uniforms.M && b_global < uniforms.N) - { - output[output_idx] = lane_output1; - output[output_idx+1] = lane_output2; - output[output_idx+2] = lane_output3; - output[output_idx+3] = lane_output4; - } -)MAIN_FN"; - - return Status::OK(); -} - Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const { const Tensor* a = context.Input(0); const Tensor* b = context.Input(1); @@ -796,16 +548,16 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context TensorShape b_shape({N_, K_}); ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b_shape, false, true)); auto* y = context.Output(0, helper.OutputShape()); - const uint32_t data_size = gsl::narrow(y->Shape().Size()); + const uint32_t data_size = onnxruntime::narrow(y->Shape().Size()); if (data_size == 0) { return Status::OK(); } - const uint32_t batch_count = gsl::narrow(helper.OutputOffsets().size()); - const uint32_t M = gsl::narrow(helper.M()); - const uint32_t N = gsl::narrow(helper.N()); - const uint32_t K = gsl::narrow(helper.K()); - const uint32_t block_size = gsl::narrow(block_size_); + const uint32_t batch_count = onnxruntime::narrow(helper.OutputOffsets().size()); + const uint32_t M = onnxruntime::narrow(helper.M()); + const uint32_t N = onnxruntime::narrow(helper.N()); + const uint32_t K = onnxruntime::narrow(helper.K()); + const uint32_t block_size = onnxruntime::narrow(block_size_); constexpr uint32_t nbits = 4; const uint32_t n_blocks_per_col = (K + block_size - 1) / block_size; @@ -822,56 +574,17 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context return ApplySubgroupMatrixMatMulNBits(a, b, scales, M, N, K, context, y); } - const bool has_subgroup = context.Device().HasFeature(wgpu::FeatureName::Subgroups); - // macOS - Avoid using dp4a on Metal, as it does not appear to have native dp4a support. - // https://github.com/gpuweb/gpuweb/issues/2677#issuecomment-1713292226 - const bool use_dp4a = has_subgroup && context.AdapterInfo().backendType != wgpu::BackendType::Metal; - if (accuracy_level_ == 4 && block_size == 32 && - batch_count == 1 && components_a == 4 && K % 64 == 0 && N % 16 == 0 && - !has_zero_points && use_dp4a && M >= kMinMForTileOptimization) { - constexpr uint32_t kVec4Components = 4; - constexpr uint32_t kVec2Components = 2; - constexpr uint32_t kU32Components = 4; - - constexpr uint32_t kBlockSizeA = 128; - DP4AMatMulQuantizeProgram quantize_program; - quantize_program.SetWorkgroupSize(1); - quantize_program.SetDispatchGroupSize(M * K / kBlockSizeA, 1, 1); - TensorShape a_quant_shape{1, M, K / kU32Components}; - Tensor a_quant = context.CreateGPUTensor(DataTypeImpl::GetType(), a_quant_shape); - TensorShapeVector a_scales_dims({1, 1, M, K / kBlockSizeA}); - Tensor a_scale = context.CreateGPUTensor(a->DataType(), a_scales_dims); - quantize_program.AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow(kVec4Components)}}) - .AddOutputs({{&a_quant, ProgramTensorMetadataDependency::Rank, a_quant.Shape(), gsl::narrow(1)}, - {&a_scale, ProgramTensorMetadataDependency::Rank, a_scale.Shape(), gsl::narrow(1)}}) - .AddUniformVariable({static_cast(M * K / kVec4Components)}); - ORT_RETURN_IF_ERROR(context.RunProgram(quantize_program)); - - constexpr uint32_t kTileSize = 64; - TensorShape reshaped_y_shape{1, M, N / kVec4Components}; - DP4AMatMulNBitsProgram mul_program; - mul_program.SetWorkgroupSize(256); - mul_program.SetDispatchGroupSize( - (M + kTileSize - 1) / kTileSize, - (N + kTileSize - 1) / kTileSize, 1); - mul_program.AddInputs({{&a_quant, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow(kVec4Components)}, - {&a_scale, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow(1)}, - {b, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow(kVec2Components * kU32Components)}, - {scales, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow(1)}}) - .AddUniformVariables({{static_cast(M)}, - {static_cast(N)}, - {static_cast(K)}, - {static_cast(K / 8)}, - {static_cast(K / 16)}}) - .AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, reshaped_y_shape, gsl::narrow(kVec4Components)}); - return context.RunProgram(mul_program); + if (M >= kMinMForTileOptimization && + CanApplyDP4AMatrixMatMulNBits(context, accuracy_level_, block_size, batch_count, N, K, components_a, has_zero_points)) { + return ApplyDP4AMatrixMatMulNBits(a, b, scales, M, N, K, block_size, context, y); } // TODO: Support output_number > 1. Some cases are failed when output_number > 1. constexpr uint32_t output_number = 1; const uint32_t tile_m = M > kMinMForTileOptimization ? 4 : 1; + const bool has_subgroup = context.Device().HasFeature(wgpu::FeatureName::Subgroups); const bool use_subgroup = has_subgroup && context.AdapterInfo().vendor == std::string_view{"intel"} && components_a == 4 && block_size == 32; - MatMulNBitsProgram program{output_number, block_size, tile_m, gsl::narrow(components_b), has_zero_points, use_subgroup}; + MatMulNBitsProgram program{output_number, block_size, tile_m, static_cast(components_b), has_zero_points, use_subgroup}; if (M > kMinMForTileOptimization && block_size == 32) { components = 1; constexpr uint32_t workgroup_size = 64; @@ -884,7 +597,8 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context program.CacheHint("T_M" + std::to_string(tile_m) + "Subgroup" + std::to_string(use_subgroup)); } else if (block_size == 32) { components = 1; - constexpr uint32_t workgroup_size = 64; + // TODO: Tune the workgroup size when `M=1`. + constexpr uint32_t workgroup_size = 128; const uint32_t workgroup_y = N % 8 == 0 ? 8 : 1; const uint32_t workgroup_x = workgroup_size / workgroup_y; program.SetWorkgroupSize(workgroup_x, workgroup_y, 1); @@ -900,10 +614,10 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context TensorShape reshaped_y_shape{batch_count, M, N / components}; program - .AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, reshaped_a_shape, gsl::narrow(components_a)}, - {b, ProgramTensorMetadataDependency::TypeAndRank, reshaped_b_shape, gsl::narrow(components_b * 4 /** b will be accessed as uint32 which includs 4 uint8. So here we need to multiply 4.*/)}, + .AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, reshaped_a_shape, static_cast(components_a)}, + {b, ProgramTensorMetadataDependency::TypeAndRank, reshaped_b_shape, static_cast(components_b * 4 /** b will be accessed as uint32 which includs 4 uint8. So here we need to multiply 4.*/)}, {scales, ProgramTensorMetadataDependency::None}}) - .AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, reshaped_y_shape, gsl::narrow(components)}) + .AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, reshaped_y_shape, static_cast(components)}) .AddUniformVariable({block_size}); if (has_zero_points) { program.AddInput({zero_points, ProgramTensorMetadataDependency::None, {(zero_points->Shape().Size() + 3) / 4}, 4}); diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h index 3d72629bf6b25..10221e19c7400 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h @@ -35,25 +35,6 @@ class MatMulNBitsProgram final : public Program { bool use_subgroup_; }; -class DP4AMatMulQuantizeProgram final : public Program { - public: - DP4AMatMulQuantizeProgram() : Program{"DP4AMatMulQuantize"} {} - Status GenerateShaderCode(ShaderHelper& sh) const override; - WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"input_size", ProgramUniformVariableDataType::Uint32}); -}; - -class DP4AMatMulNBitsProgram final : public Program { - public: - DP4AMatMulNBitsProgram() : Program{"DP4AMatMulNBits"} {} - Status GenerateShaderCode(ShaderHelper& sh) const override; - WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES( - {"M", ProgramUniformVariableDataType::Uint32}, - {"N", ProgramUniformVariableDataType::Uint32}, - {"K", ProgramUniformVariableDataType::Uint32}, - {"K8", ProgramUniformVariableDataType::Uint32}, - {"K16", ProgramUniformVariableDataType::Uint32}); -}; - class MatMulNBits final : public WebGpuKernel { public: MatMulNBits(const OpKernelInfo& info) : WebGpuKernel(info) { diff --git a/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc index 2944a4d61b8ef..cb024d2a758a9 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc @@ -185,13 +185,13 @@ Status ApplySubgroupMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Te mul_program.SetDispatchGroupSize( (N + kTileSizeB - 1) / kTileSizeB, (M + kTileSizeA - 1) / kTileSizeA, 1); - mul_program.AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow(1)}, - {b, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow(kU32Components)}, - {scales, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow(1)}}) + mul_program.AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, 1}, + {b, ProgramTensorMetadataDependency::TypeAndRank, static_cast(kU32Components)}, + {scales, ProgramTensorMetadataDependency::TypeAndRank, 1}}) .AddUniformVariables({{static_cast(M)}, {static_cast(N)}, {static_cast(K)}}) - .AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, y_shape, gsl::narrow(1)}); + .AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, y_shape, 1}); return context.RunProgram(mul_program); } diff --git a/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc b/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc index 2e7ed5a16a2f0..068a94c7390e2 100644 --- a/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc @@ -37,8 +37,8 @@ Status RegisterWebGpuContribKernels(KernelRegistry& kernel_registry) { static const BuildKernelCreateInfoFn function_table[] = { BuildKernelCreateInfo, // default entry to avoid the list become empty after ops-reducing // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, BuildKernelCreateInfo, // BuildKernelCreateInfo, // BuildKernelCreateInfo, diff --git a/onnxruntime/core/framework/compute_capability.h b/onnxruntime/core/framework/compute_capability.h index 5f21ba2f013e0..819264b3960e7 100644 --- a/onnxruntime/core/framework/compute_capability.h +++ b/onnxruntime/core/framework/compute_capability.h @@ -2,8 +2,11 @@ // Licensed under the MIT License. #pragma once +#include #include "core/common/common.h" #include "core/graph/indexed_sub_graph.h" +#include "core/graph/graph.h" +#include "core/optimizer/graph_optimizer_registry.h" namespace onnxruntime { // A structure encodes a subgraph and the method to run it. @@ -21,5 +24,22 @@ struct ComputeCapability { ComputeCapability(std::unique_ptr t_sub_graph) : sub_graph(std::move(t_sub_graph)) {} + + // Optional function to optimize this ComputeCapability. + // This will be called by ORT once the ComputeCapability is assigned to the EP. + std::function + optimization_func; + + // Optional ComputeCapability instances for sets of nodes within this ComputeCapability that should be optimized. + // when an optimization is applied, ORT will update this ComputeCapability to reflect the changes made. + // IndexedSubGraph.nodes: + // - update based on RemovedNode/AddNode calls + // IndexedSubGraph.MetaDef (if present): + // - inputs and outputs will be unchanged + // - constant_initializers MAY change if we constant fold an initializer during optimization + std::vector> nodes_to_optimize; }; } // namespace onnxruntime diff --git a/onnxruntime/core/framework/execution_provider.cc b/onnxruntime/core/framework/execution_provider.cc index 3a937a119d03b..df85daa006a43 100644 --- a/onnxruntime/core/framework/execution_provider.cc +++ b/onnxruntime/core/framework/execution_provider.cc @@ -14,6 +14,7 @@ namespace onnxruntime { std::vector> IExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph, const IKernelLookup& kernel_lookup, + const GraphOptimizerRegistry&, IResourceAccountant*) const { std::vector> result; for (const auto& node : graph.Nodes()) { diff --git a/onnxruntime/core/framework/external_data_loader.cc b/onnxruntime/core/framework/external_data_loader.cc index fe73a55735631..c577805e69cc4 100644 --- a/onnxruntime/core/framework/external_data_loader.cc +++ b/onnxruntime/core/framework/external_data_loader.cc @@ -60,7 +60,12 @@ common::Status LoadWebAssemblyExternalData(const Env& env, break; case 1: // Load external data to GPU. - Module.jsepUploadExternalBuffer(dataIdOrBuffer, data); + // TODO: use a unified interface for upload external buffer. + if (Module.webgpuUploadExternalBuffer) { + Module.webgpuUploadExternalBuffer(dataIdOrBuffer, data); + } else { + Module.jsepUploadExternalBuffer(dataIdOrBuffer, data); + } break; default: return 4; // Unknown error occurred in memory copy. diff --git a/onnxruntime/core/framework/external_data_loader.h b/onnxruntime/core/framework/external_data_loader.h index 117da7d0a4afa..90d48ca800797 100644 --- a/onnxruntime/core/framework/external_data_loader.h +++ b/onnxruntime/core/framework/external_data_loader.h @@ -42,7 +42,7 @@ class IExternalDataLoader { enum class ExternalDataLoadType { CPU = 0, -#if defined(USE_JSEP) +#if defined(USE_JSEP) || defined(USE_WEBGPU) WEBGPU_BUFFER = 1, #endif }; diff --git a/onnxruntime/core/framework/fallback_cpu_capability.cc b/onnxruntime/core/framework/fallback_cpu_capability.cc index 1eb7420b44d2c..d3e435c0341b0 100644 --- a/onnxruntime/core/framework/fallback_cpu_capability.cc +++ b/onnxruntime/core/framework/fallback_cpu_capability.cc @@ -1,6 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) + #include "core/framework/fallback_cpu_capability.h" #include "core/common/inlined_containers.h" @@ -176,3 +178,5 @@ std::unordered_set GetCpuPreferredNodes(const onnxruntime::GraphViewe } } // namespace onnxruntime + +#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) diff --git a/onnxruntime/core/framework/fallback_cpu_capability.h b/onnxruntime/core/framework/fallback_cpu_capability.h index bca75adbfd5a7..ddcc1de96d2af 100644 --- a/onnxruntime/core/framework/fallback_cpu_capability.h +++ b/onnxruntime/core/framework/fallback_cpu_capability.h @@ -3,6 +3,8 @@ #pragma once +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) + #include #include "core/common/inlined_containers_fwd.h" #include "core/framework/execution_provider.h" // for IExecutionProvider::IKernelLookup @@ -26,3 +28,5 @@ std::unordered_set GetCpuPreferredNodes(const GraphViewer& graph, const logging::Logger& logger); } // namespace onnxruntime + +#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) diff --git a/onnxruntime/core/framework/graph_partitioner.cc b/onnxruntime/core/framework/graph_partitioner.cc index 111f8e0a5fc34..ff4d300f665b1 100644 --- a/onnxruntime/core/framework/graph_partitioner.cc +++ b/onnxruntime/core/framework/graph_partitioner.cc @@ -142,13 +142,15 @@ struct GetCapabilityForEPParams { std::reference_wrapper debug_graph_fn; #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) IResourceAccountant* resource_accountant; + std::reference_wrapper graph_optimizer_registry; }; auto get_capabilities = [](const IExecutionProvider& ep, const GraphViewer& graph_viewer, const IExecutionProvider::IKernelLookup& kernel_lookup, - IResourceAccountant* resource_accountant) { - auto capabilities = ep.GetCapability(graph_viewer, kernel_lookup, resource_accountant); + IResourceAccountant* resource_accountant, + const GraphOptimizerRegistry& graph_optimizer_registry) { + auto capabilities = ep.GetCapability(graph_viewer, kernel_lookup, graph_optimizer_registry, resource_accountant); // In theory an EP could return an empty capability. Remove those. capabilities.erase(std::remove_if(capabilities.begin(), capabilities.end(), @@ -182,10 +184,11 @@ static Status GetCapabilityForEP(const GetCapabilityForEPParams& params, const l auto& graph = params.graph.get(); auto& capabilities = params.capabilities.get(); + const auto& graph_optimizer_registry = params.graph_optimizer_registry.get(); { const GraphViewer graph_viewer(graph); - capabilities = get_capabilities(current_ep, graph_viewer, kernel_lookup, params.resource_accountant); + capabilities = get_capabilities(current_ep, graph_viewer, kernel_lookup, params.resource_accountant, graph_optimizer_registry); if (capabilities.empty()) { return Status::OK(); @@ -223,7 +226,7 @@ static Status GetCapabilityForEP(const GetCapabilityForEPParams& params, const l capabilities.clear(); const GraphViewer graph_viewer(graph); - capabilities = get_capabilities(current_ep, graph_viewer, kernel_lookup, params.resource_accountant); + capabilities = get_capabilities(current_ep, graph_viewer, kernel_lookup, params.resource_accountant, graph_optimizer_registry); // all nodes with an index >= first_new_node with domain of kMSInternalNHWCDomain should be in the capabilities InlinedHashSet new_nodes_in_capabilities; @@ -261,6 +264,7 @@ static Status GetCapabilityForEP(const GetCapabilityForEPParams& params, const l static Status GetCapabilityForEPForAotInlining(const GraphViewer& graph_viewer, const KernelRegistryManager& kernel_registry_mgr, const IExecutionProvider& current_ep, + const GraphOptimizerRegistry& graph_optimizer_registry, const logging::Logger& logger, std::vector>& capabilities) { const auto& ep_type = current_ep.Type(); @@ -272,14 +276,62 @@ static Status GetCapabilityForEPForAotInlining(const GraphViewer& graph_viewer, logger}; // TODO: Provide EP with a capability to look inside the functions. - capabilities = get_capabilities(current_ep, graph_viewer, kernel_lookup, nullptr); + capabilities = get_capabilities(current_ep, graph_viewer, kernel_lookup, nullptr, graph_optimizer_registry); return Status::OK(); } /** - * Check if a node can be placed on a specific provider. - * Do nothing if the node is already assigned + * Check whether the given IndexedSubGraph is available for assigning to a specific provider. + * + */ +static bool IsIndexedSubGraphAvailableForAssignment(Graph& graph, + const IndexedSubGraph& capability, + GraphPartitioner::Mode mode, + const std::string& provider_type) { + // The provider can run a single node in the if not using meta-defs. + if (capability.GetMetaDef() == nullptr && capability.nodes.size() == 1) { + auto* node = graph.GetNode(capability.nodes[0]); + if (nullptr != node && node->GetExecutionProviderType().empty()) { + // The node was not fused or assigned. + return true; + } + return false; + } + + // if mode is kAssignOnly we want all nodes that can _potentially_ be taken by compiling EPs to be assigned, + // so that we aggregate the nodes covered and ensure the original nodes remain in the ORT format model by + // preventing level 2 and 3 optimizers from changing them. optimizers check the EP the node is assigned to + // and only make changes if the EP is on the optimizer's list of supported EPs. an EP that compiles nodes + // should never be on those lists. + // + // when the ORT format model is loaded we will process it normally with EP priority being applied for + // whichever EPs are enabled at the time. + // + // e.g. an Android NNAPI EP may take different/overlapping nodes to a iOS CoreML EP. + // We want the ORT format model to be able to be run as efficiently as possible on either platform, + // so we want all the nodes that either may take to be preserved. If we did not do this we would + // need to create one ORT format model for Android and one for iOS. + if (mode == GraphPartitioner::Mode::kAssignOnly) { + return true; + } + + for (auto node_index : capability.nodes) { + const auto* node = graph.GetNode(node_index); + if ((nullptr == node) || + (!node->GetExecutionProviderType().empty() && node->GetExecutionProviderType() != provider_type)) { + // The node was fused or assigned, so that the whole sub-graph will not be assigned to this + // The assumption is that this can only run the sub-graph as a whole unit. + return false; + } + } + + return true; +} + +/** + * Return a fused node or assign the nodes in the indexed subgraph to the current EP. + * * \param graph * \param capability * \param kernel_registry_mgr @@ -298,75 +350,42 @@ static Node* PlaceNode(Graph& graph, const IndexedSubGraph& capability, if (nullptr == capability.GetMetaDef()) { TryAssignSingleNode(graph, capability, provider_type); } else { - // The can run a fused in the . + const bool acc_enabled = capability.IsAccountingEnabled(); + if (mode == GraphPartitioner::Mode::kNormal) { + std::ostringstream oss; + oss << provider_type << "_" << capability.GetMetaDef()->name << "_" << fused_node_unique_id++; + std::string node_name = oss.str(); - // Check whether any node in the was already assigned. If so it cannot be stolen as assignment is done - // in order of EP priority - bool sub_graph_available_for_assignment = true; - if (mode != GraphPartitioner::Mode::kAssignOnly) { - // if mode is kAssignOnly we want all nodes that can _potentially_ be taken by compiling EPs to be assigned, - // so that we aggregate the nodes covered and ensure the original nodes remain in the ORT format model by - // preventing level 2 and 3 optimizers from changing them. optimizers check the EP the node is assigned to - // and only make changes if the EP is on the optimizer's list of supported EPs. an EP that compiles nodes - // should never be on those lists. - // - // when the ORT format model is loaded we will process it normally with EP priority being applied for - // whichever EPs are enabled at the time. - // - // e.g. an Android NNAPI EP may take different/overlapping nodes to a iOS CoreML EP. - // We want the ORT format model to be able to be run as efficiently as possible on either platform, - // so we want all the nodes that either may take to be preserved. If we did not do this we would - // need to create one ORT format model for Android and one for iOS. - for (auto node_index : capability.nodes) { - const auto* node = graph.GetNode(node_index); - if ((nullptr == node) || - (!node->GetExecutionProviderType().empty() && node->GetExecutionProviderType() != provider_type)) { - // The node was fused or assigned, so that the whole sub-graph will not be assigned to this - // The assumption is that this can only run the sub-graph as a whole unit. - sub_graph_available_for_assignment = false; - break; - } + Node* fused_node = nullptr; + if (fusion_style == IExecutionProvider::FusionStyle::Function) { + fused_node = &graph.FuseSubGraph(capability, node_name); + } else { + // create a fused node without copying everything to a Function body. The IndexedSubGraph will be passed + // through to Compile via a filtered GraphViewer. + fused_node = &graph.BeginFuseSubGraph(capability, node_name); } - } - if (sub_graph_available_for_assignment) { - const bool acc_enabled = capability.IsAccountingEnabled(); - if (mode == GraphPartitioner::Mode::kNormal) { - std::ostringstream oss; - oss << provider_type << "_" << capability.GetMetaDef()->name << "_" << fused_node_unique_id++; - std::string node_name = oss.str(); - - Node* fused_node = nullptr; - if (fusion_style == IExecutionProvider::FusionStyle::Function) { - fused_node = &graph.FuseSubGraph(capability, node_name); - } else { - // create a fused node without copying everything to a Function body. The IndexedSubGraph will be passed - // through to Compile via a filtered GraphViewer. - fused_node = &graph.BeginFuseSubGraph(capability, node_name); - } - - fused_node->SetExecutionProviderType(provider_type); - if (acc_enabled) { - // We account for the fused node. We operate under assumption - // that the fused node would use no more memory when the nodes we are fusing. - // and potentially less than that, and therefore, no threshold check is needed here. - // All threshold checks are done within the EP. - capability.ComputeAndAccountForNode(*fused_node); - } + fused_node->SetExecutionProviderType(provider_type); + if (acc_enabled) { + // We account for the fused node. We operate under assumption + // that the fused node would use no more memory when the nodes we are fusing. + // and potentially less than that, and therefore, no threshold check is needed here. + // All threshold checks are done within the EP. + capability.ComputeAndAccountForNode(*fused_node); + } - result = fused_node; - } else { - // assign the nodes in the indexed subgraph to the current EP so that level 2+ optimizers will not change them. - // This is used when exporting an ORT format model to maintain the original nodes and re-do the fusion - // at runtime. The original nodes provide a fallback if fewer nodes can be fused at runtime due to device - // capabilities. - for (size_t i = 0, limit = capability.nodes.size(); i < limit; ++i) { - auto* node = graph.GetNode(capability.nodes[i]); - if (node != nullptr) { - node->SetExecutionProviderType(provider_type); - if (acc_enabled) { - capability.AccountForNode(i); - } + result = fused_node; + } else { + // assign the nodes in the indexed subgraph to the current EP so that level 2+ optimizers will not change them. + // This is used when exporting an ORT format model to maintain the original nodes and re-do the fusion + // at runtime. The original nodes provide a fallback if fewer nodes can be fused at runtime due to device + // capabilities. + for (size_t i = 0, limit = capability.nodes.size(); i < limit; ++i) { + auto* node = graph.GetNode(capability.nodes[i]); + if (node != nullptr) { + node->SetExecutionProviderType(provider_type); + if (acc_enabled) { + capability.AccountForNode(i); } } } @@ -386,7 +405,8 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, FuncManager& func_mgr, int& fused_node_unique_id, const layout_transformation::TransformLayoutFunction& transform_layout_fn, const layout_transformation::DebugGraphFn& debug_graph_fn, - const logging::Logger& logger, IResourceAccountant* resource_accountant) { + const logging::Logger& logger, IResourceAccountant* resource_accountant, + const GraphOptimizerRegistry& graph_optimizer_registry) { // handle testing edge case where optimizers or constant lifting results in graph with no nodes. // doing it here saves all providers checking for this in GetCapability if (graph.NumberOfNodes() == 0) { @@ -400,7 +420,7 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, FuncManager& func_mgr, // we pass through the FuncManager from the top level graph ORT_RETURN_IF_ERROR(PartitionOnnxFormatModelImpl(*subgraph, func_mgr, kernel_registry_mgr, fused_kernel_registry, current_ep, mode, fused_node_unique_id, - transform_layout_fn, debug_graph_fn, logger, resource_accountant)); + transform_layout_fn, debug_graph_fn, logger, resource_accountant, graph_optimizer_registry)); } } @@ -424,7 +444,8 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, FuncManager& func_mgr, mode, std::cref(transform_layout_fn), std::cref(debug_graph_fn), - resource_accountant}; + resource_accountant, + std::ref(graph_optimizer_registry)}; ORT_RETURN_IF_ERROR(GetCapabilityForEP(get_capability_params, logger)); if (capabilities.empty()) { @@ -450,7 +471,30 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, FuncManager& func_mgr, entry->sub_graph->GetMetaDef() != nullptr; })); for (auto& capability : capabilities) { - Node* n = PlaceNode(graph, *capability->sub_graph, fusion_style, type, mode, fused_node_unique_id); + // The can run a fused in the . + // Check whether any node in the was already assigned. If so it cannot be stolen as assignment is done + // in order of EP priority + bool sub_graph_available_for_assignment = IsIndexedSubGraphAvailableForAssignment(graph, *capability->sub_graph, mode, type); + + // If the is available to be assigned to the EP and the ComputeCapability has nodes_to_optimize, + // run EP related optimizations and update ComputeCapability. + if (sub_graph_available_for_assignment && !capability->nodes_to_optimize.empty()) { + for (auto& optimization_cc : capability->nodes_to_optimize) { + if (optimization_cc->optimization_func) { + auto status = optimization_cc->optimization_func(graph, *optimization_cc, *capability, graph_optimizer_registry); + if (status != Status::OK()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, type, "The optimization function failed to finish."); + } + // #TODO: Handle nested optimization ComputeCapability + } + } + } + + Node* n = nullptr; + if (sub_graph_available_for_assignment) { + n = PlaceNode(graph, *capability->sub_graph, fusion_style, type, mode, fused_node_unique_id); + } + if (n != nullptr) { // searching in kernel registries, if no kernel registered for the fused_node, use compile approach if (!KernelRegistryManager::HasImplementationOf(kernel_registry_mgr, *n, type, logger)) { @@ -587,6 +631,7 @@ static Status InlineNodes(Graph& graph, bool& modified_graph) { static Status InlineFunctionsAOTImpl(const ExecutionProviders& execution_providers, const KernelRegistryManager& kernel_registry_mgr, Graph& graph, + const GraphOptimizerRegistry& graph_optimizer_registry, const logging::Logger& logger, InlinedHashSet& not_inlined, size_t& inlined_count) { @@ -603,6 +648,7 @@ static Status InlineFunctionsAOTImpl(const ExecutionProviders& execution_provide ORT_RETURN_IF_ERROR(InlineFunctionsAOTImpl(execution_providers, kernel_registry_mgr, *subgraph, + graph_optimizer_registry, logger, not_inlined, inlined_count)); @@ -627,7 +673,7 @@ static Status InlineFunctionsAOTImpl(const ExecutionProviders& execution_provide InlinedHashSet claimed_by_ep; for (const auto& ep : execution_providers) { std::vector> capabilities; - ORT_RETURN_IF_ERROR(GetCapabilityForEPForAotInlining(graph_viewer, kernel_registry_mgr, *ep, logger, + ORT_RETURN_IF_ERROR(GetCapabilityForEPForAotInlining(graph_viewer, kernel_registry_mgr, *ep, graph_optimizer_registry, logger, capabilities)); for (auto& capability : capabilities) { const auto& nodes = capability->sub_graph->nodes; @@ -667,23 +713,28 @@ static Status InlineFunctionsAOTImpl(const ExecutionProviders& execution_provide } // Validate the ep_context_path to make sure it is file path and check whether the file exist already -static Status EpContextFilePathCheck(const std::string& ep_context_path, - const std::filesystem::path& model_path) { - std::filesystem::path context_cache_path; +static Status GetValidatedEpContextPath(const std::filesystem::path& ep_context_path, + const std::filesystem::path& model_path, + std::filesystem::path& context_cache_path) { if (!ep_context_path.empty()) { context_cache_path = ep_context_path; if (!context_cache_path.has_filename()) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "context_file_path should not point to a folder."); } } else if (!model_path.empty()) { - context_cache_path = model_path.native() + ORT_TSTR("_ctx.onnx"); + auto pos = model_path.native().find_last_of(ORT_TSTR(".")); + if (pos != std::string::npos) { + context_cache_path = model_path.native().substr(0, pos) + ORT_TSTR("_ctx.onnx"); + } else { + context_cache_path = model_path.native() + ORT_TSTR("_ctx.onnx"); + } } else { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Both ep_context_path and model_path are empty."); } if (std::filesystem::exists(context_cache_path)) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to generate EP context model since the file '", - context_cache_path, "' exist already."); + context_cache_path, "' exist already. Please remove the EP context model if you want to re-generate it."); } return Status::OK(); @@ -714,15 +765,7 @@ static Status CreateEpContextModel(const ExecutionProviders& execution_providers }; std::filesystem::path context_cache_path; - const std::filesystem::path& model_path = graph.ModelPath(); - - if (!ep_context_path.empty()) { - context_cache_path = ep_context_path; - } else if (!model_path.empty()) { - context_cache_path = model_path.native() + ORT_TSTR("_ctx.onnx"); - } else { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Both ep_context_path and model_path are empty"); - } + ORT_RETURN_IF_ERROR(GetValidatedEpContextPath(ep_context_path, graph.ModelPath(), context_cache_path)); Model ep_context_model(graph.Name(), false, graph.GetModel().MetaData(), graph.GetModel().ModelPath(), // use source model path so that external initializers can find the data file path @@ -794,6 +837,7 @@ static Status PartitionOnnxFormatModel(const PartitionParams& partition_params, const ExecutionProviders& execution_providers, KernelRegistryManager& kernel_registry_manager, const std::optional& acc_map, + const GraphOptimizerRegistry& graph_optimizer_registry, const logging::Logger& logger) { bool modified_graph = false; @@ -817,7 +861,7 @@ static Status PartitionOnnxFormatModel(const PartitionParams& partition_params, fused_kernel_registry, *ep, mode, fused_node_unique_id, transform_layout_function, partition_params.debug_graph_fn, - logger, resource_accountant)); + logger, resource_accountant, graph_optimizer_registry)); } // expand any nodes that have an ONNX function definition but no matching ORT kernel. @@ -838,6 +882,7 @@ static Status PartitionOnnxFormatModel(const PartitionParams& partition_params, static Status PartitionOrtFormatModelImpl(const PartitionParams& partition_params, KernelRegistryManager& kernel_registry_mgr, IExecutionProvider& current_ep, + const GraphOptimizerRegistry& graph_optimizer_registry, const logging::Logger& logger) { // handle testing edge case where optimizers or constant lifting results in graph with no nodes. // doing it here saves all providers checking for this in GetCapability @@ -853,7 +898,7 @@ static Status PartitionOrtFormatModelImpl(const PartitionParams& partition_param PartitionParams subgraph_partition_params = partition_params; subgraph_partition_params.graph = std::ref(subgraph); ORT_RETURN_IF_ERROR(PartitionOrtFormatModelImpl(subgraph_partition_params, kernel_registry_mgr, - current_ep, logger)); + current_ep, graph_optimizer_registry, logger)); } } @@ -869,7 +914,8 @@ static Status PartitionOrtFormatModelImpl(const PartitionParams& partition_param std::cref(partition_params.transform_layout_function), std::cref(partition_params.debug_graph_fn), #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) - nullptr + nullptr, + std::ref(graph_optimizer_registry) }; // clang-format on @@ -962,10 +1008,11 @@ static Status PartitionOrtFormatModelImpl(const PartitionParams& partition_param static Status PartitionOrtFormatModel(const PartitionParams& partition_params, const ExecutionProviders& execution_providers, KernelRegistryManager& kernel_registry_manager, + const GraphOptimizerRegistry& graph_optimizer_registry, const logging::Logger& logger) { // process full graph with each EP for (const auto& ep : execution_providers) { - ORT_RETURN_IF_ERROR(PartitionOrtFormatModelImpl(partition_params, kernel_registry_manager, *ep, logger)); + ORT_RETURN_IF_ERROR(PartitionOrtFormatModelImpl(partition_params, kernel_registry_manager, *ep, graph_optimizer_registry, logger)); } return Status::OK(); @@ -992,6 +1039,7 @@ Status GraphPartitioner::InlineFunctionsAOT(Model& model, ORT_RETURN_IF_ERROR(InlineFunctionsAOTImpl(execution_providers, kernel_registry_manager, graph, + *graph_optimizer_registry_, logger, not_inlined, inlined_count)); @@ -1048,8 +1096,7 @@ Status GraphPartitioner::Partition(Graph& graph, FuncManager& func_mgr, std::ref(*fused_kernel_registry), std::ref(fused_node_unique_id), std::cref(transform_layout_function), - std::cref(debug_graph_fn), - }; + std::cref(debug_graph_fn)}; #else // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) @@ -1068,7 +1115,8 @@ Status GraphPartitioner::Partition(Graph& graph, FuncManager& func_mgr, if (ep_context_enabled) { std::string ep_context_path = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextFilePath, ""); // Check before EP compile graphs - ORT_RETURN_IF_ERROR(EpContextFilePathCheck(ep_context_path, graph.ModelPath())); + std::filesystem::path context_cache_path; + ORT_RETURN_IF_ERROR(GetValidatedEpContextPath(ep_context_path, graph.ModelPath(), context_cache_path)); } // We use this only if Resource Aware Partitioning is enabled for any of the EPs @@ -1077,7 +1125,7 @@ Status GraphPartitioner::Partition(Graph& graph, FuncManager& func_mgr, ORT_RETURN_IF_ERROR(NodeStatsRecorder::CreateAccountants(config_options, graph.ModelPath(), ep_acc_map)); ORT_RETURN_IF_ERROR(PartitionOnnxFormatModel(partition_params, mode, providers_, kernel_registry_mgr_, - ep_acc_map, logger)); + ep_acc_map, *graph_optimizer_registry_, logger)); if (ep_context_enabled) { std::string ep_context_path = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextFilePath, ""); @@ -1091,7 +1139,7 @@ Status GraphPartitioner::Partition(Graph& graph, FuncManager& func_mgr, return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "ONNX models are not supported in this build."); #endif //! defined(ORT_MINIMAL_BUILD) } else { - ORT_RETURN_IF_ERROR(PartitionOrtFormatModel(partition_params, providers_, kernel_registry_mgr_, logger)); + ORT_RETURN_IF_ERROR(PartitionOrtFormatModel(partition_params, providers_, kernel_registry_mgr_, *graph_optimizer_registry_, logger)); } #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) diff --git a/onnxruntime/core/framework/graph_partitioner.h b/onnxruntime/core/framework/graph_partitioner.h index d1ef193cf1520..b9d4022cb5a14 100644 --- a/onnxruntime/core/framework/graph_partitioner.h +++ b/onnxruntime/core/framework/graph_partitioner.h @@ -7,6 +7,7 @@ #include "core/graph/graph.h" #include "core/framework/fuse_nodes_funcs.h" #include "core/framework/transform_layout_functions.h" +#include "core/optimizer/graph_optimizer_registry.h" namespace onnxruntime { @@ -24,9 +25,12 @@ class GraphPartitioner { }; // The order of providers represents the user preference. - GraphPartitioner(KernelRegistryManager& kernel_registry_mgr, const ExecutionProviders& providers) + GraphPartitioner(KernelRegistryManager& kernel_registry_mgr, + const ExecutionProviders& providers, + std::unique_ptr graph_optimizer_registry) : kernel_registry_mgr_(kernel_registry_mgr), - providers_(providers) { + providers_(providers), + graph_optimizer_registry_(std::move(graph_optimizer_registry)) { } // Run partitioning. @@ -64,6 +68,7 @@ class GraphPartitioner { KernelRegistryManager& kernel_registry_mgr_; const ExecutionProviders& providers_; + std::unique_ptr graph_optimizer_registry_; }; } // namespace onnxruntime diff --git a/onnxruntime/core/framework/onnxruntime_typeinfo.cc b/onnxruntime/core/framework/onnxruntime_typeinfo.cc index a884927abddb7..1c446840b7938 100644 --- a/onnxruntime/core/framework/onnxruntime_typeinfo.cc +++ b/onnxruntime/core/framework/onnxruntime_typeinfo.cc @@ -10,8 +10,8 @@ #include "core/framework/sparse_tensor.h" #include "core/graph/onnx_protobuf.h" #include "core/session/ort_apis.h" +#include "core/session/model_editor_api.h" #include "core/framework/error_code_helper.h" - #include "core/framework/tensor_type_and_shape.h" #include "core/framework/onnxruntime_map_type_info.h" #include "core/framework/onnxruntime_sequence_type_info.h" @@ -40,7 +40,7 @@ OrtTypeInfo::OrtTypeInfo(std::unique_ptr optional_type_info : type(ONNX_TYPE_OPTIONAL), optional_type_info(std::move(optional_type_info)) {} OrtTypeInfo::OrtTypeInfo(ONNXType type, std::unique_ptr data) noexcept - : type(type), data(std::move(data)) { + : type(type), tensor_type_info(std::move(data)) { } OrtTypeInfo::~OrtTypeInfo() = default; @@ -55,7 +55,9 @@ ORT_API_STATUS_IMPL(OrtApis::GetOnnxTypeFromTypeInfo, _In_ const struct OrtTypeI ORT_API_STATUS_IMPL(OrtApis::CastTypeInfoToTensorInfo, _In_ const struct OrtTypeInfo* input, _Outptr_result_maybenull_ const struct OrtTensorTypeAndShapeInfo** out) { API_IMPL_BEGIN - *out = (input->type == ONNX_TYPE_TENSOR || input->type == ONNX_TYPE_SPARSETENSOR) ? input->data.get() : nullptr; + *out = (input->type == ONNX_TYPE_TENSOR || input->type == ONNX_TYPE_SPARSETENSOR) + ? input->tensor_type_info.get() + : nullptr; return nullptr; API_IMPL_END } @@ -84,8 +86,8 @@ ORT_API_STATUS_IMPL(OrtApis::CastTypeInfoToOptionalTypeInfo, _In_ const OrtTypeI API_IMPL_END } -ORT_API_STATUS_IMPL(OrtApis::GetDenotationFromTypeInfo, _In_ const OrtTypeInfo* type_info, _Out_ const char** const out, - _Out_ size_t* len) { +ORT_API_STATUS_IMPL(OrtApis::GetDenotationFromTypeInfo, _In_ const OrtTypeInfo* type_info, + _Out_ const char** const out, _Out_ size_t* len) { API_IMPL_BEGIN *out = type_info->denotation.c_str(); *len = type_info->denotation.size(); @@ -93,6 +95,61 @@ ORT_API_STATUS_IMPL(OrtApis::GetDenotationFromTypeInfo, _In_ const OrtTypeInfo* API_IMPL_END } +#if !defined(ORT_MINIMAL_BUILD) +ORT_API_STATUS_IMPL(OrtModelEditorAPI::CreateTensorTypeInfo, _In_ const OrtTensorTypeAndShapeInfo* tensor_info, + _Out_ OrtTypeInfo** type_info) { + API_IMPL_BEGIN + auto ti = std::make_unique(ONNXType::ONNX_TYPE_TENSOR); + ti->tensor_type_info = tensor_info->Clone(); + *type_info = ti.release(); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtModelEditorAPI::CreateSparseTensorTypeInfo, _In_ const OrtTensorTypeAndShapeInfo* tensor_info, + _Out_ OrtTypeInfo** type_info) { + API_IMPL_BEGIN + auto ti = std::make_unique(ONNXType::ONNX_TYPE_SPARSETENSOR); + ti->tensor_type_info = tensor_info->Clone(); + *type_info = ti.release(); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtModelEditorAPI::CreateMapTypeInfo, ONNXTensorElementDataType map_key_type, + _In_ const OrtTypeInfo* map_value_type, _Out_ OrtTypeInfo** type_info) { + API_IMPL_BEGIN + auto ti = std::make_unique(ONNXType::ONNX_TYPE_MAP); + ti->map_type_info = std::make_unique(map_key_type, map_value_type->Clone()); + *type_info = ti.release(); + + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtModelEditorAPI::CreateSequenceTypeInfo, _In_ const OrtTypeInfo* sequence_type, + _Out_ OrtTypeInfo** type_info) { + API_IMPL_BEGIN + auto ti = std::make_unique(ONNXType::ONNX_TYPE_SEQUENCE); + ti->sequence_type_info = std::make_unique(sequence_type->Clone()); + *type_info = ti.release(); + + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtModelEditorAPI::CreateOptionalTypeInfo, _In_ const OrtTypeInfo* contained_type, + _Out_ OrtTypeInfo** type_info) { + API_IMPL_BEGIN + auto ti = std::make_unique(ONNXType::ONNX_TYPE_OPTIONAL); + ti->optional_type_info = std::make_unique(contained_type->Clone()); + *type_info = ti.release(); + + return nullptr; + API_IMPL_END +} +#endif // !defined(ORT_MINIMAL_BUILD) + ORT_API(void, OrtApis::ReleaseTypeInfo, _Frees_ptr_opt_ OrtTypeInfo* ptr) { std::unique_ptr p(ptr); } @@ -298,8 +355,8 @@ std::unique_ptr OrtTypeInfo::Clone() const { #endif case ONNX_TYPE_TENSOR: { std::unique_ptr info; - if (data) { - info = data->Clone(); + if (tensor_type_info) { + info = tensor_type_info->Clone(); } result = MakePtr(type, std::move(info)); result->denotation = denotation; diff --git a/onnxruntime/core/framework/onnxruntime_typeinfo.h b/onnxruntime/core/framework/onnxruntime_typeinfo.h index 72d263d5fa442..54bb946e0d36b 100644 --- a/onnxruntime/core/framework/onnxruntime_typeinfo.h +++ b/onnxruntime/core/framework/onnxruntime_typeinfo.h @@ -31,7 +31,7 @@ struct OrtTypeInfo { ONNXType type; std::string denotation; - std::unique_ptr data; + std::unique_ptr tensor_type_info; std::unique_ptr map_type_info; std::unique_ptr sequence_type_info; std::unique_ptr optional_type_info; diff --git a/onnxruntime/core/framework/session_state_utils.cc b/onnxruntime/core/framework/session_state_utils.cc index 83a353615bc35..9d45ec38e5a32 100644 --- a/onnxruntime/core/framework/session_state_utils.cc +++ b/onnxruntime/core/framework/session_state_utils.cc @@ -81,6 +81,11 @@ static common::Status ExtDataTensorProtoToTensor(const Env& env, ORT_RETURN_IF_ERROR(utils::GetExtDataFromTensorProto(env, proto_path.c_str(), tensor_proto, ext_data_buf, ext_data_len, ext_data_deleter, buffered_tensor, &prepacked_for_graph)); + if constexpr (endian::native != endian::little) { + if (!proto_path.empty() && (proto_path.compare(onnxruntime::utils::kTensorProtoMemoryAddressTag) != 0)) { + utils::ConvertRawDataInTensorProto(const_cast(&tensor_proto), ext_data_buf, ext_data_len); + } + } // NB: creating a do-nothing allocator per tensor is wasteful; can perhaps be // avoided if the Tensor class implements the do-nothing behavior when given a @@ -203,13 +208,12 @@ static common::Status DeserializeTensorProto(const Env& env, const std::basic_st } } -common::Status AllocateTensor( - const onnxruntime::MemBuffer* m, - std::unique_ptr& p_tensor, - const onnxruntime::DataTypeImpl* const& type, - onnxruntime::TensorShape& tensor_shape, - bool use_device_allocator_for_initializers, - const onnxruntime::AllocatorPtr& alloc) { +common::Status AllocateTensor(const onnxruntime::MemBuffer* m, + std::unique_ptr& p_tensor, + const onnxruntime::DataTypeImpl* const& type, + onnxruntime::TensorShape& tensor_shape, + bool use_device_allocator_for_initializers, + const onnxruntime::AllocatorPtr& alloc) { if (m != nullptr) { p_tensor = std::make_unique(type, tensor_shape, m->GetBuffer(), m->GetAllocInfo()); if (m->GetLen() < p_tensor->SizeInBytes()) { @@ -354,6 +358,7 @@ common::Status SaveInitializedTensors( } ORT_RETURN_IF_ERROR(planner.Trace(entry.first, entry.second)); } + // 2. allocate weight buffer on different locations // planned_initializers_memory_size_in_byte is not actual physical size. // It's the virtual size computed by planner. @@ -386,6 +391,9 @@ common::Status SaveInitializedTensors( if (user_supplied_initializer_ids.find(entry.first) != user_supplied_initializer_ids.end()) { ort_value = *(session_options.initializers_to_share_map.at(name)); LOGS(logger, INFO) << "Using user supplied initializer with name (" << name << ")."; + + } else if (graph.GetOrtValueInitializer(name, ort_value)) { + // populated OrtValue from the Graph instance } else { const ONNX_NAMESPACE::TensorProto& tensor_proto = *(entry.second); @@ -397,10 +405,9 @@ common::Status SaveInitializedTensors( session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsUseDeviceAllocatorForInitializers, "0") == "1"; Tensor* p_tensor = nullptr; - if (auto iter = buffered_tensors.find(name); - iter != buffered_tensors.end()) { - p_tensor = iter->second.release(); - buffered_tensors.erase(iter); + auto buffered_tensors_iter = buffered_tensors.find(name); + if (buffered_tensors_iter != buffered_tensors.end()) { + p_tensor = buffered_tensors_iter->second.get(); } Status st = DeserializeTensorProto(env, graph_loc, tensor_proto, (m.has_value()) ? &*m : nullptr, alloc, @@ -412,6 +419,12 @@ common::Status SaveInitializedTensors( oss << "Deserialize tensor " << name << " failed." << st.ErrorMessage(); return Status(st.Category(), st.Code(), oss.str()); } + + if (p_tensor != nullptr) { + // p_tensor was wrapped in a deleter by DeserializeTensorProto so we can simply release it here. + ORT_IGNORE_RETURN_VALUE(buffered_tensors_iter->second.release()); + buffered_tensors.erase(buffered_tensors_iter); + } } // 'name' is a reference to a string within the TensorProto that save_tensor_func may free diff --git a/onnxruntime/core/framework/tensor_type_and_shape.cc b/onnxruntime/core/framework/tensor_type_and_shape.cc index 418e46924fb9f..9bbea279da82d 100644 --- a/onnxruntime/core/framework/tensor_type_and_shape.cc +++ b/onnxruntime/core/framework/tensor_type_and_shape.cc @@ -49,10 +49,27 @@ ORT_API_STATUS_IMPL(OrtApis::SetTensorElementType, _Inout_ OrtTensorTypeAndShape API_IMPL_END } -ORT_API_STATUS_IMPL(OrtApis::SetDimensions, OrtTensorTypeAndShapeInfo* this_ptr, +ORT_API_STATUS_IMPL(OrtApis::SetDimensions, OrtTensorTypeAndShapeInfo* info, _In_ const int64_t* dim_values, size_t dim_count) { API_IMPL_BEGIN - this_ptr->shape = onnxruntime::TensorShape(dim_values, dim_count); + if (std::any_of(dim_values, dim_values + dim_count, [](int64_t v) { return v < -1; })) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "dim_values must be -1 (symbolic dimension) or larger."); + } + + auto num_dims = std::max(dim_count, info->dim_params.size()); + + // make shape and dim_values consistent + info->dim_params.resize(num_dims, ""); + + onnxruntime::TensorShapeVector dims; + dims.resize(num_dims, -1); + + for (size_t idx = 0; idx < dim_count; ++idx) { + dims[idx] = dim_values[idx]; + } + + info->shape = onnxruntime::TensorShape(dims); + return nullptr; API_IMPL_END } @@ -88,10 +105,22 @@ ORT_API_STATUS_IMPL(OrtApis::GetSymbolicDimensions, ORT_API_STATUS_IMPL(OrtApis::SetSymbolicDimensions, _In_ struct OrtTensorTypeAndShapeInfo* info, _In_ const char** names, _In_ size_t dim_params_length) { + auto num_dims = std::max(info->shape.NumDimensions(), dim_params_length); + + // make shape and dim_values consistent + if (num_dims > info->shape.NumDimensions()) { + auto dim_values = info->shape.AsShapeVector(); + dim_values.resize(num_dims, -1); + info->shape = onnxruntime::TensorShape(dim_values); + } + info->dim_params.clear(); + info->dim_params.resize(num_dims, ""); + for (size_t idx = 0; idx < dim_params_length; ++idx) { - info->dim_params.push_back(names[idx]); + info->dim_params[idx] = names[idx]; } + return nullptr; } diff --git a/onnxruntime/core/framework/tensorprotoutils.cc b/onnxruntime/core/framework/tensorprotoutils.cc index 17c37b8882168..94a2a6677358e 100644 --- a/onnxruntime/core/framework/tensorprotoutils.cc +++ b/onnxruntime/core/framework/tensorprotoutils.cc @@ -270,10 +270,15 @@ void SetRawDataInTensorProto(ONNX_NAMESPACE::TensorProto& tensor_proto, std::str tensor_proto.set_raw_data(std::move(param)); } -void ConvertRawDataInTensorProto(TensorProto* tensor) { +void ConvertRawDataInTensorProto(TensorProto* tensor, + void* ext_data_buf, + size_t ext_data_len) { size_t element_size = 1; char* bytes = NULL; size_t num_elements = 0; + if (ext_data_buf && !ext_data_len) { + return; + } switch (tensor->data_type()) { case TensorProto_DataType_FLOAT: bytes = reinterpret_cast(tensor->mutable_float_data()->mutable_data()); @@ -337,6 +342,15 @@ void ConvertRawDataInTensorProto(TensorProto* tensor) { num_elements = (tensor->raw_data().size()) / element_size; bytes = const_cast(tensor->mutable_raw_data()->c_str()); } + + if (element_size == 1) { + return; + } + if (ext_data_buf) { + ORT_ENFORCE(ext_data_len % element_size == 0); + num_elements = ext_data_len / element_size; + bytes = reinterpret_cast(ext_data_buf); + } for (size_t i = 0; i < num_elements; ++i) { char* start_byte = bytes + i * element_size; char* end_byte = start_byte + element_size - 1; @@ -1317,22 +1331,15 @@ ONNX_NAMESPACE::TensorProto TensorToTensorProto(const Tensor& tensor, const auto* raw_data = tensor.DataRaw(); ORT_ENFORCE(raw_data, "Missing raw data for tensor proto. Invalid tensor."); static_assert(sizeof(void*) <= sizeof(ExternalDataInfo::OFFSET_TYPE)); - tensor_proto.set_data_location(ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL); // we reinterpret_cast this back to void* in tensorprotoutils.cc:GetExtDataFromTensorProto. // use intptr_t as OFFSET_TYPE is signed. in theory you could get a weird looking value if the address uses the // high bit, but that should be unlikely in a scenario where we care about memory usage enough to use this path. auto offset = narrow(reinterpret_cast(raw_data)); - ONNX_NAMESPACE::StringStringEntryProto* entry = tensor_proto.mutable_external_data()->Add(); - entry->set_key("location"); - entry->set_value(ToUTF8String(onnxruntime::utils::kTensorProtoMemoryAddressTag)); - entry = tensor_proto.mutable_external_data()->Add(); - entry->set_key("offset"); - entry->set_value(std::to_string(offset)); - entry = tensor_proto.mutable_external_data()->Add(); - entry->set_key("length"); - entry->set_value(std::to_string(tensor.SizeInBytes())); + ExternalDataInfo::SetExternalLocationToProto(onnxruntime::utils::kTensorProtoMemoryAddressTag, + offset, tensor.SizeInBytes(), tensor_proto); + } else { utils::SetRawDataInTensorProto(tensor_proto, tensor.DataRaw(), tensor.SizeInBytes()); } diff --git a/onnxruntime/core/framework/tensorprotoutils.h b/onnxruntime/core/framework/tensorprotoutils.h index f5dec7ae988f2..79eae48c10411 100644 --- a/onnxruntime/core/framework/tensorprotoutils.h +++ b/onnxruntime/core/framework/tensorprotoutils.h @@ -41,12 +41,18 @@ Status GetExternalDataInfo(const ONNX_NAMESPACE::TensorProto& tensor_proto, ExternalDataInfo::PrepackedInfos* prepacked_infos = nullptr); /** * This function is used to convert the endianess of Tensor data. + * If ext_data_buf is provided, then this buffer content's endianess + * will be changed. * Mostly, will be used in big endian system to support the model file * generated on little endian system. - * @param initializer given initializer tensor + * @param tensor_proto given initializer tensor + * @param ext_data_buf optional externl data buffer + * @param ext_data_len optional externl data buffer lengeh * @returns None */ -void ConvertRawDataInTensorProto(ONNX_NAMESPACE::TensorProto* initializer); +void ConvertRawDataInTensorProto(ONNX_NAMESPACE::TensorProto* tensor_proto, + void* ext_data_buf = NULL, + size_t ext_data_len = 0); /** * Wrapper function for set_raw_data. diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index e4915616b7b7c..39ffc6a5b0cee 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -7,30 +7,34 @@ #include #include #include -#include #include +#include -#include "core/common/common.h" #include + +#include "core/common/common.h" #include "core/common/inlined_containers.h" #include "core/common/logging/logging.h" #include "core/common/narrow.h" #include "core/flatbuffers/flatbuffers_utils.h" +#include "core/framework/tensor_type_and_shape.h" #include "core/flatbuffers/schema/ort.fbs.h" -#include "core/framework/tensor_shape.h" #include "core/framework/tensor_external_data_info.h" +#include "core/framework/tensor_shape.h" +#include "core/framework/tensor_type_and_shape.h" #include "core/framework/tensorprotoutils.h" #include "core/framework/utils.h" +#include "core/graph/function_utils.h" #include "core/graph/graph_flatbuffers_utils.h" #include "core/graph/graph_viewer.h" #include "core/graph/indexed_sub_graph.h" #include "core/graph/model.h" +#include "core/graph/model_editor_api_types.h" #include "core/graph/model_load_utils.h" #include "core/graph/model_saving_options.h" #include "core/graph/node_attr_utils.h" #include "core/graph/op.h" #include "core/graph/runtime_optimization_record_container.h" -#include "core/graph/function_utils.h" #if !defined(ORT_MINIMAL_BUILD) #include "core/graph/function.h" @@ -3500,6 +3504,10 @@ void Graph::RemoveInitializedTensor(const std::string& tensor_name) { #if !defined(DISABLE_SPARSE_TENSORS) sparse_tensor_names_.erase(tensor_name); #endif + + // doesn't matter if it existed or not + ORT_IGNORE_RETURN_VALUE(ortvalue_initializers_.erase(tensor_name)); + SetGraphResolveNeeded(); } else { #if !defined(DISABLE_SPARSE_TENSORS) @@ -3631,8 +3639,8 @@ Status Graph::InjectExternalInitializersFromFilesInMemory( return Status::OK(); } -#endif // DISABLE_EXTERNAL_INITIALIZERS +#endif // DISABLE_EXTERNAL_INITIALIZERS #endif // !defined(ORT_MINIMAL_BUILD) bool Graph::GetInitializedTensor(const std::string& tensor_name, const TensorProto*& value) const { @@ -3645,6 +3653,16 @@ bool Graph::GetInitializedTensor(const std::string& tensor_name, const TensorPro return true; } +bool Graph::GetOrtValueInitializer(const std::string& name, OrtValue& value) const { + auto it = ortvalue_initializers_.find(name); + if (it == ortvalue_initializers_.end()) { + return false; + } + + value = it->second; + return true; +} + void Graph::CleanAllInitializedTensors() noexcept { name_to_initial_tensor_.clear(); #if !defined(DISABLE_SPARSE_TENSORS) @@ -3660,6 +3678,8 @@ void Graph::CleanAllInitializedTensors() noexcept { delete graph_proto_->mutable_initializer()->ReleaseCleared(); } #endif + + ortvalue_initializers_.clear(); } const ONNX_NAMESPACE::TensorProto* Graph::GetConstantInitializer(const std::string& initializer_name, @@ -3709,13 +3729,14 @@ void Graph::AddValueInfo(const NodeArg* new_value_info) { value_info_.insert(new_value_info); } -std::vector Graph::CreateNodeArgs(const google::protobuf::RepeatedPtrField& names, +template +std::vector Graph::CreateNodeArgs(const StringRange& names, const ArgNameToTypeMap& name_to_type_map) { const auto name_to_type_map_end = name_to_type_map.end(); std::vector results; results.reserve(names.size()); - for (auto& name : names) { + for (const std::string& name : names) { const TypeProto* type = nullptr; auto name_to_type_iter = name_to_type_map.find(name); @@ -4076,27 +4097,51 @@ ONNX_NAMESPACE::GraphProto Graph::ToGraphProto() const { // This is used for constructing full path for external data // if it exists + auto add_initializer = [](TensorList& output_initializers, const TensorProto& initializer) -> void { + TensorProto& output = *output_initializers.Add(); + output = initializer; + + // copy any in-memory external data into raw data + if (utils::HasExternalData(initializer)) { + const std::filesystem::path ignored; + std::basic_string location; + onnxruntime::FileOffsetType file_offset; + SafeInt tensor_byte_size; + + ORT_THROW_IF_ERROR(utils::GetExternalDataInfo(initializer, ignored, location, file_offset, tensor_byte_size)); + + if (location == onnxruntime::utils::kTensorProtoMemoryAddressTag) { + // file_offset is address + void* data = reinterpret_cast(file_offset); + + // set in raw data + output.clear_data_location(); + output.set_raw_data(data, tensor_byte_size); + } + } + }; + + auto* mutable_initializers = result.mutable_initializer(); + #if !defined(DISABLE_SPARSE_TENSORS) const auto& model_path = ModelPath(); // We want to make sure that sparse initializers do not appear // as dense duplicates within the initializers list. - if (!sparse_tensor_names_.empty()) { - const auto sparse_end = sparse_tensor_names_.end(); - auto* mutable_initializer = result.mutable_initializer(); - for (const auto& initializer : graph_proto_->initializer()) { - if (sparse_end == sparse_tensor_names_.find(initializer.name())) { - *mutable_initializer->Add() = initializer; - } else { - auto& sparse_initializer = *result.add_sparse_initializer(); - auto status = utils::DenseTensorToSparseTensorProto(initializer, model_path, sparse_initializer); - ORT_ENFORCE(status.IsOK(), "Failed to convert dense initializer to sparse"); - } + const bool has_sparse_initializers = !sparse_tensor_names_.empty(); + const auto sparse_end = sparse_tensor_names_.end(); + for (const auto& initializer : graph_proto_->initializer()) { + if (!has_sparse_initializers || sparse_end == sparse_tensor_names_.find(initializer.name())) { + add_initializer(*mutable_initializers, initializer); + } else { + auto& sparse_initializer = *result.add_sparse_initializer(); + auto status = utils::DenseTensorToSparseTensorProto(initializer, model_path, sparse_initializer); + ORT_ENFORCE(status.IsOK(), "Failed to convert dense initializer to sparse"); } - } else { - *result.mutable_initializer() = graph_proto_->initializer(); } #else - *result.mutable_initializer() = graph_proto_->initializer(); + for (const auto& initializer : graph_proto_->initializer()) { + add_initializer(*mutable_initializers, initializer); + } #endif return result; @@ -5345,6 +5390,9 @@ Status Graph::InlineFunction(Node& callnode) { } void Graph::SetInputs(gsl::span inputs) { + graph_inputs_including_initializers_.clear(); + graph_inputs_excluding_initializers_.clear(); + // creating graph from scratch // rely on SetGraphInputsOutputs() to fix up graph_inputs_excluding_initializers_ // if is_loaded_from_model_file_ == false @@ -5353,7 +5401,6 @@ void Graph::SetInputs(gsl::span inputs) { if (is_loaded_from_model_file_) { // graph loaded from model file - graph_inputs_excluding_initializers_.clear(); for (const auto* input : inputs) { ORT_ENFORCE(input->Exists(), "Input to set must exist."); if (name_to_initial_tensor_.find(input->Name()) == name_to_initial_tensor_.end()) { @@ -5370,6 +5417,7 @@ void Graph::SetInputs(gsl::span inputs) { } void Graph::SetOutputs(gsl::span outputs) { + graph_outputs_.clear(); graph_outputs_.reserve(outputs.size()); graph_outputs_.assign(outputs.begin(), outputs.end()); @@ -5688,4 +5736,207 @@ common::Status Graph::LoadFromOrtFormat(const onnxruntime::fbs::Graph& fbs_graph return Status::OK(); } +#if !defined(ORT_MINIMAL_BUILD) +namespace { +ValueInfoProto OrtValueInfoToOnnx(const OrtValueInfo& vi) { + // the model builder API checks that the OrtValueInfo has a complete and valid OrtTypeInfo instance and that the + // name is not null/empty. + ORT_ENFORCE(vi.type_info->type == ONNX_TYPE_TENSOR, + "Internal error. Model Editor API should only allow OrtValueInfo for tensor to be created."); + + ValueInfoProto value_info_proto; + value_info_proto.set_name(vi.name); + + auto* tensor = value_info_proto.mutable_type()->mutable_tensor_type(); + const OrtTensorTypeAndShapeInfo& tensor_info = *vi.type_info->tensor_type_info.get(); + tensor->set_elem_type(tensor_info.type); + + auto& shape = *tensor->mutable_shape(); + + size_t idx = 0; + for (auto dim : tensor_info.shape.GetDims()) { + auto& dim_proto = *shape.add_dim(); + if (dim >= 0) { + dim_proto.set_dim_value(dim); + } else { + const std::string& dim_param = tensor_info.dim_params[idx]; + // if empty leave the new dim_proto with neither dim_value nor dim_param set. this represents an 'unknown' dim + if (!dim_param.empty()) { + dim_proto.set_dim_param(dim_param); + } + } + } + + return value_info_proto; +} +} // namespace + +Status Graph::LoadFromModelEditorApiModel(const OrtGraph& api_graph, bool updating_existing_graph) { + ArgNameToTypeMap name_to_type_map; + + // NOTE: need to create NodeArgs as we go along + + // add inputs first. the shape from an input for a non-const initializer is preferred, so we want to create the + // NodeArg for the value using that + + auto add_graph_inputs_outputs = [&, this]( + const InlinedVector>& graph_inputs_or_outputs, + bool is_input) { + // when updating a model we don't require the inputs or outputs to be set if they're unchanged. + if (updating_existing_graph && graph_inputs_or_outputs.empty()) { + return; + } + + std::vector node_args; + node_args.reserve(graph_inputs_or_outputs.size()); + for (auto& ort_value_info : graph_inputs_or_outputs) { + ValueInfoProto value_info = OrtValueInfoToOnnx(*ort_value_info); + + name_to_type_map[value_info.name()] = value_info.type(); + node_args.push_back(&GetOrCreateNodeArg(value_info.name(), &value_info.type())); + } + + if (is_input) { + SetInputs(node_args); + } else { + SetOutputs(node_args); + } + }; + + auto add_initializers = [this](const std::unordered_map>& initializers, + bool is_external) { + for (auto& name_and_ortvalue : initializers) { + // convert from OrtValue to TensorProto + const std::string& name = name_and_ortvalue.first; + OrtValue& v = *name_and_ortvalue.second; + + ORT_ENFORCE(v.IsTensor(), "Initializers must be Tensors"); + const Tensor& t = v.Get(); + TensorProto& tensor_proto = *graph_proto_->add_initializer(); + + tensor_proto.set_name(name); + tensor_proto.set_data_type(t.GetElementType()); + for (auto dim : t.Shape().GetDims()) { + tensor_proto.add_dims(dim); + } + + if (is_external) { + // pre-existing memory that we don't own. avoid a copy by storing the pointer in the ExternalDataInfo + const void* data_offset = t.DataRaw(); // address of memory not offset into file + auto offset = narrow(reinterpret_cast(data_offset)); + + ExternalDataInfo::SetExternalLocationToProto(onnxruntime::utils::kTensorProtoMemoryAddressTag, + offset, t.SizeInBytes(), tensor_proto); + + // add OrtValue to ortvalue_initializers_ to keep it alive and to store the deleter if provided. + ortvalue_initializers_.emplace(name, std::move(v)); + } else { + tensor_proto.set_raw_data(t.DataRaw(), t.SizeInBytes()); + } + + TypeProto type_proto{TypeProtoFromTensorProto(tensor_proto)}; + ORT_IGNORE_RETURN_VALUE(GetOrCreateNodeArg(name, &type_proto)); + + name_to_initial_tensor_.emplace(name, &tensor_proto); + } + }; + + // process graph inputs first as we want the type/shape from them to be preferred if a graph input + // has a matching initializer + add_graph_inputs_outputs(api_graph.inputs, /*input*/ true); + + // add initializers + ortvalue_initializers_.reserve(api_graph.external_initializers.size()); + add_initializers(api_graph.external_initializers, /*is_external*/ true); + add_initializers(api_graph.initializers, /*is_external*/ false); + + // add graph outputs + add_graph_inputs_outputs(api_graph.outputs, /*input*/ false); + + // add nodes + for (const auto& ort_node : api_graph.nodes) { + const OrtNode& node = *ort_node; + + // convert Constant nodes to initializers + if (node.operator_name == "Constant" && node.domain_name == kOnnxDomain) { + // graph_proto_ provides storage + TensorProto& tensor = *graph_proto_->add_initializer(); + + // create NodeProto from OrtNode so we can use the existing conversion functions + NodeProto node_proto; + + // 'Constant' node has no inputs or attributes + ORT_RETURN_IF_NOT(node.input_names.empty() && node.attributes.size() == 1 && node.output_names.size() == 1, + node.node_name, + " is an invalid 'Constant' node. " + "Must have no inputs, one attribute and one output. "); + + node_proto.add_attribute()->CopyFrom(node.attributes[0]); + node_proto.add_output(node.output_names[0]); + + node_proto.set_op_type(node.operator_name); + node_proto.set_name(node.node_name); + node_proto.set_domain(node.domain_name); + + ORT_RETURN_IF_ERROR(utils::ConstantNodeProtoToTensorProto(node_proto, /*model_path*/ "", tensor)); + name_to_initial_tensor_.emplace(node.output_names[0], &tensor); + + continue; + } + + auto input_defs = CreateNodeArgs(node.input_names, name_to_type_map); + auto output_defs = CreateNodeArgs(node.output_names, name_to_type_map); + + const auto num_attributes = node.attributes.size(); + + NodeAttributes attributes; + attributes.reserve(num_attributes); + + for (const auto& attr : node.attributes) { + attributes[attr.name()] = attr; + } + + ORT_IGNORE_RETURN_VALUE(AddNode(node.node_name, node.operator_name, /*doc_string*/ "", + input_defs, output_defs, &attributes, node.domain_name)); + } + + return Resolve(); +} + +// static +Status Graph::LoadFromModelEditorApiModel(const OrtGraph& api_graph, + const Model& owning_model, + const std::unordered_map& domain_to_version, + IOnnxRuntimeOpSchemaCollectionPtr schema_registry, + bool strict_shape_type_inference, + const logging::Logger& logger, + std::unique_ptr& graph) { + graph = std::make_unique(owning_model, + domain_to_version, + schema_registry, + /*parent_graph*/ nullptr, /*parent_node*/ nullptr, + logger, + strict_shape_type_inference); + + return graph->LoadFromModelEditorApiModel(api_graph); +} + +Status Graph::UpdateUsingModelEditorApiModel(const OrtModel& api_model) { + for (auto& entry : api_model.domain_to_version) { + if (auto it = domain_to_version_.find(entry.first); it != domain_to_version_.end()) { + if (it->second != entry.second) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Domain version can not be changed for '", entry.first, + "'. Current version: ", it->second); + } + } else { + domain_to_version_.insert(entry); + } + } + + // this will replace inputs/outputs and add nodes. + return LoadFromModelEditorApiModel(*api_model.graph, /*updating_existing_graph*/ true); +} + +#endif // !defined(ORT_MINIMAL_BUILD) } // namespace onnxruntime diff --git a/onnxruntime/core/graph/graph_flatbuffers_utils.cc b/onnxruntime/core/graph/graph_flatbuffers_utils.cc index 922759b02e75f..199aa79cc1dde 100644 --- a/onnxruntime/core/graph/graph_flatbuffers_utils.cc +++ b/onnxruntime/core/graph/graph_flatbuffers_utils.cc @@ -300,8 +300,6 @@ Status LoadInitializerOrtFormat(const fbs::Tensor& fbs_tensor, TensorProto& init const auto* fbs_raw_data = fbs_tensor.raw_data(); if (fbs_raw_data) { if (load_options.can_use_flatbuffer_for_initializers && fbs_raw_data->size() > 127) { - initializer.set_data_location(ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL); - static_assert(sizeof(void*) <= sizeof(ExternalDataInfo::OFFSET_TYPE)); const void* data_offset = fbs_raw_data->Data(); // we reinterpret_cast this back to void* in tensorprotoutils.cc:GetExtDataFromTensorProto. @@ -309,15 +307,9 @@ Status LoadInitializerOrtFormat(const fbs::Tensor& fbs_tensor, TensorProto& init // high bit, but that should be unlikely in a scenario where we care about memory usage enough to use this path. auto offset = narrow(reinterpret_cast(data_offset)); - ONNX_NAMESPACE::StringStringEntryProto* entry = initializer.mutable_external_data()->Add(); - entry->set_key("location"); - entry->set_value(ToUTF8String(onnxruntime::utils::kTensorProtoMemoryAddressTag)); - entry = initializer.mutable_external_data()->Add(); - entry->set_key("offset"); - entry->set_value(std::to_string(offset)); - entry = initializer.mutable_external_data()->Add(); - entry->set_key("length"); - entry->set_value(std::to_string(fbs_raw_data->size())); + ExternalDataInfo::SetExternalLocationToProto(onnxruntime::utils::kTensorProtoMemoryAddressTag, + offset, fbs_raw_data->size(), initializer); + } else { // fbs_raw_data is uint8_t vector, so the size is byte size initializer.set_raw_data(fbs_raw_data->Data(), fbs_raw_data->size()); diff --git a/onnxruntime/core/graph/model.cc b/onnxruntime/core/graph/model.cc index be0531e6473fb..7629e40c1b5fe 100644 --- a/onnxruntime/core/graph/model.cc +++ b/onnxruntime/core/graph/model.cc @@ -7,6 +7,7 @@ #include "core/flatbuffers/flatbuffers_utils.h" #include "core/framework/tensorprotoutils.h" #include "core/graph/model.h" +#include "core/graph/model_editor_api_types.h" #include "core/graph/model_load_utils.h" #ifdef _MSC_VER @@ -738,6 +739,36 @@ Status Model::Load(int fd, const PathString& model_path, std::shared_ptr& return Status::OK(); } +// static +common::Status Model::LoadFromModelEditorApiModel(const OrtModel& model_editor_api_model, + const IOnnxRuntimeOpSchemaRegistryList* local_registries, + const ModelOptions& options, + const logging::Logger& logger, + std::unique_ptr& model) { + model = std::make_unique(); + model->model_proto_.set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); + // The optimizer Initializer class requires a path if external data is used, however in the Graph API usage the + // external data is pointing to pre-allocated memory and does not require a path. Set a dummy value to make it happy. + model->model_path_ = std::filesystem::path("_GRAPH_API_MODEL_"); + + auto schema_registry = std::make_shared(); + if (local_registries != nullptr) { + for (const auto& schema_collection : *local_registries) { + schema_registry->RegisterRegistry(schema_collection); + } + } + + ORT_RETURN_IF_ERROR(Graph::LoadFromModelEditorApiModel(*model_editor_api_model.graph, + *model, + model_editor_api_model.domain_to_version, + schema_registry, + options.strict_shape_type_inference, + logger, + model->graph_)); + + return Status::OK(); +} + Status Model::Save(Model& model, int p_fd) { if (p_fd < 0) { return Status(ONNXRUNTIME, INVALID_ARGUMENT, " is less than 0."); @@ -917,5 +948,4 @@ common::Status Model::LoadFromOrtFormat(const fbs::Model& fbs_model, #endif return Status::OK(); } - } // namespace onnxruntime diff --git a/onnxruntime/core/graph/model.h b/onnxruntime/core/graph/model.h index 2d2086aef41fd..6fd94c60d6b99 100644 --- a/onnxruntime/core/graph/model.h +++ b/onnxruntime/core/graph/model.h @@ -280,6 +280,12 @@ class Model { const logging::Logger& logger, const ModelOptions& options = {}); + static common::Status LoadFromModelEditorApiModel(const OrtModel& graph_api_model, + const IOnnxRuntimeOpSchemaRegistryList* local_registries, + const ModelOptions& options, + const logging::Logger& logger, + std::unique_ptr& model); + common::Status SaveToOrtFormat(flatbuffers::FlatBufferBuilder& builder, flatbuffers::Offset& model) const; @@ -333,7 +339,7 @@ class Model { ModelMetaData model_metadata_; // Path to model file. May be empty. - const std::filesystem::path model_path_; + std::filesystem::path model_path_; // Main graph of the model. std::unique_ptr graph_; diff --git a/onnxruntime/core/graph/model_editor_api_types.h b/onnxruntime/core/graph/model_editor_api_types.h new file mode 100644 index 0000000000000..d72bd13093b61 --- /dev/null +++ b/onnxruntime/core/graph/model_editor_api_types.h @@ -0,0 +1,47 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/common/inlined_containers_fwd.h" +#include "core/framework/ort_value.h" +#include "core/framework/onnxruntime_typeinfo.h" +#include "core/graph/onnx_protobuf.h" + +// ORT C interface types for OrtGraphApi can't be in a namespace. +// We need to define them here so onnxruntime::Model can be created from OrtModel. + +struct OrtValueInfo { + std::string name; + std::unique_ptr type_info; +}; + +struct OrtOpAttr { + ONNX_NAMESPACE::AttributeProto attr_proto; +}; + +struct OrtNode { + std::string operator_name; + std::string domain_name; + std::string node_name; + + // OrtOpAttr is 1:1 with ONNX_NAMESPACE::AttributeProto currently. + // https://github.com/microsoft/onnxruntime/blob/bd5a759d0cdbed6e7f611c990d4eb5457a9ecf60/onnxruntime/core/session/standalone_op_invoker.cc#L318 + onnxruntime::InlinedVector attributes; + onnxruntime::InlinedVector input_names; + onnxruntime::InlinedVector output_names; + + // FUTURE if we need control flow nodes + // std::unordered_map subgraphs; +}; + +struct OrtGraph { + onnxruntime::InlinedVector> inputs; + onnxruntime::InlinedVector> outputs; + std::unordered_map> initializers; + std::unordered_map> external_initializers; + std::vector> nodes; +}; + +struct OrtModel { + std::unique_ptr graph; + std::unordered_map domain_to_version; +}; diff --git a/onnxruntime/core/optimizer/constant_folding.cc b/onnxruntime/core/optimizer/constant_folding.cc index e755b4bfa6364..e36eef672c1ed 100644 --- a/onnxruntime/core/optimizer/constant_folding.cc +++ b/onnxruntime/core/optimizer/constant_folding.cc @@ -21,7 +21,16 @@ ConstantFolding::ConstantFolding(const IExecutionProvider& execution_provider, const ConfigOptions& config_options, const InlinedHashSet& compatible_execution_providers, const InlinedHashSet& excluded_initializers) noexcept - : GraphTransformer("ConstantFolding", compatible_execution_providers), + : ConstantFolding("ConstantFolding", execution_provider, skip_dequantize_linear, config_options, compatible_execution_providers, excluded_initializers) { +} + +ConstantFolding::ConstantFolding(const std::string& name, + const IExecutionProvider& execution_provider, + bool skip_dequantize_linear, + const ConfigOptions& config_options, + const InlinedHashSet& compatible_execution_providers, + const InlinedHashSet& excluded_initializers) noexcept + : GraphTransformer(name, compatible_execution_providers), skip_dequantize_linear_(skip_dequantize_linear), config_options_(config_options), excluded_initializers_(excluded_initializers), @@ -144,7 +153,7 @@ Status ConstantFolding::ApplyImpl(Graph& graph, bool& modified, int graph_level, for (NodeIndex i : order) { auto* node = graph.GetNode(i); - if (!node) { + if (!node || !AllowConstantFolding(*node)) { continue; } diff --git a/onnxruntime/core/optimizer/constant_folding.h b/onnxruntime/core/optimizer/constant_folding.h index 14eb2a9c5f06b..29bc67d560788 100644 --- a/onnxruntime/core/optimizer/constant_folding.h +++ b/onnxruntime/core/optimizer/constant_folding.h @@ -28,6 +28,24 @@ class ConstantFolding : public GraphTransformer { const InlinedHashSet& compatible_execution_providers = {}, const InlinedHashSet& excluded_initializers = {}) noexcept; + protected: + /** + * Same as the constructor above but with a name provided by derived class. + */ + ConstantFolding(const std::string& name, + const IExecutionProvider& execution_provider, + bool skip_dequantize_linear, + const ConfigOptions& config_options, + const InlinedHashSet& compatible_execution_providers = {}, + const InlinedHashSet& excluded_initializers = {}) noexcept; + /** + * Derived class can implement this virtual function to limit the nodes that can be constant folded. + */ + virtual bool AllowConstantFolding(const Node& node) const { + ORT_UNUSED_PARAMETER(node); + return true; + } + private: Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override; diff --git a/onnxruntime/core/optimizer/graph_optimizer_registry.cc b/onnxruntime/core/optimizer/graph_optimizer_registry.cc new file mode 100644 index 0000000000000..8ede372470485 --- /dev/null +++ b/onnxruntime/core/optimizer/graph_optimizer_registry.cc @@ -0,0 +1,49 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/optimizer/graph_optimizer_registry.h" +#include "core/optimizer/graph_transformer_utils.h" +#include "core/optimizer/selection_and_optimization_func.h" +#include "core/optimizer/qdq_transformer/constant_folding_dq_node.h" + +using namespace onnxruntime; +using namespace ::onnxruntime::common; + +namespace onnxruntime { +#if !defined(ORT_MINIMAL_BUILD) +GraphOptimizerRegistry::GraphOptimizerRegistry(const onnxruntime::SessionOptions* sess_options, + const onnxruntime::IExecutionProvider* cpu_ep, + const logging::Logger* logger) : session_options_(sess_options), + cpu_ep_(cpu_ep), + logger_(logger) { + auto status = CreatePredefinedSelectionFuncs(); + ORT_ENFORCE(status.IsOK(), "Could not create pre-defined selection functions. Error Message: ", + status.ErrorMessage()); +} + +Status GraphOptimizerRegistry::CreatePredefinedSelectionFuncs() { + transformer_name_to_selection_func_[kConstantFoldingDQ] = ConstantFoldingDQFuncs::Select; + + return Status::OK(); +} + +std::optional GraphOptimizerRegistry::GetSelectionFunc(std::string& name) const { + auto lookup = transformer_name_to_selection_func_.find(name); + if (lookup != transformer_name_to_selection_func_.end()) { + return transformer_name_to_selection_func_.at(name); + } + LOGS(*logger_, WARNING) << "Can't find selection function of " << name; + return std::nullopt; +} +#else +GraphOptimizerRegistry::GraphOptimizerRegistry(const onnxruntime::SessionOptions* sess_options, + const onnxruntime::IExecutionProvider* cpu_ep, + const logging::Logger* logger) : session_options_(sess_options), + cpu_ep_(cpu_ep), + logger_(logger) {} + +std::optional GraphOptimizerRegistry::GetSelectionFunc(std::string& /*name*/) const { + return std::nullopt; +} +#endif +} // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/graph_optimizer_registry.h b/onnxruntime/core/optimizer/graph_optimizer_registry.h new file mode 100644 index 0000000000000..15c9287c0eac8 --- /dev/null +++ b/onnxruntime/core/optimizer/graph_optimizer_registry.h @@ -0,0 +1,77 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/common/inlined_containers.h" +#include "core/common/logging/logging.h" +#include "core/common/common.h" +#include "core/optimizer/graph_transformer.h" +#include "core/framework/execution_providers.h" +#include "core/framework/compute_capability.h" + +namespace onnxruntime { +/** + * Optimizer's selection function: Selects a set of nodes from a given graph for optimization. Additional key/value strings can be provided to configure the optimizer. + * If needed, use graph_optimizer_registry to access the session options, the CPU EP and the logger. + * + * Optimizer's optimization function: Gets the nodes in ComputeCapability from nodes_to_optimize. Use graph_optimizer_registry to access the session options, the CPU EP + * and the logger if needed to create the optimizer. Run optimization on the nodes/subgraph, and finally, update the ComputeCapability. + * + */ +using KeyValueConfig = std::unordered_map; +using SelectionFunc = std::function>(const GraphViewer&, + const KeyValueConfig&, + const GraphOptimizerRegistry& graph_optimizer_registry)>; +using OptimizationFunc = std::function; + +/** + * A registration/lookup class for re-usable optimizers for EPs. + */ +class GraphOptimizerRegistry { + public: + /** + * The constructor takes in session options, the CPU EP and a logger as these are required by some optimizers. + */ + GraphOptimizerRegistry(const onnxruntime::SessionOptions* sess_options, + const onnxruntime::IExecutionProvider* cpu_ep, + const logging::Logger* logger); + + /** + * Get optimizer selection function. If the optimizer name can't be found, return nullopt. + */ + std::optional GetSelectionFunc(std::string& name) const; + + /** + * Get CPU EP. + */ + const onnxruntime::IExecutionProvider& GetCpuEp() const { return *cpu_ep_; } + + /** + * Get Session Options. + */ + const onnxruntime::SessionOptions& GetSessionOptions() const { return *session_options_; } + + /** + * Get Logger. + */ + const logging::Logger* GetLogger() const { return logger_; } + + private: + const onnxruntime::SessionOptions* session_options_; + const onnxruntime::IExecutionProvider* cpu_ep_; + const logging::Logger* logger_; + +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) + InlinedHashMap transformer_name_to_selection_func_; + + /** + * Create pre-defined selection functions. + */ + Status CreatePredefinedSelectionFuncs(); +#endif +}; +} // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/qdq_transformer/constant_folding_dq_node.cc b/onnxruntime/core/optimizer/qdq_transformer/constant_folding_dq_node.cc new file mode 100644 index 0000000000000..a2f46d6ae693c --- /dev/null +++ b/onnxruntime/core/optimizer/qdq_transformer/constant_folding_dq_node.cc @@ -0,0 +1,26 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/optimizer/qdq_transformer/constant_folding_dq_node.h" +#include "core/optimizer/graph_optimizer_registry.h" +#include "core/graph/graph_utils.h" + +namespace onnxruntime { + +ConstantFoldingDQ::ConstantFoldingDQ(const IExecutionProvider& execution_provider, + bool skip_dequantize_linear, + const ConfigOptions& config_options, + const InlinedHashSet& node_index_set, + const InlinedHashSet& compatible_execution_providers, + const InlinedHashSet& excluded_initializers) noexcept + : ConstantFolding("ConstantFoldingDQ", execution_provider, skip_dequantize_linear, config_options, compatible_execution_providers, excluded_initializers), + node_index_set_(node_index_set) {} + +bool ConstantFoldingDQ::AllowConstantFolding(const Node& node) const { + if (node_index_set_.find(node.Index()) != node_index_set_.end()) { + return true; + } + return false; +} + +} // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/qdq_transformer/constant_folding_dq_node.h b/onnxruntime/core/optimizer/qdq_transformer/constant_folding_dq_node.h new file mode 100644 index 0000000000000..7aed87fa06adb --- /dev/null +++ b/onnxruntime/core/optimizer/qdq_transformer/constant_folding_dq_node.h @@ -0,0 +1,37 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/optimizer/graph_transformer.h" +#include "core/optimizer/constant_folding.h" +#include "core/framework/ort_value.h" +#include +#include "core/framework/execution_provider.h" + +namespace onnxruntime { + +/** +@class ConstantFoldingDQ + +It's the derived class from ConstantFolding. +*/ +class ConstantFoldingDQ : public ConstantFolding { + public: + /*! Constant folding will not be applied to nodes that have one of initializers from excluded_initializers as input. + \param execution_provider Execution provider instance to execute constant folding. + */ + ConstantFoldingDQ(const IExecutionProvider& execution_provider, + bool skip_dequantize_linear, + const ConfigOptions& config_options, + const InlinedHashSet& node_index_set, + const InlinedHashSet& compatible_execution_providers = {}, + const InlinedHashSet& excluded_initializers = {}) noexcept; + + bool AllowConstantFolding(const Node& node) const override; + + private: + InlinedHashSet node_index_set_; +}; + +} // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/selection_and_optimization_func.cc b/onnxruntime/core/optimizer/selection_and_optimization_func.cc new file mode 100644 index 0000000000000..151c61952a631 --- /dev/null +++ b/onnxruntime/core/optimizer/selection_and_optimization_func.cc @@ -0,0 +1,99 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "selection_and_optimization_func.h" +#include "core/graph/graph_utils.h" +#include "core/framework/compute_capability.h" +#include "core/optimizer/qdq_transformer/constant_folding_dq_node.h" + +namespace onnxruntime { + +std::vector> ConstantFoldingDQFuncs::Select(const GraphViewer& graph_viewer, + const KeyValueConfig& /*config*/, + const GraphOptimizerRegistry& /*graph_optimizer_registry*/) { + std::vector> result; + std::unique_ptr sub_graph = std::make_unique(); + const std::vector& node_index = graph_viewer.GetNodesInTopologicalOrder(ExecutionOrder::PRIORITY_BASED /*priority-based topological sort*/); + InitializedTensorSet constant_inputs; + const InlinedHashSet excluded_initializers; + + // Select DequantizeLinear node where all inputs are constant + for (const auto& index : node_index) { + const auto& node = graph_viewer.GetNode(index); + if (node->OpType() != "DequantizeLinear") { + continue; + } + if (!graph_utils::AllNodeInputsAreConstant(graph_viewer.GetGraph(), *node, constant_inputs, excluded_initializers)) { + continue; + } + sub_graph->nodes.push_back(index); + } + + result.push_back(std::make_unique(std::move(sub_graph))); + result.back()->optimization_func = ConstantFoldingDQFuncs::Optimize; + return result; +} + +Status ConstantFoldingDQFuncs::Optimize(Graph& graph, + const ComputeCapability& optimization_cc, + ComputeCapability& cc_to_update, + const GraphOptimizerRegistry& graph_optimizer_registry) { + std::string optimizer_name = kConstantFoldingDQ; + std::unordered_set original_initializers_to_remove; + std::unordered_set new_initializers_to_add; + InlinedHashSet dq_node_index_set; + + // iterate the nodes in node_to_optimize to: + // 1. get original initializers to remove + // 2. add new initializers + // 3. create dq node index set + for (const auto& index : optimization_cc.sub_graph->nodes) { + auto node = graph.GetNode(index); + if (node->OpType() != "DequantizeLinear") { + continue; + } + auto input_0 = node->InputDefs()[0]; + auto output_0 = node->OutputDefs()[0]; + original_initializers_to_remove.insert(input_0->Name()); + new_initializers_to_add.insert(output_0->Name()); + dq_node_index_set.insert(index); + } + + static auto transformer = std::make_unique(graph_optimizer_registry.GetCpuEp(), + false /*skip_dequantize_linear*/, + graph_optimizer_registry.GetSessionOptions().config_options, + dq_node_index_set); + + bool modified = false; + ORT_RETURN_IF_ERROR(transformer->Apply(graph, modified, *graph_optimizer_registry.GetLogger())); + + // update the overall ComputeCapability + std::vector updated_nodes; + for (auto index : cc_to_update.sub_graph->nodes) { + if (dq_node_index_set.find(index) != dq_node_index_set.end()) { + continue; + } + updated_nodes.push_back(index); + } + cc_to_update.sub_graph->nodes = updated_nodes; + + auto meta_def = cc_to_update.sub_graph->GetMutableMetaDef(); + std::vector updated_constant_initializers; + + for (auto constant_initializer : meta_def->constant_initializers) { + if (original_initializers_to_remove.find(constant_initializer) != original_initializers_to_remove.end()) { + continue; + } + updated_constant_initializers.push_back(constant_initializer); + } + + for (auto constant_initializer : new_initializers_to_add) { + updated_constant_initializers.push_back(constant_initializer); + } + + meta_def->constant_initializers = updated_constant_initializers; + + return Status::OK(); +} + +} // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/selection_and_optimization_func.h b/onnxruntime/core/optimizer/selection_and_optimization_func.h new file mode 100644 index 0000000000000..6ad62518833b0 --- /dev/null +++ b/onnxruntime/core/optimizer/selection_and_optimization_func.h @@ -0,0 +1,31 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/optimizer/graph_optimizer_registry.h" +#include "core/framework/compute_capability.h" +#include "core/graph/graph_viewer.h" + +namespace onnxruntime { +static const std::string kConstantFoldingDQ = "ConstantFoldingDQ"; + +/** + * Optimizer's selection function: Selects a set of nodes from a given graph for optimization. Additional key/value strings can be provided to configure the optimizer. + * If needed, use graph_optimizer_registry to access the session options, the CPU EP and the logger. + * + * Optimizer's optimization function: Gets the nodes in ComputeCapability from nodes_to_optimize. Use graph_optimizer_registry to access the session options, the CPU EP + * and the logger if needed to create the optimizer. Run optimization on the nodes/subgraph, and finally, update the ComputeCapability. + * + */ + +struct ConstantFoldingDQFuncs { + static std::vector> Select(const GraphViewer& graph_viewer, + const KeyValueConfig& configs, + const GraphOptimizerRegistry& graph_optimizer_registry); + static Status Optimize(Graph& graph, + const ComputeCapability& optimization_cc, + ComputeCapability& cc_to_update, + const GraphOptimizerRegistry& graph_optimizer_registry); +}; +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/acl/acl_execution_provider.cc b/onnxruntime/core/providers/acl/acl_execution_provider.cc index ede476ff74d1b..def1d5e4b704c 100644 --- a/onnxruntime/core/providers/acl/acl_execution_provider.cc +++ b/onnxruntime/core/providers/acl/acl_execution_provider.cc @@ -153,6 +153,7 @@ std::shared_ptr ACLExecutionProvider::GetKernelRegistry() const std::vector> ACLExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph, const IKernelLookup& kernel_lookup, + const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant*) const { std::vector> result; for (const auto& node : graph.Nodes()) { diff --git a/onnxruntime/core/providers/acl/acl_execution_provider.h b/onnxruntime/core/providers/acl/acl_execution_provider.h index d635e56add30b..80e4aaaf021e3 100755 --- a/onnxruntime/core/providers/acl/acl_execution_provider.h +++ b/onnxruntime/core/providers/acl/acl_execution_provider.h @@ -39,6 +39,7 @@ class ACLExecutionProvider : public IExecutionProvider { std::vector> GetCapability( const onnxruntime::GraphViewer& graph, const IKernelLookup& kernel_lookup, + const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* resource_accountant) const override; Status OnRunStart(const onnxruntime::RunOptions&) override; diff --git a/onnxruntime/core/providers/cann/cann_execution_provider.cc b/onnxruntime/core/providers/cann/cann_execution_provider.cc index 07e83933a890c..be09eefba791b 100644 --- a/onnxruntime/core/providers/cann/cann_execution_provider.cc +++ b/onnxruntime/core/providers/cann/cann_execution_provider.cc @@ -1254,6 +1254,7 @@ GetSubGraphPartition(const std::vector& topological_order, const std: std::vector> CANNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer, const IKernelLookup& kernel_lookup, + const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant*) const { std::vector> result; diff --git a/onnxruntime/core/providers/cann/cann_execution_provider.h b/onnxruntime/core/providers/cann/cann_execution_provider.h index 5ff935463a1c1..f28ae77e49f83 100644 --- a/onnxruntime/core/providers/cann/cann_execution_provider.h +++ b/onnxruntime/core/providers/cann/cann_execution_provider.h @@ -56,6 +56,7 @@ class CANNExecutionProvider : public IExecutionProvider { std::vector> GetCapability( const onnxruntime::GraphViewer& graph_viewer, const IKernelLookup& kernel_lookup, + const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* resource_accountant) const override; Status Compile(const std::vector& fused_nodes_and_graphs, diff --git a/onnxruntime/core/providers/coreml/coreml_execution_provider.cc b/onnxruntime/core/providers/coreml/coreml_execution_provider.cc index 3fa3868267c9b..cc7beed6bb298 100644 --- a/onnxruntime/core/providers/coreml/coreml_execution_provider.cc +++ b/onnxruntime/core/providers/coreml/coreml_execution_provider.cc @@ -39,6 +39,7 @@ CoreMLExecutionProvider::~CoreMLExecutionProvider() {} std::vector> CoreMLExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer, const IKernelLookup& /*kernel_lookup*/, + const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* /* resource_accountant */) const { std::vector> result; diff --git a/onnxruntime/core/providers/coreml/coreml_execution_provider.h b/onnxruntime/core/providers/coreml/coreml_execution_provider.h index 0609bf6af726d..574ae1fc0106b 100644 --- a/onnxruntime/core/providers/coreml/coreml_execution_provider.h +++ b/onnxruntime/core/providers/coreml/coreml_execution_provider.h @@ -20,6 +20,7 @@ class CoreMLExecutionProvider : public IExecutionProvider { std::vector> GetCapability(const onnxruntime::GraphViewer& graph_viewer, const IKernelLookup& /*kernel_lookup*/, + const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* resource_accountant) const override; #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) diff --git a/onnxruntime/core/providers/cpu/controlflow/loop.cc b/onnxruntime/core/providers/cpu/controlflow/loop.cc index c65dd2a04bf55..b33b1f189594b 100644 --- a/onnxruntime/core/providers/cpu/controlflow/loop.cc +++ b/onnxruntime/core/providers/cpu/controlflow/loop.cc @@ -244,7 +244,7 @@ static Status ConcatenateCpuOutput(void* /*stream*/, // we can't easily use a C++ template for the tensor element type, // so use a span for some protection but work in bytes - gsl::span output_span = gsl::make_span(static_cast(output), + gsl::span output_span = gsl::make_span(static_cast(output), output_size_in_bytes); for (size_t i = 0, num_iterations = per_iteration_output.size(); i < num_iterations; ++i) { @@ -257,7 +257,7 @@ static Status ConcatenateCpuOutput(void* /*stream*/, " Expected:", per_iteration_shape, " Got:", iteration_data.Shape()); } - auto src = gsl::make_span(static_cast(iteration_data.DataRaw()), + auto src = gsl::make_span(static_cast(iteration_data.DataRaw()), bytes_per_iteration); auto dst = output_span.subspan(i * bytes_per_iteration, bytes_per_iteration); gsl::copy(src, dst); diff --git a/onnxruntime/core/providers/cpu/quantization/conv_integer.cc b/onnxruntime/core/providers/cpu/quantization/conv_integer.cc index 03b39e19ed748..f3c6b18f8e753 100644 --- a/onnxruntime/core/providers/cpu/quantization/conv_integer.cc +++ b/onnxruntime/core/providers/cpu/quantization/conv_integer.cc @@ -34,17 +34,18 @@ ONNX_OPERATOR_KERNEL_EX( ConvInteger); Status ConvInteger::Compute(OpKernelContext* context) const { - size_t num_inputs = OpKernel::Node().InputDefs().size(); + const auto input_defs = Node().InputDefs(); + size_t num_inputs = input_defs.size(); const auto* X = context->Input(0); const auto* W = context->Input(1); uint8_t input_offset = 0; uint8_t filter_offset = 0; - if (num_inputs >= 3) { + if (num_inputs >= 3 && input_defs[2]->Exists()) { const auto* X_Zero_Point = context->Input(2); ORT_ENFORCE(IsScalarOr1ElementVector(X_Zero_Point), "Must be a scalar or 1D tensor or size 1."); input_offset = *(X_Zero_Point->Data()); } - if (num_inputs >= 4) { + if (num_inputs >= 4 && input_defs[3]->Exists()) { const auto* W_Zero_Point = context->Input(3); ORT_ENFORCE(IsScalarOr1ElementVector(W_Zero_Point), "Non per-tensor quantization is not supported now."); filter_offset = *(W_Zero_Point->Data()); diff --git a/onnxruntime/core/providers/cuda/controlflow/loop.cc b/onnxruntime/core/providers/cuda/controlflow/loop.cc index 3295b73a800c9..d66de7c74e647 100644 --- a/onnxruntime/core/providers/cuda/controlflow/loop.cc +++ b/onnxruntime/core/providers/cuda/controlflow/loop.cc @@ -84,10 +84,10 @@ static Status ConcatenateGpuOutput(void* stream, std::vector& per_iter CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(cur_output, iteration_data.DataRaw(), bytes_per_iteration, cudaMemcpyDeviceToDevice, static_cast(stream))); - cur_output = static_cast((static_cast(cur_output) + bytes_per_iteration)); + cur_output = static_cast((static_cast(cur_output) + bytes_per_iteration)); } - ORT_ENFORCE(static_cast(cur_output) - static_cast(output) == output_size_in_bytes, + ORT_ENFORCE(static_cast(cur_output) - static_cast(output) == output_size_in_bytes, "Concatenation did not fill output buffer as expected."); return Status::OK(); diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index b675c08e5f804..54fb4429c0536 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -2660,6 +2660,7 @@ std::unique_ptr CUDAExecutionProvider::GetDataTransf std::vector> CUDAExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph, const IKernelLookup& kernel_lookup, + const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* resource_accountant) const { std::vector> result; const logging::Logger& logger = *GetLogger(); diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.h b/onnxruntime/core/providers/cuda/cuda_execution_provider.h index 79a48e7cb89e1..a75e81f1f0c6d 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.h +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.h @@ -73,6 +73,7 @@ class CUDAExecutionProvider : public IExecutionProvider { std::vector> GetCapability( const onnxruntime::GraphViewer& graph, const IKernelLookup& kernel_lookup, + const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* resource_accountant) const override; int GetDeviceId() const override { return info_.device_id; } diff --git a/onnxruntime/core/providers/cuda/tensor/upsample.cc b/onnxruntime/core/providers/cuda/tensor/upsample.cc index cbf745d3c7b4f..a38fe1efad540 100644 --- a/onnxruntime/core/providers/cuda/tensor/upsample.cc +++ b/onnxruntime/core/providers/cuda/tensor/upsample.cc @@ -290,16 +290,16 @@ Status Upsample::BaseCompute(OpKernelContext* context, scales_div[i] = fast_divmod(gsl::narrow_cast(ceil(scales[i]))); } - UpampleImpl(Stream(context), - mode_, - rank, - (UpsampleMode::LINEAR == mode_) ? (rank == 2 ? X_dims[0] : X_dims[2]) : 0, - input_strides, - output_div_pitches, - scales_div, - reinterpret_cast(X->Data()), - reinterpret_cast(Y->MutableData()), - output_count); + UpsampleImpl(Stream(context), + mode_, + rank, + (UpsampleMode::LINEAR == mode_) ? (rank == 2 ? X_dims[0] : X_dims[2]) : 0, + input_strides, + output_div_pitches, + scales_div, + reinterpret_cast(X->Data()), + reinterpret_cast(Y->MutableData()), + output_count); } return Status::OK(); diff --git a/onnxruntime/core/providers/cuda/tensor/upsample_impl.cu b/onnxruntime/core/providers/cuda/tensor/upsample_impl.cu index d1c2ae6332994..24aeada559979 100644 --- a/onnxruntime/core/providers/cuda/tensor/upsample_impl.cu +++ b/onnxruntime/core/providers/cuda/tensor/upsample_impl.cu @@ -8,12 +8,12 @@ namespace onnxruntime { namespace cuda { template -__global__ void _UpampleNearestKernel(const TArray input_pitches, - const TArray output_div_pitches, - const TArray scales_div, - const T* __restrict__ input_data, - T* __restrict__ output_data, - const size_t N) { +__global__ void _UpsampleNearestKernel(const TArray input_pitches, + const TArray output_div_pitches, + const TArray scales_div, + const T* __restrict__ input_data, + T* __restrict__ output_data, + const size_t N) { CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N); CUDA_LONG input_index = 0; CUDA_LONG output_index = id; @@ -36,13 +36,13 @@ __global__ void _UpampleNearestKernel(const TArray input_pitches, // This is the common use-case where the 4-D input (batched multi-channel images) // is usually of shape [N, C, H, W] and the scales are [1.0, 1.0, height_scale, width_scale] template -__global__ void _UpampleBilinear4DInputKernel(const int64_t input_dim2, - const TArray input_pitches, - const TArray output_div_pitches, - const TArray scales_div, - const T* __restrict__ input_data, - T* __restrict__ output_data, - const size_t N) { +__global__ void _UpsampleBilinear4DInputKernel(const int64_t input_dim2, + const TArray input_pitches, + const TArray output_div_pitches, + const TArray scales_div, + const T* __restrict__ input_data, + T* __restrict__ output_data, + const size_t N) { CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N); CUDA_LONG input_index = 0; @@ -95,13 +95,13 @@ __global__ void _UpampleBilinear4DInputKernel(const int64_t input_dim2, // The following method supports a 2-D input in 'Linear mode' template -__global__ void _UpampleBilinear2DInputKernel(const int64_t input_dim0, - const TArray input_pitches, - const TArray output_div_pitches, - const TArray scales_div, - const T* __restrict__ input_data, - T* __restrict__ output_data, - const size_t N) { +__global__ void _UpsampleBilinear2DInputKernel(const int64_t input_dim0, + const TArray input_pitches, + const TArray output_div_pitches, + const TArray scales_div, + const T* __restrict__ input_data, + T* __restrict__ output_data, + const size_t N) { CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N); CUDA_LONG input_index = 0; @@ -147,32 +147,32 @@ __global__ void _UpampleBilinear2DInputKernel(const int64_t input_dim0, } template -void UpampleImpl(cudaStream_t stream, - const onnxruntime::UpsampleMode upsample_mode, - const size_t rank, - const int64_t input_dim2, - const TArray& input_pitches, - const TArray& output_div_pitches, - const TArray& scales_div, - const T* input_data, - T* output_data, - const size_t N) { +void UpsampleImpl(cudaStream_t stream, + const onnxruntime::UpsampleMode upsample_mode, + const size_t rank, + const int64_t input_dim2, + const TArray& input_pitches, + const TArray& output_div_pitches, + const TArray& scales_div, + const T* input_data, + T* output_data, + const size_t N) { int blocksPerGrid = (int)(ceil(static_cast(N) / GridDim::maxThreadsPerBlock)); if (onnxruntime::UpsampleMode::NN == upsample_mode) { if (rank == 4) { - _UpampleNearestKernel<<>>( + _UpsampleNearestKernel<<>>( input_pitches, output_div_pitches, scales_div, input_data, output_data, N); } else if (rank == 3) { - _UpampleNearestKernel<<>>( + _UpsampleNearestKernel<<>>( input_pitches, output_div_pitches, scales_div, input_data, output_data, N); } else if (rank == 2) { - _UpampleNearestKernel<<>>( + _UpsampleNearestKernel<<>>( input_pitches, output_div_pitches, scales_div, input_data, output_data, N); } else if (rank == 1) { - _UpampleNearestKernel<<>>( + _UpsampleNearestKernel<<>>( input_pitches, output_div_pitches, scales_div, input_data, output_data, N); } else { @@ -180,11 +180,11 @@ void UpampleImpl(cudaStream_t stream, } } else if (onnxruntime::UpsampleMode::LINEAR == upsample_mode) { if (rank == 4) { - _UpampleBilinear4DInputKernel<<>>( + _UpsampleBilinear4DInputKernel<<>>( input_dim2, input_pitches, output_div_pitches, scales_div, input_data, output_data, N); } else if (rank == 2) { - _UpampleBilinear2DInputKernel<<>>( + _UpsampleBilinear2DInputKernel<<>>( input_dim2, input_pitches, output_div_pitches, scales_div, input_data, output_data, N); } else { @@ -197,17 +197,17 @@ void UpampleImpl(cudaStream_t stream, } } -#define SPECIALIZED_IMPL(T) \ - template void UpampleImpl(cudaStream_t stream, \ - const onnxruntime::UpsampleMode upsample_mode, \ - const size_t rank, \ - const int64_t input_dim2, \ - const TArray& input_pitches, \ - const TArray& output_div_pitches, \ - const TArray& scales_div, \ - const T* input_data, \ - T* output_data, \ - const size_t N); +#define SPECIALIZED_IMPL(T) \ + template void UpsampleImpl(cudaStream_t stream, \ + const onnxruntime::UpsampleMode upsample_mode, \ + const size_t rank, \ + const int64_t input_dim2, \ + const TArray& input_pitches, \ + const TArray& output_div_pitches, \ + const TArray& scales_div, \ + const T* input_data, \ + T* output_data, \ + const size_t N); SPECIALIZED_IMPL(float) SPECIALIZED_IMPL(double) diff --git a/onnxruntime/core/providers/cuda/tensor/upsample_impl.h b/onnxruntime/core/providers/cuda/tensor/upsample_impl.h index 250ec6b272e34..fb47ad8301615 100644 --- a/onnxruntime/core/providers/cuda/tensor/upsample_impl.h +++ b/onnxruntime/core/providers/cuda/tensor/upsample_impl.h @@ -11,16 +11,16 @@ namespace onnxruntime { namespace cuda { template -void UpampleImpl(cudaStream_t stream, - const onnxruntime::UpsampleMode upsample_mode, - const size_t rank, - const int64_t input_dim2, - const TArray& input_pitches, - const TArray& output_div_pitches, - const TArray& scales_div, - const T* input_data, - T* output_data, - const size_t N); +void UpsampleImpl(cudaStream_t stream, + const onnxruntime::UpsampleMode upsample_mode, + const size_t rank, + const int64_t input_dim2, + const TArray& input_pitches, + const TArray& output_div_pitches, + const TArray& scales_div, + const T* input_data, + T* output_data, + const size_t N); } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp index 9d23b8b950272..868b2103586f9 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp @@ -93,12 +93,13 @@ namespace Dml ExecutionProvider::GetCapability( const onnxruntime::GraphViewer& graph, const onnxruntime::IExecutionProvider::IKernelLookup& kernel_lookup, + const onnxruntime::GraphOptimizerRegistry& graph_optimizer_registry, onnxruntime::IResourceAccountant* resource_accountant) const { #ifdef ENABLE_GRAPH_COMPILATION - return m_impl->GetCapability(graph, kernel_lookup, resource_accountant, *GetLogger()); + return m_impl->GetCapability(graph, kernel_lookup, graph_optimizer_registry, resource_accountant, *GetLogger()); #else - return onnxruntime::IExecutionProvider::GetCapability(graph, kernel_lookup, resource_accountant); + return onnxruntime::IExecutionProvider::GetCapability(graph, kernel_lookup, graph_optimizer_registry, resource_accountant); #endif } @@ -878,6 +879,7 @@ namespace Dml ExecutionProviderImpl::GetCapability( const onnxruntime::GraphViewer& graph, const onnxruntime::IExecutionProvider::IKernelLookup& kernel_lookup, + const onnxruntime::GraphOptimizerRegistry& /* graph_optimizer_registry */, onnxruntime::IResourceAccountant*, const onnxruntime::logging::Logger& logger) const { uint32_t deviceDataTypeMask = GetSupportedDeviceDataTypeMask(); // Each bit corresponds to each DML_TENSOR_DATA_TYPE. diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h index 7f420f8850001..aa3d8b0b4a409 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h @@ -13,6 +13,7 @@ namespace onnxruntime { class IResourceAccountant; +class GraphOptimizerRegistry; } namespace WRL { @@ -93,6 +94,7 @@ namespace Dml GetCapability( const onnxruntime::GraphViewer& graph, const onnxruntime::IExecutionProvider::IKernelLookup& kernel_lookup, + const onnxruntime::GraphOptimizerRegistry& graph_optimizer_registry, onnxruntime::IResourceAccountant* resource_accountant, const onnxruntime::logging::Logger& logger) const; @@ -288,6 +290,7 @@ namespace Dml std::vector> GetCapability(const onnxruntime::GraphViewer& graph, const onnxruntime::IExecutionProvider::IKernelLookup& kernel_lookup, + const onnxruntime::GraphOptimizerRegistry& /* graph_optimizer_registry */, onnxruntime::IResourceAccountant* resource_accountant) const final override; onnxruntime::common::Status OnSessionInitializationEnd() override diff --git a/onnxruntime/core/providers/dnnl/dnnl_execution_provider.cc b/onnxruntime/core/providers/dnnl/dnnl_execution_provider.cc index 4da82b351f1d6..d0e5b0b1588ef 100644 --- a/onnxruntime/core/providers/dnnl/dnnl_execution_provider.cc +++ b/onnxruntime/core/providers/dnnl/dnnl_execution_provider.cc @@ -147,6 +147,7 @@ std::vector> DnnlExecutionProvider::GetSupportedNodes(con std::vector> DnnlExecutionProvider::GetCapability( const GraphViewer& graph_viewer, const IKernelLookup& /*kernel_lookup*/, + const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* /* resource_accountant */) const { // follow from coreml ep's Getcapability diff --git a/onnxruntime/core/providers/dnnl/dnnl_execution_provider.h b/onnxruntime/core/providers/dnnl/dnnl_execution_provider.h index bde18e139f2a3..8f951efef2a94 100644 --- a/onnxruntime/core/providers/dnnl/dnnl_execution_provider.h +++ b/onnxruntime/core/providers/dnnl/dnnl_execution_provider.h @@ -25,6 +25,7 @@ class DnnlExecutionProvider : public IExecutionProvider { std::vector> GetCapability(const onnxruntime::GraphViewer& graph, const IKernelLookup& /*kernel_lookup*/, + const GraphOptimizerRegistry& /* graph_optimizer_registry */, onnxruntime::IResourceAccountant* /* resource_accountant */) const override; common::Status Compile(const std::vector& fused_nodes_and_graphs, diff --git a/onnxruntime/core/providers/js/js_execution_provider.cc b/onnxruntime/core/providers/js/js_execution_provider.cc index 9d00436150286..d8e24ff1f5053 100644 --- a/onnxruntime/core/providers/js/js_execution_provider.cc +++ b/onnxruntime/core/providers/js/js_execution_provider.cc @@ -791,6 +791,7 @@ std::vector JsExecutionProvider::CreatePreferredAllocators() { std::vector> JsExecutionProvider::GetCapability( const onnxruntime::GraphViewer& graph, const IKernelLookup& kernel_lookup, + const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* /* resource_accountant */) const { InlinedVector candidates; // `tenative_candidates` is a subset of `candidates`. diff --git a/onnxruntime/core/providers/js/js_execution_provider.h b/onnxruntime/core/providers/js/js_execution_provider.h index 4bead50fc782e..c87303209c689 100644 --- a/onnxruntime/core/providers/js/js_execution_provider.h +++ b/onnxruntime/core/providers/js/js_execution_provider.h @@ -45,6 +45,7 @@ class JsExecutionProvider : public IExecutionProvider { std::vector> GetCapability( const onnxruntime::GraphViewer& graph_viewer, const IKernelLookup& /*kernel_lookup*/, + const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* /* resource_accountant */) const override; std::shared_ptr GetKernelRegistry() const override; diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc index 1558d22137c05..9a694b03387ae 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc @@ -993,6 +993,7 @@ GetPartitionedSubgraphs(const std::vector& topological_order, std::vector> MIGraphXExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer, const IKernelLookup& /*kernel_lookup*/, + const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* /* resource_accountant */) const { std::vector> result; auto model = graph_viewer.CreateModel(*GetLogger()); diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h index d6af991f9b77e..7c89b5ec544a1 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h @@ -69,6 +69,7 @@ class MIGraphXExecutionProvider : public IExecutionProvider { std::vector> GetCapability(const onnxruntime::GraphViewer& graph_viewer, const IKernelLookup& /*kernel_lookup*/, + const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* /* resource_accountant */) const override; common::Status Compile(const std::vector& fused_nodes, diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.cc index 27bd584e2d3c6..28cfde817a620 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.cc @@ -81,6 +81,7 @@ NnapiExecutionProvider::~NnapiExecutionProvider() {} std::vector> NnapiExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer, const IKernelLookup& /*kernel_lookup*/, + const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* /* resource_accountant */) const { std::vector> result; const logging::Logger& logger = *GetLogger(); diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.h b/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.h index ebf9372eb668d..a2269fdd89436 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.h +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.h @@ -26,6 +26,7 @@ class NnapiExecutionProvider : public IExecutionProvider { std::vector> GetCapability(const onnxruntime::GraphViewer& graph_view, const IKernelLookup& /*kernel_lookup*/, + const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* /* resource_accountant */) const override; #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) diff --git a/onnxruntime/core/providers/openvino/backends/basic_backend.cc b/onnxruntime/core/providers/openvino/backends/basic_backend.cc index 9d4ad88e2c2b3..d026ce386e5c3 100644 --- a/onnxruntime/core/providers/openvino/backends/basic_backend.cc +++ b/onnxruntime/core/providers/openvino/backends/basic_backend.cc @@ -647,7 +647,7 @@ void BasicBackend::CompleteAsyncInference(Ort::KernelContext& context, OVInferRe const auto& out_name = item.first; auto node = item.second; Ort::UnownedValue output_tensor = GetOutputTensor(context, - std::move(out_name), + out_name, subgraph_context_.output_names, node); auto mem_info = output_tensor.GetTensorMemoryInfo(); diff --git a/onnxruntime/core/providers/openvino/openvino_execution_provider.cc b/onnxruntime/core/providers/openvino/openvino_execution_provider.cc index 12c16e9c9b8f6..6482a07ee92bc 100644 --- a/onnxruntime/core/providers/openvino/openvino_execution_provider.cc +++ b/onnxruntime/core/providers/openvino/openvino_execution_provider.cc @@ -107,6 +107,7 @@ OpenVINOExecutionProvider::~OpenVINOExecutionProvider() { std::vector> OpenVINOExecutionProvider::GetCapability(const GraphViewer& graph_viewer, const IKernelLookup& /*kernel_lookup*/, + const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* /* resource_accountant */) const { std::vector> result; diff --git a/onnxruntime/core/providers/openvino/openvino_execution_provider.h b/onnxruntime/core/providers/openvino/openvino_execution_provider.h index bbcca583b074b..020aec16e507c 100644 --- a/onnxruntime/core/providers/openvino/openvino_execution_provider.h +++ b/onnxruntime/core/providers/openvino/openvino_execution_provider.h @@ -51,6 +51,7 @@ class OpenVINOExecutionProvider : public IExecutionProvider { std::vector> GetCapability(const GraphViewer& graph_viewer, const IKernelLookup& /*kernel_lookup*/, + const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* /* resource_accountant */) const override; Status Compile(const std::vector& fused_nodes, diff --git a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc index 3df231e53e7c0..d85277627a3de 100644 --- a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc +++ b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc @@ -198,35 +198,13 @@ Status LoadQnnCtxFromOnnxGraph(const onnxruntime::GraphViewer& graph_viewer, return Status::OK(); } -// Figure out the real context cache file path -// return true if context cache file exists -bool ValidateContextCacheFilePath(bool is_qnn_ctx_model, - const std::string& customer_context_cache_path, - const onnxruntime::PathString& model_pathstring, - onnxruntime::PathString& context_cache_path) { - // always try the path set by user first, it's the only way to set it if load model from memory - if (!customer_context_cache_path.empty()) { - context_cache_path = ToPathString(customer_context_cache_path); - } else if (!model_pathstring.empty()) { // model loaded from file - if (is_qnn_ctx_model) { - // it's a context cache model, just use the model path - context_cache_path = model_pathstring; - } else if (!model_pathstring.empty()) { - // this is not a normal Onnx model, no customer path, create a default path for generation: model_path + _ctx.onnx - context_cache_path = model_pathstring + ToPathString("_ctx.onnx"); - } - } - - return std::filesystem::is_regular_file(context_cache_path) && std::filesystem::exists(context_cache_path); -} - Status CreateEPContextNodes(Model* model, unsigned char* buffer, uint64_t buffer_size, const std::string& sdk_build_version, const std::vector& fused_nodes_and_graphs, const QnnModelLookupTable& qnn_models, - const onnxruntime::PathString& context_cache_path, + const onnxruntime::PathString& context_model_path, bool qnn_context_embed_mode, uint64_t max_spill_fill_buffer_size, const logging::Logger& logger) { @@ -262,7 +240,19 @@ Status CreateEPContextNodes(Model* model, std::string cache_payload(buffer, buffer + buffer_size); ep_node.AddAttribute(EP_CACHE_CONTEXT, cache_payload); } else { - onnxruntime::PathString context_bin_path = context_cache_path + ToPathString("_" + graph_name + ".bin"); + onnxruntime::PathString context_bin_path; + auto pos = context_model_path.find_last_of(ORT_TSTR(".")); + if (pos != std::string::npos) { + context_bin_path = context_model_path.substr(0, pos); + } else { + context_bin_path = context_model_path; + } + std::string graph_name_in_file(graph_name); + auto name_pos = graph_name_in_file.find_first_of(kQnnExecutionProvider); + if (name_pos != std::string::npos) { + graph_name_in_file.replace(name_pos, strlen(kQnnExecutionProvider), ""); + } + context_bin_path = context_bin_path + ToPathString(graph_name_in_file + ".bin"); std::string context_cache_name(std::filesystem::path(context_bin_path).filename().string()); std::ofstream of_stream(context_bin_path.c_str(), std::ofstream::binary); if (!of_stream) { diff --git a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h index 3dfa0ae21001b..c54cd3ca6e90c 100644 --- a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h +++ b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h @@ -38,11 +38,6 @@ Status CreateNodeArgs(const std::vector& names, std::vector& node_args, onnxruntime::Graph& graph); -bool ValidateContextCacheFilePath(bool is_qnn_ctx_model, - const std::string& customer_context_cache_path, - const onnxruntime::PathString& model_pathstring, - onnxruntime::PathString& context_cache_path); - Status GetEpContextFromMainNode(const onnxruntime::Node& main_context_node, const onnxruntime::PathString& ctx_onnx_model_path, QnnBackendManager* qnn_backend_manager, @@ -67,7 +62,7 @@ Status CreateEPContextNodes(Model* model, const std::string& sdk_build_version, const std::vector& fused_nodes_and_graphs, const std::unordered_map>& qnn_models, - const onnxruntime::PathString& context_cache_path, + const onnxruntime::PathString& context_model_path, bool qnn_context_embed_mode, uint64_t max_spill_fill_buffer_size, const logging::Logger& logger); diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc index bcde69beceef7..26d792c008edc 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc @@ -470,8 +470,10 @@ Status QnnBackendManager::InitializeProfiling() { QnnProfile_Level_t qnn_profile_level = QNN_PROFILE_LEVEL_BASIC; if (ProfilingLevel::BASIC == profiling_level_merge_) { qnn_profile_level = QNN_PROFILE_LEVEL_BASIC; + LOGS_DEFAULT(VERBOSE) << "Profiling level set to basic."; } else if (ProfilingLevel::DETAILED == profiling_level_merge_) { qnn_profile_level = QNN_PROFILE_LEVEL_DETAILED; + LOGS_DEFAULT(VERBOSE) << "Profiling level set to detailed."; } Qnn_ErrorHandle_t result = qnn_interface_.profileCreate(backend_handle_, qnn_profile_level, &profile_backend_handle_); ORT_RETURN_IF(QNN_PROFILE_NO_ERROR != result, "Failed to create QNN profile! Error: ", QnnErrorHandleToString(result)); diff --git a/onnxruntime/core/providers/qnn/qnn_allocator.cc b/onnxruntime/core/providers/qnn/qnn_allocator.cc index 1fb8742f724cd..cb92e927ff65a 100644 --- a/onnxruntime/core/providers/qnn/qnn_allocator.cc +++ b/onnxruntime/core/providers/qnn/qnn_allocator.cc @@ -181,7 +181,9 @@ void HtpSharedMemoryAllocator::Free(void* allocation_address) { // Avoid throwing exceptions as this may be running from a destructor. try { // take ownership of shared memory and free at end of scope - auto shared_memory = WrapSharedMemoryWithUniquePtr(allocation_address, rpcmem_lib_->Api()); + const size_t allocation_offset = AllocationOffsetFromStartOfHeader(); + void* raw_allocation_address = (void*)((std::byte*)allocation_address - allocation_offset); + auto shared_memory = WrapSharedMemoryWithUniquePtr(raw_allocation_address, rpcmem_lib_->Api()); // destroy header allocation_header.~AllocationHeader(); diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc index 3fc537066ae0b..a5813dc2a4adc 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc @@ -195,6 +195,10 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio share_ep_contexts_ = config_options->GetConfigOrDefault(kOrtSessionOptionShareEpContexts, "0") == "1"; LOGS_DEFAULT(VERBOSE) << "User specified option - share EP contexts across sessions: " << share_ep_contexts_; + + stop_share_ep_contexts_ = + config_options->GetConfigOrDefault(kOrtSessionOptionStopShareEpContexts, "0") == "1"; + LOGS_DEFAULT(VERBOSE) << "User specified option - stop share EP contexts across sessions: " << stop_share_ep_contexts_; } static const std::string BACKEND_PATH = "backend_path"; @@ -384,17 +388,27 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio } } - qnn_backend_manager_ = qnn::QnnBackendManager::Create( - qnn::QnnBackendManagerConfig{backend_path, - profiling_level_etw, - profiling_level, - profiling_file_path, - context_priority, - qnn_saver_path, - device_id_, - htp_arch, - soc_model, - enable_htp_weight_sharing}); + // For context binary generation with weight sharing enabled, use the QnnBackendManager from the shared context if it exits + // So that all graphs from later sessions will be compiled into the same QNN context + if (context_cache_enabled_ && share_ep_contexts_ && SharedContext::GetInstance().GetSharedQnnBackendManager()) { + qnn_backend_manager_ = SharedContext::GetInstance().GetSharedQnnBackendManager(); + // Clear the QnnBackendManager from singleton to stop the resource share + if (stop_share_ep_contexts_) { + SharedContext::GetInstance().ResetSharedQnnBackendManager(); + } + } else { + qnn_backend_manager_ = qnn::QnnBackendManager::Create( + qnn::QnnBackendManagerConfig{backend_path, + profiling_level_etw, + profiling_level, + profiling_file_path, + context_priority, + qnn_saver_path, + device_id_, + htp_arch, + soc_model, + enable_htp_weight_sharing}); + } #if defined(_WIN32) if (onnxruntime::logging::EtwRegistrationManager::SupportsETW()) { @@ -655,6 +669,7 @@ static void PartitionCtxModel(const onnxruntime::GraphViewer& graph_viewer, std::vector> QNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer, const IKernelLookup& /*kernel_lookup*/, + const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* /* resource_accountant */) const { std::vector> result; @@ -904,25 +919,33 @@ Status QNNExecutionProvider::CompileFromOrtGraph(const std::vector& fused_nodes_and_graphs, std::vector& node_compute_funcs) { const auto& logger = *GetLogger(); bool is_qnn_ctx_model = qnn::IsFusedGraphHasCtxNode(fused_nodes_and_graphs); - onnxruntime::PathString context_cache_path; + onnxruntime::PathString context_model_path; bool is_ctx_file_exist = false; if (is_qnn_ctx_model || context_cache_enabled_) { const onnxruntime::GraphViewer& graph_viewer_0(fused_nodes_and_graphs[0].filtered_graph); - is_ctx_file_exist = qnn::ValidateContextCacheFilePath(is_qnn_ctx_model, - context_cache_path_cfg_, - graph_viewer_0.ModelPath().native(), - context_cache_path); + // Figure out the EP context model path from model path or session option + GetContextOnnxModelFilePath(context_cache_path_cfg_, + graph_viewer_0.ModelPath().native(), + context_model_path); } - ORT_RETURN_IF(is_ctx_file_exist && !is_qnn_ctx_model && context_cache_enabled_, - "The inference session is created from normal ONNX model. And an EP context model file is provided and existed. ", - "Please remove the EP context model manually if you want to re-generate it."); - if (is_qnn_ctx_model) { // Get QnnModel from EP shared contexts if (share_ep_contexts_ && SharedContext::GetInstance().HasSharedQnnModels()) { @@ -965,7 +988,7 @@ Status QNNExecutionProvider::Compile(const std::vector& fused const onnxruntime::GraphViewer& main_ctx_graph_viewer(fused_nodes_and_graphs[main_context_pos].filtered_graph); // Create QNN context from the cached binary, deserialize the QNN graph from the binary ORT_RETURN_IF_ERROR(qnn::LoadQnnCtxFromOnnxGraph(main_ctx_graph_viewer, - context_cache_path, + context_model_path, qnn_backend_manager_.get(), qnn_models, logger, @@ -1025,10 +1048,16 @@ Status QNNExecutionProvider::Compile(const std::vector& fused qnn_backend_manager_->GetSdkVersion(), fused_nodes_and_graphs, qnn_models_, - context_cache_path, + context_model_path, qnn_context_embed_mode_, max_spill_fill_buffer_size, logger)); + + if (share_ep_contexts_ && !stop_share_ep_contexts_ && + nullptr == SharedContext::GetInstance().GetSharedQnnBackendManager()) { + ORT_RETURN_IF_NOT(SharedContext::GetInstance().SetSharedQnnBackendManager(qnn_backend_manager_), + "Failed to set shared QnnBackendManager."); + } } return Status::OK(); } diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.h b/onnxruntime/core/providers/qnn/qnn_execution_provider.h index 31c34855ca4c0..d7a5d04d22692 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.h +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.h @@ -31,6 +31,7 @@ class QNNExecutionProvider : public IExecutionProvider { std::vector> GetCapability(const onnxruntime::GraphViewer& graph_view, const IKernelLookup& /*kernel_lookup*/, + const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* /* resource_accountant */) const override; Status Compile(const std::vector& fused_nodes_and_graphs, @@ -90,6 +91,7 @@ class QNNExecutionProvider : public IExecutionProvider { uint32_t default_rpc_control_latency_ = 0; bool enable_HTP_FP16_precision_ = true; bool share_ep_contexts_ = false; + bool stop_share_ep_contexts_ = false; bool enable_spill_fill_buffer_ = false; #if defined(_WIN32) onnxruntime::logging::EtwRegistrationManager::EtwInternalCallback callback_ETWSink_provider_ = nullptr; diff --git a/onnxruntime/core/providers/qnn/shared_context.h b/onnxruntime/core/providers/qnn/shared_context.h index 81de357dbe677..277a484ad8528 100644 --- a/onnxruntime/core/providers/qnn/shared_context.h +++ b/onnxruntime/core/providers/qnn/shared_context.h @@ -61,13 +61,39 @@ class SharedContext { return graph_exist; } + bool SetSharedQnnBackendManager(std::shared_ptr& qnn_backend_manager) { + const std::lock_guard lock(mtx_); + + if (qnn_backend_manager_ != nullptr) { + if (qnn_backend_manager_ == qnn_backend_manager) { + return true; + } + return false; + } + qnn_backend_manager_ = qnn_backend_manager; + return true; + } + + std::shared_ptr GetSharedQnnBackendManager() { + const std::lock_guard lock(mtx_); + return qnn_backend_manager_; + } + + void ResetSharedQnnBackendManager() { + const std::lock_guard lock(mtx_); + qnn_backend_manager_.reset(); + } + private: SharedContext() = default; ~SharedContext() = default; ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(SharedContext); + // Used for passing through QNN models (deserialized from context binary) across sessions std::vector> shared_qnn_models_; + // Used for compiling multiple models into same QNN context binary + std::shared_ptr qnn_backend_manager_; // Producer sessions can be in parallel // Consumer sessions have to be after producer sessions initialized std::mutex mtx_; diff --git a/onnxruntime/core/providers/rknpu/rknpu_execution_provider.cc b/onnxruntime/core/providers/rknpu/rknpu_execution_provider.cc index 10fd81786f977..e9343e2b2e06a 100644 --- a/onnxruntime/core/providers/rknpu/rknpu_execution_provider.cc +++ b/onnxruntime/core/providers/rknpu/rknpu_execution_provider.cc @@ -51,6 +51,7 @@ std::vector> RknpuExecutionProvider::GetSupportedNodes( std::vector> RknpuExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer, const IKernelLookup& /*kernel_lookup*/, + const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* /* resource_accountant */) const { // Find inputs, initializers and outputs for each supported subgraph std::vector> result; diff --git a/onnxruntime/core/providers/rknpu/rknpu_execution_provider.h b/onnxruntime/core/providers/rknpu/rknpu_execution_provider.h index ce16d63e111d9..75cae37d117a0 100644 --- a/onnxruntime/core/providers/rknpu/rknpu_execution_provider.h +++ b/onnxruntime/core/providers/rknpu/rknpu_execution_provider.h @@ -20,6 +20,7 @@ class RknpuExecutionProvider : public IExecutionProvider { std::vector> GetCapability(const onnxruntime::GraphViewer& graph, const IKernelLookup& /*kernel_lookup*/, + const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* /* resource_accountant */) const override; common::Status Compile(const std::vector& fused_nodes_and_graphs, std::vector& node_compute_funcs) override; diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc index 9d6e9df907ce3..49771488efc44 100644 --- a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc +++ b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc @@ -2441,6 +2441,7 @@ std::unique_ptr ROCMExecutionProvider::GetDataTransf std::vector> ROCMExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph, const IKernelLookup& kernel_lookup, + const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* /* resource_accountant */) const { InlinedVector candidates; // A subset of the above vector. A subset of the tentative_nodes might be moved to CPU. diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider.h b/onnxruntime/core/providers/rocm/rocm_execution_provider.h index ff2bff7c98723..2baaf2ff1a886 100644 --- a/onnxruntime/core/providers/rocm/rocm_execution_provider.h +++ b/onnxruntime/core/providers/rocm/rocm_execution_provider.h @@ -62,6 +62,7 @@ class ROCMExecutionProvider : public IExecutionProvider { std::vector> GetCapability( const onnxruntime::GraphViewer& graph, const IKernelLookup& kernel_lookup, + const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* /* resource_accountant */) const override; int GetDeviceId() const override { return info_.device_id; } diff --git a/onnxruntime/core/providers/shared_library/provider_api.h b/onnxruntime/core/providers/shared_library/provider_api.h index 6ff2572e5e668..9d61e1f12f5b6 100644 --- a/onnxruntime/core/providers/shared_library/provider_api.h +++ b/onnxruntime/core/providers/shared_library/provider_api.h @@ -200,6 +200,7 @@ struct SparseTensor; class TensorSeq; class SessionState; class ModelMetadefIdGenerator; +class GraphOptimizerRegistry; class If; class Loop; diff --git a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc index 2dab9f6a402a0..90fd36ea29956 100644 --- a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc +++ b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc @@ -332,8 +332,9 @@ bool IAllocator::CalcMemSizeForArrayWithAlignment(size_t nmemb, size_t size, siz std::vector> IExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer, const IKernelLookup& kernel_lookup, + const GraphOptimizerRegistry& graph_optimizer_registry, IResourceAccountant* resource_accountant) const { - return g_host->IExecutionProvider__GetCapability(this, graph_viewer, kernel_lookup, resource_accountant); + return g_host->IExecutionProvider__GetCapability(this, graph_viewer, kernel_lookup, graph_optimizer_registry, resource_accountant); } common::Status IExecutionProvider::Compile(const std::vector& fused_nodes_and_graphs, std::vector& node_compute_funcs) { diff --git a/onnxruntime/core/providers/shared_library/provider_interfaces.h b/onnxruntime/core/providers/shared_library/provider_interfaces.h index a77f0cb4c27b0..83d615c1bde0a 100644 --- a/onnxruntime/core/providers/shared_library/provider_interfaces.h +++ b/onnxruntime/core/providers/shared_library/provider_interfaces.h @@ -105,6 +105,8 @@ using ModelMetaData = std::unordered_map; using IOnnxRuntimeOpSchemaCollectionPtr = std::shared_ptr; using IOnnxRuntimeOpSchemaRegistryList = std::list; using InitializedTensorSet = std::unordered_map; +using KeyValueConfig = std::unordered_map; +using SelectionFunc = std::function>(const GraphViewer&, const KeyValueConfig&, const GraphOptimizerRegistry&)>; struct Node__NodeIterator { virtual ~Node__NodeIterator() {} @@ -151,6 +153,10 @@ struct ConstGraphNodes_Iterator { struct ProviderHost { virtual const OrtApiBase* OrtGetApiBase() = 0; + virtual Status GetOptimizerByName(const std::string& name, + const GraphOptimizerRegistry& graph_optimizer_registry, + SelectionFunc& selection_func) = 0; + virtual void* HeapAllocate(size_t size) = 0; virtual void HeapFree(void*) = 0; @@ -253,6 +259,7 @@ struct ProviderHost { // IExecutionProvider virtual std::vector> IExecutionProvider__GetCapability(const IExecutionProvider* p, const onnxruntime::GraphViewer& graph_viewer, const IExecutionProvider::IKernelLookup& kernel_lookup, + const GraphOptimizerRegistry& graph_optimizer_registry, IResourceAccountant* resource_accountant) = 0; virtual common::Status IExecutionProvider__Compile(IExecutionProvider* p, const std::vector& fused_nodes_and_graphs, std::vector& node_compute_funcs) = 0; @@ -627,6 +634,8 @@ struct ProviderHost { virtual std::unique_ptr ComputeCapability__construct(std::unique_ptr t_sub_graph) = 0; virtual void ComputeCapability__operator_delete(ComputeCapability* p) = 0; virtual std::unique_ptr& ComputeCapability__SubGraph(ComputeCapability* p) = 0; + virtual void ComputeCapability__copy_optimization_func(ComputeCapability* p, ComputeCapability* selection_cc) = 0; + virtual void ComputeCapability__add_nodes_to_optimize(ComputeCapability* p, std::unique_ptr optimization_cc) = 0; // DataTransferManager virtual Status DataTransferManager__CopyTensor(const DataTransferManager* p, const Tensor& src, Tensor& dst) = 0; diff --git a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h index a502ce9c66f69..e2af144f455e4 100644 --- a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h +++ b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h @@ -527,6 +527,9 @@ struct ComputeCapability final { std::unique_ptr& SubGraph() { return g_host->ComputeCapability__SubGraph(this); } + void copy_optimization_func(ComputeCapability* selection_cc) { g_host->ComputeCapability__copy_optimization_func(this, selection_cc); } + void add_nodes_to_optimize(std::unique_ptr optimization_cc) { g_host->ComputeCapability__add_nodes_to_optimize(this, std::move(optimization_cc)); } + ComputeCapability() = delete; ComputeCapability(const ComputeCapability&) = delete; void operator=(const ComputeCapability&) = delete; diff --git a/onnxruntime/core/providers/snpe/snpe_execution_provider.cc b/onnxruntime/core/providers/snpe/snpe_execution_provider.cc index c7fc6d3a556a7..4eae7c97f9ab0 100644 --- a/onnxruntime/core/providers/snpe/snpe_execution_provider.cc +++ b/onnxruntime/core/providers/snpe/snpe_execution_provider.cc @@ -72,6 +72,7 @@ SNPEExecutionProvider::~SNPEExecutionProvider() {} std::vector> SNPEExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph, const IKernelLookup& kernel_lookup, + const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* /* resource_accountant */) const { std::vector candidates; for (auto& node_index : graph.GetNodesInTopologicalOrder()) { diff --git a/onnxruntime/core/providers/snpe/snpe_execution_provider.h b/onnxruntime/core/providers/snpe/snpe_execution_provider.h index 99033649fcbbf..4b7987b38ee93 100644 --- a/onnxruntime/core/providers/snpe/snpe_execution_provider.h +++ b/onnxruntime/core/providers/snpe/snpe_execution_provider.h @@ -19,6 +19,7 @@ class SNPEExecutionProvider : public IExecutionProvider { std::vector> GetCapability( const onnxruntime::GraphViewer& graph, const IKernelLookup& kernel_lookup, + const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* /* resource_accountant */) const override; std::shared_ptr GetKernelRegistry() const override; diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index e59d252793532..523ebbfae807a 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -2459,6 +2459,7 @@ bool TensorrtExecutionProvider::DetectTensorRTGraphCycles(SubGraphCollection_t& std::vector> TensorrtExecutionProvider::GetCapability(const GraphViewer& graph, const IKernelLookup& /*kernel_lookup*/, + const GraphOptimizerRegistry& graph_optimizer_registry, IResourceAccountant* /* resource_accountant */) const { // Construct subgraph capability from node list std::vector> result; @@ -2664,11 +2665,61 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph, } } + /** + * Enable EP related L2+ graph optimizations: + * + * 1. Calls provider bridge API to lookup pre-defined optimizer by name and get selection function. + * - Example: g_host->GetOptimizerByName(optimizer_name, graph_optimizer_registry, selection_func) + * 2. Executes the selection function to obtain the selection ComputeCapability. + * - ComputeCapability.optimize_func would be set by the optimizer to the function that does the optimization. + * 3. Uses the selection ComputeCapability to create the optimization ComputeCapability. + * 4. Returns the final ComputeCapability, with nodes_to_optimize set to the optimization ComputeCapability. + * + * Current available optimizations: + * - (ConstantFoldingDQ) constant folding on DQ nodes, i.e. dequantize INT32, UINT16, INT16 constant to FP32. + */ + + SelectionFunc selection_func; + std::vector> selection_cc; + + // Prepare for ConstantFoldingDQ optimizer + // Note: The NodeIndex here is the node index in the graph, not the index in node vector in supported_nodes_vector. + std::unordered_set trt_selection_node_set; // The qualified dq nodes selected by TRT EP + std::unordered_map consumer_to_dq; // consumer node -> dq node + + if (dla_enable_) { + std::string optimizer_name = "ConstantFoldingDQ"; + const std::unordered_map key_value_config; + auto status = g_host->GetOptimizerByName(optimizer_name, graph_optimizer_registry, selection_func); + if (status == Status::OK()) { + if (selection_func) { + selection_cc = selection_func(graph, key_value_config, graph_optimizer_registry); + SelectQualifiedDQNode(graph, trt_selection_node_set, consumer_to_dq); + } + } else { + LOGS_DEFAULT(WARNING) << "[TensorRT EP] Can't get optimizer " << optimizer_name; + } + } + + // Create ComputeCapability int number_of_trt_nodes = 0, subgraph_index = 0; - for (const auto& group : supported_nodes_vector) { + for (auto& group : supported_nodes_vector) { if (!group.first.empty()) { + if (!selection_cc.empty()) { + // Include DQ nodes that are filtered out by TRT parser + UpdateSupportedNodeVectorForDQ(graph, group, supported_nodes_vector, consumer_to_dq); + } + std::unique_ptr sub_graph = GetSubGraph(group, graph, model_hash, subgraph_index); - result.push_back(ComputeCapability::Create(std::move(sub_graph))); + auto compute_capability = ComputeCapability::Create(std::move(sub_graph)); + + // add optimization ComputeCapability to node_to_optimize + for (auto& cc : selection_cc) { + std::unique_ptr optimization_cc = CreateOptimizationComputeCapability(cc.get(), trt_selection_node_set, compute_capability.get()); + compute_capability->add_nodes_to_optimize(std::move(optimization_cc)); + } + + result.push_back(std::move(compute_capability)); number_of_trt_nodes += static_cast(group.first.size()); subgraph_index++; } diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h index 873826a81c51b..934cc06eed45f 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h @@ -249,6 +249,7 @@ class TensorrtExecutionProvider : public IExecutionProvider { std::vector> GetCapability(const GraphViewer& graph, const IKernelLookup& /*kernel_lookup*/, + const GraphOptimizerRegistry& graph_optimizer_registry, IResourceAccountant* /* resource_accountant */) const override; int GetDeviceId() const { return device_id_; } @@ -592,5 +593,35 @@ class TensorrtExecutionProvider : public IExecutionProvider { * This function only creates the instance at the first time it's being called." */ nvinfer1::IBuilder* GetBuilder(TensorrtLogger& trt_logger) const; + + /** + * This is the helper function for ConstantFoldingDQ graph transformer. + * + * It selects the qualified/required DQ node to be optimized as well as provides a mapping table + * to help TRT EP later include the DQ node which is filtered out by TRT parser. + */ + void SelectQualifiedDQNode(const GraphViewer& graph, + std::unordered_set& selection_node_set, + std::unordered_map& consumer_to_dq) const; + + /** + * This function returns an optimization ComputeCapability that is limited to: + * 1. the DQ nodes in this individual TRT ComputeCapability + * 2. the DQ nodes that are qualified and selected by TRT EP + * + * It also needs to make sure the DQ nodes is a subset of the complete list of DQ nodes to optimize in original selection ComputeCapability. + * Finally, copy the optimization function from the original selection ComputeCapability. + */ + std::unique_ptr CreateOptimizationComputeCapability(ComputeCapability* selection_cc, + std::unordered_set& trt_selection_node_set, + ComputeCapability* trt_cc) const; + /** + * This function helps add back the DQ nodes that are filtered out by TRT parser. + * The reason is the DQ nodes can be optimized and dequantized by applying ConstantFoldingDQ optimizer by ORT L2+ optimization. + */ + void UpdateSupportedNodeVectorForDQ(const GraphViewer& graph, + SubGraph_t& supported_node_vector, + SubGraphCollection_t& supported_nodes_vector, + std::unordered_map consumer_to_dq) const; }; } // namespace onnxruntime diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_helper.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_helper.cc index 92fa101118506..71674f7c9c557 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_helper.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_helper.cc @@ -258,4 +258,133 @@ void TensorrtExecutionProvider::SetAllGraphInputs(Graph& graph) const { graph.SetInputs(graph_inputs_including_initializers); } + +/** + * This is the helper function for ConstantFoldingDQ graph transformer. + * + * It selects the qualified/required DQ node to be optimized as well as provides a mapping table + * to help TRT EP later include the DQ node which is filtered out by TRT parser. + */ +void TensorrtExecutionProvider::SelectQualifiedDQNode(const GraphViewer& graph, + std::unordered_set& selection_node_set, + std::unordered_map& consumer_to_dq) const { + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Select qualified DQ nodes ..."; + const std::vector& node_index = graph.GetNodesInTopologicalOrder(1 /*priority-based topological sort*/); + for (auto index : node_index) { + auto* node = graph.GetNode(index); + if (!node) { + continue; + } + + const auto* input_def = node->InputDefs()[0]; // Get NodeArg of the initializer of the DequantizeLinear node; + auto data_type = input_def->TypeAsProto()->tensor_type().elem_type(); + auto constant_initializer = graph.IsConstantInitializer(input_def->Name(), true); + + // Node selection: (i.e. initializer -> DQ -> bias of X) + // 1. DequantizeLinear op + // 2. DQ node does not produce graph output, single consumer + // 3. The first input of DQ is constant initializer. + // 4. The data type of initializer is INT32, UINT16 or INT16 + // 5. X should be Gemm, Conv or LayerNormalization ? + if (node->OpType() == "DequantizeLinear" && + node->GetOutputEdgesCount() == 1 && + (data_type == ONNX_NAMESPACE::TensorProto_DataType_INT32 || data_type == ONNX_NAMESPACE::TensorProto_DataType_INT16 || data_type == ONNX_NAMESPACE::TensorProto_DataType_UINT16) && + constant_initializer) { + const Node& consumer_node = *node->OutputNodesBegin(); + selection_node_set.insert(index); + consumer_to_dq[consumer_node.Index()] = index; + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] " << consumer_node.Name() << " <- " << node->Name(); + } + } + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Total " << selection_node_set.size() << " DequantizeLinear node(s) are selected."; +} + +/** + * This function returns an optimization ComputeCapability that is limited to: + * 1. the DQ nodes in this individual TRT ComputeCapability + * 2. the DQ nodes that are qualified and selected by TRT EP + * + * It also needs to make sure the DQ nodes is a subset of the complete list of DQ nodes to optimize in original selection ComputeCapability. + * Finally, copy the optimization function from the original selection ComputeCapability. + */ +std::unique_ptr TensorrtExecutionProvider::CreateOptimizationComputeCapability(ComputeCapability* selection_cc, + std::unordered_set& trt_selection_node_set, + ComputeCapability* trt_cc) const { + auto sub_graph = onnxruntime::IndexedSubGraph::Create(); + std::unordered_set selection_node_set; + + for (auto index : selection_cc->SubGraph()->Nodes()) { + selection_node_set.insert(index); + } + + for (auto index : trt_cc->SubGraph()->Nodes()) { + if (selection_node_set.find(index) == selection_node_set.end()) { + continue; + } + if (trt_selection_node_set.find(index) == trt_selection_node_set.end()) { + continue; + } + sub_graph->Nodes().push_back(index); + } + auto compute_capability = ComputeCapability::Create(std::move(sub_graph)); + compute_capability->copy_optimization_func(selection_cc); + return compute_capability; +} + +/** + * This function helps add back the DQ nodes that are filtered out by TRT parser. + * The reason is the DQ nodes can be optimized and dequantized by applying ConstantFoldingDQ optimizer by ORT L2+ optimization. + */ +void TensorrtExecutionProvider::UpdateSupportedNodeVectorForDQ(const GraphViewer& graph, + SubGraph_t& supported_node_vector, + SubGraphCollection_t& supported_nodes_vector, + std::unordered_map consumer_to_dq) const { + if (consumer_to_dq.empty()) { + return; + } + + if (!supported_node_vector.second) { + return; + } + + const std::vector& node_index = graph.GetNodesInTopologicalOrder(1); + auto supported_nodes = supported_node_vector.first; + for (auto index : supported_nodes) { + if (consumer_to_dq.find(node_index[index]) == consumer_to_dq.end()) { + continue; + } + + auto dq_node_index = consumer_to_dq[node_index[index]]; + + // Check if DQ node is included in one of the subgraphs + auto in_the_subgraph_collection = [&](NodeIndex node_idx) -> bool { + for (auto& node_vector : supported_nodes_vector) { + if (!node_vector.second) { + continue; + } + for (auto i : node_vector.first) { + if (node_index[i] == node_idx) { + return true; + } + } + } + return false; + }; + + // If the DQ node is already in the subgraph, do nothing. + if (in_the_subgraph_collection(dq_node_index)) { + continue; + } + + // Find the iterator pointing to the target element + auto it = std::find(node_index.begin(), node_index.end(), dq_node_index); + if (it != node_index.end()) { + // Calculate the index + size_t idx = std::distance(node_index.begin(), it); + supported_node_vector.first.push_back(idx); + auto node = graph.GetNode(dq_node_index); + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] " << node->Name() << " is included which is filtered out by TRT parser."; + } + } +} } // namespace onnxruntime diff --git a/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc b/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc index 5d2204b0b1979..ab8a95b38491d 100644 --- a/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc +++ b/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc @@ -51,7 +51,7 @@ const InlinedVector VitisAIExecutionProvider::GetEpContextNodes() c return ep_context_node_ptrs; } std::vector> VitisAIExecutionProvider::GetCapability( - const onnxruntime::GraphViewer& graph_viewer, const IKernelLookup& kernel_lookup, IResourceAccountant* /* resource_accountant */) const { + const onnxruntime::GraphViewer& graph_viewer, const IKernelLookup& kernel_lookup, const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* /* resource_accountant */) const { if (graph_viewer.IsSubgraph()) { // VITIS AI EP not support sungraph. Assigned to CPU. return {}; diff --git a/onnxruntime/core/providers/vitisai/vitisai_execution_provider.h b/onnxruntime/core/providers/vitisai/vitisai_execution_provider.h index 5b031ab882839..f72f8cc721fbd 100644 --- a/onnxruntime/core/providers/vitisai/vitisai_execution_provider.h +++ b/onnxruntime/core/providers/vitisai/vitisai_execution_provider.h @@ -29,6 +29,7 @@ class VitisAIExecutionProvider : public IExecutionProvider { std::vector> GetCapability(const onnxruntime::GraphViewer& graph_viewer, const IKernelLookup& /*kernel_lookup*/, + const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* /* resource_accountant */) const override; int GetDeviceId() const { return 0; } diff --git a/onnxruntime/core/providers/vsinpu/vsinpu_execution_provider.cc b/onnxruntime/core/providers/vsinpu/vsinpu_execution_provider.cc index 4b9f6fae86423..3b5daef04dd50 100644 --- a/onnxruntime/core/providers/vsinpu/vsinpu_execution_provider.cc +++ b/onnxruntime/core/providers/vsinpu/vsinpu_execution_provider.cc @@ -62,6 +62,7 @@ VSINPUExecutionProvider::~VSINPUExecutionProvider() {} std::vector> VSINPUExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer, const IKernelLookup& /*kernel_lookup*/, + const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* /* resource_accountant */) const { std::vector> result; diff --git a/onnxruntime/core/providers/vsinpu/vsinpu_execution_provider.h b/onnxruntime/core/providers/vsinpu/vsinpu_execution_provider.h index 16cfbc8a9c581..1c0b8b63a8e6c 100644 --- a/onnxruntime/core/providers/vsinpu/vsinpu_execution_provider.h +++ b/onnxruntime/core/providers/vsinpu/vsinpu_execution_provider.h @@ -40,6 +40,7 @@ class VSINPUExecutionProvider : public IExecutionProvider { std::vector> GetCapability( const onnxruntime::GraphViewer& graph_viewer, const IKernelLookup& kernel_lookup, + const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* /* resource_accountant */) const override; std::shared_ptr GetKernelRegistry() const override; Status Compile(const std::vector& fused_nodes_and_graphs, diff --git a/onnxruntime/core/providers/webgpu/external_data_loader.cc b/onnxruntime/core/providers/webgpu/external_data_loader.cc new file mode 100644 index 0000000000000..6da9598b146f5 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/external_data_loader.cc @@ -0,0 +1,40 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if defined(__wasm__) + +#include + +#include "core/framework/tensor.h" +#include "core/providers/webgpu/external_data_loader.h" + +namespace onnxruntime { +namespace webgpu { + +bool ExternalDataLoader::CanLoad(const OrtMemoryInfo& target_memory_info) const { + return target_memory_info.device.Type() == OrtDevice::CPU || + (target_memory_info.device.Type() == OrtDevice::GPU && target_memory_info.name == WEBGPU_BUFFER); +} + +common::Status ExternalDataLoader::LoadTensor(const Env& env, + const std::filesystem::path& data_file_path, + FileOffsetType data_offset, + SafeInt data_length, + Tensor& tensor) const { + ExternalDataLoadType load_type; + if (tensor.Location().device.Type() == OrtDevice::CPU) { + load_type = ExternalDataLoadType::CPU; + } else if (tensor.Location().device.Type() == OrtDevice::GPU && + tensor.Location().name == WEBGPU_BUFFER) { + load_type = ExternalDataLoadType::WEBGPU_BUFFER; + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported tensor location: ", tensor.Location().ToString()); + } + + return LoadWebAssemblyExternalData(env, data_file_path, data_offset, data_length, load_type, tensor.MutableDataRaw()); +} + +} // namespace webgpu +} // namespace onnxruntime + +#endif diff --git a/onnxruntime/core/providers/webgpu/external_data_loader.h b/onnxruntime/core/providers/webgpu/external_data_loader.h new file mode 100644 index 0000000000000..7ced4e930bf7a --- /dev/null +++ b/onnxruntime/core/providers/webgpu/external_data_loader.h @@ -0,0 +1,30 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#if defined(__wasm__) + +#include "core/framework/external_data_loader.h" + +namespace onnxruntime { +namespace webgpu { + +class ExternalDataLoader : public IExternalDataLoader { + public: + ExternalDataLoader() {}; + ~ExternalDataLoader() {}; + + bool CanLoad(const OrtMemoryInfo& target_memory_info) const override; + + common::Status LoadTensor(const Env& env, + const std::filesystem::path& data_file_path, + FileOffsetType data_offset, + SafeInt data_length, + Tensor& tensor) const override; +}; + +} // namespace webgpu +} // namespace onnxruntime + +#endif diff --git a/onnxruntime/core/providers/webgpu/generator/range.cc b/onnxruntime/core/providers/webgpu/generator/range.cc index a0b65f08a5b4e..99c5a1c1b5566 100644 --- a/onnxruntime/core/providers/webgpu/generator/range.cc +++ b/onnxruntime/core/providers/webgpu/generator/range.cc @@ -23,7 +23,7 @@ Status Range::ComputeInternal(ComputeContext& context) const { return Status::OK(); } - uint32_t output_size = gsl::narrow(n); + uint32_t output_size = onnxruntime::narrow(n); RangeProgram program{}; #if defined(__GNUC__) #pragma GCC diagnostic push diff --git a/onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.cc b/onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.cc index 75866513e2c7d..8a22e45f17047 100644 --- a/onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.cc +++ b/onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.cc @@ -141,7 +141,7 @@ Status BinaryElementwise::ComputeInternal(ComputeContext& context) const { } } - uint32_t vec_size = gsl::narrow((size + 3) / 4); + uint32_t vec_size = onnxruntime::narrow((size + 3) / 4); BinaryElementwiseProgram program{kernel_name_, expression_, is_broadcast, diff --git a/onnxruntime/core/providers/webgpu/math/softmax.cc b/onnxruntime/core/providers/webgpu/math/softmax.cc new file mode 100644 index 0000000000000..d06fc5a57eb8c --- /dev/null +++ b/onnxruntime/core/providers/webgpu/math/softmax.cc @@ -0,0 +1,238 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include + +#include "core/common/inlined_containers.h" +#include "core/providers/common.h" +#include "core/providers/webgpu/math/softmax.h" +#include "core/providers/webgpu/tensor/transpose.h" +#include "core/providers/cpu/tensor/utils.h" +#include "core/providers/webgpu/shader_variable.h" +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/webgpu_supported_types.h" +namespace onnxruntime { +namespace webgpu { + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Softmax, + kOnnxDomain, + 1, 10, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedNumberTypes()), + Softmax); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Softmax, + kOnnxDomain, + 11, 12, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedNumberTypes()), + Softmax); + +ONNX_OPERATOR_KERNEL_EX( + Softmax, + kOnnxDomain, + 13, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedNumberTypes()), + Softmax); + +static std::string MaxVector(const std::string& name, int components) { + switch (components) { + case 1: + return name; + case 2: + return "max(" + name + ".x, " + name + ".y)"; + case 3: + return "max(max(" + name + ".x, " + name + ".y), " + name + ".z)"; + case 4: + return "max(max(" + name + ".x, " + name + ".y), max(" + name + ".z, " + name + ".w))"; + default: + ORT_THROW("Unsupported number of components: ", components); + } +} + +static std::string SumVector(const std::string& x, int components) { + switch (components) { + case 1: + return x; + case 2: + return "(" + x + ".x + " + x + ".y" + ")"; + case 4: + return "(" + x + ".x + " + x + ".y + " + x + ".w + " + x + ".z" + ")"; + default: + ORT_THROW("Unsupported number of components: ", components); + } +} + +static int GetMaxComponents(int64_t size) { + if (size % 4 == 0) { + return 4; + } else if (size % 2 == 0) { + return 2; + } + return 1; +} + +Status SoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const { + // Add input and output variables + const auto& input = shader.AddInput("x", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias); + shader.AddOutput("result", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias); + int components = input.NumComponents(); + + const std::string thread_max_decl = is_fp32_ + ? "var thread_max = x_value_t(-3.402823e+38f);\n" + : "var thread_max = x_value_t(-65504.0h);\n"; + + // Define shared memory for row max and row sum + shader.AdditionalImplementation() + << "var row_max_shared : x_value_t;\n" + << "var row_sum_shared : x_value_t;\n" + << "var thread_shared : array;\n"; + + // Define helper functions to get and set values + shader.AdditionalImplementation() + << "fn getValue(row: i32, col: i32, row_stride: i32) -> x_value_t {\n" + << " let index = row * row_stride + col;\n" + << " return x[index];\n" + << "}\n" + << "fn setValue(row: i32, col: i32, row_stride: i32, value: x_value_t) {\n" + << " let index = row * row_stride + col;\n" + << " result[index] = value;\n" + << "}\n"; + + // Main function body + shader.MainFunctionBody() + << " let gindex = i32(global_idx);\n" + << " let lindex = i32(local_idx);\n" + << " const wg = " << wg_ << ";\n" + << " let row = gindex / wg;\n" + << " let cols = uniforms.packedCols;\n" + << " let row_stride : i32 = uniforms.packedCols;\n" + + // Find the row's max value + << thread_max_decl + << " for (var col = lindex; col < cols; col += wg) {\n" + << " let value = getValue(row, col, row_stride);\n" + << " thread_max = max(thread_max, value);\n" + << " }\n" + << " if (lindex < cols) {\n" + << " thread_shared[lindex] = thread_max;\n" + << " }\n" + << " workgroupBarrier();\n" + + // Reduce to find the max value + << " var reduce_size = min(cols, wg);\n" + << " for (var curr_size = reduce_size >> 1; curr_size > 0; curr_size = reduce_size >> 1) {\n" + << " reduce_size = curr_size + (reduce_size & 1);\n" + << " if (lindex < curr_size) {\n" + << " thread_shared[lindex] = max(thread_shared[lindex], thread_shared[lindex + reduce_size]);\n" + << " }\n" + << " workgroupBarrier();\n" + << " }\n" + << " if (lindex == 0) {\n" + << " row_max_shared = x_value_t(" << MaxVector("thread_shared[0]", components) << ");\n" + << " }\n" + << " workgroupBarrier();\n" + + // Find the row's sum of exponentials + << " var thread_sum = x_value_t(0.0);\n" + << " for (var col = lindex; col < cols; col += wg) {\n" + << " let sub_exp = exp(getValue(row, col, row_stride) - row_max_shared);\n" + << " thread_sum += sub_exp;\n" + << " }\n" + << " thread_shared[lindex] = thread_sum;\n" + << " workgroupBarrier();\n" + + // Reduce to find the sum of exponentials + << " for (var curr_size = wg >> 1; curr_size > 0; curr_size = curr_size >> 1) {\n" + << " if (lindex < curr_size) {\n" + << " thread_shared[lindex] = thread_shared[lindex] + thread_shared[lindex + curr_size];\n" + << " }\n" + << " workgroupBarrier();\n" + << " }\n" + << " if (lindex == 0) {\n" + << " row_sum_shared = x_value_t(" << SumVector("thread_shared[0]", components) << ");\n" + << " }\n" + << " workgroupBarrier();\n" + + // Calculate the final value for each element in the row + << " for (var col = lindex; col < cols; col += wg) {\n" + << " let value = exp(getValue(row, col, row_stride) - row_max_shared) / row_sum_shared;\n" + << " setValue(row, col, row_stride, value);\n" + << " }\n"; + + return Status::OK(); +} + +Status Softmax::ComputeInternal(ComputeContext& context) const { + const auto* input_tensor = context.Input(0); + const TensorShape& input_shape = input_tensor->Shape(); + size_t input_rank = input_shape.NumDimensions(); + auto* output_tensor = context.Output(0, input_shape); + + // normalize axis + size_t axis = static_cast(HandleNegativeAxis(axis_, input_rank)); + bool is_transpose_required = axis < input_rank - 1; + + TensorShape transposed_input_shape; + Tensor transposed_input_tensor; + Tensor intermediate_output; + InlinedVector perm(input_rank); + + if (is_transpose_required) { + std::iota(std::begin(perm), std::end(perm), 0); + perm[axis] = input_rank - 1; + perm[input_rank - 1] = axis; + + TensorShapeVector transposed_input_dims; + for (auto e : perm) { + transposed_input_dims.push_back(input_shape[e]); + } + + transposed_input_shape = TensorShape(transposed_input_dims); + transposed_input_tensor = context.CreateGPUTensor(input_tensor->DataType(), transposed_input_shape); + ORT_RETURN_IF_ERROR(Transpose::DoTranspose(context, perm, *input_tensor, transposed_input_tensor)); + intermediate_output = context.CreateGPUTensor(output_tensor->DataType(), transposed_input_shape); + } + + const int64_t cols = is_transpose_required ? transposed_input_shape[input_rank - 1] : input_shape[input_rank - 1]; + const int64_t rows = input_shape.Size() / cols; + const int64_t components = GetMaxComponents(cols); + const auto packed_cols = cols / components; + uint32_t workgroup_size = rows == 1 ? 256 : 64; + // check input tensor element type is float + const bool is_fp32 = input_tensor->GetElementType() == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; + + SoftmaxProgram program{workgroup_size, is_fp32}; + if (is_transpose_required) { + program + .AddInputs({{&transposed_input_tensor, ProgramTensorMetadataDependency::TypeAndRank, static_cast(components)}}) + .AddOutputs({{&intermediate_output, ProgramTensorMetadataDependency::TypeAndRank, static_cast(components)}}); + } else { + program + .AddInputs({{input_tensor, ProgramTensorMetadataDependency::TypeAndRank, static_cast(components)}}) + .AddOutputs({{output_tensor, ProgramTensorMetadataDependency::TypeAndRank, static_cast(components)}}); + } + + program + .CacheHint(std::to_string(components), std::to_string(workgroup_size)) + .SetWorkgroupSize(workgroup_size) + .SetDispatchGroupSize(static_cast(rows)) + .AddUniformVariables({{static_cast(packed_cols)}}); + + ORT_RETURN_IF_ERROR(context.RunProgram(program)); + + // If transpose was required, transpose the result back + if (is_transpose_required) { + ORT_RETURN_IF_ERROR(Transpose::DoTranspose(context, perm, intermediate_output, *output_tensor)); + } + + return Status::OK(); +} +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/math/softmax.h b/onnxruntime/core/providers/webgpu/math/softmax.h new file mode 100644 index 0000000000000..cc97611dcb4bc --- /dev/null +++ b/onnxruntime/core/providers/webgpu/math/softmax.h @@ -0,0 +1,54 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/webgpu/webgpu_supported_types.h" +#include "core/providers/webgpu/webgpu_kernel.h" +#include "core/providers/webgpu/program.h" +#include "core/framework/op_kernel.h" + +namespace onnxruntime { +namespace webgpu { + +class Softmax final : public WebGpuKernel { + public: + Softmax(const OpKernelInfo& info) : WebGpuKernel{info} { + int opset_ = info.node().SinceVersion(); + int64_t axis; + Status status = info.GetAttr("axis", &axis); + + if (status.IsOK()) { + axis_ = axis; + } else { + if (opset_ < 13) { + axis_ = 1; // opset-12 and below, the default axis value is 1 + } else { + axis_ = -1; // opset-13, the default axis value is -1 + } + } + } + + Status ComputeInternal(ComputeContext& context) const override; + + private: + int64_t axis_; +}; + +class SoftmaxProgram final : public Program { + public: + SoftmaxProgram(uint32_t wg, bool is_fp32) + : Program{"Softmax"}, wg_{wg}, is_fp32_{is_fp32} { + } + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"packedCols", ProgramUniformVariableDataType::Int32}); + + private: + uint32_t wg_; + bool is_fp32_; +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc index eaaad206ebaf5..189d7baafce6a 100644 --- a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc +++ b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc @@ -27,7 +27,7 @@ Status UnaryElementwise::ComputeInternal(ComputeContext& context) const { if (size == 0) { return Status::OK(); } - uint32_t vec_size = gsl::narrow((size + 3) / 4); + uint32_t vec_size = onnxruntime::narrow((size + 3) / 4); UnaryElementwiseProgram program{kernel_name_, expression_, additional_impl_, additional_usage_}; program .AddInputs({{input_tensor, ProgramTensorMetadataDependency::Type, {vec_size}, 4}}) diff --git a/onnxruntime/core/providers/webgpu/nn/layer_norm.cc b/onnxruntime/core/providers/webgpu/nn/layer_norm.cc index 64172021e82f1..28ad686909a47 100644 --- a/onnxruntime/core/providers/webgpu/nn/layer_norm.cc +++ b/onnxruntime/core/providers/webgpu/nn/layer_norm.cc @@ -23,7 +23,7 @@ static size_t NormalizeAxis(int64_t axis, size_t tensor_rank) { if (axis < -rank && axis >= rank) { ORT_THROW("invalid axis: ", axis); } - return gsl::narrow(axis < 0 ? axis + rank : axis); + return onnxruntime::narrow(axis < 0 ? axis + rank : axis); } static std::string SumVector(std::string x, int components) { @@ -92,10 +92,10 @@ Status LayerNorm::ComputeInternal(onnxruntime::webgpu::ComputeContex const bool is_fp16 = x->GetElementType() == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16; const size_t axis = NormalizeAxis(axis_, x_shape.NumDimensions()); - const uint32_t norm_count = gsl::narrow(x_shape.SizeToDimension(axis)); + const uint32_t norm_count = onnxruntime::narrow(x_shape.SizeToDimension(axis)); const int64_t norm_size = x_shape.SizeFromDimension(axis); const int components = GetMaxComponents(norm_size); - const uint32_t norm_size_vectorized = gsl::narrow((norm_size + components - 1) / components); + const uint32_t norm_size_vectorized = onnxruntime::narrow((norm_size + components - 1) / components); const auto scale_size = scale->Shape().Size(); const auto bias_size = (bias) ? bias->Shape().Size() : 0; diff --git a/onnxruntime/core/providers/webgpu/program.cc b/onnxruntime/core/providers/webgpu/program.cc index d1d4c242c4697..976b7927ac3dd 100644 --- a/onnxruntime/core/providers/webgpu/program.cc +++ b/onnxruntime/core/providers/webgpu/program.cc @@ -206,6 +206,26 @@ ProgramVariableDataType ToProgramVariableDataType(int32_t element_type, int comp } } +std::ostream& operator<<(std::ostream& os, ValidationMode mode) { + switch (mode) { + case ValidationMode::Disabled: + os << "Disabled"; + break; + case ValidationMode::WGPUOnly: + os << "WGPUOnly"; + break; + case ValidationMode::Basic: + os << "Basic"; + break; + case ValidationMode::Full: + os << "Full"; + break; + default: + os << "Unknown(" << static_cast(mode) << ")"; + } + return os; +} + namespace { TensorShape GetReducedShape(const TensorShape& shape, int component /* > 1 */) { ORT_ENFORCE(shape.NumDimensions() > 0 && shape.GetDims()[shape.NumDimensions() - 1] % component == 0, diff --git a/onnxruntime/core/providers/webgpu/program.h b/onnxruntime/core/providers/webgpu/program.h index 7bfd9e8800099..95fef36144025 100644 --- a/onnxruntime/core/providers/webgpu/program.h +++ b/onnxruntime/core/providers/webgpu/program.h @@ -237,6 +237,7 @@ enum class ValidationMode { Basic, Full }; +std::ostream& operator<<(std::ostream& os, ValidationMode mode); namespace details { class ProgramWrapper; diff --git a/onnxruntime/core/providers/webgpu/program_manager.cc b/onnxruntime/core/providers/webgpu/program_manager.cc index 1fdd312d4f0d8..7a4a873a1adf3 100644 --- a/onnxruntime/core/providers/webgpu/program_manager.cc +++ b/onnxruntime/core/providers/webgpu/program_manager.cc @@ -24,14 +24,14 @@ Status ProgramManager::NormalizeDispatchGroupSize(uint32_t& x, uint32_t& y, uint auto limit_per_dimension = limits_.maxComputeWorkgroupsPerDimension; if (x > limit_per_dimension || y > limit_per_dimension || z > limit_per_dimension) { - auto size = static_cast(x) * static_cast(y) * static_cast(z); - uint32_t dispatch_avg = gsl::narrow(std::ceil(std::sqrt(size))); + double size = static_cast(x) * static_cast(y) * static_cast(z); + double dispatch_avg = std::ceil(std::sqrt(size)); if (dispatch_avg > limit_per_dimension) { - dispatch_avg = gsl::narrow(std::ceil(std::cbrt(size))); + dispatch_avg = std::ceil(std::cbrt(size)); ORT_RETURN_IF(dispatch_avg > limit_per_dimension, "The dispatch group size exceeds WebGPU maximum."); - x = y = z = dispatch_avg; + x = y = z = static_cast(dispatch_avg); } else { - x = y = dispatch_avg; + x = y = static_cast(dispatch_avg); z = 1; } } diff --git a/onnxruntime/core/providers/webgpu/reduction/reduction_ops.cc b/onnxruntime/core/providers/webgpu/reduction/reduction_ops.cc new file mode 100644 index 0000000000000..eb7903e7903b6 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/reduction/reduction_ops.cc @@ -0,0 +1,168 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/webgpu/reduction/reduction_ops.h" +#include +#include "core/framework/data_transfer_manager.h" +#include "core/providers/webgpu/data_transfer.h" +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/webgpu_supported_types.h" + +namespace onnxruntime { +namespace webgpu { + +#define REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceOp, begin, end) \ + ONNX_OPERATOR_VERSIONED_KERNEL_EX( \ + ReduceOp, \ + kOnnxDomain, \ + begin, end, \ + kWebGpuExecutionProvider, \ + (*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedNumberTypes()), \ + ReduceOp); + +#define REGISTER_UNARY_ELEMENTWISE_KERNEL(ReduceOp, version) \ + ONNX_OPERATOR_KERNEL_EX( \ + ReduceOp, \ + kOnnxDomain, \ + version, \ + kWebGpuExecutionProvider, \ + (*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedNumberTypes()).InputMemoryType(OrtMemTypeCPUInput, 1), \ + ReduceOp); + +REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceMean, 1, 10); +REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceMean, 11, 12); +REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceMean, 13, 17); +REGISTER_UNARY_ELEMENTWISE_KERNEL(ReduceMean, 18); + +Status ReduceKernelProgram::GenerateShaderCode(ShaderHelper& shader) const { + const auto& input = shader.AddInput("input", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias); + const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias); + bool reduce_on_all_axes = no_op_with_empty_axes_ == false && axes_.empty(); + std::string loop_header = code_[0]; + std::string loop_body = "let current_element: input_value_t = " + input.GetByIndices("input_indices") + ";\n" + code_[1]; + std::string loop_footer = code_[2]; + const auto input_rank = input.Rank(); + for (int i = 0, l = 0; i < input_rank; ++i) { + if (reduce_on_all_axes || std::find(axes_.begin(), axes_.end(), i) != axes_.end()) { + if (keepdims_) { + l++; + } + std::stringstream ss; + std::string index = "i" + std::to_string(i); + ss << "for (var " << index << " : u32 = 0; " << index << " < " << input.IndicesGet("uniforms.input_shape", i) << "; " << index << "++) {\n"; + ss << input.IndicesSet("input_indices", i, index) << ";\n"; + ss << loop_body << "\n"; + ss << "}\n"; + loop_body = ss.str(); + } else { + std::stringstream ss; + ss << loop_header << "\n"; + std::string index = "i" + std::to_string(i); + ss << "let " << index << " = " << output.IndicesGet("output_indices", l) << ";\n"; + ss << input.IndicesSet("input_indices", i, index) << ";\n"; + loop_header = ss.str(); + l++; + } + } + std::stringstream input_indices_init_value; + for (int i = 0; i < input_rank - 1; ++i) { + input_indices_init_value << "0, "; + } + input_indices_init_value << "0"; + shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size") + << "let output_indices: output_indices_t = " << output.OffsetToIndices("global_idx") << ";\n" + << "var input_indices: input_indices_t = input_indices_t(" << input_indices_init_value.str() << ");\n" + << loop_header << loop_body << loop_footer; + shader.MainFunctionBody() << output.SetByOffset("global_idx", "output_value"); + return Status::OK(); +} + +template +Status ReduceKernel::ComputeInternal(ComputeContext& context) const { + const auto* input_tensor = context.Input(0); + InlinedVector input_axes; + auto rank = input_tensor->Shape().NumDimensions(); + auto transform_axis = [rank](int64_t axis) { + if (axis < 0) { + axis += rank; + } + if (axis < 0 || static_cast(axis) >= rank) { + ORT_THROW("Axes values must be in the range [-rank, rank-1]. Got: ", axis); + } + return static_cast(axis); + }; + // Check if axes input is provided and copy the axes values to input_axes + if (context.InputCount() > 1) { + ORT_ENFORCE(axes_.empty(), "Axes attribute may not be specified when axes input is also provided."); + const Tensor* axes_tensor = context.Input(1); + auto size = static_cast(axes_tensor->Shape()[0]); + const auto* data = axes_tensor->Data(); + input_axes.reserve(size); + std::transform(data, data + size, std::back_inserter(input_axes), transform_axis); + } else { + input_axes.reserve(axes_.size()); + std::transform(axes_.begin(), axes_.end(), std::back_inserter(input_axes), transform_axis); + } + if (input_axes.empty()) { + if (noop_with_empty_axes_ || rank == 0) { + // If axes is empty and noop_with_empty_axes_ is true, it is a no-op according to the spec + // If input tensor is a scalar, return the input tensor as is. + // This is not correct for ReduceLogSum and ReduceSumSquare + // TODO handle these cases separately. + auto output = context.Output(0, input_tensor->Shape()); + if (output->DataRaw() != input_tensor->DataRaw()) { + ORT_RETURN_IF_ERROR(Info().GetDataTransferManager().CopyTensor(*input_tensor, *output)); + } + return Status::OK(); + } else { + // If axes is empty and noop_with_empty_axes_ is false, it is a reduction over all axes + input_axes.resize(rank); + std::iota(input_axes.begin(), input_axes.end(), 0); + } + } + const auto code = GetOpSpecificCode(input_tensor, input_axes.size()); + // Compute output shape + std::vector output_shape; + for (size_t i = 0; i < input_tensor->Shape().NumDimensions(); ++i) { + if (std::find(input_axes.begin(), input_axes.end(), i) != input_axes.end()) { + if (keepdims_) { + output_shape.push_back(1); + } + } else { + output_shape.push_back(input_tensor->Shape()[i]); + } + } + TensorShape output_tensor_shape(output_shape); + int64_t output_size = output_tensor_shape.Size(); + ReduceKernelProgram program("ReduceMean", keepdims_, noop_with_empty_axes_, input_axes, code); + program.AddInput({input_tensor, ProgramTensorMetadataDependency::TypeAndRank}) + .AddOutput({context.Output(0, output_shape), ProgramTensorMetadataDependency::TypeAndRank}) + .SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) + .AddUniformVariables({{static_cast(output_size)}, + {static_cast(noop_with_empty_axes_ ? 1 : 0)}, + {input_axes}, + {static_cast(input_axes.size())}}); + + return context.RunProgram(program); +} + +ReduceOpSpecificCode ReduceMean::GetOpSpecificCode(const Tensor* input_tensor, size_t axes_size) const { + const TensorShape& input_shape = input_tensor->Shape(); + size_t input_rank = input_shape.NumDimensions(); + std::stringstream ss; + ss << "var size: u32 = 1;\n" + << "for (var i: u32 = 0; i < uniforms.axes_size; i += 1) { \n" + << " let index = " << GetElementAt("uniforms.axes", "i", axes_size) << ";\n" + << " size = size * " << GetElementAt("uniforms.input_shape", "index", input_rank) << ";\n" + << "}\n" + << "let output_value = output_value_t(sum / f32(size));"; + ReduceOpSpecificCode code({"var sum = f32(0);", "sum += f32(current_element);", ss.str()}); + return code; +} + +Status ReduceMean::ComputeInternal(ComputeContext& ctx) const { + return ReduceKernel::ComputeInternal(ctx); +} + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/reduction/reduction_ops.h b/onnxruntime/core/providers/webgpu/reduction/reduction_ops.h new file mode 100644 index 0000000000000..e93eb06f20886 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/reduction/reduction_ops.h @@ -0,0 +1,62 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "core/common/optional.h" +#include "core/providers/webgpu/webgpu_supported_types.h" +#include "core/providers/webgpu/webgpu_kernel.h" +#include "core/providers/cpu/reduction/reduction_kernel_base.h" +#include "core/providers/webgpu/program.h" +#include "core/providers/webgpu/shader_helper.h" +namespace onnxruntime { +namespace webgpu { +// reduceOpSpecificCode is a 3-element array of strings that represent the op specific code for the reduce operation. +// The first element is the loop header, the second element is the loop body, and the third element is the loop footer. +// The loop header is the code that is executed before the loop starts. The loop body is the code that is executed for each element in the loop. +// The loop footer is the code that is executed after the loop ends. +typedef std::array ReduceOpSpecificCode; +class ReduceKernelProgram final : public Program { + public: + ReduceKernelProgram(std::string name, bool keepdims, bool no_op_with_empty_axes, const InlinedVector& axes, ReduceOpSpecificCode code) : Program{name}, keepdims_(keepdims), no_op_with_empty_axes_(no_op_with_empty_axes), axes_(axes.begin(), axes.end()), code_(code) {} + Status GenerateShaderCode(ShaderHelper& wgpuShaderModuleAddRef) const override; + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"output_size", ProgramUniformVariableDataType::Uint32}, + {"no_op_with_empty_axes", ProgramUniformVariableDataType::Uint32}, + {"axes", ProgramUniformVariableDataType::Uint32}, + {"axes_size", ProgramUniformVariableDataType::Uint32}); + + private: + const bool keepdims_; + const bool no_op_with_empty_axes_; + InlinedVector axes_; + ReduceOpSpecificCode code_; +}; + +template +class ReduceKernel : public WebGpuKernel, public ReduceKernelBase { + protected: + using ReduceKernelBase::axes_; + using ReduceKernelBase::noop_with_empty_axes_; + using ReduceKernelBase::keepdims_; + using ReduceKernelBase::select_last_index_; + + ReduceKernel(const OpKernelInfo& info, std::string name, optional keepdims_override = {}) + : WebGpuKernel(info), + ReduceKernelBase(info, keepdims_override), + name_(name) { + } + Status ComputeInternal(ComputeContext& ctx) const; + virtual ReduceOpSpecificCode GetOpSpecificCode(const Tensor* input_tensor, size_t axes_size) const = 0; + + private: + std::string name_; +}; + +class ReduceMean final : public ReduceKernel { + public: + ReduceMean(const OpKernelInfo& info) : ReduceKernel(info, "ReduceMean") {} + ReduceOpSpecificCode GetOpSpecificCode(const Tensor* input_tensor, size_t axes_size) const override; + Status ComputeInternal(ComputeContext& ctx) const override; +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/shader_helper.cc b/onnxruntime/core/providers/webgpu/shader_helper.cc index 8fccbacac903b..19cab9b178b1f 100644 --- a/onnxruntime/core/providers/webgpu/shader_helper.cc +++ b/onnxruntime/core/providers/webgpu/shader_helper.cc @@ -345,9 +345,6 @@ Status ShaderHelper::GenerateSourceCode(std::string& code, std::vector& sha })) { ORT_RETURN_IF_NOT(device_.HasFeature(wgpu::FeatureName::ShaderF16), "Program ", program_.Name(), " requires f16 but the device does not support it."); ss << "enable f16;\n"; - if (device_.HasFeature(wgpu::FeatureName::SubgroupsF16)) { - ss << "enable subgroups_f16;\n"; - } } if (device_.HasFeature(wgpu::FeatureName::Subgroups)) { ss << "enable subgroups;\n"; diff --git a/onnxruntime/core/providers/webgpu/shader_variable.cc b/onnxruntime/core/providers/webgpu/shader_variable.cc index 5e5920f582251..f8e1e0b3b8d2b 100644 --- a/onnxruntime/core/providers/webgpu/shader_variable.cc +++ b/onnxruntime/core/providers/webgpu/shader_variable.cc @@ -91,7 +91,7 @@ ShaderIndicesHelper::ShaderIndicesHelper(std::string_view name, ProgramVariableD : name_(name), type_(type), num_components_{NumberOfComponents(type)}, - rank_{gsl::narrow(dims.NumDimensions())}, + rank_{static_cast(dims.NumDimensions())}, dims_{dims}, usage_(usage), indices_type_{GetIndicesType(rank_)}, diff --git a/onnxruntime/core/providers/webgpu/tensor/cast.cc b/onnxruntime/core/providers/webgpu/tensor/cast.cc index 8b5bede34e6d0..7f92ea4ed3776 100644 --- a/onnxruntime/core/providers/webgpu/tensor/cast.cc +++ b/onnxruntime/core/providers/webgpu/tensor/cast.cc @@ -69,7 +69,7 @@ Status Cast::ComputeInternal(ComputeContext& context) const { if (size == 0) { return Status::OK(); } - uint32_t vec_size = gsl::narrow((size + 3) / 4); + uint32_t vec_size = onnxruntime::narrow((size + 3) / 4); CastProgram program{to_}; program diff --git a/onnxruntime/core/providers/webgpu/tensor/cast.h b/onnxruntime/core/providers/webgpu/tensor/cast.h index ef5c4d5d0dabe..925cd200f0aba 100644 --- a/onnxruntime/core/providers/webgpu/tensor/cast.h +++ b/onnxruntime/core/providers/webgpu/tensor/cast.h @@ -26,7 +26,7 @@ class Cast final : public WebGpuKernel { int64_t to; Status status = info.GetAttr("to", &to); ORT_ENFORCE(status.IsOK(), "Attribute to is not set."); - to_ = gsl::narrow(to); + to_ = onnxruntime::narrow(to); // ignore attribute 'saturate' as float8 is not supported in WebGPU } diff --git a/onnxruntime/core/providers/webgpu/tensor/concat.cc b/onnxruntime/core/providers/webgpu/tensor/concat.cc index 5ed8099fde05e..5cfd6c78f8929 100644 --- a/onnxruntime/core/providers/webgpu/tensor/concat.cc +++ b/onnxruntime/core/providers/webgpu/tensor/concat.cc @@ -104,7 +104,7 @@ Status Concat::ComputeInternal(ComputeContext& context) const { return Status::OK(); } - uint32_t output_size = gsl::narrow_cast(prepare.output_tensor->Shape().Size()); + uint32_t output_size = onnxruntime::narrow(prepare.output_tensor->Shape().Size()); size_t axis = static_cast(prepare.axis); ConcatProgram program{axis}; diff --git a/onnxruntime/core/providers/webgpu/tensor/expand.cc b/onnxruntime/core/providers/webgpu/tensor/expand.cc index 809616660aa9e..9bdebe2c1e0d3 100644 --- a/onnxruntime/core/providers/webgpu/tensor/expand.cc +++ b/onnxruntime/core/providers/webgpu/tensor/expand.cc @@ -42,7 +42,7 @@ Status Expand::ComputeInternal(ComputeContext& context) const { : 1; const int components_o = output_shape.IsScalar() ? 1 : output_shape[output_shape.NumDimensions() - 1] % 4 == 0 ? 4 : 1; - uint32_t data_size = gsl::narrow(output_shape.Size() / components_o); + uint32_t data_size = onnxruntime::narrow(output_shape.Size() / components_o); ExpandProgram program{}; program diff --git a/onnxruntime/core/providers/webgpu/tensor/gather.cc b/onnxruntime/core/providers/webgpu/tensor/gather.cc index 9f6e5f2420d86..39d07991f3c5a 100644 --- a/onnxruntime/core/providers/webgpu/tensor/gather.cc +++ b/onnxruntime/core/providers/webgpu/tensor/gather.cc @@ -42,7 +42,7 @@ Status GatherProgram::GenerateShaderCode(ShaderHelper& shader) const { Status Gather::ComputeInternal(ComputeContext& context) const { Prepare p; ORT_RETURN_IF_ERROR(PrepareForCompute(&context.KernelContext(), p)); - uint32_t data_size = gsl::narrow(p.output_tensor->Shape().Size()); + uint32_t data_size = onnxruntime::narrow(p.output_tensor->Shape().Size()); if (data_size == 0) { return Status::OK(); } diff --git a/onnxruntime/core/providers/webgpu/tensor/pad.cc b/onnxruntime/core/providers/webgpu/tensor/pad.cc new file mode 100644 index 0000000000000..6a8bc6554b772 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/tensor/pad.cc @@ -0,0 +1,261 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include + +#include "core/util/math.h" +#include "core/providers/webgpu/tensor/pad.h" +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/webgpu_supported_types.h" + +namespace onnxruntime { +namespace webgpu { + +Status PadProgram::GenerateShaderCode(ShaderHelper& shader) const { + if (!dim_value_zero_) { + shader.AddInput("data", ShaderUsage::UseUniform | ShaderUsage::UseShapeAndStride); + } + const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseShapeAndStride | ShaderUsage::UseValueTypeAlias); + + shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size"); + std::string constant_value_str = std::string("let constant_value = ") + + (is_float16_ ? "bitcast>(uniforms.constant_value)[0];\n" : "bitcast(uniforms.constant_value);\n"); + if (dim_value_zero_) { + // Only Constant mode needs fill output if the one dim value or mores dims' values of input are zero. + shader.MainFunctionBody() << constant_value_str + << "output[global_idx] = constant_value;\n"; + return Status::OK(); + } + + shader.MainFunctionBody() << " let output_indices = " << output.OffsetToIndices("global_idx") << ";\n" + << " var input_index = u32(0);\n" + << " var use_pad_value = false;\n" + << " var in_coord = i32(0);\n"; + + const int rank = output.Rank(); + std::string output_indices_str = "i32(" + GetElementAt("output_indices", "dim", rank) + ")"; + std::string lower_pads_str = GetElementAt("uniforms.lower_pads", "dim", rank); + std::string data_shape_str = "i32(" + GetElementAt("uniforms.data_shape", "dim", rank) + ")"; + std::string data_stride_str = rank == 1 ? "" : " * " + GetElementAt("uniforms.data_stride", "dim", rank - 1); + std::string begin_axis_statement = "in_coord = "; + std::string end_axis_statement = "in_coord = "; + std::string in_axis_statement = "in_coord = " + output_indices_str + " - " + lower_pads_str + ";\n"; + switch (mode_) { + case Mode::Constant: + begin_axis_statement = "use_pad_value = true;\n"; + end_axis_statement = "use_pad_value = true;\n"; + break; + case Mode::Edge: + begin_axis_statement += "0;\n"; + end_axis_statement += data_shape_str + " - 1;\n"; + break; + case Mode::Reflect: + begin_axis_statement += lower_pads_str + " - " + output_indices_str + ";\n"; + end_axis_statement += data_shape_str + " - 2 - (" + output_indices_str + + " - (" + lower_pads_str + " + " + data_shape_str + "));\n"; + break; + case Mode::Wrap: + begin_axis_statement += data_shape_str + " + " + output_indices_str + " - " + lower_pads_str + ";\n"; + end_axis_statement += output_indices_str + " - " + lower_pads_str + " - " + data_shape_str + ";\n"; + break; + default: + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unsupported mode type: ", static_cast(mode_)); + } + + shader.MainFunctionBody() << " for (var dim = 0; dim < " << rank << " && !use_pad_value; dim++) {\n" + << " if (" << output_indices_str << " < " << lower_pads_str << ") {\n" + << " " << begin_axis_statement << " }\n" + << " else if (" << output_indices_str << " >= " << lower_pads_str << " + " << data_shape_str << ") {\n" + << " " << end_axis_statement << " }\n" + << " else {\n" + << " " << in_axis_statement << " }\n" + << " input_index += select(u32(in_coord)" << data_stride_str << ", u32(in_coord), dim == " << rank - 1 << ");\n" + << " }\n" + << " " << constant_value_str + << " " << output.SetByOffset("global_idx", "select(data[input_index], constant_value, use_pad_value)"); + + return Status::OK(); +} + +Status Pad::ComputeInternal(ComputeContext& context) const { + const Tensor* input_tensor = context.Input(0); + auto const& input_shape = input_tensor->Shape(); + size_t dimension_count = input_shape.NumDimensions(); + + const PadsVector* p_pads = &pads_; + const PadsVector* p_slices = &slices_; + + PadsVector pads; + PadsVector slices; + // kOnnxDomain Pad opset >= 11 (Or) kMsDomain opset == 1 + if (is_dynamic_) { + size_t data_rank = input_tensor->Shape().NumDimensions(); + + const Tensor* pads_tensor = context.Input(1); + auto pads_tensor_dims = pads_tensor->Shape().GetDims(); + ORT_ENFORCE(pads_tensor_dims.size() == 1 || (pads_tensor_dims.size() == 2 && pads_tensor_dims[0] == 1), + "Pads tensor should be a 1D tensor of shape [2 * num_axes] " + "or a 2D tensor of shape [1, 2 * num_axes]"); + + const auto pads_data = pads_tensor->DataAsSpan(); + + // Compute Pads by applying axes if specified otherwise copy the supplied pads. + PadBase::ComputePads(context.KernelContext(), data_rank, pads_data, pads); + + // Separate out any negative pads into the slices array + PadBase::SeparateNegativeToSlices(pads, slices); + + p_pads = &pads; + p_slices = &slices; + } + + auto output_dims(input_shape.AsShapeVector()); + ORT_ENFORCE(dimension_count * 2 == p_pads->size(), "'pads' attribute has wrong number of values"); + + // Calculate output dimensions, and handle any negative padding + std::vector lower_pads(dimension_count); + for (size_t i = 0; i < dimension_count; i++) { + int64_t lower_pad = (*p_pads)[i] + (*p_slices)[i]; + int64_t upper_pad = (*p_pads)[i + dimension_count] + (*p_slices)[i + dimension_count]; + lower_pads[i] = static_cast(lower_pad); + output_dims[i] += lower_pad + upper_pad; + } + TensorShape output_shape(output_dims); + + // special case when there is a dim value of 0 in the shape. behavior depends on mode + bool dim_value_zero = input_shape.Size() == 0; + if (dim_value_zero) { + ORT_RETURN_IF_ERROR(PadBase::HandleDimValueZero(mode_, input_shape, output_shape)); + } + + auto* output_tensor = context.Output(0, output_shape); + uint32_t output_size = onnxruntime::narrow(output_shape.Size()); + if (output_size == 0) { + // Do not need to fill output, return + return Status::OK(); + } + + // Read constant value and bitcast to uint32. + uint32_t value_uint32 = 0; + const auto data_type = input_tensor->GetElementType(); + bool is_float16 = data_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16; + const Tensor* value_tensor = context.Input(2); + if (!is_dynamic_) { + if (is_float16) { + uint16_t value = math::floatToHalf(value_); + std::memcpy(&value_uint32, &value, sizeof(value)); + } else { + value_uint32 = *reinterpret_cast(&value_); + } + } else if (value_tensor) { + ORT_ENFORCE(value_tensor->DataType() == input_tensor->DataType() && value_tensor->Shape().Size() == 1, + "Value tensor should be a 1D tensor of size 1 with the same type as that of the input tensor"); + switch (data_type) { + case ONNX_NAMESPACE::TensorProto_DataType_INT32: { + int32_t value = value_tensor->Data()[0]; + value_uint32 = *reinterpret_cast(&value); + } break; + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: { + float value = value_tensor->Data()[0]; + value_uint32 = *reinterpret_cast(&value); + } break; + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: { + uint16_t value = value_tensor->Data()[0].val; + std::memcpy(&value_uint32, &value, sizeof(value)); + } break; + case ONNX_NAMESPACE::TensorProto_DataType_UINT32: { + value_uint32 = value_tensor->Data()[0]; + } break; + default: + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unsupported input type: ", static_cast(data_type)); + } + } + + PadProgram program{mode_, dim_value_zero, is_float16}; + if (!dim_value_zero) { + program.AddInput({input_tensor, ProgramTensorMetadataDependency::TypeAndRank}); + } + program.AddOutput({output_tensor, ProgramTensorMetadataDependency::Rank}) + .SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) + .CacheHint(std::to_string(static_cast(mode_)), dim_value_zero) + .AddUniformVariables({{gsl::span(lower_pads.data(), lower_pads.size())}, {output_size}, {value_uint32}}); + + return context.RunProgram(program); +} + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Pad, + kOnnxDomain, + 2, 10, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedNumberTypes()), + Pad); +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Pad, + kOnnxDomain, + 11, 12, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .InputMemoryType(OrtMemTypeCPUInput, 1) + .InputMemoryType(OrtMemTypeCPUInput, 2) + .TypeConstraint("T", WebGpuSupportedNumberTypes()), + Pad); +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Pad, + kOnnxDomain, + 13, 17, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .InputMemoryType(OrtMemTypeCPUInput, 1) + .InputMemoryType(OrtMemTypeCPUInput, 2) + .TypeConstraint("T", WebGpuSupportedNumberTypes()), + Pad); +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Pad, + kOnnxDomain, + 18, 18, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .InputMemoryType(OrtMemTypeCPUInput, 1) + .InputMemoryType(OrtMemTypeCPUInput, 2) + .InputMemoryType(OrtMemTypeCPUInput, 3) + .TypeConstraint("T", WebGpuSupportedNumberTypes()), + Pad); +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Pad, + kOnnxDomain, + 19, 20, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .InputMemoryType(OrtMemTypeCPUInput, 1) + .InputMemoryType(OrtMemTypeCPUInput, 2) + .InputMemoryType(OrtMemTypeCPUInput, 3) + .TypeConstraint("T", WebGpuSupportedNumberTypes()), + Pad); +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Pad, + kOnnxDomain, + 21, 22, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .InputMemoryType(OrtMemTypeCPUInput, 1) + .InputMemoryType(OrtMemTypeCPUInput, 2) + .InputMemoryType(OrtMemTypeCPUInput, 3) + .TypeConstraint("T", WebGpuSupportedNumberTypes()), + Pad); +ONNX_OPERATOR_KERNEL_EX( + Pad, + kOnnxDomain, + 23, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .InputMemoryType(OrtMemTypeCPUInput, 1) + .InputMemoryType(OrtMemTypeCPUInput, 2) + .InputMemoryType(OrtMemTypeCPUInput, 3) + .TypeConstraint("T", WebGpuSupportedNumberTypes()), + Pad); + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/tensor/pad.h b/onnxruntime/core/providers/webgpu/tensor/pad.h new file mode 100644 index 0000000000000..58049ddb0e5ce --- /dev/null +++ b/onnxruntime/core/providers/webgpu/tensor/pad.h @@ -0,0 +1,40 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/webgpu/program.h" +#include "core/providers/webgpu/webgpu_kernel.h" +#include "core/providers/cpu/tensor/padbase.h" + +namespace onnxruntime { +namespace webgpu { + +class PadProgram final : public Program { + public: + PadProgram(const Mode mode, bool dim_value_zero, bool is_float16) : Program{"Pad"}, + mode_{mode}, + dim_value_zero_{dim_value_zero}, + is_float16_{is_float16} {} + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"lower_pads", ProgramUniformVariableDataType::Int32}, + {"output_size", ProgramUniformVariableDataType::Uint32}, + {"constant_value", ProgramUniformVariableDataType::Uint32}); + + private: + Mode mode_; + bool dim_value_zero_; + bool is_float16_; +}; + +class Pad final : public PadBase, public WebGpuKernel { + public: + Pad(const OpKernelInfo& info) : PadBase(info), WebGpuKernel(info) {} + + Status ComputeInternal(ComputeContext& context) const override; +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/tensor/resize_impl.cc b/onnxruntime/core/providers/webgpu/tensor/resize_impl.cc index 455e7dc54bf1d..f68ace3c1d8a1 100644 --- a/onnxruntime/core/providers/webgpu/tensor/resize_impl.cc +++ b/onnxruntime/core/providers/webgpu/tensor/resize_impl.cc @@ -211,7 +211,7 @@ Status ResizeNearestImpl(ComputeContext& context, onnxruntime::ResizeNearestMode nearest_mode) { TensorShape output_shape(output_dims); auto* output_tensor = context.Output(0, output_shape); - uint32_t output_size = gsl::narrow(output_shape.Size()); + uint32_t output_size = onnxruntime::narrow(output_shape.Size()); ResizeNearestProgram program{coordinate_transform_mode, nearest_mode, extrapolation_enabled, rank}; program.AddInput({input_tensor, ProgramTensorMetadataDependency::TypeAndRank}) @@ -299,7 +299,7 @@ Status ResizeBilinearImpl(ComputeContext& context, onnxruntime::ResizeCoordinateTransformationMode coordinate_transform_mode) { TensorShape output_shape(output_dims); auto* output_tensor = context.Output(0, output_shape); - uint32_t output_size = gsl::narrow(output_shape.Size()); + uint32_t output_size = onnxruntime::narrow(output_shape.Size()); ResizeBilinearProgram program{coordinate_transform_mode, extrapolation_enabled, rank}; program.AddInput({input_tensor, ProgramTensorMetadataDependency::TypeAndRank}) @@ -413,7 +413,7 @@ Status ResizeTrilinearImpl(ComputeContext& context, onnxruntime::ResizeCoordinateTransformationMode coordinate_transform_mode) { TensorShape output_shape(output_dims); auto* output_tensor = context.Output(0, output_shape); - uint32_t output_size = gsl::narrow(output_shape.Size()); + uint32_t output_size = onnxruntime::narrow(output_shape.Size()); ResizeTrilinearProgram program{coordinate_transform_mode, extrapolation_enabled, rank}; program.AddInput({input_tensor, ProgramTensorMetadataDependency::TypeAndRank}) @@ -534,7 +534,7 @@ Status ResizeBiCubicImpl(ComputeContext& context, onnxruntime::ResizeCoordinateTransformationMode coordinate_transform_mode) { TensorShape output_shape(output_dims); auto* output_tensor = context.Output(0, output_shape); - uint32_t output_size = gsl::narrow(output_shape.Size()); + uint32_t output_size = onnxruntime::narrow(output_shape.Size()); ResizeBiCubicProgram program{coordinate_transform_mode, extrapolation_enabled, exclude_outside, rank}; program.AddInput({input_tensor, ProgramTensorMetadataDependency::TypeAndRank}) diff --git a/onnxruntime/core/providers/webgpu/tensor/split.cc b/onnxruntime/core/providers/webgpu/tensor/split.cc index 83bf832cc5b11..d93b75fa21c16 100644 --- a/onnxruntime/core/providers/webgpu/tensor/split.cc +++ b/onnxruntime/core/providers/webgpu/tensor/split.cc @@ -107,7 +107,7 @@ Status Split::ComputeInternal(ComputeContext& context) const { ORT_RETURN_IF_ERROR(PrepareForCompute(input_shape, num_outputs, axis, before_dims, after_dims_including_split_axis, after_dims_excluding_split, split_sizes)); - SplitProgram program{gsl::narrow_cast(axis)}; + SplitProgram program{static_cast(axis)}; program.AddInput({input, ProgramTensorMetadataDependency::TypeAndRank}); auto output_dimensions = input_shape.AsShapeVector(); @@ -120,7 +120,7 @@ Status Split::ComputeInternal(ComputeContext& context) const { program.AddOutput({output, ProgramTensorMetadataDependency::Rank}); } - uint32_t input_size = gsl::narrow(input_shape.Size()); + uint32_t input_size = onnxruntime::narrow(input_shape.Size()); // Early return if the input tensor is empty. if (input_size == 0) { return Status::OK(); @@ -130,7 +130,7 @@ Status Split::ComputeInternal(ComputeContext& context) const { std::vector sizes_in_split_axis; // sizes_in_split_axis are the cumulative sizes of the splits in the split axis. for (auto split_size : split_sizes) { - previous_sum += gsl::narrow(split_size); + previous_sum += onnxruntime::narrow(split_size); sizes_in_split_axis.push_back(previous_sum); } diff --git a/onnxruntime/core/providers/webgpu/tensor/transpose.cc b/onnxruntime/core/providers/webgpu/tensor/transpose.cc index c40ec43dd0009..0df7d1ae9fa2f 100644 --- a/onnxruntime/core/providers/webgpu/tensor/transpose.cc +++ b/onnxruntime/core/providers/webgpu/tensor/transpose.cc @@ -47,7 +47,10 @@ ONNX_OPERATOR_KERNEL_EX( .TypeConstraint("T", WebGpuSupportedNumberTypes()), Transpose); -auto SqueezeShape(const gsl::span& shape, const gsl::span& adjusted_perm, InlinedVector& new_shape, InlinedVector& new_perm) { +auto SqueezeShape(const gsl::span& shape, + const gsl::span& adjusted_perm, + TensorShapeVector& new_shape, + TensorShapeVector& new_perm) { for (size_t i = 0; i < shape.size(); ++i) { if (shape[i] != 1) { new_shape.push_back(shape[i]); @@ -97,26 +100,28 @@ Status TransposeProgram::GenerateShaderCode(ShaderHelper& shader) const { return Status::OK(); } -Status Transpose::ComputeInternal(ComputeContext& context) const { - const auto* input_tensor = context.Input(0); - const TensorShape& input_shape = input_tensor->Shape(); - int32_t rank = gsl::narrow_cast(input_shape.NumDimensions()); +Status Transpose::DoTranspose(onnxruntime::webgpu::ComputeContext& context, + gsl::span permutations, + const Tensor& input, Tensor& output) { + const auto& input_shape = input.Shape(); + const auto& input_dims = input_shape.GetDims(); + int32_t rank = static_cast(input_shape.NumDimensions()); TensorShapeVector output_dims(rank); - InlinedVector default_perm(rank); - const InlinedVector* p_perm = nullptr; - ORT_RETURN_IF_ERROR(ComputeOutputShape(*input_tensor, output_dims, default_perm, p_perm)); - TensorShape output_shape(output_dims); - auto* output_tensor = context.Output(0, output_shape); - InlinedVector new_shape{}; - InlinedVector new_perm{}; - SqueezeShape(input_shape.GetDims(), *p_perm, new_shape, new_perm); - const bool channels_last = new_perm == InlinedVector({2, 3, 1}); - const bool channels_first = new_perm == InlinedVector({3, 1, 2}); + for (int32_t i = 0; i < rank; i++) { + output_dims[i] = input_dims[permutations[i]]; + } + + TensorShapeVector new_shape{}; + TensorShapeVector new_perm{}; + SqueezeShape(input_shape.GetDims(), permutations, new_shape, new_perm); + const bool channels_last = new_perm == TensorShapeVector({2, 3, 1}); + const bool channels_first = new_perm == TensorShapeVector({3, 1, 2}); const bool use_shared = (new_shape.size() == 2 && new_perm[0] > new_perm[1]) || channels_last || channels_first; auto new_input_shape = input_shape; TensorShape new_output_shape(output_dims); + if (use_shared) { new_input_shape = channels_last ? TensorShape({new_shape[0], new_shape[1] * new_shape[2]}) @@ -126,16 +131,16 @@ Status Transpose::ComputeInternal(ComputeContext& context) const { new_output_shape = TensorShape({new_input_shape[1], new_input_shape[0]}); } - uint32_t output_size = gsl::narrow_cast(input_tensor->Shape().Size()); - TransposeProgram program{*p_perm, use_shared}; + uint32_t output_size = onnxruntime::narrow(input_shape.Size()); + TransposeProgram program{permutations, use_shared}; + if (use_shared) { program.SetWorkgroupSize(TILE_SIZE, TILE_SIZE, 1); } - program - .CacheHint(absl::StrJoin(*p_perm, "-")) - .AddInputs({{input_tensor, ProgramTensorMetadataDependency::TypeAndRank, new_input_shape, 1}}) - .AddOutputs({{output_tensor, ProgramTensorMetadataDependency::None, new_output_shape, 1}}) + .CacheHint(absl::StrJoin(permutations, "-")) + .AddInputs({{&input, ProgramTensorMetadataDependency::TypeAndRank, new_input_shape, 1}}) + .AddOutputs({{&output, ProgramTensorMetadataDependency::None, new_output_shape, 1}}) .SetDispatchGroupSize(static_cast((new_output_shape[1] + TILE_SIZE - 1) / TILE_SIZE), static_cast(((new_output_shape[0] + TILE_SIZE - 1) / TILE_SIZE))) .AddUniformVariables({ @@ -148,5 +153,20 @@ Status Transpose::ComputeInternal(ComputeContext& context) const { return context.RunProgram(program); } +Status Transpose::ComputeInternal(ComputeContext& context) const { + const auto* input_tensor = context.Input(0); + const TensorShape& input_shape = input_tensor->Shape(); + int32_t rank = static_cast(input_shape.NumDimensions()); + + TensorShapeVector output_dims(rank); + InlinedVector default_perm(rank); + const InlinedVector* p_perm = nullptr; + ORT_RETURN_IF_ERROR(ComputeOutputShape(*input_tensor, output_dims, default_perm, p_perm)); + TensorShape output_shape(output_dims); + auto* output_tensor = context.Output(0, output_shape); + + return DoTranspose(context, *p_perm, *input_tensor, *output_tensor); +} + } // namespace webgpu } // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/tensor/transpose.h b/onnxruntime/core/providers/webgpu/tensor/transpose.h index 7cf5c1fe0865d..b62a419fa12bc 100644 --- a/onnxruntime/core/providers/webgpu/tensor/transpose.h +++ b/onnxruntime/core/providers/webgpu/tensor/transpose.h @@ -16,6 +16,8 @@ class Transpose final : public WebGpuKernel, public TransposeBase { Transpose(const OpKernelInfo& info) : WebGpuKernel{info}, TransposeBase{info} { } Status ComputeInternal(ComputeContext& context) const override; + static Status DoTranspose(onnxruntime::webgpu::ComputeContext& context, gsl::span permutations, const Tensor& input, Tensor& output); + constexpr static uint32_t TILE_SIZE = 16; }; diff --git a/onnxruntime/core/providers/webgpu/tensor/where.cc b/onnxruntime/core/providers/webgpu/tensor/where.cc index e8cdabb9dbe40..d7272ec525296 100644 --- a/onnxruntime/core/providers/webgpu/tensor/where.cc +++ b/onnxruntime/core/providers/webgpu/tensor/where.cc @@ -127,7 +127,7 @@ Status Where::ComputeInternal(ComputeContext& context) const { ORT_RETURN_IF_ERROR(ComputeOutputShape(cond_shape, x_shape, y_shape, output_shape)); auto* output_tensor = context.Output(0, output_shape); constexpr int component = 4; - uint32_t vec_size = gsl::narrow_cast((output_shape.Size() + 3) / component); + uint32_t vec_size = onnxruntime::narrow((output_shape.Size() + 3) / component); const auto is_broadcast = !(x_shape == y_shape && y_shape == cond_shape); WhereProgram program{is_broadcast}; diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.cc b/onnxruntime/core/providers/webgpu/webgpu_context.cc index 163dd691b7f16..97144573dde2d 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_context.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_context.cc @@ -134,6 +134,8 @@ void WebGpuContext::Initialize(const WebGpuBufferCacheConfig& buffer_cache_confi ORT_ENFORCE(device_ != nullptr, "Failed to get a WebGPU device."); } + LOGS_DEFAULT(VERBOSE) << "WebGPU EP Context is created for: Instance=" << instance_.Get() << ", Device=" << device_.Get() << "."; + // cache adapter info ORT_ENFORCE(Device().GetAdapterInfo(&adapter_info_)); // cache device limits @@ -165,7 +167,6 @@ void WebGpuContext::Initialize(const WebGpuBufferCacheConfig& buffer_cache_confi #if defined(ENABLE_PIX_FOR_WEBGPU_EP) // set pix frame generator pix_frame_generator_ = std::make_unique(instance_, - Adapter(), Device()); #else ORT_THROW("Support PIX capture requires extra build flags (--enable_pix_capture)"); @@ -321,9 +322,9 @@ Status WebGpuContext::Run(ComputeContext& context, const ProgramBase& program) { std::vector dims(expected_rank); std::vector stride(expected_rank - 1); for (size_t j = 0; j < expected_rank; ++j) { - dims[j] = gsl::narrow(shape[j]); + dims[j] = onnxruntime::narrow(shape[j]); if (j < expected_rank - 1) { - stride[j] = gsl::narrow(shape.SizeFromDimension(j + 1)); + stride[j] = onnxruntime::narrow(shape.SizeFromDimension(j + 1)); } } @@ -490,8 +491,7 @@ std::vector WebGpuContext::GetAvailableRequiredFeatures(const #endif wgpu::FeatureName::TimestampQuery, wgpu::FeatureName::ShaderF16, - wgpu::FeatureName::Subgroups, - wgpu::FeatureName::SubgroupsF16}; + wgpu::FeatureName::Subgroups}; for (auto feature : features) { if (adapter.HasFeature(feature)) { required_features.push_back(feature); @@ -708,45 +708,46 @@ WebGpuContext& WebGpuContextFactory::CreateContext(const WebGpuContextConfig& co WGPUInstance instance = config.instance; WGPUDevice device = config.device; - if (context_id == 0) { - // context ID is preserved for the default context. User cannot use context ID 0 as a custom context. - ORT_ENFORCE(instance == nullptr && device == nullptr, - "WebGPU EP default context (contextId=0) must not have custom WebGPU instance or device."); - - std::call_once(init_default_flag_, [ + std::call_once(init_default_flag_, [ #if !defined(__wasm__) - dawn_proc_table = config.dawn_proc_table + dawn_proc_table = config.dawn_proc_table #endif - ]() { - // Step.1 - setup dawn proc table (only for non-WASM build) + ]() { + // Step.1 - setup dawn proc table (only for non-WASM build) #if !defined(__wasm__) - const DawnProcTable* dawn_procs = reinterpret_cast(dawn_proc_table); + const DawnProcTable* dawn_procs = reinterpret_cast(dawn_proc_table); #if defined(BUILD_DAWN_MONOLITHIC_LIBRARY) - ORT_ENFORCE(dawn_procs == nullptr, "setting DawnProcTable is not allowed when dynamically linked to webgpu_dawn."); + ORT_ENFORCE(dawn_procs == nullptr, "setting DawnProcTable is not allowed when dynamically linked to webgpu_dawn."); #else #if !defined(USE_EXTERNAL_DAWN) - if (dawn_procs == nullptr) { - dawn_procs = &dawn::native::GetProcs(); - } + if (dawn_procs == nullptr) { + dawn_procs = &dawn::native::GetProcs(); + } #else - ORT_ENFORCE(dawn_procs != nullptr, "DawnProcTable must be provided."); + ORT_ENFORCE(dawn_procs != nullptr, "DawnProcTable must be provided."); #endif - dawnProcSetProcs(dawn_procs); + dawnProcSetProcs(dawn_procs); #endif #endif - // Step.2 - Create wgpu::Instance + // Step.2 - Create wgpu::Instance #if !defined(__wasm__) - wgpu::InstanceDescriptor instance_desc{}; - instance_desc.capabilities.timedWaitAnyEnable = true; - default_instance_ = wgpu::CreateInstance(&instance_desc); + wgpu::InstanceDescriptor instance_desc{}; + instance_desc.capabilities.timedWaitAnyEnable = true; + default_instance_ = wgpu::CreateInstance(&instance_desc); #else - default_instance_ = wgpu::CreateInstance(nullptr); + default_instance_ = wgpu::CreateInstance(nullptr); #endif - ORT_ENFORCE(default_instance_ != nullptr, "Failed to create wgpu::Instance."); - }); + ORT_ENFORCE(default_instance_ != nullptr, "Failed to create wgpu::Instance."); + }); + + if (context_id == 0) { + // context ID is preserved for the default context. User cannot use context ID 0 as a custom context. + ORT_ENFORCE(instance == nullptr && device == nullptr, + "WebGPU EP default context (contextId=0) must not have custom WebGPU instance or device."); + instance = default_instance_.Get(); } else { // for context ID > 0, user must provide custom WebGPU instance and device. @@ -800,5 +801,9 @@ void CleanupWebGpuContexts() { WebGpuContextFactory::Cleanup(); } +WGPUDevice GetDevice(int context_id) { + return WebGpuContextFactory::GetContext(context_id).Device().Get(); +} + } // namespace webgpu } // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc index d44cf4674d8a3..df7f2d6dcdeab 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc @@ -23,6 +23,7 @@ #include "core/providers/webgpu/webgpu_context.h" #include "core/providers/webgpu/data_transfer.h" +#include "core/providers/webgpu/external_data_loader.h" #include "core/providers/webgpu/webgpu_profiler.h" namespace onnxruntime { @@ -363,7 +364,9 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxD class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, Pad); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 17, Pad); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 18, 18, Pad); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 19, Pad); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 19, 20, Pad); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 21, 22, Pad); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 23, Pad); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, If); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, If); @@ -516,10 +519,10 @@ std::unique_ptr RegisterKernels() { // BuildKernelCreateInfo, // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -625,9 +628,9 @@ std::unique_ptr RegisterKernels() { // BuildKernelCreateInfo, // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -685,11 +688,13 @@ std::unique_ptr RegisterKernels() { // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, // BuildKernelCreateInfo, // BuildKernelCreateInfo, @@ -760,6 +765,7 @@ std::vector WebGpuExecutionProvider::CreatePreferredAllocators() { std::vector> WebGpuExecutionProvider::GetCapability( const onnxruntime::GraphViewer& graph, const IKernelLookup& kernel_lookup, + const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* /* resource_accountant */) const { InlinedVector candidates; // `tenative_candidates` is a subset of `candidates`. @@ -821,6 +827,12 @@ std::unique_ptr WebGpuExecutionProvider::GetDataTran return std::make_unique(context_); } +#if defined(__wasm__) +std::unique_ptr WebGpuExecutionProvider::GetExternalDataLoader() const { + return std::make_unique(); +} +#endif + WebGpuExecutionProvider::~WebGpuExecutionProvider() { WebGpuContextFactory::ReleaseContext(context_id_); } diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h index 7a0ade97aa3df..e2e23b6a307cf 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h @@ -45,10 +45,14 @@ class WebGpuExecutionProvider : public IExecutionProvider { std::vector> GetCapability( const onnxruntime::GraphViewer& graph_viewer, const IKernelLookup& /*kernel_lookup*/, + const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* /* resource_accountant */) const override; std::shared_ptr GetKernelRegistry() const override; std::unique_ptr GetDataTransfer() const override; +#if defined(__wasm__) + std::unique_ptr GetExternalDataLoader() const override; +#endif DataLayout GetPreferredLayout() const override { return preferred_data_layout_; } diff --git a/onnxruntime/core/providers/webgpu/webgpu_pix_frame_generator.cc b/onnxruntime/core/providers/webgpu/webgpu_pix_frame_generator.cc index 90b99b7b38bb1..9b287b7b7df99 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_pix_frame_generator.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_pix_frame_generator.cc @@ -11,7 +11,7 @@ namespace onnxruntime { namespace webgpu { -WebGpuPIXFrameGenerator::WebGpuPIXFrameGenerator(wgpu::Instance instance, wgpu::Adapter adapter, wgpu::Device device) { +WebGpuPIXFrameGenerator::WebGpuPIXFrameGenerator(wgpu::Instance instance, wgpu::Device device) { // Trivial window size for surface texture creation and provide frame concept for PIX. static constexpr uint32_t kWidth = 512u; static constexpr uint32_t kHeight = 512u; @@ -32,7 +32,7 @@ WebGpuPIXFrameGenerator::WebGpuPIXFrameGenerator(wgpu::Instance instance, wgpu:: wgpu::TextureFormat format; wgpu::SurfaceCapabilities capabilities; - surface_.GetCapabilities(adapter, &capabilities); + surface_.GetCapabilities(device.GetAdapter(), &capabilities); format = capabilities.formats[0]; wgpu::SurfaceConfiguration config; diff --git a/onnxruntime/core/providers/webgpu/webgpu_pix_frame_generator.h b/onnxruntime/core/providers/webgpu/webgpu_pix_frame_generator.h index 52a7459a81eba..0d9393321284d 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_pix_frame_generator.h +++ b/onnxruntime/core/providers/webgpu/webgpu_pix_frame_generator.h @@ -41,7 +41,7 @@ namespace webgpu { // WebGpuContext destruction. class WebGpuPIXFrameGenerator { public: - WebGpuPIXFrameGenerator(wgpu::Instance instance, wgpu::Adapter adapter, wgpu::Device device); + WebGpuPIXFrameGenerator(wgpu::Instance instance, wgpu::Device device); ~WebGpuPIXFrameGenerator(); void GeneratePIXFrame(); diff --git a/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc b/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc index 60c61b2ca5665..1d779152f91f3 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc @@ -151,6 +151,12 @@ std::shared_ptr WebGpuProviderFactoryCreator::Create( validation_mode, }; + LOGS_DEFAULT(VERBOSE) << "WebGPU EP Device ID: " << context_id; + LOGS_DEFAULT(VERBOSE) << "WebGPU EP WGPUInstance: " << webgpu_instance; + LOGS_DEFAULT(VERBOSE) << "WebGPU EP WGPUDevice: " << webgpu_device; + LOGS_DEFAULT(VERBOSE) << "WebGPU EP DawnProcTable: " << dawn_proc_table; + LOGS_DEFAULT(VERBOSE) << "WebGPU EP ValidationMode: " << validation_mode; + // // STEP.3 - prepare parameters for WebGPU context initialization. // diff --git a/onnxruntime/core/providers/webnn/builders/impl/rotaryEmbedding_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/rotaryEmbedding_op_builder.cc index cbaff79f4fd4f..966deb14196dd 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/rotaryEmbedding_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/rotaryEmbedding_op_builder.cc @@ -219,9 +219,17 @@ Status RotaryEmbeddingOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_build sign_buffer.set(0, -1.0f); sign_buffer.set(1, 1.0f); } else if (input_data_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) { - sign_buffer = emscripten::val::global("Uint16Array").new_(2); - sign_buffer.set(0, PackFloat32ToUint16AsFloat16(-1.0f)); - sign_buffer.set(1, PackFloat32ToUint16AsFloat16(1.0f)); + if (model_builder.IsFloat16ArrayAvailable()) { + // Float16Array is avaliable - use Float16Array. + sign_buffer = emscripten::val::global("Float16Array").new_(2); + sign_buffer.set(0, -1.0f); + sign_buffer.set(1, 1.0f); + } else { + // Float16Array is not available - use Uint16Array instead. + sign_buffer = emscripten::val::global("Uint16Array").new_(2); + sign_buffer.set(0, PackFloat32ToUint16AsFloat16(-1.0f)); + sign_buffer.set(1, PackFloat32ToUint16AsFloat16(1.0f)); + } } else { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unsupported input data type: ", input_data_type); } diff --git a/onnxruntime/core/providers/webnn/builders/model_builder.cc b/onnxruntime/core/providers/webnn/builders/model_builder.cc index ace6519a1fc11..cf4ce216ed5b3 100644 --- a/onnxruntime/core/providers/webnn/builders/model_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/model_builder.cc @@ -197,7 +197,8 @@ Status ModelBuilder::RegisterInitializers() { // Wasm memory grow will cause all array buffers reallocation, which will be treated as detached // buffers in JS side. Simply create a copy to fix it. - operand = wnn_builder_.call("constant", desc, view.call("slice")); + view = view.call("slice"); + operand = wnn_builder_.call("constant", desc, view["buffer"]); } } else { // TODO: support other type. @@ -350,7 +351,8 @@ Status ModelBuilder::AddOperandFromPersistMemoryBuffer( emscripten::val operand = emscripten::val::object(); // Wasm memory grow will cause all array buffers reallocation, which will be treated as detached // buffers in JS side. Simply create a copy to fix it. - operand = wnn_builder_.call("constant", desc, view.call("slice")); + view = view.call("slice"); + operand = wnn_builder_.call("constant", desc, view["buffer"]); AddOperand(name, operand); mem_persist_buffers_.push_back(std::move(persist_buffer)); diff --git a/onnxruntime/core/providers/webnn/builders/model_builder.h b/onnxruntime/core/providers/webnn/builders/model_builder.h index 4e2d84f481df0..1e5f859506d6b 100644 --- a/onnxruntime/core/providers/webnn/builders/model_builder.h +++ b/onnxruntime/core/providers/webnn/builders/model_builder.h @@ -30,6 +30,7 @@ class ModelBuilder { Status Compile(std::unique_ptr& model) ORT_MUST_USE_RESULT; // Accessors for members. + bool IsFloat16ArrayAvailable() const { return is_float16array_available_; } const GraphViewer& GetGraphViewer() const { return graph_viewer_; } InitializedTensorSet GetInitializerTensors(); @@ -68,6 +69,8 @@ class ModelBuilder { private: const GraphViewer& graph_viewer_; const logging::Logger& logger_; + const bool is_float16array_available_ = !emscripten::val::global("Float16Array").isUndefined() && + emscripten::val::global("Float16Array").hasOwnProperty("from"); emscripten::val wnn_context_ = emscripten::val::undefined(); emscripten::val wnn_builder_ = emscripten::val::undefined(); @@ -172,9 +175,12 @@ const emscripten::val& ModelBuilder::CreateOrGetConstant(const int32_t& data_typ } break; case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: - buffer = emscripten::val::global("Uint16Array").new_(num_elements); + buffer = is_float16array_available_ + ? emscripten::val::global("Float16Array").new_(num_elements) + : emscripten::val::global("Uint16Array").new_(num_elements); if (value) { - buffer.call("fill", emscripten::val(PackFloat32ToUint16AsFloat16(value))); + buffer.call("fill", + emscripten::val(is_float16array_available_ ? value : PackFloat32ToUint16AsFloat16(value))); } break; case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: diff --git a/onnxruntime/core/providers/webnn/webnn_execution_provider.cc b/onnxruntime/core/providers/webnn/webnn_execution_provider.cc index 39e6520e3912b..7410ff66add30 100644 --- a/onnxruntime/core/providers/webnn/webnn_execution_provider.cc +++ b/onnxruntime/core/providers/webnn/webnn_execution_provider.cc @@ -56,6 +56,7 @@ WebNNExecutionProvider::~WebNNExecutionProvider() {} std::vector> WebNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer, const IKernelLookup& /*kernel_registries*/, + const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* /* resource_accountant */) const { // For subgraph which is the attribute of the control flow nodes, part of its initializers are stored in its // ancestor graphs as common initializers shared for other subgraphs. We need to collect all of them used for diff --git a/onnxruntime/core/providers/webnn/webnn_execution_provider.h b/onnxruntime/core/providers/webnn/webnn_execution_provider.h index e806dc340d53e..b8775e717668a 100644 --- a/onnxruntime/core/providers/webnn/webnn_execution_provider.h +++ b/onnxruntime/core/providers/webnn/webnn_execution_provider.h @@ -25,6 +25,7 @@ class WebNNExecutionProvider : public IExecutionProvider { std::vector> GetCapability(const onnxruntime::GraphViewer& graph_viewer, const IKernelLookup& /*kernel_registries*/, + const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* /* resource_accountant */) const override; DataLayout GetPreferredLayout() const override { return preferred_layout_; } diff --git a/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.cc b/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.cc index 641f8b0729d0a..ab14c083884d3 100644 --- a/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.cc +++ b/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.cc @@ -258,6 +258,7 @@ static void AddComputeCapabilityForEachNodeInNodeUnit( std::vector> XnnpackExecutionProvider::GetCapability( const onnxruntime::GraphViewer& graph, const IKernelLookup& /*kernel_lookup*/, + const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* /* resource_accountant */) const { const auto& logger = *GetLogger(); std::vector> capabilities; diff --git a/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.h b/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.h index 152bef1a1c52c..9c4d2484f9f4b 100644 --- a/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.h +++ b/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.h @@ -33,6 +33,7 @@ class XnnpackExecutionProvider : public IExecutionProvider { std::vector> GetCapability( const onnxruntime::GraphViewer& graph_viewer, const IKernelLookup& /*kernel_lookup*/, + const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* /* resource_accountant */) const override; std::shared_ptr GetKernelRegistry() const override; diff --git a/onnxruntime/core/session/abi_session_options.cc b/onnxruntime/core/session/abi_session_options.cc index 7ef23d6c9e895..2e733f67a888c 100644 --- a/onnxruntime/core/session/abi_session_options.cc +++ b/onnxruntime/core/session/abi_session_options.cc @@ -1,17 +1,18 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "core/graph/onnx_protobuf.h" -#include "core/common/inlined_containers.h" -#include "core/session/onnxruntime_c_api.h" -#include "core/session/ort_apis.h" -#include "core/framework/error_code_helper.h" -#include #include +#include #include + +#include "core/common/inlined_containers.h" +#include "core/framework/error_code_helper.h" +#include "core/graph/onnx_protobuf.h" +#include "core/session/abi_session_options_impl.h" #include "core/session/inference_session.h" -#include "abi_session_options_impl.h" -#include "api_utils.h" +#include "core/session/onnxruntime_c_api.h" +#include "core/session/ort_apis.h" +#include "core/session/utils.h" OrtSessionOptions::~OrtSessionOptions() = default; diff --git a/onnxruntime/core/session/api_utils.cc b/onnxruntime/core/session/api_utils.cc deleted file mode 100644 index f7cb8520b1e5d..0000000000000 --- a/onnxruntime/core/session/api_utils.cc +++ /dev/null @@ -1,25 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "api_utils.h" - -onnxruntime::common::Status CopyStringToOutputArg(std::string_view str, const char* err_msg, char* out, size_t* size) { - const size_t str_len = str.size(); - const size_t req_size = str_len + 1; - - if (out == nullptr) { // User is querying the total output buffer size - *size = req_size; - return onnxruntime::common::Status::OK(); - } - - if (*size >= req_size) { // User provided a buffer of sufficient size - std::memcpy(out, str.data(), str_len); - out[str_len] = '\0'; - *size = req_size; - return onnxruntime::common::Status::OK(); - } - - // User has provided a buffer that is not large enough - *size = req_size; - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, err_msg); -} diff --git a/onnxruntime/core/session/api_utils.h b/onnxruntime/core/session/api_utils.h deleted file mode 100644 index 27c2bbd66f8d5..0000000000000 --- a/onnxruntime/core/session/api_utils.h +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include "core/common/common.h" -#include - -onnxruntime::common::Status CopyStringToOutputArg(std::string_view str, const char* err_msg, char* out, size_t* size); diff --git a/onnxruntime/core/session/custom_ops.cc b/onnxruntime/core/session/custom_ops.cc index 8492391172133..f583767346d88 100644 --- a/onnxruntime/core/session/custom_ops.cc +++ b/onnxruntime/core/session/custom_ops.cc @@ -20,7 +20,7 @@ #include "core/framework/tensorprotoutils.h" #include "core/graph/onnx_protobuf.h" #include "core/session/allocator_adapters.h" -#include "core/session/api_utils.h" +#include "core/session/utils.h" #include "core/session/custom_ops.h" #include "core/session/inference_session.h" #include "core/session/ort_apis.h" @@ -900,13 +900,14 @@ KernelCreateInfo CreateKernelCreateInfo(const std::string& domain, const OrtCust ONNX_NAMESPACE::OpSchema CreateSchema(const std::string& domain, const std::vector& ops) { // The function registers the first schema assuming all the other one are the same except the types constraints. ORT_ENFORCE(ops.size() > 0, "No kernels to registers."); - int undefined = 0; + int num_inputs_with_dynamic_type = 0; // Creation of the schema for the first kernel in ops. const OrtCustomOp* op = *ops.begin(); ONNX_NAMESPACE::OpSchema schema(op->GetName(op), "custom op registered at runtime", 0); - auto create_type_constraint = [&ops, &schema, &undefined](const OrtCustomOp* op, int count, int i, bool is_input) { + auto create_type_constraint = [&ops, &schema, &num_inputs_with_dynamic_type]( + const OrtCustomOp* op, int count, int i, bool is_input) { onnx::OpSchema::FormalParameterOption option = onnx::OpSchema::FormalParameterOption::Single; bool is_homogeneous = true; int min_arity = 1; @@ -976,7 +977,9 @@ ONNX_NAMESPACE::OpSchema CreateSchema(const std::string& domain, const std::vect } else { // all_types is empty. As mentioned in the previous loop, all types are allowed. schema.TypeConstraint(name, DataTypeImpl::ToString(SUPPORTED_TENSOR_TYPES), "all types"); - undefined++; + if (is_input) { + ++num_inputs_with_dynamic_type; + } } }; @@ -985,19 +988,21 @@ ONNX_NAMESPACE::OpSchema CreateSchema(const std::string& domain, const std::vect create_type_constraint(op, static_cast(input_count), static_cast(i), true); } + const bool have_shape_infer_fn = op->version >= min_ort_version_with_shape_inference && op->InferOutputShapeFn; + const size_t output_count = op->GetOutputTypeCount(op); for (size_t i = 0; i < output_count; i++) { const auto type = op->GetOutputType(op, i); if (ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED == type) { if (op->GetOutputCharacteristic(op, i) == OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED) { - ORT_ENFORCE(1 == undefined, - "There must be one (and only one) dynamic typed input to the custom op. " - "Its type info at runtime will be used to infer the type info of this dynamic typed output " - "which is required for the success of the model loading step. " - "More than one dynamic typed inputs are currently not supported as differing types at runtime " - "means the output type cannot be inferred without which model loading cannot proceed."); + // if there's a dynamically typed input and output we infer they both have the same type from the input. + // if that isn't the case the user must provide an output shape inference fn which must set the output type. + ORT_ENFORCE(num_inputs_with_dynamic_type == 1 || have_shape_infer_fn, + "The type of a dynamically typed output can be inferred from a single dynamically typed input, " + "or by a user provided OrtCustomOp->InferOutputShapeFn that sets the output type."); } } + create_type_constraint(op, static_cast(output_count), static_cast(i), false); } diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index a1903898ea7f0..e5ea562ce3535 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -38,9 +38,11 @@ #include "core/framework/utils.h" #include "core/graph/graph_viewer.h" #include "core/graph/model.h" +#include "core/graph/model_editor_api_types.h" #include "core/graph/model_saving_options.h" #include "core/optimizer/graph_transformer_utils.h" #include "core/optimizer/graph_transformer.h" +#include "core/optimizer/graph_optimizer_registry.h" #include "core/optimizer/layout_transformation/layout_transformation.h" #include "core/optimizer/insert_cast_transformer.h" #include "core/optimizer/qdq_transformer/ensure_unique_dq_for_node_unit.h" @@ -67,11 +69,11 @@ #include "core/optimizer/stft_decomposition.h" #endif #include "core/session/environment.h" -#include "core/session/user_logging_sink.h" #include "core/session/IOBinding.h" #include "core/session/inference_session_utils.h" #include "core/session/onnxruntime_session_options_config_keys.h" #include "core/session/onnxruntime_run_options_config_keys.h" +#include "core/session/user_logging_sink.h" #include "core/util/protobuf_parsing_utils.h" #include "core/util/thread_utils.h" @@ -1215,6 +1217,56 @@ common::Status InferenceSession::Load() { return LoadWithLoader(loader, "model_loading_from_saved_proto"); } +common::Status InferenceSession::Load(const OrtModel& model_editor_api_model) { + std::lock_guard l(session_mutex_); + + if (is_model_loaded_) { // already loaded + Status status(common::ONNXRUNTIME, common::MODEL_LOADED, "This session already contains a loaded model."); + LOGS(*session_logger_, ERROR) << status.ErrorMessage(); + return status; + } + + if (is_inited_) { + Status status(common::ONNXRUNTIME, common::MODEL_LOADED, "This session has already been initialized."); + LOGS(*session_logger_, ERROR) << status.ErrorMessage(); + return status; + } + + const bool strict_shape_type_inference = session_options_.config_options.GetConfigOrDefault( + kOrtSessionOptionsConfigStrictShapeTypeInference, "0") == "1"; + + // need to go from unique_ptr to shared_ptr when moving into model_ + std::unique_ptr tmp_model; + ORT_RETURN_IF_ERROR(Model::LoadFromModelEditorApiModel(model_editor_api_model, + HasLocalSchema() ? &custom_schema_registries_ : nullptr, + ModelOptions(true, strict_shape_type_inference), + *session_logger_, tmp_model)); + + model_ = std::move(tmp_model); + + is_model_loaded_ = true; + + return Status::OK(); +} + +common::Status InferenceSession::ApplyUpdates(const OrtModel& model_editor_api_model) { + std::lock_guard l(session_mutex_); + + if (!is_model_loaded_) { + Status status(common::ONNXRUNTIME, common::MODEL_LOADED, "This session does not contain a loaded model."); + LOGS(*session_logger_, ERROR) << status.ErrorMessage(); + return status; + } + + if (is_inited_) { + Status status(common::ONNXRUNTIME, common::MODEL_LOADED, "This session has already been initialized."); + LOGS(*session_logger_, ERROR) << status.ErrorMessage(); + return status; + } + + return model_->MainGraph().UpdateUsingModelEditorApiModel(model_editor_api_model); +} + common::Status InferenceSession::TransformGraph(onnxruntime::Graph& graph, bool saving_model_in_ort_format) { // The transformer order: // 1. Ensure we inline as many functions as possible. We refer to it as Ahead Of Time (AOT) function inlining. @@ -1227,8 +1279,13 @@ common::Status InferenceSession::TransformGraph(onnxruntime::Graph& graph, bool // 6. insert cast nodes (required transformer). // 7. insert copy nodes (required transformer). + // Create GraphOptimizerRegistry instance for providing predefined graph optimizers and selection functions for EPs to lookup + auto graph_optimizer_registry = std::make_unique(&session_options_, + execution_providers_.Get(onnxruntime::kCpuExecutionProvider), + session_logger_); + GraphPartitioner partitioner(kernel_registry_manager_, execution_providers_, std::move(graph_optimizer_registry)); + // Run Ahead Of time function inlining - GraphPartitioner partitioner(kernel_registry_manager_, execution_providers_); if (const bool disable_aot_function_inlining = session_options_.config_options.GetConfigOrDefault( kOrtSessionOptionsDisableAheadOfTimeFunctionInlining, "0") == "1"; @@ -1631,7 +1688,7 @@ Status PartitionOrtFormatModel(onnxruntime::Graph& graph, const ExecutionProviders& providers, KernelRegistryManager& kernel_registry_manager, SessionState& session_state, - const ConfigOptions& config_options, + const SessionOptions& sess_options, const logging::Logger& logger) { layout_transformation::TransformLayoutFunction transform_layout_fn = nullptr; @@ -1649,11 +1706,16 @@ Status PartitionOrtFormatModel(onnxruntime::Graph& graph, } #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) - GraphPartitioner partitioner(kernel_registry_manager, providers); + // Create GraphOptimizerRegistry instance for providing predefined graph optimizers and selection functions for EPs to lookup + auto graph_optimizer_registry = std::make_unique(&sess_options, + providers.Get(onnxruntime::kCpuExecutionProvider), + &logger); + + GraphPartitioner partitioner(kernel_registry_manager, providers, std::move(graph_optimizer_registry)); ORT_RETURN_IF_ERROR(partitioner.Partition(graph, session_state.GetMutableFuncMgr(), transform_layout_fn, - config_options, + sess_options.config_options, logger, GraphPartitioner::Mode::kOrtFormatLoad)); @@ -2096,7 +2158,7 @@ common::Status InferenceSession::Initialize() { #endif // !defined(ORT_MINIMAL_BUILD) } else { ORT_RETURN_IF_ERROR_SESSIONID_(PartitionOrtFormatModel(graph, execution_providers_, kernel_registry_manager_, - *session_state_, session_options_.config_options, *session_logger_)); + *session_state_, session_options_, *session_logger_)); #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) const auto& cpu_ep = *execution_providers_.Get(onnxruntime::kCpuExecutionProvider); @@ -3336,6 +3398,10 @@ common::Status InferenceSession::WaitForNotification(Notification* p_executor_do return Status::OK(); } +const Model& InferenceSession::GetModel() const { + return *model_; +} + SessionIOBinding::SessionIOBinding(InferenceSession* session) : sess_(session) { ORT_ENFORCE(session->NewIOBinding(&binding_).IsOK()); } diff --git a/onnxruntime/core/session/inference_session.h b/onnxruntime/core/session/inference_session.h index 2c0c09dfd3e51..5b484103c9ecf 100644 --- a/onnxruntime/core/session/inference_session.h +++ b/onnxruntime/core/session/inference_session.h @@ -47,6 +47,9 @@ namespace ONNX_NAMESPACE { class ModelProto; } // namespace ONNX_NAMESPACE +// OrtModelEditorApi Model. Used to dynamically construct a model via C API at runtime. +struct OrtModel; + namespace onnxruntime { // forward declarations class CustomRegistry; class Environment; @@ -320,6 +323,27 @@ class InferenceSession { * @return OK if success. */ [[nodiscard]] common::Status Load(); + + /** + * Load an OrtModel that was dynamically constructed via OrtModelEditorApi. + * + * @param graph_api_model OrtModel from OrtModelEditorApi + * @return OK if success. + */ + [[nodiscard]] common::Status Load(const OrtModel& graph_api_model); + + /** + * Apply updates from an OrtModel that was created via OrtModelEditorApi. + * This can: + * - add nodes at the start and end of the model + * - add initializers + * - update the graph inputs/outputs + * + * @param graph_api_model OrtModel from OrtModelEditorApi + * @return OK if success. + */ + [[nodiscard]] common::Status ApplyUpdates(const OrtModel& graph_api_model); + #endif // !defined(ORT_MINIMAL_BUILD) /** @@ -571,6 +595,8 @@ class InferenceSession { #endif + const Model& GetModel() const; + protected: #if !defined(ORT_MINIMAL_BUILD) @@ -627,6 +653,12 @@ class InferenceSession { /// convenience pointer to logger. should always be the same as session_state_.Logger(); const logging::Logger* session_logger_; + // The list of execution providers. + // This MUST be prior to model_ in case there are values in the model that were allocated using an allocator + // provided by the EP. If that is the case the allocator's `free` implementation may depend on other parts of the + // EP instance. + ExecutionProviders execution_providers_; + // The model served by this inference session instance. // Currently this has to be a shared ptr because the Model::Load method // returns a shared_ptr only. Ideally factory functions should always return @@ -637,9 +669,6 @@ class InferenceSession { // The file path of where the model was loaded. e.g. /tmp/test_squeezenet/model.onnx PathString model_location_; - // The list of execution providers. - ExecutionProviders execution_providers_; - private: ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(InferenceSession); void SetLoggingManager(const SessionOptions& session_options, diff --git a/onnxruntime/core/session/model_editor_api.h b/onnxruntime/core/session/model_editor_api.h new file mode 100644 index 0000000000000..71004866bc867 --- /dev/null +++ b/onnxruntime/core/session/model_editor_api.h @@ -0,0 +1,65 @@ +namespace OrtModelEditorAPI { + +// implementation that returns the API struct +ORT_API(const OrtModelEditorApi*, GetModelEditorApi); + +// APIs to create/edit type info +ORT_API_STATUS_IMPL(CreateTensorTypeInfo, _In_ const OrtTensorTypeAndShapeInfo* tensor_info, + _Out_ OrtTypeInfo** type_info); +ORT_API_STATUS_IMPL(CreateSparseTensorTypeInfo, _In_ const OrtTensorTypeAndShapeInfo* tensor_info, + _Out_ OrtTypeInfo** type_info); +ORT_API_STATUS_IMPL(CreateMapTypeInfo, ONNXTensorElementDataType map_key_type, _In_ const OrtTypeInfo* map_value_type, + _Out_ OrtTypeInfo** type_info); +ORT_API_STATUS_IMPL(CreateSequenceTypeInfo, _In_ const OrtTypeInfo* sequence_type, _Out_ OrtTypeInfo** type_info); +ORT_API_STATUS_IMPL(CreateOptionalTypeInfo, _In_ const OrtTypeInfo* contained_type, _Out_ OrtTypeInfo** type_info); + +ORT_API_STATUS_IMPL(CreateValueInfo, _In_ const char* name, _In_ const OrtTypeInfo* type_info, + _Outptr_ OrtValueInfo** value_info); + +ORT_API_STATUS_IMPL(CreateNode, const char* operator_name, const char* domain_name, _In_ const char* node_name, + _In_reads_(input_names_len) const char* const* input_names, size_t input_names_len, + _In_reads_(output_names_len) const char* const* output_names, size_t output_names_len, + _In_reads_(attribs_len) _Inout_opt_ OrtOpAttr** attributes, _In_opt_ size_t attribs_len, + _Outptr_ OrtNode** node); + +ORT_API_STATUS_IMPL(CreateGraph, _Outptr_ OrtGraph** graph); +ORT_API_STATUS_IMPL(SetGraphInputs, _In_ OrtGraph* graph, + _In_reads_(inputs_len) _In_ OrtValueInfo** inputs, _In_ size_t inputs_len); +ORT_API_STATUS_IMPL(SetGraphOutputs, _In_ OrtGraph* graph, + _In_reads_(outputs_len) _In_ OrtValueInfo** outputs, _In_ size_t outputs_len); +ORT_API_STATUS_IMPL(AddInitializerToGraph, _In_ OrtGraph* graph, _In_ const char* name, _Inout_ OrtValue* tensor, + bool data_is_external); +ORT_API_STATUS_IMPL(AddNodeToGraph, _In_ OrtGraph* graph, _Inout_ OrtNode* node); + +ORT_API_STATUS_IMPL(CreateModel, + _In_reads_(opset_entries_len) const char* const* domain_names, + _In_reads_(opset_entries_len) const int* opset_versions, + size_t opset_entries_len, + _Outptr_ OrtModel** model); +ORT_API_STATUS_IMPL(AddGraphToModel, _In_ OrtModel* model, _Inout_ OrtGraph* graph); + +ORT_API_STATUS_IMPL(CreateSessionFromModel, _In_ const OrtEnv* env, _In_ const OrtModel* model, + _In_ const OrtSessionOptions* options, _Outptr_ OrtSession** out); + +// +// Model editing APIs for updating existing model by adding node/s at start or end. +// +ORT_API_STATUS_IMPL(CreateModelEditorSession, _In_ const OrtEnv* env, + _In_ const ORTCHAR_T* model_path, + _In_ const OrtSessionOptions* options, + _Outptr_ OrtSession** out); + +ORT_API_STATUS_IMPL(CreateModelEditorSessionFromArray, _In_ const OrtEnv* env, + _In_ const void* model_data, size_t model_data_length, + _In_ const OrtSessionOptions* options, + _Outptr_ OrtSession** out); + +ORT_API_STATUS_IMPL(SessionGetOpsetForDomain, _In_ const OrtSession* session, _In_ const char* domain, + _Out_ int* opset); + +ORT_API_STATUS_IMPL(ApplyModelToModelEditorSession, _In_ OrtSession* session, _In_ OrtModel* model); + +ORT_API_STATUS_IMPL(FinalizeModelEditorSession, _In_ OrtSession* session, _In_ const OrtSessionOptions* options, + _Inout_ OrtPrepackedWeightsContainer* prepacked_weights_container); + +} // namespace OrtModelEditorAPI diff --git a/onnxruntime/core/session/model_editor_c_api.cc b/onnxruntime/core/session/model_editor_c_api.cc new file mode 100644 index 0000000000000..2f09b903ed941 --- /dev/null +++ b/onnxruntime/core/session/model_editor_c_api.cc @@ -0,0 +1,358 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if !defined(ORT_MINIMAL_BUILD) + +#include + +#include "core/framework/error_code_helper.h" +#include "core/framework/ort_value.h" +#include "core/framework/onnxruntime_typeinfo.h" +#include "core/framework/tensor_type_and_shape.h" +#include "core/graph/constants.h" +#include "core/graph/model.h" +#include "core/graph/model_editor_api_types.h" +#include "core/graph/onnx_protobuf.h" +#include "core/session/abi_session_options_impl.h" +#include "core/session/inference_session.h" +#include "core/session/model_editor_api.h" +#include "core/session/ort_apis.h" +#include "core/session/ort_env.h" +#include "core/session/utils.h" + +using namespace onnxruntime; + +ORT_API_STATUS_IMPL(OrtModelEditorAPI::CreateValueInfo, _In_ const char* name, _In_ const OrtTypeInfo* type_info, + _Outptr_ OrtValueInfo** value_info) { + API_IMPL_BEGIN + if (name == nullptr || *name == '\0') { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "name cannot be null or empty string"); + } + + if (type_info == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "type_info cannot be null"); + } + + if (type_info->type != ONNX_TYPE_TENSOR) { + return OrtApis::CreateStatus(ORT_FAIL, "Only tensor types are supported currently"); + } + + if (type_info->tensor_type_info == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "tensor_type_info cannot be null"); + } + + auto vi = std::make_unique(); + vi->name = name; + vi->type_info = type_info->Clone(); + + *value_info = vi.release(); + + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtModelEditorAPI::CreateNode, const char* operator_name, const char* domain_name, + _In_ const char* node_name, + _In_reads_(input_names_len) const char* const* input_names, size_t input_names_len, + _In_reads_(output_names_len) const char* const* output_names, size_t output_names_len, + _In_reads_(attribs_len) _Inout_opt_ OrtOpAttr** attributes, _In_opt_ size_t attribs_len, + _Outptr_ OrtNode** node) { + API_IMPL_BEGIN + auto n = std::make_unique(); + n->operator_name = operator_name; + n->domain_name = domain_name == kOnnxDomainAlias ? kOnnxDomain : domain_name; + n->node_name = node_name; + + n->input_names.reserve(input_names_len); + for (size_t i = 0; i < input_names_len; ++i) { + n->input_names.push_back(input_names[i]); + } + + n->output_names.reserve(output_names_len); + for (size_t i = 0; i < output_names_len; ++i) { + n->output_names.push_back(output_names[i]); + } + + if (attributes != nullptr) { + n->attributes.reserve(attribs_len); + for (size_t i = 0; i < attribs_len; ++i) { + n->attributes.push_back(*reinterpret_cast(attributes[i])); + // take ownership. as we took a copy that means releasing the original value + OrtApis::ReleaseOpAttr(attributes[i]); + attributes[i] = nullptr; + } + } + + *node = n.release(); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtModelEditorAPI::CreateGraph, _Outptr_ OrtGraph** graph) { + API_IMPL_BEGIN + auto g = std::make_unique(); + + // do some reserves to reduce reallocation. if we had a hint about sizes upfront that would be optimal + g->initializers.reserve(32); + g->external_initializers.reserve(32); + g->nodes.reserve(64); + + *graph = g.release(); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtModelEditorAPI::SetGraphInputs, _In_ OrtGraph* graph, + _In_reads_(inputs_len) _In_ OrtValueInfo** inputs, _In_ size_t inputs_len) { + API_IMPL_BEGIN + graph->inputs.clear(); + for (size_t i = 0; i < inputs_len; ++i) { + if (inputs[i] == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "inputs cannot contain null entries"); + } + + graph->inputs.push_back(std::unique_ptr(inputs[i])); // take ownership + inputs[i] = nullptr; + } + + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtModelEditorAPI::SetGraphOutputs, _In_ OrtGraph* graph, + _In_reads_(outputs_len) _In_ OrtValueInfo** outputs, _In_ size_t outputs_len) { + API_IMPL_BEGIN + graph->outputs.clear(); + for (size_t i = 0; i < outputs_len; ++i) { + if (outputs[i] == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "outputs cannot contain null entries"); + } + + graph->outputs.push_back(std::unique_ptr(outputs[i])); // take ownership + outputs[i] = nullptr; + } + + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtModelEditorAPI::AddInitializerToGraph, _In_ OrtGraph* graph, _In_ const char* name, + _Inout_ OrtValue* tensor, bool data_is_external) { + API_IMPL_BEGIN + if (!tensor->IsTensor()) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Only Tensor is currently supported."); + } + + if (!tensor->IsAllocated()) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Tensor must be allocated."); + } + + const auto& t = tensor->Get(); + if (t.Location().device.Type() != OrtDevice::CPU) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Only CPU based tensors are currently supported."); + } + + if (data_is_external) { + // enforce that an external initializer is not used if the data size is < 128 bytes. + // the reason for this is to avoid potential shape inferencing errors if this initializer is providing an + // input involved in that. the ONNX shape inferencing does not support external data for those values. + // e.g. Reshape's `shape` input, Reduce's `axes', Slice's `starts`, `ends`, `steps`, Clip's `min`, `max`, etc. + if (t.SizeInBytes() < 128) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, + "External initializer should only be used for data >= 128 bytes. " + "Please use CreateTensorAsOrtValue instead."); + } + + graph->external_initializers[name] = std::unique_ptr(tensor); // take ownership + } else { + graph->initializers[name] = std::unique_ptr(tensor); // take ownership + } + + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtModelEditorAPI::AddNodeToGraph, _In_ OrtGraph* graph, _Inout_ OrtNode* node) { + API_IMPL_BEGIN + graph->nodes.push_back(std::unique_ptr(node)); // take ownership + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtModelEditorAPI::CreateModel, + _In_reads_(opset_entries_len) const char* const* domain_names, + _In_reads_(opset_entries_len) const int* opset_versions, + size_t opset_entries_len, + _Outptr_ OrtModel** model) { + API_IMPL_BEGIN + auto m = std::make_unique(); + for (size_t i = 0; i < opset_entries_len; ++i) { + m->domain_to_version[domain_names[i]] = opset_versions[i]; + } + + *model = m.release(); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtModelEditorAPI::AddGraphToModel, _In_ OrtModel* model, _Inout_ OrtGraph* graph) { + API_IMPL_BEGIN + + if (graph == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "graph cannot be null"); + } + + model->graph = std::unique_ptr(graph); // take ownership + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtModelEditorAPI::CreateSessionFromModel, _In_ const OrtEnv* env, _In_ const OrtModel* model, + _In_ const OrtSessionOptions* options, _Outptr_ OrtSession** out) { + API_IMPL_BEGIN + + std::unique_ptr sess; + OrtStatus* status = nullptr; + *out = nullptr; + + ORT_TRY { + sess = std::make_unique( + options == nullptr ? onnxruntime::SessionOptions() : options->value, + env->GetEnvironment()); + + ORT_API_RETURN_IF_STATUS_NOT_OK(sess->Load(*model)); + + ORT_API_RETURN_IF_ERROR(InitializeSession(options, *sess)); + + *out = reinterpret_cast(sess.release()); + } + ORT_CATCH(const std::exception& e) { + ORT_HANDLE_EXCEPTION([&]() { + status = OrtApis::CreateStatus(ORT_FAIL, e.what()); + }); + } + + return status; + + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtModelEditorAPI::CreateModelEditorSession, + _In_ const OrtEnv* env, _In_ const ORTCHAR_T* model_path, _In_ const OrtSessionOptions* options, + _Outptr_ OrtSession** out) { + API_IMPL_BEGIN + std::unique_ptr session; + OrtStatus* status = nullptr; + *out = nullptr; + + ORT_TRY { + ORT_API_RETURN_IF_ERROR(CreateSessionAndLoadModel(options, env, model_path, nullptr, 0, session)); + *out = reinterpret_cast(session.release()); + } + ORT_CATCH(const std::exception& e) { + ORT_HANDLE_EXCEPTION([&]() { + status = OrtApis::CreateStatus(ORT_FAIL, e.what()); + }); + } + + return status; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtModelEditorAPI::CreateModelEditorSessionFromArray, _In_ const OrtEnv* env, + _In_ const void* model_data, size_t model_data_length, + _In_ const OrtSessionOptions* options, + _Outptr_ OrtSession** out) { + API_IMPL_BEGIN + std::unique_ptr session; + OrtStatus* status = nullptr; + *out = nullptr; + + ORT_TRY { + ORT_API_RETURN_IF_ERROR(CreateSessionAndLoadModel(options, env, nullptr, model_data, model_data_length, session)); + *out = reinterpret_cast(session.release()); + } + ORT_CATCH(const std::exception& e) { + ORT_HANDLE_EXCEPTION([&]() { + status = OrtApis::CreateStatus(ORT_FAIL, e.what()); + }); + } + + return status; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtModelEditorAPI::SessionGetOpsetForDomain, _In_ const OrtSession* ort_session, + _In_ const char* domain, _Out_ int* opset) { + const auto& session = *reinterpret_cast(ort_session); + const auto& domain_opset_map = session.GetModel().MainGraph().DomainToVersionMap(); + + auto it = domain_opset_map.find(domain); + if (it == domain_opset_map.cend()) { + return OrtApis::CreateStatus(ORT_FAIL, "Domain not used by model."); + } + + *opset = it->second; + return nullptr; +} + +ORT_API_STATUS_IMPL(OrtModelEditorAPI::ApplyModelToModelEditorSession, + _In_ OrtSession* session, _In_ OrtModel* model) { + API_IMPL_BEGIN + auto sess = reinterpret_cast(session); + ORT_API_RETURN_IF_STATUS_NOT_OK(sess->ApplyUpdates(*model)); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtModelEditorAPI::FinalizeModelEditorSession, _In_ OrtSession* session, + _In_ const OrtSessionOptions* options, + _Inout_ OrtPrepackedWeightsContainer* prepacked_weights_container) { + API_IMPL_BEGIN + auto sess = reinterpret_cast(session); + ORT_API_RETURN_IF_ERROR(InitializeSession(options, *sess, prepacked_weights_container)); + return nullptr; + API_IMPL_END +} + +static constexpr OrtModelEditorApi ort_model_editor_api = { + // NOTE: The C# bindings depend on the API order within this struct so all additions must be at the end, + // and no functions can be removed (the implementation needs to change to return an error). + + &OrtModelEditorAPI::CreateTensorTypeInfo, + &OrtModelEditorAPI::CreateSparseTensorTypeInfo, + &OrtModelEditorAPI::CreateMapTypeInfo, + &OrtModelEditorAPI::CreateSequenceTypeInfo, + &OrtModelEditorAPI::CreateOptionalTypeInfo, + + &OrtModelEditorAPI::CreateValueInfo, + + &OrtModelEditorAPI::CreateNode, + + &OrtModelEditorAPI::CreateGraph, + &OrtModelEditorAPI::SetGraphInputs, + &OrtModelEditorAPI::SetGraphOutputs, + &OrtModelEditorAPI::AddInitializerToGraph, + &OrtModelEditorAPI::AddNodeToGraph, + + &OrtModelEditorAPI::CreateModel, + &OrtModelEditorAPI::AddGraphToModel, + + &OrtModelEditorAPI::CreateSessionFromModel, + + &OrtModelEditorAPI::CreateModelEditorSession, + &OrtModelEditorAPI::CreateModelEditorSessionFromArray, + &OrtModelEditorAPI::SessionGetOpsetForDomain, + &OrtModelEditorAPI::ApplyModelToModelEditorSession, + &OrtModelEditorAPI::FinalizeModelEditorSession, +}; + +// checks that we don't violate the rule that the functions must remain in the slots they were originally assigned +static_assert(offsetof(OrtModelEditorApi, FinalizeModelEditorSession) / sizeof(void*) == 19, + "Size of version 21 API cannot change"); // initial version in ORT 1.21 + +ORT_API(const OrtModelEditorApi*, OrtModelEditorAPI::GetModelEditorApi) { + return &ort_model_editor_api; +} + +#endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 4eedcd591154f..0e23d7a791bec 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -1,45 +1,47 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "core/session/onnxruntime_c_api.h" -#include "core/session/allocator_adapters.h" -#include "core/session/inference_session_utils.h" -#include "core/session/IOBinding.h" -#include "core/framework/allocator.h" -#include "core/framework/error_code_helper.h" -#include "core/framework/execution_provider.h" -#include "core/framework/tensor_type_and_shape.h" -#include "core/framework/utils.h" #include #include #include +#include #include #include "core/common/common.h" #include "core/common/logging/logging.h" #include "core/common/narrow.h" -#include "core/common/status.h" #include "core/common/safeint.h" -#include "core/graph/constants.h" -#include "core/graph/graph.h" +#include "core/common/status.h" +#include "core/common/string_helper.h" #include "core/framework/allocator.h" -#include "core/framework/tensor.h" +#include "core/framework/allocator.h" +#include "core/framework/callback.h" +#include "core/framework/data_types.h" +#include "core/framework/error_code_helper.h" +#include "core/framework/execution_provider.h" +#include "core/framework/onnxruntime_typeinfo.h" #include "core/framework/ort_value.h" +#include "core/framework/tensor.h" +#include "core/framework/tensor_type_and_shape.h" +#include "core/framework/tensorprotoutils.h" +#include "core/framework/TensorSeq.h" +#include "core/framework/utils.h" +#include "core/graph/constants.h" +#include "core/graph/graph.h" +#include "core/graph/model_editor_api_types.h" #include "core/providers/get_execution_providers.h" +#include "core/session/abi_session_options_impl.h" +#include "core/session/allocator_adapters.h" #include "core/session/environment.h" -#include "core/framework/callback.h" -#include "core/framework/tensorprotoutils.h" -#include "core/framework/onnxruntime_typeinfo.h" #include "core/session/inference_session.h" +#include "core/session/inference_session_utils.h" +#include "core/session/IOBinding.h" +#include "core/session/lora_adapters.h" +#include "core/session/model_editor_api.h" +#include "core/session/onnxruntime_c_api.h" #include "core/session/ort_apis.h" #include "core/session/ort_env.h" -#include "core/framework/data_types.h" -#include "abi_session_options_impl.h" -#include "core/framework/TensorSeq.h" -#include -#include "core/common/string_helper.h" - -#include "core/session/lora_adapters.h" +#include "core/session/utils.h" #ifdef USE_CUDA #include "core/providers/cuda/cuda_provider_factory.h" @@ -114,6 +116,72 @@ using namespace onnxruntime; auto v = (value); \ auto tensor = v->GetMutable(); +namespace { +// Create tensor. Allocates memory. Tensor owns memory. Allocator is wrapped and stored in a shared_ptr in Tensor. +ORT_STATUS_PTR CreateTensorImpl(MLDataType ml_type, const int64_t* shape, size_t shape_len, + OrtAllocator* allocator, OrtValue& value) { + TensorShape tensor_shape(shape, shape_len); + AllocatorPtr alloc_ptr = std::make_shared(allocator); + Tensor::InitOrtValue(ml_type, tensor_shape, std::move(alloc_ptr), value); + return nullptr; +} + +// Create Tensor with existing data. Tensor does not own memory. +ORT_STATUS_PTR CreateTensorImpl(MLDataType ml_type, + const int64_t* shape, size_t shape_len, + const OrtMemoryInfo* info, + void* p_data, size_t p_data_len, + OrtValue& ort_value) { + TensorShape tensor_shape(shape, shape_len); + if (std::any_of(tensor_shape.GetDims().begin(), tensor_shape.GetDims().end(), [](int64_t v) { return v < 0; })) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "tried creating tensor with negative value in shape"); + } + + size_t size_to_allocate = 0; + Status status = Tensor::CalculateTensorStorageSize(ml_type, tensor_shape, 0 /*alignment*/, size_to_allocate); + if (!status.IsOK()) { + return ToOrtStatus(status); + } + if (size_to_allocate > p_data_len) { + std::ostringstream oss; + oss << "not enough space: expected " << size_to_allocate << ", got " << p_data_len; + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, oss.str().c_str()); + } + + Tensor::InitOrtValue(ml_type, tensor_shape, p_data, *info, ort_value); + return nullptr; +} + +ORT_STATUS_PTR CreateTensorImpl(MLDataType ml_type, + const int64_t* shape, size_t shape_len, + OrtAllocator* deleter, + void* p_data, size_t p_data_len, + OrtValue& ort_value) { + TensorShape tensor_shape(shape, shape_len); + if (std::any_of(tensor_shape.GetDims().begin(), tensor_shape.GetDims().end(), [](int64_t v) { return v < 0; })) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "tried creating tensor with negative value in shape"); + } + + size_t size_to_allocate = 0; + Status status = Tensor::CalculateTensorStorageSize(ml_type, tensor_shape, 0 /*alignment*/, size_to_allocate); + + if (!status.IsOK()) { + return ToOrtStatus(status); + } + + if (size_to_allocate > p_data_len) { + std::ostringstream oss; + oss << "p_data_len was smaller than expected. Expected:" << size_to_allocate << " Got:" << p_data_len; + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, oss.str().c_str()); + } + + AllocatorPtr alloc_ptr = std::make_shared(deleter); + Tensor::InitOrtValue(ml_type, tensor_shape, p_data, std::move(alloc_ptr), ort_value); + return nullptr; +} + +} // namespace + ORT_API_STATUS_IMPL(OrtApis::CreateEnvWithCustomLogger, OrtLoggingFunction logging_function, _In_opt_ void* logger_param, OrtLoggingLevel logging_level, _In_ const char* logid, _Outptr_ OrtEnv** out) { @@ -187,50 +255,6 @@ ORT_API_STATUS_IMPL(OrtApis::UpdateEnvWithCustomLogLevel, _In_ OrtEnv* ort_env, API_IMPL_END } -ORT_STATUS_PTR CreateTensorImpl(MLDataType ml_type, const int64_t* shape, size_t shape_len, - _Inout_ OrtAllocator* allocator, OrtValue& value) { - TensorShape tensor_shape(shape, shape_len); - AllocatorPtr alloc_ptr = std::make_shared(allocator); - Tensor::InitOrtValue(ml_type, tensor_shape, std::move(alloc_ptr), value); - return nullptr; -} - -ORT_STATUS_PTR CreateTensorImplForSeq(MLDataType elem_type, const int64_t* shape, size_t shape_len, Tensor& out) { - OrtAllocator* allocator; - // TODO(pranav): what allocator should be used to create the tensor here? - // for the sake of simplicity of the API using the default one here - ORT_API_RETURN_IF_ERROR(OrtApis::GetAllocatorWithDefaultOptions(&allocator)); - AllocatorPtr alloc_ptr = std::make_shared(allocator); - TensorShape tensor_shape(shape, shape_len); - out = Tensor(elem_type, tensor_shape, std::move(alloc_ptr)); - return nullptr; -} - -/** - * - * this function will create a copy of the allocator info - */ -ORT_STATUS_PTR CreateTensorImpl(MLDataType ml_type, const int64_t* shape, size_t shape_len, const OrtMemoryInfo* info, - void* p_data, size_t p_data_len, OrtValue& ort_value) { - TensorShape tensor_shape(shape, shape_len); - if (std::any_of(tensor_shape.GetDims().begin(), tensor_shape.GetDims().end(), [](int64_t v) { return v < 0; })) { - return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "tried creating tensor with negative value in shape"); - } - - size_t size_to_allocate = 0; - Status status = Tensor::CalculateTensorStorageSize(ml_type, tensor_shape, 0 /*alignment*/, size_to_allocate); - if (!status.IsOK()) { - return ToOrtStatus(status); - } - if (size_to_allocate > p_data_len) { - std::ostringstream oss; - oss << "not enough space: expected " << size_to_allocate << ", got " << p_data_len; - return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, oss.str().c_str()); - } - Tensor::InitOrtValue(ml_type, tensor_shape, p_data, *info, ort_value); - return nullptr; -} - ORT_API_STATUS_IMPL(OrtApis::CreateTensorWithDataAsOrtValue, _In_ const OrtMemoryInfo* info, _Inout_ void* p_data, size_t p_data_len, _In_ const int64_t* shape, size_t shape_len, ONNXTensorElementDataType type, _Outptr_ OrtValue** out) { @@ -243,6 +267,20 @@ ORT_API_STATUS_IMPL(OrtApis::CreateTensorWithDataAsOrtValue, _In_ const OrtMemor API_IMPL_END } +ORT_API_STATUS_IMPL(OrtApis::CreateTensorWithDataAndDeleterAsOrtValue, _In_ OrtAllocator* deleter, + _In_ void* p_data, size_t p_data_len, + _In_ const int64_t* shape, size_t shape_len, + ONNXTensorElementDataType type, + _Outptr_ OrtValue** out) { + API_IMPL_BEGIN + auto ml_type = DataTypeImpl::TensorTypeFromONNXEnum(type)->GetElementType(); + auto value = std::make_unique(); + ORT_API_RETURN_IF_ERROR(CreateTensorImpl(ml_type, shape, shape_len, deleter, p_data, p_data_len, *value)); + *out = value.release(); + return nullptr; + API_IMPL_END +} + ORT_API_STATUS_IMPL(OrtApis::CreateTensorAsOrtValue, _Inout_ OrtAllocator* allocator, _In_ const int64_t* shape, size_t shape_len, ONNXTensorElementDataType type, _Outptr_ OrtValue** out) { @@ -678,97 +716,6 @@ ORT_API_STATUS_IMPL(OrtApis::EnableOrtCustomOps, _Inout_ OrtSessionOptions* opti API_IMPL_END } -namespace { -// provider either model_path, or modal_data + model_data_length. -static ORT_STATUS_PTR CreateSessionAndLoadModel(_In_ const OrtSessionOptions* options, - _In_ const OrtEnv* env, - _In_opt_z_ const ORTCHAR_T* model_path, - _In_opt_ const void* model_data, - size_t model_data_length, - std::unique_ptr& sess) { - // quick check here to decide load path. InferenceSession will provide error message for invalid values. - // TODO: Could move to a helper - const Env& os_env = Env::Default(); // OS environment (!= ORT environment) - bool load_config_from_model = - os_env.GetEnvironmentVar(inference_session_utils::kOrtLoadConfigFromModelEnvVar) == "1"; - - if (load_config_from_model) { -#if !defined(ORT_MINIMAL_BUILD) - if (model_path != nullptr) { - sess = std::make_unique( - options == nullptr ? onnxruntime::SessionOptions() : options->value, - env->GetEnvironment(), - model_path); - } else { - sess = std::make_unique( - options == nullptr ? onnxruntime::SessionOptions() : options->value, - env->GetEnvironment(), - model_data, static_cast(model_data_length)); - } -#else - return OrtApis::CreateStatus(ORT_FAIL, "Loading config from ONNX models is not supported in this build."); -#endif - } else { - sess = std::make_unique( - options == nullptr ? onnxruntime::SessionOptions() : options->value, - env->GetEnvironment()); - } - -#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS) - // Add custom domains - if (options && !options->custom_op_domains_.empty()) { - ORT_API_RETURN_IF_STATUS_NOT_OK(sess->AddCustomOpDomains(options->custom_op_domains_)); - } -#endif - - // Finish load - if (load_config_from_model) { -#if !defined(ORT_MINIMAL_BUILD) - ORT_API_RETURN_IF_STATUS_NOT_OK(sess->Load()); -#endif - } else { - if (model_path != nullptr) { - ORT_API_RETURN_IF_STATUS_NOT_OK(sess->Load(model_path)); - } else { - ORT_API_RETURN_IF_STATUS_NOT_OK(sess->Load(model_data, static_cast(model_data_length))); - } - } - - return nullptr; -} - -static ORT_STATUS_PTR InitializeSession(_In_ const OrtSessionOptions* options, - _In_ std::unique_ptr<::onnxruntime::InferenceSession>& sess, - _Inout_opt_ OrtPrepackedWeightsContainer* prepacked_weights_container = nullptr) { - // we need to disable mem pattern if DML is one of the providers since DML doesn't have the concept of - // byte addressable memory - std::vector> provider_list; - if (options) { - for (auto& factory : options->provider_factories) { - auto provider = factory->CreateProvider(); - provider_list.push_back(std::move(provider)); - } - } - - // register the providers - for (auto& provider : provider_list) { - if (provider) { - ORT_API_RETURN_IF_STATUS_NOT_OK(sess->RegisterExecutionProvider(std::move(provider))); - } - } - - if (prepacked_weights_container != nullptr) { - ORT_API_RETURN_IF_STATUS_NOT_OK(sess->AddPrePackedWeightsContainer( - reinterpret_cast(prepacked_weights_container))); - } - - ORT_API_RETURN_IF_STATUS_NOT_OK(sess->Initialize()); - - return nullptr; -} - -} // namespace - ORT_API_STATUS_IMPL(OrtApis::CreateSession, _In_ const OrtEnv* env, _In_ const ORTCHAR_T* model_path, _In_ const OrtSessionOptions* options, _Outptr_ OrtSession** out) { API_IMPL_BEGIN @@ -778,7 +725,7 @@ ORT_API_STATUS_IMPL(OrtApis::CreateSession, _In_ const OrtEnv* env, _In_ const O ORT_TRY { ORT_API_RETURN_IF_ERROR(CreateSessionAndLoadModel(options, env, model_path, nullptr, 0, sess)); - ORT_API_RETURN_IF_ERROR(InitializeSession(options, sess)); + ORT_API_RETURN_IF_ERROR(InitializeSession(options, *sess)); *out = reinterpret_cast(sess.release()); } @@ -801,7 +748,7 @@ ORT_API_STATUS_IMPL(OrtApis::CreateSessionFromArray, _In_ const OrtEnv* env, _In ORT_TRY { ORT_API_RETURN_IF_ERROR(CreateSessionAndLoadModel(options, env, nullptr, model_data, model_data_length, sess)); - ORT_API_RETURN_IF_ERROR(InitializeSession(options, sess)); + ORT_API_RETURN_IF_ERROR(InitializeSession(options, *sess)); *out = reinterpret_cast(sess.release()); } @@ -1208,7 +1155,6 @@ ORT_API_STATUS_IMPL(OrtApis::GetResizedStringTensorElementBuffer, _Inout_ OrtVal } namespace { - OrtStatusPtr GetTensorStringSpan(const ::OrtValue& v, gsl::span& span) { if (!v.IsAllocated()) { return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "OrtValue should contain a Tensor or a Sparse Tensor"); @@ -2112,7 +2058,6 @@ ORT_API_STATUS_IMPL(OrtApis::GetOpaqueValue, _In_ const char* domain_name, _In_ } namespace { - struct ProviderBuffer { char** buffer_; char* next_write_; @@ -2342,7 +2287,7 @@ ORT_API_STATUS_IMPL(OrtApis::CreateSessionWithPrepackedWeightsContainer, _In_ co ORT_TRY { ORT_API_RETURN_IF_ERROR(CreateSessionAndLoadModel(options, env, model_path, nullptr, 0, sess)); - ORT_API_RETURN_IF_ERROR(InitializeSession(options, sess, prepacked_weights_container)); + ORT_API_RETURN_IF_ERROR(InitializeSession(options, *sess, prepacked_weights_container)); *out = reinterpret_cast(sess.release()); } @@ -2368,7 +2313,7 @@ ORT_API_STATUS_IMPL(OrtApis::CreateSessionFromArrayWithPrepackedWeightsContainer ORT_TRY { ORT_API_RETURN_IF_ERROR(CreateSessionAndLoadModel(options, env, nullptr, model_data, model_data_length, sess)); - ORT_API_RETURN_IF_ERROR(InitializeSession(options, sess, prepacked_weights_container)); + ORT_API_RETURN_IF_ERROR(InitializeSession(options, *sess, prepacked_weights_container)); *out = reinterpret_cast(sess.release()); } @@ -2410,6 +2355,39 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsSetCustomJoinThreadFn, _Inout_ OrtSes API_IMPL_END } +ORT_API(void, OrtApis::ReleaseValueInfo, _Frees_ptr_opt_ OrtValueInfo* value_info) { + delete value_info; +} + +ORT_API(void, OrtApis::ReleaseNode, _Frees_ptr_opt_ OrtNode* node) { + delete node; +} + +ORT_API(void, OrtApis::ReleaseGraph, _Frees_ptr_opt_ OrtGraph* graph) { + delete graph; +} + +ORT_API(void, OrtApis::ReleaseModel, _Frees_ptr_opt_ OrtModel* model) { + delete model; +} + +ORT_API_STATUS_IMPL(OrtApis::GetValueInfoName, _In_ const OrtValueInfo* value_info, + _Out_ const char** name) { + API_IMPL_BEGIN + *name = value_info->name.c_str(); + return nullptr; + API_IMPL_END +} +ORT_API_STATUS_IMPL(OrtApis::GetValueInfoTypeInfo, _In_ const OrtValueInfo* value_info, + _Outptr_ const OrtTypeInfo** type_info) { + API_IMPL_BEGIN + + *type_info = value_info->type_info.get(); + + return nullptr; + API_IMPL_END +} + ORT_API(const OrtTrainingApi*, OrtApis::GetTrainingApi, uint32_t version) { #ifdef ENABLE_TRAINING_APIS if (version >= 13 && version <= ORT_API_VERSION) @@ -2419,13 +2397,21 @@ ORT_API(const OrtTrainingApi*, OrtApis::GetTrainingApi, uint32_t version) { version, ORT_API_VERSION); return nullptr; #else - ORT_UNUSED_PARAMETER(version); return nullptr; #endif } +ORT_API(const OrtModelEditorApi*, OrtApis::GetModelEditorApi) { +#if !defined(ORT_MINIMAL_BUILD) + return OrtModelEditorAPI::GetModelEditorApi(); +#else + fprintf(stderr, "The Model Editor API is not supported in a minimal build.\n"); + return nullptr; +#endif +} + static constexpr OrtApiBase ort_api_base = { &OrtApis::GetApi, &OrtApis::GetVersionString}; @@ -2812,6 +2798,18 @@ static constexpr OrtApi ort_api_1_to_22 = { &OrtApis::SetEpDynamicOptions, // End of Version 20 - DO NOT MODIFY ABOVE (see above text for more information) + + &OrtApis::ReleaseValueInfo, + &OrtApis::ReleaseNode, + &OrtApis::ReleaseGraph, + &OrtApis::ReleaseModel, + + &OrtApis::GetValueInfoName, + &OrtApis::GetValueInfoTypeInfo, + + &OrtApis::GetModelEditorApi, + + &OrtApis::CreateTensorWithDataAndDeleterAsOrtValue, }; // OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase. diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index 52d3c98d526dc..9d8aeb18a782f 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -20,6 +20,10 @@ ORT_API(void, ReleaseCustomOpDomain, _Frees_ptr_opt_ OrtCustomOpDomain*); ORT_API(void, ReleaseMapTypeInfo, _Frees_ptr_opt_ OrtMapTypeInfo*); ORT_API(void, ReleaseSequenceTypeInfo, _Frees_ptr_opt_ OrtSequenceTypeInfo*); ORT_API(void, ReleaseModelMetadata, _Frees_ptr_opt_ OrtModelMetadata*); +ORT_API(void, ReleaseValueInfo, _Frees_ptr_opt_ OrtValueInfo*); +ORT_API(void, ReleaseNode, _Frees_ptr_opt_ OrtNode*); +ORT_API(void, ReleaseGraph, _Frees_ptr_opt_ OrtGraph*); +ORT_API(void, ReleaseModel, _Frees_ptr_opt_ OrtModel*); _Check_return_ _Ret_notnull_ [[nodiscard]] OrtStatus* ORT_API_CALL CreateStatus(OrtErrorCode code, _In_z_ const char* msg) NO_EXCEPTION; @@ -533,4 +537,16 @@ ORT_API_STATUS_IMPL(RunOptionsAddActiveLoraAdapter, _Inout_ OrtRunOptions* optio ORT_API_STATUS_IMPL(SetEpDynamicOptions, _Inout_ OrtSession* sess, _In_reads_(kv_len) const char* const* keys, _In_reads_(kv_len) const char* const* values, _In_ size_t kv_len); + +ORT_API_STATUS_IMPL(GetValueInfoName, _In_ const OrtValueInfo* value_info, _Out_ const char** name); +ORT_API_STATUS_IMPL(GetValueInfoTypeInfo, _In_ const OrtValueInfo* value_info, _Outptr_ const OrtTypeInfo** type_info); + +ORT_API(const OrtModelEditorApi*, GetModelEditorApi); + +ORT_API_STATUS_IMPL(CreateTensorWithDataAndDeleterAsOrtValue, _In_ OrtAllocator* deleter, + _In_ void* p_data, size_t p_data_len, + _In_ const int64_t* shape, size_t shape_len, + ONNXTensorElementDataType type, + _Outptr_ OrtValue** out); + } // namespace OrtApis diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index 77c6d4c371f69..2ea4a93d21f2e 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -4,6 +4,7 @@ // This is the Onnxruntime side of the bridge to allow providers to be built as a DLL // It implements onnxruntime::ProviderHost +#include #include "core/common/inlined_containers.h" #include "core/common/path_string.h" #include "core/framework/allocator_utils.h" @@ -35,6 +36,7 @@ #include "core/graph/graph_proto_serializer.h" #include "core/framework/murmurhash3.h" #include "core/framework/model_metadef_id_generator.h" +#include "core/optimizer/graph_optimizer_registry.h" #include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h" #include "core/optimizer/qdq_transformer/selectors_actions/shared/utils.h" @@ -237,6 +239,21 @@ common::Status LoadDynamicLibraryFromProvider(onnxruntime::PathString library_na struct ProviderHostImpl : ProviderHost { const OrtApiBase* OrtGetApiBase() override { return ::OrtGetApiBase(); } + Status GetOptimizerByName(const std::string& name, + const GraphOptimizerRegistry& graph_optimizer_registry, + SelectionFunc& selection_func) override { + std::string optimizer_name(name); + + auto func = graph_optimizer_registry.GetSelectionFunc(optimizer_name); + + if (func.has_value()) { + selection_func = func.value(); + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to get optimizer " + optimizer_name); + } + return Status::OK(); + }; + void* HeapAllocate(size_t size) override { return new uint8_t[size]; } void HeapFree(void* p) override { delete[] reinterpret_cast(p); } @@ -360,8 +377,9 @@ struct ProviderHostImpl : ProviderHost { std::vector> IExecutionProvider__GetCapability( const IExecutionProvider* p, const onnxruntime::GraphViewer& graph_viewer, const IExecutionProvider::IKernelLookup& kernel_lookup, + const GraphOptimizerRegistry& graph_optimizer_registry, IResourceAccountant* resource_accountant) override { - return p->IExecutionProvider::GetCapability(graph_viewer, kernel_lookup, resource_accountant); + return p->IExecutionProvider::GetCapability(graph_viewer, kernel_lookup, graph_optimizer_registry, resource_accountant); } common::Status IExecutionProvider__Compile(IExecutionProvider* p, const std::vector& fused_nodes_and_graphs, std::vector& node_compute_funcs) override { @@ -797,6 +815,8 @@ struct ProviderHostImpl : ProviderHost { std::unique_ptr ComputeCapability__construct(std::unique_ptr t_sub_graph) override { return std::make_unique(std::move(t_sub_graph)); } void ComputeCapability__operator_delete(ComputeCapability* p) override { delete p; } std::unique_ptr& ComputeCapability__SubGraph(ComputeCapability* p) override { return p->sub_graph; } + void ComputeCapability__copy_optimization_func(ComputeCapability* p, ComputeCapability* selection_cc) override { p->optimization_func = selection_cc->optimization_func; } + void ComputeCapability__add_nodes_to_optimize(ComputeCapability* p, std::unique_ptr optimization_cc) override { p->nodes_to_optimize.push_back(std::move(optimization_cc)); } // DataTransferManager (wrapped) Status DataTransferManager__CopyTensor(const DataTransferManager* p, const Tensor& src, Tensor& dst) override { return p->CopyTensor(src, dst); } @@ -1631,6 +1651,7 @@ struct ProviderHostImpl : ProviderHost { Status LoadDynamicLibrary(onnxruntime::PathString library_name) override { return LoadDynamicLibraryFromProvider(library_name); }; #endif } provider_host_; + #if defined(_MSC_VER) && !defined(__clang__) #pragma warning(pop) #endif diff --git a/onnxruntime/core/session/utils.cc b/onnxruntime/core/session/utils.cc new file mode 100644 index 0000000000000..afb1ed2696c9f --- /dev/null +++ b/onnxruntime/core/session/utils.cc @@ -0,0 +1,125 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/session/utils.h" + +#include "core/framework/error_code_helper.h" +#include "core/framework/execution_provider.h" +#include "core/session/abi_session_options_impl.h" +// #include "core/session/environment.h" +#include "core/session/inference_session.h" +#include "core/session/inference_session_utils.h" +#include "core/session/onnxruntime_c_api.h" +#include "core/session/ort_apis.h" +#include "core/session/ort_env.h" + +using namespace onnxruntime; + +common::Status CopyStringToOutputArg(std::string_view str, const char* err_msg, char* out, size_t* size) { + const size_t str_len = str.size(); + const size_t req_size = str_len + 1; + + if (out == nullptr) { // User is querying the total output buffer size + *size = req_size; + return onnxruntime::common::Status::OK(); + } + + if (*size >= req_size) { // User provided a buffer of sufficient size + std::memcpy(out, str.data(), str_len); + out[str_len] = '\0'; + *size = req_size; + return onnxruntime::common::Status::OK(); + } + + // User has provided a buffer that is not large enough + *size = req_size; + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, err_msg); +} + +// provider either model_path, or modal_data + model_data_length. +OrtStatus* CreateSessionAndLoadModel(_In_ const OrtSessionOptions* options, + _In_ const OrtEnv* env, + _In_opt_z_ const ORTCHAR_T* model_path, + _In_opt_ const void* model_data, + size_t model_data_length, + std::unique_ptr& sess) { + // quick check here to decide load path. InferenceSession will provide error message for invalid values. + // TODO: Could move to a helper + const Env& os_env = Env::Default(); // OS environment (!= ORT environment) + bool load_config_from_model = + os_env.GetEnvironmentVar(inference_session_utils::kOrtLoadConfigFromModelEnvVar) == "1"; + + if (load_config_from_model) { +#if !defined(ORT_MINIMAL_BUILD) + if (model_path != nullptr) { + sess = std::make_unique( + options == nullptr ? onnxruntime::SessionOptions() : options->value, + env->GetEnvironment(), + model_path); + } else { + sess = std::make_unique( + options == nullptr ? onnxruntime::SessionOptions() : options->value, + env->GetEnvironment(), + model_data, static_cast(model_data_length)); + } +#else + return OrtApis::CreateStatus(ORT_FAIL, "Loading config from ONNX models is not supported in this build."); +#endif + } else { + sess = std::make_unique( + options == nullptr ? onnxruntime::SessionOptions() : options->value, + env->GetEnvironment()); + } + +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS) + // Add custom domains + if (options && !options->custom_op_domains_.empty()) { + ORT_API_RETURN_IF_STATUS_NOT_OK(sess->AddCustomOpDomains(options->custom_op_domains_)); + } +#endif + + // Finish load + if (load_config_from_model) { +#if !defined(ORT_MINIMAL_BUILD) + ORT_API_RETURN_IF_STATUS_NOT_OK(sess->Load()); +#endif + } else { + if (model_path != nullptr) { + ORT_API_RETURN_IF_STATUS_NOT_OK(sess->Load(model_path)); + } else { + ORT_API_RETURN_IF_STATUS_NOT_OK(sess->Load(model_data, static_cast(model_data_length))); + } + } + + return nullptr; +} + +OrtStatus* InitializeSession(_In_ const OrtSessionOptions* options, + _In_ onnxruntime::InferenceSession& sess, + _Inout_opt_ OrtPrepackedWeightsContainer* prepacked_weights_container) { + // we need to disable mem pattern if DML is one of the providers since DML doesn't have the concept of + // byte addressable memory + std::vector> provider_list; + if (options) { + for (auto& factory : options->provider_factories) { + auto provider = factory->CreateProvider(); + provider_list.push_back(std::move(provider)); + } + } + + // register the providers + for (auto& provider : provider_list) { + if (provider) { + ORT_API_RETURN_IF_STATUS_NOT_OK(sess.RegisterExecutionProvider(std::move(provider))); + } + } + + if (prepacked_weights_container != nullptr) { + ORT_API_RETURN_IF_STATUS_NOT_OK(sess.AddPrePackedWeightsContainer( + reinterpret_cast(prepacked_weights_container))); + } + + ORT_API_RETURN_IF_STATUS_NOT_OK(sess.Initialize()); + + return nullptr; +} diff --git a/onnxruntime/core/session/utils.h b/onnxruntime/core/session/utils.h new file mode 100644 index 0000000000000..ac8ad60758b5b --- /dev/null +++ b/onnxruntime/core/session/utils.h @@ -0,0 +1,28 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include "core/common/common.h" +#include "core/session/onnxruntime_c_api.h" + +onnxruntime::common::Status CopyStringToOutputArg(std::string_view str, const char* err_msg, char* out, size_t* size); + +struct OrtSessionOptions; +struct OrtStatus; +struct OrtPrepackedWeightsContainer; +namespace onnxruntime { +class InferenceSession; +} + +OrtStatus* CreateSessionAndLoadModel(_In_ const OrtSessionOptions* options, + _In_ const OrtEnv* env, + _In_opt_z_ const ORTCHAR_T* model_path, + _In_opt_ const void* model_data, + size_t model_data_length, + std::unique_ptr& sess); + +OrtStatus* InitializeSession(_In_ const OrtSessionOptions* options, + _In_ onnxruntime::InferenceSession& sess, + _Inout_opt_ OrtPrepackedWeightsContainer* prepacked_weights_container = nullptr); diff --git a/onnxruntime/python/tools/quantization/execution_providers/qnn/quant_config.py b/onnxruntime/python/tools/quantization/execution_providers/qnn/quant_config.py index ea995d4707ba3..50da0025752aa 100644 --- a/onnxruntime/python/tools/quantization/execution_providers/qnn/quant_config.py +++ b/onnxruntime/python/tools/quantization/execution_providers/qnn/quant_config.py @@ -204,9 +204,9 @@ def get_qnn_qdq_config( calibrate_method=calibrate_method, activation_type=activation_type, weight_type=weight_type, - op_types_to_quantize=op_types_to_quantize - if op_types_to_quantize - else list(op_types.difference(OP_TYPES_TO_EXCLUDE)), + op_types_to_quantize=( + op_types_to_quantize if op_types_to_quantize else list(op_types.difference(OP_TYPES_TO_EXCLUDE)) + ), nodes_to_exclude=nodes_to_exclude, per_channel=per_channel, use_external_data_format=(model_has_external_data or model.ByteSize() >= MODEL_SIZE_THRESHOLD), diff --git a/onnxruntime/python/tools/quantization/quantize.py b/onnxruntime/python/tools/quantization/quantize.py index fa468a9676a65..d19bebad8a12c 100644 --- a/onnxruntime/python/tools/quantization/quantize.py +++ b/onnxruntime/python/tools/quantization/quantize.py @@ -240,6 +240,8 @@ def get_qdq_config( keep_removable_activations: bool = False, min_real_range: float | None = None, tensor_quant_overrides: dict[str, list[dict[str, Any]]] | None = None, + calibration_providers: list[str] | None = None, + op_types_to_quantize: list[str] | None = None, nodes_to_exclude: list[str] | Callable[[onnx.ModelProto, onnx.NodeProto], bool] | None = None, extra_options: dict | None = None, ) -> StaticQuantConfig: @@ -294,6 +296,10 @@ def get_qdq_config( 'convert["recv_nodes"] = Set : Set of node names that consume the converted activation, other nodes get the original type. If not specified, assume all consumer nodes get the converted type. + calibration_providers: Execution providers to run the session during calibration. Default is None which uses + [ "CPUExecutionProvider" ]. + op_types_to_quantize: List of operator types to quantize. If None, all operators other than Cast, DequantizeLinear, + and QuantizeLinear are quantized. nodes_to_exclude: List of nodes names to exclude from quantization. Alternatively, can provide a function that accepts an onnx.ModelProto and onnx.NodeProto as arguments and returns true if the give onnx.NodeProto should be excluded from quantization. @@ -324,17 +330,20 @@ def get_qdq_config( if onnx.external_data_helper.uses_external_data(initializer): model_has_external_data = True - final_nodes_to_exclude = [] - if nodes_to_exclude is not None and isinstance(nodes_to_exclude, list): - final_nodes_to_exclude.extend(nodes_to_exclude) + op_types_to_quantize_set = set(op_types_to_quantize) if op_types_to_quantize else None + nodes_to_exclude_set = set(nodes_to_exclude) if isinstance(nodes_to_exclude, list) else set() # Iterate through nodes to get all operator types in the model and # call user's function to filter out nodes from quantization. for node in model.graph.node: - op_types.add(node.op_type) - if nodes_to_exclude is not None and callable(nodes_to_exclude): - if nodes_to_exclude(model, node): - final_nodes_to_exclude.append(node.name) + if op_types_to_quantize_set and node.op_type not in op_types_to_quantize_set: + continue + if node.name in nodes_to_exclude_set: + continue + if callable(nodes_to_exclude) and nodes_to_exclude(model, node): + nodes_to_exclude_set.add(node.name) + else: + op_types.add(node.op_type) final_extra_options = { "MinimumRealRange": min_real_range, @@ -378,11 +387,14 @@ def get_qdq_config( quant_format=QuantFormat.QDQ, activation_type=activation_type, weight_type=weight_type, - op_types_to_quantize=list(op_types.difference(op_types_to_exclude)), - nodes_to_exclude=final_nodes_to_exclude, + op_types_to_quantize=( + op_types_to_quantize if op_types_to_quantize else list(op_types.difference(op_types_to_exclude)) + ), + nodes_to_exclude=list(nodes_to_exclude_set), per_channel=per_channel, reduce_range=reduce_range, use_external_data_format=(model_has_external_data or model.ByteSize() >= MODEL_SIZE_THRESHOLD), + calibration_providers=calibration_providers, extra_options=final_extra_options, ) @@ -442,7 +454,7 @@ def check_static_quant_arguments(quant_format: QuantFormat, activation_type: Qua if activation_type != QuantType.QFLOAT8E4M3FN and weight_type == QuantType.QFLOAT8E4M3FN: raise ValueError( f"ONNXRuntime quantization doesn't support data format: activation_type={activation_type} " - f"!=QuantType.QFLOAT8E4M3FN, weight_type=QuantType.QFLOAT8E4M3FN." + "!=QuantType.QFLOAT8E4M3FN, weight_type=QuantType.QFLOAT8E4M3FN." ) if activation_type == QuantType.QFLOAT8E4M3FN and weight_type != QuantType.QFLOAT8E4M3FN: diff --git a/onnxruntime/python/tools/transformers/models/sam2/README.md b/onnxruntime/python/tools/transformers/models/sam2/README.md index e7cafeffc6231..463d154525f8f 100644 --- a/onnxruntime/python/tools/transformers/models/sam2/README.md +++ b/onnxruntime/python/tools/transformers/models/sam2/README.md @@ -96,8 +96,7 @@ We can create a conda environment then run GPU benchmark like the following: conda create -n sam2_gpu python=3.11 -y conda activate sam2_gpu install_dir=$HOME -profiling=true -bash benchmark_sam2.sh $install_dir gpu $profiling +bash benchmark_sam2.sh $install_dir gpu ``` or create a new conda environment for CPU benchmark: @@ -107,16 +106,28 @@ conda activate sam2_cpu bash benchmark_sam2.sh $HOME cpu ``` -The first parameter is a directory to clone git repositories or install CUDA/cuDNN for benchmark. -The second parameter can be either "gpu" or "cpu", which indicates the device to run benchmark. -The third parameter is optional. Value "true" will enable profiling after running benchmarking on GPU. +The usage of the script like the following: +``` +bash benchmark_sam2.sh [profiling] [benchmarking] [nightly] [dynamo] +``` + +| Parameter| Default | Description | +|----------|----------| ------------| +| install_dir | $HOME | a directory to clone git repositories or install CUDA/cuDNN for benchmark | +| cpu_or_gpu | gpu | the device to run benchmark. The value can be either "gpu" or "cpu" | +| profiling | false | run gpu profiling | +| benchmarking | true | run benchmark | +| nightly | false | install onnxruntime nightly or official release package | +| dynamo | false | export image encoder using dynamo or not. | -The script will automatically install required packages in current conda environment, download checkpoints, export onnx, -and run demo, benchmark and optionally run profiling. +The dynamo export is experimental since graph optimization still need extra works for this model. -* The performance test result is in sam2_gpu.csv or sam2_cpu.csv, which can be loaded into Excel. -* The demo output is sam2_demo_fp16_gpu.png or sam2_demo_fp32_cpu.png. -* The profiling results are in *.nsys-rep or *.json files in current directory. Use Nvidia NSight System to view the *.nsys-rep file. +Output files: +* sam2_cpu_[timestamp].csv or sam2_gpu_[timestamp].csv has benchmark results. Use Excel to load the file to view it. +* onnxruntime_image_[encoder|decoder].json has ONNX Runtime profiling results. Use `chrome://tracing` in Chrome browser to view it. +* torch_image_[encoder|decoder].json has PyTorch profiling results. Use `chrome://tracing` in Chrome browser to view it. +* sam2_fp16_profile_image_[encoder|decoder]_[ort|torch]_gpu.[nsys-rep|sqlite] has NVTX profiling. Use Nvidia NSight System to view it. +* torch_image_encoder_compiled_code.txt has the compiled kernel code from Pytorch. ## Limitations - The exported image_decoder model does not support batch mode for now. diff --git a/onnxruntime/python/tools/transformers/models/sam2/benchmark_sam2.py b/onnxruntime/python/tools/transformers/models/sam2/benchmark_sam2.py index 16d71d5057b02..3fc24d157b0cf 100644 --- a/onnxruntime/python/tools/transformers/models/sam2/benchmark_sam2.py +++ b/onnxruntime/python/tools/transformers/models/sam2/benchmark_sam2.py @@ -46,6 +46,7 @@ def __init__( prefer_nhwc: bool = False, warm_up: int = 5, enable_nvtx_profile: bool = False, + enable_ort_profile: bool = False, enable_torch_profile: bool = False, repeats: int = 1000, verbose: bool = False, @@ -74,6 +75,7 @@ def __init__( self.prefer_nhwc = prefer_nhwc self.warm_up = warm_up self.enable_nvtx_profile = enable_nvtx_profile + self.enable_ort_profile = enable_ort_profile self.enable_torch_profile = enable_torch_profile self.repeats = repeats self.verbose = verbose @@ -317,6 +319,7 @@ def run_test( repeats=args.repeats, warm_up=args.warm_up, enable_nvtx_profile=args.enable_nvtx_profile, + enable_ort_profile=args.enable_ort_profile, enable_torch_profile=args.enable_torch_profile, torch_compile_mode=args.torch_compile_mode, verbose=False, @@ -325,7 +328,7 @@ def run_test( if args.engine == "ort": sess_options = SessionOptions() sess_options.intra_op_num_threads = args.intra_op_num_threads - if config.enable_nvtx_profile: + if config.enable_ort_profile: sess_options.enable_profiling = True sess_options.log_severity_level = 4 sess_options.log_verbosity_level = 0 @@ -349,6 +352,8 @@ def run_test( with nvtx.annotate("one_run"): _ = session.infer(input_dict) cudart.cudaProfilerStop() + + if config.enable_ort_profile: session.ort_session.end_profiling() if repeats == 0: @@ -554,6 +559,14 @@ def _parse_arguments(): help="Enable nvtx profiling. It will add an extra run for profiling before performance test.", ) + parser.add_argument( + "--enable_ort_profile", + required=False, + default=False, + action="store_true", + help="Enable ORT profiling.", + ) + parser.add_argument( "--enable_torch_profile", required=False, diff --git a/onnxruntime/python/tools/transformers/models/sam2/benchmark_sam2.sh b/onnxruntime/python/tools/transformers/models/sam2/benchmark_sam2.sh index 9e97867657ab9..c82b1ed31796e 100644 --- a/onnxruntime/python/tools/transformers/models/sam2/benchmark_sam2.sh +++ b/onnxruntime/python/tools/transformers/models/sam2/benchmark_sam2.sh @@ -5,7 +5,17 @@ # ------------------------------------------------------------------------- # Please refer to README.md for the prerequisites and usage of this script. -# bash benchmark_sam2.sh [profiling] +# bash benchmark_sam2.sh [profiling] [benchmarking] [nightly] [dynamo] +# Note that dynamo need onnxruntime 1.21 or later, or nightly build. +# Example: +# bash benchmark_sam2.sh $HOME gpu true true true false + +install_dir="${1:-$HOME}" +cpu_or_gpu="${2:-gpu}" +profiling="${3:-false}" +benchmarking="${4:-true}" +nightly="${5:-false}" +dynamo="${6:-false}" python="$CONDA_PREFIX/bin/python3" @@ -13,9 +23,6 @@ python="$CONDA_PREFIX/bin/python3" dir="$(cd "$(dirname "$0")" && pwd)" onnx_dir="$dir/sam2_onnx_models" -# Installation directory (default: $HOME) -install_dir="${1:-$HOME}" - if [ ! -d "$install_dir" ]; then echo "Error: install_dir '$install_dir' does not exist." exit 1 @@ -26,7 +33,6 @@ sam2_dir="$install_dir/segment-anything-2" model="sam2_hiera_large" # Default to GPU, switch to CPU if specified -cpu_or_gpu="${2:-gpu}" if [ "$cpu_or_gpu" != "gpu" ] && [ "$cpu_or_gpu" != "cpu" ]; then echo "Invalid option: $2. Please specify 'cpu' or 'gpu'." exit 1 @@ -35,52 +41,97 @@ fi echo "install_dir: $install_dir" echo "cpu_or_gpu: $cpu_or_gpu" -install_cuda_12() -{ - pushd $install_dir - wget https://developer.download.nvidia.com/compute/cuda/12.6.2/local_installers/cuda_12.6.2_560.35.03_linux.run - sh cuda_12.6.2_560.35.03_linux.run --toolkit --toolkitpath=$install_dir/cuda12.6 --silent --override --no-man-page +# Function to check if a command exists +command_exists() { + command -v "$1" >/dev/null 2>&1 +} + +# Ensure necessary tools are installed +if ! command_exists wget; then + echo "wget is not installed. Please install it and try again." + exit 1 +fi + +if ! command_exists git; then + echo "git is not installed. Please install it and try again." + exit 1 +fi + +if ! command_exists pip; then + echo "pip is not installed. Please install it and try again." + exit 1 +fi + +cuda_version=12.6 +cudnn_version=9.5 - export PATH="$install_dir/cuda12.6/bin:$PATH" - export LD_LIBRARY_PATH="$install_dir/cuda12.6/lib64:$LD_LIBRARY_PATH" - popd +# Install CUDA 12.6 +install_cuda_12() { + if ! [ -d "$install_dir/cuda${cuda_version}" ]; then + pushd "$install_dir" || exit + wget https://developer.download.nvidia.com/compute/cuda/12.6.2/local_installers/cuda_12.6.2_560.35.03_linux.run + sh cuda_12.6.2_560.35.03_linux.run --toolkit --toolkitpath="$install_dir/cuda${cuda_version}" --silent --override --no-man-page + popd || exit + fi + export PATH="$install_dir/cuda${cuda_version}/bin:$PATH" + export LD_LIBRARY_PATH="$install_dir/cuda${cuda_version}/lib64:$LD_LIBRARY_PATH" } -# Function to install cuDNN 9.4 +# Install cuDNN 9.5 install_cudnn_9() { - pushd "$install_dir" - wget -q https://developer.download.nvidia.com/compute/cudnn/redist/cudnn/linux-x86_64/cudnn-linux-x86_64-9.5.0.50_cuda12-archive.tar.xz - mkdir -p "$install_dir/cudnn9.5" - tar -Jxvf cudnn-linux-x86_64-9.5.0.50_cuda12-archive.tar.xz -C "$install_dir/cudnn9.5" --strip=1 - export LD_LIBRARY_PATH="$install_dir/cudnn9.5/lib:$LD_LIBRARY_PATH" - popd + if ! [ -d "$install_dir/cudnn${cudnn_version}" ]; then + pushd "$install_dir" || exit + wget -q https://developer.download.nvidia.com/compute/cudnn/redist/cudnn/linux-x86_64/cudnn-linux-x86_64-9.5.0.50_cuda12-archive.tar.xz + mkdir -p "$install_dir/cudnn${cudnn_version}" + tar -Jxvf cudnn-linux-x86_64-9.5.0.50_cuda12-archive.tar.xz -C "$install_dir/cudnn${cudnn_version}" --strip=1 + popd || exit + fi + export LD_LIBRARY_PATH="$install_dir/cudnn${cudnn_version}/lib:$LD_LIBRARY_PATH" +} + +install_ort() { + local ort="$1" + pip uninstall onnxruntime onnxruntime-gpu -y + + if [ "$nightly" = "true" ]; then + pip install flatbuffers numpy packaging protobuf sympy + pip install --pre --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/ORT-Nightly/pypi/simple/ "$ort" + else + pip install "$ort" + fi + + pip install onnx onnxscript opencv-python matplotlib } # Install GPU dependencies install_gpu() { - [ ! -d "$install_dir/cuda12.6" ] && install_cuda_12 - [ ! -d "$install_dir/cudnn9.5" ] && install_cudnn_9 + install_cuda_12 + install_cudnn_9 + echo "PATH: $PATH" + echo "LD_LIBRARY_PATH: $LD_LIBRARY_PATH" + + # The dynamo export need torch 2.6.0 or later. Use the latest one. + pip install torch torchvision --index-url https://download.pytorch.org/whl/cu124 --upgrade - pip install torch torchvision --index-url https://download.pytorch.org/whl/cu124 - pip install onnxruntime-gpu onnx opencv-python matplotlib + install_ort "onnxruntime-gpu" } # Install CPU dependencies install_cpu() { pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu - pip install onnxruntime onnx opencv-python matplotlib + install_ort "onnxruntime" } # Clone and install SAM2 if not already installed install_sam2() { - pushd "$install_dir" + pushd "$install_dir" || exit if [ ! -d "$sam2_dir" ]; then git clone https://github.com/facebookresearch/segment-anything-2.git fi - cd "$sam2_dir" + cd "$sam2_dir" || exit pip show SAM-2 > /dev/null 2>&1 || pip install -e . [ ! -f checkpoints/sam2_hiera_large.pt ] && (cd checkpoints && sh ./download_ckpts.sh) - popd + popd || exit } # Download test image if not available @@ -90,7 +141,12 @@ download_test_image() { run_cpu_benchmark() { local repeats="$1" - $python convert_to_onnx.py --sam2_dir "$sam2_dir" --optimize --demo + + if [ "$dynamo" = "true" ]; then + $python convert_to_onnx.py --sam2_dir "$sam2_dir" --optimize --demo --dynamo + else + $python convert_to_onnx.py --sam2_dir "$sam2_dir" --optimize --demo + fi for component in image_encoder image_decoder; do $python benchmark_sam2.py --model_type "$model" --engine torch --sam2_dir "$sam2_dir" --repeats "$repeats" --dtype fp32 --component "$component" @@ -103,65 +159,75 @@ run_cpu_benchmark() { done } -run_gpu_benchmark() { +run_ort_gpu_benchmark() { local repeats="$1" - $python convert_to_onnx.py --sam2_dir "$sam2_dir" --optimize --use_gpu --dtype fp32 - $python convert_to_onnx.py --sam2_dir "$sam2_dir" --optimize --use_gpu --dtype fp16 --demo - for component in image_encoder image_decoder; do - for dtype in bf16 fp32 fp16; do - $python benchmark_sam2.py --model_type "$model" --engine torch --sam2_dir "$sam2_dir" --repeats "$repeats" --use_gpu --dtype $dtype --component "$component" - done - done + if [ "$dynamo" = "true" ]; then + $python convert_to_onnx.py --sam2_dir "$sam2_dir" --optimize --use_gpu --dtype fp32 --dynamo + $python convert_to_onnx.py --sam2_dir "$sam2_dir" --optimize --use_gpu --dtype fp16 --demo --dynamo + else + $python convert_to_onnx.py --sam2_dir "$sam2_dir" --optimize --use_gpu --dtype fp32 + $python convert_to_onnx.py --sam2_dir "$sam2_dir" --optimize --use_gpu --dtype fp16 --demo + fi component="image_encoder" for dtype in fp32 fp16; do - #TODO: --prefer_nhwc does not help with performance - $python benchmark_sam2.py --model_type "$model" --engine ort --sam2_dir "$sam2_dir" --repeats "$repeats" --use_gpu --dtype $dtype --component "$component" --onnx_path "${onnx_dir}/${model}_${component}_${dtype}_gpu.onnx" --use_cuda_graph + $python benchmark_sam2.py --model_type "$model" --engine ort --sam2_dir "$sam2_dir" --repeats "$repeats" --use_gpu --dtype "$dtype" --component "$component" --onnx_path "${onnx_dir}/${model}_${component}_${dtype}_gpu.onnx" --use_cuda_graph done + # Test prefer_nhwc. + $python benchmark_sam2.py --model_type "$model" --engine ort --sam2_dir "$sam2_dir" --repeats "$repeats" --use_gpu --dtype fp16 --component "$component" --onnx_path "${onnx_dir}/${model}_${component}_${dtype}_gpu.onnx" --use_cuda_graph --prefer_nhwc component="image_decoder" for dtype in fp32 fp16; do # TODO: decoder does not work with cuda graph - $python benchmark_sam2.py --model_type "$model" --engine ort --sam2_dir "$sam2_dir" --repeats "$repeats" --use_gpu --dtype $dtype --component "$component" --onnx_path "${onnx_dir}/${model}_${component}_${dtype}_gpu.onnx" + $python benchmark_sam2.py --model_type "$model" --engine ort --sam2_dir "$sam2_dir" --repeats "$repeats" --use_gpu --dtype "$dtype" --component "$component" --onnx_path "${onnx_dir}/${model}_${component}_${dtype}_gpu.onnx" done + # Test prefer_nhwc. + $python benchmark_sam2.py --model_type "$model" --engine ort --sam2_dir "$sam2_dir" --repeats "$repeats" --use_gpu --dtype fp16 --component "$component" --onnx_path "${onnx_dir}/${model}_${component}_${dtype}_gpu.onnx" --prefer_nhwc } -run_torch_compile_gpu_benchmark() { +run_torch_gpu_benchmark() { local repeats="$1" + # Test PyTorch eager mode. + for component in image_encoder image_decoder; do + for dtype in bf16 fp32 fp16; do + $python benchmark_sam2.py --model_type "$model" --engine torch --sam2_dir "$sam2_dir" --repeats "$repeats" --use_gpu --dtype "$dtype" --component "$component" + done + done + # Test different torch compile modes on image encoder for torch_compile_mode in none max-autotune reduce-overhead max-autotune-no-cudagraphs do - $python benchmark_sam2.py --model_type $model --engine torch --sam2_dir "$sam2_dir" --repeats "$repeats" --use_gpu --dtype fp16 --component image_encoder --torch_compile_mode $torch_compile_mode + $python benchmark_sam2.py --model_type "$model" --engine torch --sam2_dir "$sam2_dir" --repeats "$repeats" --use_gpu --dtype fp16 --component image_encoder --torch_compile_mode $torch_compile_mode done } - -# Main script -run_benchmarks() { - if [ ! -v CONDA_PREFIX ]; then - echo "Please activate conda environment before running this script." - exit 1 +install_all() { + if [ "$cpu_or_gpu" = "gpu" ]; then + install_gpu + else + install_cpu fi - - # Install dependencies - [ "$cpu_or_gpu" = "gpu" ] && install_gpu || install_cpu install_sam2 download_test_image +} - # Run benchmarks - output_csv="sam2_${cpu_or_gpu}.csv" +run_benchmarks() { + suffix=$(date +"%Y_%m_%d_%H_%M_%S") + [ "$dynamo" = "true" ] && suffix="${suffix}_dynamo" + output_csv="sam2_${cpu_or_gpu}_${suffix}.csv" if [ ! -f "$output_csv" ]; then echo "Running $cpu_or_gpu benchmark..." if [ "$cpu_or_gpu" = "gpu" ]; then - run_gpu_benchmark 1000 - run_torch_compile_gpu_benchmark 1000 + run_ort_gpu_benchmark 1000 + run_torch_gpu_benchmark 1000 else run_cpu_benchmark 100 fi cat benchmark*.csv > combined_csv awk '!x[$0]++' combined_csv > "$output_csv" + rm benchmark*.csv rm combined_csv echo "Benchmark results saved in $output_csv" else @@ -169,7 +235,16 @@ run_benchmarks() { fi } -run_benchmarks +if [ ! -v CONDA_PREFIX ]; then + echo "Please activate conda environment before running this script." + exit 1 +fi + +install_all + +if [ "$benchmarking" = "true" ]; then + run_benchmarks +fi #-------------------------------------------------------------------------- # Below are for profiling @@ -177,79 +252,100 @@ run_benchmarks # Build onnxruntime-gpu from source for profiling build_onnxruntime_gpu_for_profiling() { - pushd "$install_dir" + pushd "$install_dir" || exit if ! [ -d onnxruntime ]; then git clone https://github.com/microsoft/onnxruntime fi - cd onnxruntime - CUDA_ARCH=$(python3 -c "import torch; cc = torch.cuda.get_device_capability(); print(f'{cc[0]}{cc[1]}')") - if [ -n "$CUDA_ARCH" ]; then - pip install --upgrade pip cmake psutil setuptools wheel packaging ninja numpy==1.26.4 - sh build.sh --config Release --build_dir build/cuda12 --build_shared_lib --parallel \ - --use_cuda --cuda_version 12.6 --cuda_home $install_dir/cuda12.6 \ - --cudnn_home $install_dir/cudnn9.5 \ - --build_wheel --skip_tests \ - --cmake_generator Ninja \ - --compile_no_warning_as_error \ - --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=$CUDA_ARCH \ - --cmake_extra_defines onnxruntime_ENABLE_NVTX_PROFILE=ON \ - --enable_cuda_line_info - - pip install build/cuda12/Release/dist/onnxruntime_gpu-*-linux_x86_64.whl numpy==1.26.4 - else - echo "No CUDA device found." - exit 1 - fi - popd + cd onnxruntime || exit + pip install --upgrade pip cmake psutil setuptools wheel packaging ninja numpy + build_dir=build/cuda${cuda_version} + rm -rf ${build_dir}/Release/dist + sh build.sh --config Release --build_dir "${build_dir}" --build_shared_lib --parallel \ + --use_cuda --cuda_version ${cuda_version} --cuda_home "$install_dir/cuda${cuda_version}" \ + --cudnn_home "$install_dir/cudnn${cudnn_version}" \ + --build_wheel --skip_tests \ + --cmake_generator Ninja \ + --compile_no_warning_as_error \ + --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=native \ + --cmake_extra_defines onnxruntime_ENABLE_NVTX_PROFILE=ON \ + --enable_cuda_line_info + pip uninstall onnxruntime-gpu -y + pip install "${build_dir}/Release/dist/onnxruntime_gpu-*-linux_x86_64.whl" + popd || exit } # Run profiling with NVTX. -run_nvtx_profile() -{ - pip install nvtx cuda-python==12.6.0 - +run_nvtx_profile() { + local engine="$1" # Only trace one device to avoid huge output file size. device_id=0 - envs="CUDA_VISIBLE_DEVICES=$device_id,ORT_ENABLE_CUDNN_FLASH_ATTENTION=1,LD_LIBRARY_PATH=$LD_LIBRARY_PATH" + envs="CUDA_VISIBLE_DEVICES=$device_id,ORT_ENABLE_CUDNN_FLASH_ATTENTION=1,LD_LIBRARY_PATH=$LD_LIBRARY_PATH,TORCHINDUCTOR_UNIQUE_KERNEL_NAMES=1" cuda_graph_trace=node - for engine in ort torch; do - for component in image_encoder image_decoder; do - sudo $install_dir/cuda12.6/bin/nsys profile --capture-range=nvtx --nvtx-capture='one_run' \ - --gpu-metrics-device $device_id --force-overwrite true \ - --sample process-tree --backtrace fp --stats true \ - -t cuda,cudnn,cublas,osrt,nvtx --cuda-memory-usage true --cudabacktrace all \ - --cuda-graph-trace $cuda_graph_trace \ - -e $envs,NSYS_NVTX_PROFILER_REGISTER_ONLY=0 \ - -o sam2_fp16_profile_${component}_${engine}_${cpu_or_gpu} \ - $python benchmark_sam2.py --model_type $model --engine $engine \ - --sam2_dir $sam2_dir --warm_up 1 --repeats 0 \ - --onnx_path ${onnx_dir}/${model}_${component}_fp16_gpu.onnx \ - --component $component \ - --use_gpu --dtype fp16 --enable_nvtx_profile - done + for component in image_encoder image_decoder; do + sudo "$install_dir/cuda${cuda_version}/bin/nsys" profile --capture-range=nvtx --nvtx-capture='one_run' \ + --gpu-metrics-devices $device_id --force-overwrite true \ + --sample process-tree --backtrace fp --stats true \ + -t cuda,cudnn,cublas,osrt,nvtx --cuda-memory-usage true --cudabacktrace all \ + --cuda-graph-trace "$cuda_graph_trace" \ + -e "$envs,NSYS_NVTX_PROFILER_REGISTER_ONLY=0" \ + -o "sam2_fp16_profile_${component}_${engine}_${cpu_or_gpu}" \ + $python benchmark_sam2.py --model_type "$model" --engine "$engine" \ + --sam2_dir "$sam2_dir" --warm_up 1 --repeats 0 \ + --onnx_path "${onnx_dir}/${model}_${component}_fp16_gpu.onnx" \ + --component "$component" \ + --use_gpu --dtype fp16 --enable_nvtx_profile done } -# Run profiling with PyTorch -run_torch_profile() { +run_ort_profile() { + export ORT_ENABLE_CUDNN_FLASH_ATTENTION=1 + rm -f onnxruntime_*.json for component in image_encoder image_decoder; do - $python benchmark_sam2.py --model_type $model --engine torch \ - --sam2_dir $sam2_dir --warm_up 1 --repeats 0 \ - --component $component \ - --use_gpu --dtype fp16 --enable_torch_profile + $python benchmark_sam2.py --model_type "$model" --engine ort \ + --sam2_dir "$sam2_dir" --warm_up 1 --repeats 0 \ + --onnx_path "${onnx_dir}/${model}_${component}_fp16_gpu.onnx" \ + --component "$component" \ + --use_gpu --dtype fp16 --enable_ort_profile + mv onnxruntime_profile*.json onnxruntime_$component.json done } -run_profilings() { - build_onnxruntime_gpu_for_profiling +# Run profiling with PyTorch +run_torch_profile() { + # Enable logging might could help get the code of compiled kernels. You can turn it off to reduce overhead. + export TORCH_LOGS="+inductor,+output_code" + export TORCHINDUCTOR_UNIQUE_KERNEL_NAMES=1 + component=image_encoder + $python benchmark_sam2.py --model_type "$model" --engine torch \ + --sam2_dir "$sam2_dir" --warm_up 1 --repeats 0 \ + --component "$component" \ + --torch_compile_mode max-autotune \ + --use_gpu --dtype fp16 --enable_torch_profile > "torch_${component}_compiled_code.txt" + + component=image_decoder + $python benchmark_sam2.py --model_type "$model" --engine torch \ + --sam2_dir "$sam2_dir" --warm_up 1 --repeats 0 \ + --component "$component" \ + --torch_compile_mode none \ + --use_gpu --dtype fp16 --enable_torch_profile +} +run_nvtx_profilings() { + build_onnxruntime_gpu_for_profiling rm -f *.nsys-rep *.sqlite - run_nvtx_profile + run_nvtx_profile ort + run_nvtx_profile torch +} +run_profilings() { + pip install nvtx cuda-python==${cuda_version}.0 + run_ort_profile run_torch_profile + + # NVTX profiling need to build onnxruntime-gpu from source so it is put as the last step. + run_nvtx_profilings } -profiling="${3:-false}" if [ "$profiling" = "true" ] && [ "$cpu_or_gpu" = "gpu" ]; then run_profilings fi diff --git a/onnxruntime/python/tools/transformers/models/sam2/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/sam2/convert_to_onnx.py index cacad717faf9c..3533a274b9972 100644 --- a/onnxruntime/python/tools/transformers/models/sam2/convert_to_onnx.py +++ b/onnxruntime/python/tools/transformers/models/sam2/convert_to_onnx.py @@ -113,6 +113,14 @@ def parse_arguments(): help="Optimize onnx models for GPU", ) + parser.add_argument( + "--dynamo", + required=False, + default=False, + action="store_true", + help="Use dynamo for exporting onnx model. Only image_encoder supports dynamo right now.", + ) + parser.add_argument( "--verbose", required=False, @@ -151,8 +159,10 @@ def main(): onnx_model_path = sam2_onnx_path(args.output_dir, args.model_type, component, args.multimask_output) if component == "image_encoder": if args.overwrite or not os.path.exists(onnx_model_path): - export_image_encoder_onnx(sam2_model, onnx_model_path, args.dynamic_batch_axes, args.verbose) - test_image_encoder_onnx(sam2_model, onnx_model_path, dynamic_batch_axes=False) + export_image_encoder_onnx( + sam2_model, onnx_model_path, args.dynamic_batch_axes, args.verbose, args.dynamo + ) + test_image_encoder_onnx(sam2_model, onnx_model_path, dynamic_batch_axes=args.dynamic_batch_axes) elif component == "mask_decoder": if args.overwrite or not os.path.exists(onnx_model_path): diff --git a/onnxruntime/python/tools/transformers/models/sam2/image_decoder.py b/onnxruntime/python/tools/transformers/models/sam2/image_decoder.py index 07ed150631f50..376e6ba7d802c 100644 --- a/onnxruntime/python/tools/transformers/models/sam2/image_decoder.py +++ b/onnxruntime/python/tools/transformers/models/sam2/image_decoder.py @@ -246,7 +246,7 @@ def test_decoder_onnx( import onnxruntime - ort_session = onnxruntime.InferenceSession(onnx_model_path, providers=onnxruntime.get_available_providers()) + ort_session = onnxruntime.InferenceSession(onnx_model_path, providers=["CPUExecutionProvider"]) model_inputs = ort_session.get_inputs() input_names = [model_inputs[i].name for i in range(len(model_inputs))] diff --git a/onnxruntime/python/tools/transformers/models/sam2/image_encoder.py b/onnxruntime/python/tools/transformers/models/sam2/image_encoder.py index c5ce339732063..79e9297788c36 100644 --- a/onnxruntime/python/tools/transformers/models/sam2/image_encoder.py +++ b/onnxruntime/python/tools/transformers/models/sam2/image_encoder.py @@ -90,6 +90,8 @@ def export_image_encoder_onnx( onnx_model_path: str, dynamic_batch_axes: bool = False, verbose: bool = False, + dynamo: bool = False, + clear_dynamo_metadata: bool = False, ): image = random_sam2_input_image() @@ -113,17 +115,65 @@ def export_image_encoder_onnx( if not verbose: warnings.filterwarnings("ignore", category=torch.jit.TracerWarning) warnings.filterwarnings("ignore", category=UserWarning) - torch.onnx.export( - sam2_encoder, - image, - onnx_model_path, - export_params=True, - opset_version=17, - do_constant_folding=True, - input_names=["image"], - output_names=["image_features_0", "image_features_1", "image_embeddings"], - dynamic_axes=dynamic_axes, - ) + + if not dynamo: + torch.onnx.export( + sam2_encoder, + image, + onnx_model_path, + export_params=True, + opset_version=17, + do_constant_folding=True, + input_names=["image"], + output_names=["image_features_0", "image_features_1", "image_embeddings"], + dynamic_axes=dynamic_axes, + ) + else: + torch._dynamo.config.capture_scalar_outputs = True + ep = torch.export.export( + sam2_encoder, + args=(image,), + strict=False, + dynamic_shapes=[ + {0: torch.export.Dim.AUTO}, + ], + ) + + onnx_program = torch.onnx.export( + ep, + (), + opset_version=17, + input_names=["image"], + output_names=["image_features_0", "image_features_1", "image_embeddings"], + dynamo=True, + ) + onnx_program.optimize() + onnx_program.save(onnx_model_path + ".dynamo.onnx", external_data=False) + import onnx + + from onnxruntime.transformers.dynamo_onnx_helper import DynamoOnnxHelper + + onnx_model = onnx.load_model(onnx_model_path + ".dynamo.onnx", load_external_data=True) + if dynamic_batch_axes: + # Fix labels of dynamic axes since they can't be specified during Dynamo export currently + onnx_model.graph.input[0].type.tensor_type.shape.dim[0].dim_param = "batch_size" + for i in range(3): + onnx_model.graph.output[i].type.tensor_type.shape.dim[0].dim_param = "batch_size" + + onnx_model_helper = DynamoOnnxHelper(onnx_model) + onnx_model_helper.convert_constants_to_initializers() + if clear_dynamo_metadata: + onnx_model_helper.clear_metadata() + + import os + + if os.path.exists(onnx_model_path): + os.remove(onnx_model_path) + if os.path.exists(onnx_model_path + ".data"): + os.remove(onnx_model_path + ".data") + onnx_model_helper.model.save_model_to_file( + onnx_model_path, use_external_data_format=True, all_tensors_to_one_file=True, convert_attribute=True + ) print("encoder onnx model saved to", onnx_model_path) @@ -133,7 +183,7 @@ def test_image_encoder_onnx( onnx_model_path: str, dynamic_batch_axes=False, ): - ort_session = onnxruntime.InferenceSession(onnx_model_path, providers=onnxruntime.get_available_providers()) + ort_session = onnxruntime.InferenceSession(onnx_model_path, providers=["CPUExecutionProvider"]) model_inputs = ort_session.get_inputs() input_names = [model_inputs[i].name for i in range(len(model_inputs))] diff --git a/onnxruntime/python/tools/transformers/models/sam2/mask_decoder.py b/onnxruntime/python/tools/transformers/models/sam2/mask_decoder.py index 56473c002d4ae..fa83e2f666d06 100644 --- a/onnxruntime/python/tools/transformers/models/sam2/mask_decoder.py +++ b/onnxruntime/python/tools/transformers/models/sam2/mask_decoder.py @@ -177,7 +177,7 @@ def test_mask_decoder_onnx( import onnxruntime - ort_session = onnxruntime.InferenceSession(onnx_model_path, providers=onnxruntime.get_available_providers()) + ort_session = onnxruntime.InferenceSession(onnx_model_path, providers=["CPUExecutionProvider"]) model_inputs = ort_session.get_inputs() input_names = [model_inputs[i].name for i in range(len(model_inputs))] diff --git a/onnxruntime/python/tools/transformers/models/sam2/prompt_encoder.py b/onnxruntime/python/tools/transformers/models/sam2/prompt_encoder.py index 883c51858346c..f25e6ff23324b 100644 --- a/onnxruntime/python/tools/transformers/models/sam2/prompt_encoder.py +++ b/onnxruntime/python/tools/transformers/models/sam2/prompt_encoder.py @@ -146,7 +146,7 @@ def test_prompt_encoder_onnx( import onnxruntime - ort_session = onnxruntime.InferenceSession(onnx_model_path, providers=onnxruntime.get_available_providers()) + ort_session = onnxruntime.InferenceSession(onnx_model_path, providers=["CPUExecutionProvider"]) model_inputs = ort_session.get_inputs() input_names = [model_inputs[i].name for i in range(len(model_inputs))] diff --git a/onnxruntime/test/qnn_ctx_gen/README.md b/onnxruntime/test/ep_weight_sharing_ctx_gen/README.md similarity index 82% rename from onnxruntime/test/qnn_ctx_gen/README.md rename to onnxruntime/test/ep_weight_sharing_ctx_gen/README.md index 97ab89d79cbd2..be1a1fe039366 100644 --- a/onnxruntime/test/qnn_ctx_gen/README.md +++ b/onnxruntime/test/ep_weight_sharing_ctx_gen/README.md @@ -2,17 +2,19 @@ This tool provides the way to generate Onnx models that wraps QNN context binary warpt with weight sharing enabled. The options to use with the tool are listed below: -`onnxruntime_qnn_ctx_gen [options...] model_path,model_path` +`ep_weight_sharing_ctx_gen [options...] model_1_path,model_2_path` -./onnxruntime_qnn_ctx_gen -v -i "soc_model|60 htp_graph_finalization_optimization_mode|3" -C "ep.context_enable|1 ep.context_embed_mode|0" /mnt/c/model1.onnx,/mnt/c/model2.onnx +./ep_weight_sharing_ctx_gen -e qnn -v -i "soc_model|60 htp_graph_finalization_optimization_mode|3" /mnt/c/model1.onnx,/mnt/c/model2.onnx Options: - + + -e [qnn|tensorrt|openvino|vitisai]: Specifies the compile based provider qnn, tensorrt, openvino, vitisai. Default is qnn. + -v: Show verbose information. -C: [session_config_entries]: Specify session configuration entries as key-value pairs: -C "| |" Refer to onnxruntime_session_options_config_keys.h for valid keys and values. - [Example] -C "ep.context_enable|1 ep.context_embed_mode|0" + [Example] -C "ep.context_enable|1 ep.context_embed_mode|0". These are set as default so can be ignored. -i: [provider_options]: Specify QNN EP specific runtime options as key value pairs. Different runtime options available are: [Usage]: -i '| |' diff --git a/onnxruntime/test/qnn_ctx_gen/command_args_parser.cc b/onnxruntime/test/ep_weight_sharing_ctx_gen/command_args_parser.cc similarity index 68% rename from onnxruntime/test/qnn_ctx_gen/command_args_parser.cc rename to onnxruntime/test/ep_weight_sharing_ctx_gen/command_args_parser.cc index 24c343c7b9541..bf21d54ccde41 100644 --- a/onnxruntime/test/qnn_ctx_gen/command_args_parser.cc +++ b/onnxruntime/test/ep_weight_sharing_ctx_gen/command_args_parser.cc @@ -1,5 +1,4 @@ // Copyright (c) Microsoft Corporation. All rights reserved. -// Copyright (c) 2023 NVIDIA Corporation. // Licensed under the MIT License. #include "command_args_parser.h" @@ -29,28 +28,30 @@ namespace qnnctxgen { /*static*/ void CommandLineParser::ShowUsage() { printf( - "onnxruntime_qnn_ctx_gen [options...] model1_path,model2_path\n" - "Example: ./onnxruntime_qnn_ctx_gen -i \"soc_model|60 htp_graph_finalization_optimization_mode|3\" -C \"ep.context_node_name_prefix|_part1\" ./model1.onnx,./model2.onnx\n" + "ep_weight_sharing_ctx_gen [options...] model1_path,model2_path\n" + "Example: ./ep_weight_sharing_ctx_gen -i \"soc_model|60 htp_graph_finalization_optimization_mode|3\" -C \"ep.context_node_name_prefix|_part1\" ./model1.onnx,./model2.onnx\n" "Options:\n" + "\t-e [qnn|tensorrt|openvino|vitisai]: Specifies the compile based provider 'qnn','tensorrt','openvino', 'vitisai'. " + "Default:'qnn'.\n" "\t-v: Show verbose information.\n" "\t-C: Specify session configuration entries as key-value pairs: -C \"| |\" \n" "\t Refer to onnxruntime_session_options_config_keys.h for valid keys and values. \n" "\t Force ep.context_enable to 1 and ep.context_embed_mode to 0. Change ep.context_file_path is not allowed." "\t [Example] -C \"ep.context_node_name_prefix|_part1\" \n" - "\t-i: Specify QNN EP specific runtime options as key value pairs. Different runtime options available are: \n" + "\t-i: Specify EP specific runtime options as key value pairs. Different runtime options available are: \n" "\t [Usage]: -i '| |'\n" "\n" - "\t [backend_path]: QNN backend path. e.g '/folderpath/libQnnHtp.so', '/winfolderpath/QnnHtp.dll'. default to HTP backend\n" - "\t [vtcm_mb]: QNN VTCM size in MB. default to 0(not set).\n" - "\t [htp_graph_finalization_optimization_mode]: QNN graph finalization optimization mode, options: '0', '1', '2', '3', default is '0'.\n" - "\t [soc_model]: The SoC Model number. Refer to QNN SDK documentation for specific values. Defaults to '0' (unknown). \n" - "\t [htp_arch]: The minimum HTP architecture. The driver will use ops compatible with this architecture. eg: '0', '68', '69', '73', '75'. Defaults to '0' (none). \n" - "\t [enable_htp_fp16_precision]: Enable the HTP_FP16 precision so that the float32 model will be inferenced with fp16 precision. \n" + "\t [QNN only] [backend_path]: QNN backend path. e.g '/folderpath/libQnnHtp.so', '/winfolderpath/QnnHtp.dll'. default to HTP backend\n" + "\t [QNN only] [vtcm_mb]: QNN VTCM size in MB. default to 0(not set).\n" + "\t [QNN only] [htp_graph_finalization_optimization_mode]: QNN graph finalization optimization mode, options: '0', '1', '2', '3', default is '0'.\n" + "\t [QNN only] [soc_model]: The SoC Model number. Refer to QNN SDK documentation for specific values. Defaults to '0' (unknown). \n" + "\t [QNN only] [htp_arch]: The minimum HTP architecture. The driver will use ops compatible with this architecture. eg: '0', '68', '69', '73', '75'. Defaults to '0' (none). \n" + "\t [QNN only] [enable_htp_fp16_precision]: Enable the HTP_FP16 precision so that the float32 model will be inferenced with fp16 precision. \n" "\t Otherwise, it will be fp32 precision. Works for float32 model for HTP backend. Defaults to '1' (with FP16 precision.). \n" - "\t [enable_htp_weight_sharing]: Allows common weights across graphs to be shared and stored in a single context binary. Defaults to '1' (enabled).\n" - "\t [offload_graph_io_quantization]: Offload graph input quantization and graph output dequantization to another EP (typically CPU EP). \n" - "\t Defaults to '0' (QNN EP handles the graph I/O quantization and dequantization). \n" - "\t [enable_htp_spill_fill_buffer]: Enable HTP spill file buffer, used while generating QNN context binary." + "\t [QNN only] [enable_htp_weight_sharing]: Allows common weights across graphs to be shared and stored in a single context binary. Defaults to '1' (enabled).\n" + "\t [QNN only] [offload_graph_io_quantization]: Offload graph input quantization and graph output dequantization to another EP (typically CPU EP). \n" + "\t Defaults to '1' (QNN EP handles the graph I/O quantization and dequantization). \n" + "\t [QNN only] [enable_htp_spill_fill_buffer]: Enable HTP spill file buffer, used while generating QNN context binary." "\t [Example] -i \"vtcm_mb|8 htp_arch|73\" \n" "\n" "\t-h: help\n"); @@ -109,8 +110,22 @@ static bool ParseSessionConfigs(const std::string& configs_string, /*static*/ bool CommandLineParser::ParseArguments(TestConfig& test_config, int argc, ORTCHAR_T* argv[]) { int ch; - while ((ch = getopt(argc, argv, ORT_TSTR("o:u:i:C:vh"))) != -1) { + while ((ch = getopt(argc, argv, ORT_TSTR("e:o:u:i:C:vh"))) != -1) { switch (ch) { + case 'e': + if (!CompareCString(optarg, ORT_TSTR("qnn"))) { + test_config.machine_config.provider_type_name = onnxruntime::kQnnExecutionProvider; + } else if (!CompareCString(optarg, ORT_TSTR("openvino"))) { + test_config.machine_config.provider_type_name = onnxruntime::kOpenVINOExecutionProvider; + } else if (!CompareCString(optarg, ORT_TSTR("tensorrt"))) { + test_config.machine_config.provider_type_name = onnxruntime::kTensorrtExecutionProvider; + } else if (!CompareCString(optarg, ORT_TSTR("vitisai"))) { + test_config.machine_config.provider_type_name = onnxruntime::kVitisAIExecutionProvider; + } else { + fprintf(stderr, "The execution provider is not included in this tool.\n"); + return false; + } + break; case 'v': test_config.run_config.f_verbose = true; break; @@ -162,7 +177,7 @@ static bool ParseSessionConfigs(const std::string& configs_string, 'offload_graph_io_quantization', 'enable_htp_spill_fill_buffer'])"); } - test_config.run_config.qnn_options[key] = value; + test_config.run_config.provider_options[key] = value; } break; } diff --git a/onnxruntime/test/qnn_ctx_gen/command_args_parser.h b/onnxruntime/test/ep_weight_sharing_ctx_gen/command_args_parser.h similarity index 100% rename from onnxruntime/test/qnn_ctx_gen/command_args_parser.h rename to onnxruntime/test/ep_weight_sharing_ctx_gen/command_args_parser.h diff --git a/onnxruntime/test/ep_weight_sharing_ctx_gen/main.cc b/onnxruntime/test/ep_weight_sharing_ctx_gen/main.cc new file mode 100644 index 0000000000000..104cdbdfd5abc --- /dev/null +++ b/onnxruntime/test/ep_weight_sharing_ctx_gen/main.cc @@ -0,0 +1,247 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "test_configuration.h" +#include "command_args_parser.h" + +// onnxruntime dependencies +#include "core/session/onnxruntime_cxx_api.h" +#include "core/session/onnxruntime_session_options_config_keys.h" + +// onnx dependencies +#include "onnx/onnx_pb.h" +#include + +using namespace onnxruntime; +using ProviderOptions = std::unordered_map; + +// from the last context cache Onnx model, find the EPContext node with main_context=1, +// and get the QNN context binary file name, this context binary contains all graphs from all Onnx models +// get the max spill fill buffer size +static void GetLastContextBinaryFileName(const std::basic_string last_onnx_ctx_file, + std::string& last_ctx_bin_file, + int64_t& max_size) { + max_size = 0; + + onnx::ModelProto model; + std::ifstream onnx_file_stream(last_onnx_ctx_file, std::ios_base::binary); + model.ParseFromIstream(&onnx_file_stream); + + for (auto& node : model.graph().node()) { + if (node.op_type() == "EPContext") { + int64_t is_main_context = 0; + for (auto& attr : node.attribute()) { + if (attr.name() == "main_context") { + is_main_context = attr.i(); + } + if (attr.name() == "max_size") { + max_size = attr.i(); + } + if (attr.name() == "ep_cache_context") { + last_ctx_bin_file = attr.s(); + } + } + if (is_main_context) { + return; + } + } + } + + onnx_file_stream.close(); +} + +// Update generated context cache Onnx model to make the main EPContext node point to +// the last QNN context binary file +// Remove not used QNN context binary file, only keep the last one which contains all graphs +static void UpdateEpContextModel(const std::vector>& ep_ctx_files, + const std::string& last_qnn_ctx_binary_file_name, + int64_t max_size) { + for (auto ep_ctx_file : ep_ctx_files) { + onnx::ModelProto model; + std::ifstream onnx_file_stream(ep_ctx_file, std::ios_base::binary); + model.ParseFromIstream(&onnx_file_stream); + onnx_file_stream.close(); + + for (auto& node : *(model.mutable_graph()->mutable_node())) { + if (node.op_type() == "EPContext") { + int64_t is_main_context = 0; + std::string old_qnn_ctx_binary_file_name; + int max_size_index = 0; + int ep_context_index = 0; + for (auto i = 0; i < node.attribute_size(); ++i) { + auto& attr = node.attribute()[i]; + if (attr.name() == "main_context") { + is_main_context = attr.i(); + } + if (attr.name() == "max_size") { + max_size = attr.i(); + max_size_index = i; + } + if (attr.name() == "ep_cache_context") { + old_qnn_ctx_binary_file_name = attr.s(); + ep_context_index = 0; + } + } + if (is_main_context) { + auto path_str = ToPathString(ep_ctx_file); + auto path = std::filesystem::path(path_str); + auto file_path = path.replace_filename(old_qnn_ctx_binary_file_name); + std::remove(file_path.string().c_str()); + + node.mutable_attribute(max_size_index)->set_i(max_size); + node.mutable_attribute(ep_context_index)->set_s(last_qnn_ctx_binary_file_name); + } + } + } + + // re-write the onnx ctx file + std::ofstream onnx_file_ostream(ep_ctx_file, std::ios_base::binary); + model.SerializeToOstream(&onnx_file_ostream); + onnx_file_ostream.close(); + } +} + +#ifdef _WIN32 +int real_main(int argc, wchar_t* argv[]) { +#else +int real_main(int argc, char* argv[]) { +#endif + qnnctxgen::TestConfig test_config; + if (!qnnctxgen::CommandLineParser::ParseArguments(test_config, argc, argv)) { + qnnctxgen::CommandLineParser::ShowUsage(); + return -1; + } + + OrtLoggingLevel logging_level = test_config.run_config.f_verbose + ? ORT_LOGGING_LEVEL_VERBOSE + : ORT_LOGGING_LEVEL_ERROR; + Ort::Env env(logging_level, "ep_weight_sharing"); + + ORT_TRY { + Ort::SessionOptions so; + so.SetLogId("ep_weight_sharing_ctx_gen_session_logger"); + // Set default session option to dump EPContext model with non-embed mode + so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1"); + so.AddConfigEntry(kOrtSessionOptionEpContextEmbedMode, "0"); + // enable ep.share_ep_contexts + so.AddConfigEntry(kOrtSessionOptionShareEpContexts, "1"); + + ProviderOptions provider_options; + + for (auto it : test_config.run_config.provider_options) { + provider_options[it.first] = it.second; + } + + for (auto it : test_config.run_config.session_config_entries) { + if (it.first == kOrtSessionOptionEpContextEnable && it.second != "1") { + std::cerr << "Need to enable ep context cache." << std::endl; + continue; + } + if (it.first == kOrtSessionOptionEpContextEmbedMode && it.second != "0") { + std::cerr << "Only support non-embed model for weight sharing." << std::endl; + continue; + } + if (it.first == kOrtSessionOptionEpContextFilePath) { + std::cout << "Not support to specify the generated Onnx context cache file name." << std::endl; + continue; + } + so.AddConfigEntry(it.first.c_str(), it.second.c_str()); + } + + for (auto model_path : test_config.model_file_paths) { + std::cout << "Model file path: " << ToUTF8String(model_path) << std::endl; + } + + // Generate context cache model files with QNN context binary files + // The context binary file generated later includes all graphs from previous models + { + std::string provider_name_ = test_config.machine_config.provider_type_name; + if (provider_name_ == onnxruntime::kQnnExecutionProvider) { +#ifdef USE_QNN +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + + // set default QNN EP option to enable weight sharing if not set by user + const std::string enable_htp_weight_sharing = "enable_htp_weight_sharing"; + if (provider_options.find(enable_htp_weight_sharing) == provider_options.end()) { + provider_options[enable_htp_weight_sharing] = "1"; + } + so.AppendExecutionProvider("QNN", provider_options); +#else + ORT_THROW("QNN is not supported in this build\n"); +#endif + } else if (!provider_name_.empty()) { + ORT_THROW("This execution provider is not included in this tool.\n"); + } + + size_t total_file_count = test_config.model_file_paths.size(); + for (size_t i = 0; i < total_file_count; ++i) { + auto model_path = test_config.model_file_paths[i]; + std::cout << "Generating context cache model for: " << ToUTF8String(model_path) << std::endl; + if (i == total_file_count - 1) { + so.AddConfigEntry(kOrtSessionOptionStopShareEpContexts, "1"); + } + Ort::Session session(env, model_path.c_str(), so); + } + } + + std::cout << "Start to update the generated Onnx model." << std::endl; + std::vector> ep_ctx_files; + ep_ctx_files.reserve(test_config.model_file_paths.size()); + for (auto model_path : test_config.model_file_paths) { + auto pos = model_path.find_last_of(ORT_TSTR(".")); + if (pos != std::string::npos) { + model_path = model_path.substr(0, pos) + ORT_TSTR("_ctx.onnx"); + } else { + model_path = model_path + ORT_TSTR("_ctx.onnx"); + } + ep_ctx_files.push_back(model_path); + } + + // Get the last context binary file name + std::string last_qnn_ctx_binary_file_name; + int64_t max_size = 0; + GetLastContextBinaryFileName(ep_ctx_files.back(), last_qnn_ctx_binary_file_name, max_size); + std::cout << "The last context binary file: " << last_qnn_ctx_binary_file_name << std::endl; + if (last_qnn_ctx_binary_file_name.empty()) { + throw Ort::Exception("Can't find QNN context binary file from the Onnx model.", OrtErrorCode::ORT_FAIL); + } + ep_ctx_files.pop_back(); + + // Update generated context cache Onnx model to make the main EPContext node point to + // the last QNN context binary file + // Remove not used QNN context binary file, only keep the last one only which contains all graphs + UpdateEpContextModel(ep_ctx_files, last_qnn_ctx_binary_file_name, max_size); + } + ORT_CATCH(const Ort::Exception& e) { + std::cerr << "Failed to generate context cache file: " << e.what(); + return -1; + } + + std::cout << "Generation done!"; + return 0; +} + +#ifdef _WIN32 +int wmain(int argc, wchar_t* argv[]) { +#else +int main(int argc, char* argv[]) { +#endif + int retval = -1; + ORT_TRY { + retval = real_main(argc, argv); + } + ORT_CATCH(const std::exception& ex) { + ORT_HANDLE_EXCEPTION([&]() { + fprintf(stderr, "%s\n", ex.what()); + retval = -1; + }); + } + + ::google::protobuf::ShutdownProtobufLibrary(); + + return retval; +} diff --git a/onnxruntime/test/qnn_ctx_gen/test_configuration.h b/onnxruntime/test/ep_weight_sharing_ctx_gen/test_configuration.h similarity index 75% rename from onnxruntime/test/qnn_ctx_gen/test_configuration.h rename to onnxruntime/test/ep_weight_sharing_ctx_gen/test_configuration.h index bf4c7061a3484..198d03211f561 100644 --- a/onnxruntime/test/qnn_ctx_gen/test_configuration.h +++ b/onnxruntime/test/ep_weight_sharing_ctx_gen/test_configuration.h @@ -14,15 +14,20 @@ namespace onnxruntime { namespace qnnctxgen { +struct MachineConfig { + std::string provider_type_name{onnxruntime::kQnnExecutionProvider}; +}; + struct RunConfig { bool f_verbose{false}; std::unordered_map session_config_entries; - std::unordered_map qnn_options; + std::unordered_map provider_options; }; struct TestConfig { std::vector> model_file_paths; RunConfig run_config; + MachineConfig machine_config; }; } // namespace qnnctxgen diff --git a/onnxruntime/test/framework/inference_session_test.cc b/onnxruntime/test/framework/inference_session_test.cc index 1b06eb55afbd2..95101c8075fc2 100644 --- a/onnxruntime/test/framework/inference_session_test.cc +++ b/onnxruntime/test/framework/inference_session_test.cc @@ -138,6 +138,7 @@ class FuseExecutionProvider : public IExecutionProvider { std::vector> GetCapability(const onnxruntime::GraphViewer& graph, const IKernelLookup& /*kernel_lookup*/, + const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* /* resource_accountant */) const override { // Fuse two add into one. std::vector> result; diff --git a/onnxruntime/test/framework/session_state_test.cc b/onnxruntime/test/framework/session_state_test.cc index b6b915f90d99a..8f4eede76b905 100644 --- a/onnxruntime/test/framework/session_state_test.cc +++ b/onnxruntime/test/framework/session_state_test.cc @@ -27,6 +27,7 @@ #include "test/util/include/default_providers.h" #include "test/util/include/file_util.h" #include "core/optimizer/layout_transformation/layout_transformation.h" +#include "core/optimizer/graph_optimizer_registry.h" using namespace ONNX_NAMESPACE; namespace onnxruntime { @@ -264,7 +265,11 @@ TEST_P(SessionStateTestP, TestInitializerProcessing) { SessionState session_state(graph, execution_providers, tp.get(), nullptr, dtm, edlm, DefaultLoggingManager().DefaultLogger(), profiler, sess_options); - GraphPartitioner partitioner(krm, execution_providers); + // Create GraphOptimizerRegistry instance for providing predefined graph optimizers and selection functions for EPs to lookup + auto graph_optimizer_registry = std::make_unique(&sess_options, + execution_providers.Get(onnxruntime::kCpuExecutionProvider), + &DefaultLoggingManager().DefaultLogger()); + GraphPartitioner partitioner(krm, execution_providers, std::move(graph_optimizer_registry)); ASSERT_STATUS_OK( partitioner.Partition( graph, session_state.GetMutableFuncMgr(), @@ -350,8 +355,12 @@ TEST(SessionStateTest, TestInitializerMemoryAllocatedUsingNonArenaMemory) { SessionState session_state(graph, execution_providers, nullptr, nullptr, dtm, edlm, DefaultLoggingManager().DefaultLogger(), profiler, sess_options); + // Create GraphOptimizerRegistry instance for providing predefined graph optimizers and selection functions for EPs to lookup + auto graph_optimizer_registry = std::make_unique(&sess_options, + execution_providers.Get(onnxruntime::kCpuExecutionProvider), + &DefaultLoggingManager().DefaultLogger()); // Partition the graph - GraphPartitioner partitioner(krm, execution_providers); + GraphPartitioner partitioner(krm, execution_providers, std::move(graph_optimizer_registry)); ASSERT_STATUS_OK(partitioner.Partition( graph, session_state.GetMutableFuncMgr(), [&cpu_allocator](Graph& graph, bool& modified, const IExecutionProvider& execution_provider, @@ -409,8 +418,13 @@ TEST(SessionStateTest, TestInitializerMemoryAllocatedUsingNonArenaMemory) { SessionState session_state(graph, execution_providers, nullptr, nullptr, dtm, edlm, DefaultLoggingManager().DefaultLogger(), profiler, sess_options); + // Create GraphOptimizerRegistry instance for providing predefined graph optimizers and selection functions for EPs to lookup + auto graph_optimizer_registry = std::make_unique(&sess_options, + execution_providers.Get(onnxruntime::kCpuExecutionProvider), + &DefaultLoggingManager().DefaultLogger()); + // Partition the graph - GraphPartitioner partitioner(krm, execution_providers); + GraphPartitioner partitioner(krm, execution_providers, std::move(graph_optimizer_registry)); ASSERT_STATUS_OK(partitioner.Partition( graph, session_state.GetMutableFuncMgr(), [&cpu_allocator](Graph& graph, bool& modified, @@ -479,7 +493,12 @@ void LoadWithResourceAwarePartitioning(const ORTCHAR_T* model_path, SessionState session_state(model->MainGraph(), execution_providers, tp.get(), nullptr, dtm, edlm, default_logger, profiler, sess_options); - GraphPartitioner partitioner(krm, execution_providers); + // Create GraphOptimizerRegistry instance for providing predefined graph optimizers and selection functions for EPs to lookup + auto graph_optimizer_registry = std::make_unique(&sess_options, + execution_providers.Get(onnxruntime::kCpuExecutionProvider), + &DefaultLoggingManager().DefaultLogger()); + + GraphPartitioner partitioner(krm, execution_providers, std::move(graph_optimizer_registry)); layout_transformation::TransformLayoutFunction transform_layout_fn; layout_transformation::DebugGraphFn debug_graph_fn; ASSERT_STATUS_OK( diff --git a/onnxruntime/test/framework/type_info_test.cc b/onnxruntime/test/framework/type_info_test.cc index ee787fb071d97..d8ef668bf1c7e 100644 --- a/onnxruntime/test/framework/type_info_test.cc +++ b/onnxruntime/test/framework/type_info_test.cc @@ -22,9 +22,9 @@ TEST(TypeInfoTests, TensorProto) { auto tensor_type_info = OrtTypeInfo::FromTypeProto(tensor_type.value); ASSERT_EQ(ONNX_TYPE_TENSOR, tensor_type_info->type); - ASSERT_NE(nullptr, tensor_type_info->data); - ASSERT_EQ(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, tensor_type_info->data->type); - ASSERT_TRUE(SpanEq(AsSpan({1, 2, 3, 4}), tensor_type_info->data->shape.GetDims())); + ASSERT_NE(nullptr, tensor_type_info->tensor_type_info); + ASSERT_EQ(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, tensor_type_info->tensor_type_info->type); + ASSERT_TRUE(SpanEq(AsSpan({1, 2, 3, 4}), tensor_type_info->tensor_type_info->shape.GetDims())); } TEST(TypeInfoTests, SequenceWithTensorElement) { @@ -37,9 +37,9 @@ TEST(TypeInfoTests, SequenceWithTensorElement) { const auto& tensor_type_info = *seq_type_info->sequence_type_info->sequence_key_type_; ASSERT_EQ(ONNX_TYPE_TENSOR, tensor_type_info.type); - ASSERT_NE(nullptr, tensor_type_info.data); - ASSERT_EQ(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, tensor_type_info.data->type); - ASSERT_TRUE(SpanEq(AsSpan({1, 2, 3, 4}), tensor_type_info.data->shape.GetDims())); + ASSERT_NE(nullptr, tensor_type_info.tensor_type_info); + ASSERT_EQ(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, tensor_type_info.tensor_type_info->type); + ASSERT_TRUE(SpanEq(AsSpan({1, 2, 3, 4}), tensor_type_info.tensor_type_info->shape.GetDims())); } TEST(TypeInfoTests, OptionalWithTensorProto) { @@ -54,9 +54,9 @@ TEST(TypeInfoTests, OptionalWithTensorProto) { const auto& contained_type = *optional_type_info->optional_type_info->contained_type_; ASSERT_EQ(ONNX_TYPE_TENSOR, contained_type.type); - ASSERT_NE(nullptr, contained_type.data); - ASSERT_EQ(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, contained_type.data->type); - ASSERT_TRUE(SpanEq(AsSpan({1, 2, 3, 4}), contained_type.data->shape.GetDims())); + ASSERT_NE(nullptr, contained_type.tensor_type_info); + ASSERT_EQ(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, contained_type.tensor_type_info->type); + ASSERT_TRUE(SpanEq(AsSpan({1, 2, 3, 4}), contained_type.tensor_type_info->shape.GetDims())); } #if !defined(DISABLE_ML_OPS) @@ -74,11 +74,11 @@ TEST(TypeInfoTests, MapWithTensorValue) { const auto& tensor_type_info = *map_info.map_value_type_; ASSERT_EQ(ONNX_TYPE_TENSOR, tensor_type_info.type); - ASSERT_NE(nullptr, tensor_type_info.data); - ASSERT_EQ(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, tensor_type_info.data->type); - ASSERT_TRUE(SpanEq(AsSpan({1, 2, 3, 4}), tensor_type_info.data->shape.GetDims())); + ASSERT_NE(nullptr, tensor_type_info.tensor_type_info); + ASSERT_EQ(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, tensor_type_info.tensor_type_info->type); + ASSERT_TRUE(SpanEq(AsSpan({1, 2, 3, 4}), tensor_type_info.tensor_type_info->shape.GetDims())); } #endif } // namespace test -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime diff --git a/onnxruntime/test/providers/base_tester.cc b/onnxruntime/test/providers/base_tester.cc index 6bfe7bc3856ba..eecff3fa4d8ff 100644 --- a/onnxruntime/test/providers/base_tester.cc +++ b/onnxruntime/test/providers/base_tester.cc @@ -174,7 +174,7 @@ static std::unique_ptr MakeSparseTensor(MLDataType data_type, cons return p_tensor; } -void BaseTester::CopyDataToTensor(gsl::span data, Tensor& dst) { +void BaseTester::CopyDataToTensor(gsl::span data, Tensor& dst) { ORT_ENFORCE(dst.SizeInBytes() >= data.size_bytes(), "Not enough space in the destination tensor"); memcpy(dst.MutableDataRaw(), data.data(), data.size_bytes()); } @@ -203,7 +203,7 @@ void BaseTester::AddSparseCooTensorData(std::vector& data, MLDataType data_type, const char* name, gsl::span dims, - gsl::span values, + gsl::span values, gsl::span indices, const ValidateOutputParams& check_params, const std::vector* dim_params) { @@ -247,7 +247,7 @@ void BaseTester::AddSparseCsrTensorData(std::vector& data, MLDataType data_type, const char* name, gsl::span dims, - gsl::span values, + gsl::span values, gsl::span inner_indices, gsl::span outer_indices, const ValidateOutputParams& check_params, diff --git a/onnxruntime/test/providers/base_tester.h b/onnxruntime/test/providers/base_tester.h index 512b3402c5986..d39cc3c750dec 100644 --- a/onnxruntime/test/providers/base_tester.h +++ b/onnxruntime/test/providers/base_tester.h @@ -868,7 +868,7 @@ class BaseTester { void AddShapeToTensorData(NodeArg& node_arg, gsl::span dims, const std::vector* dim_params); - void CopyDataToTensor(gsl::span data, Tensor& dst); + void CopyDataToTensor(gsl::span data, Tensor& dst); #if !defined(DISABLE_SPARSE_TENSORS) NodeArg MakeSparseNodeArg(int32_t dtype, const char* name, @@ -879,7 +879,7 @@ class BaseTester { MLDataType data_type, const char* name, gsl::span dims, - gsl::span values, + gsl::span values, gsl::span indices, const ValidateOutputParams& check_params, const std::vector* dim_params = nullptr); @@ -895,7 +895,7 @@ class BaseTester { MLDataType data_type, const char* name, gsl::span dims, - gsl::span values, + gsl::span values, gsl::span inner_indices, gsl::span outer_indices, const ValidateOutputParams& check_params, diff --git a/onnxruntime/test/providers/cpu/math/softmax_test.cc b/onnxruntime/test/providers/cpu/math/softmax_test.cc index 6f7930f722564..1c6375ebdb0b1 100644 --- a/onnxruntime/test/providers/cpu/math/softmax_test.cc +++ b/onnxruntime/test/providers/cpu/math/softmax_test.cc @@ -170,11 +170,11 @@ TEST(SoftmaxOperator, ThreeAndFourDimsAxis0) { RunTest(input_vals_60, expected_vals, three_dimensions, /*opset*/ 7, /*axis*/ 0, // axis=0 is not supported by TensorRT - {kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kDnnlExecutionProvider}); + {kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kDnnlExecutionProvider, kWebGpuExecutionProvider}); RunTest(input_vals_60, expected_vals, four_dimensions, /*opset*/ 7, /*axis*/ 0, // axis=0 is not supported by TensorRT - {kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kDnnlExecutionProvider}); + {kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kDnnlExecutionProvider, kWebGpuExecutionProvider}); } TEST(SoftmaxOperator, ThreeAndFourDimsSecondLastAxis) { @@ -201,10 +201,10 @@ TEST(SoftmaxOperator, ThreeAndFourDimsSecondLastAxis) { 0.040478885f, 0.033857856f, 0.080346674f, 0.06199841f, 0.040481992f}; RunTest(input_vals_60, expected_vals, three_dimensions, /*opset*/ 7, /*axis*/ 1, - {kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kDnnlExecutionProvider}); + {kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kDnnlExecutionProvider, kWebGpuExecutionProvider}); RunTest(input_vals_60, expected_vals, four_dimensions, /*opset*/ 7, /*axis*/ 2, - {kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kDnnlExecutionProvider}); + {kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kDnnlExecutionProvider, kWebGpuExecutionProvider}); } TEST(SoftmaxOperator, ThreeAndFourDimsSecondLastAxis_opset13) { @@ -376,8 +376,9 @@ TEST(SoftmaxOperator, DimWithZero) { RunTest(x_vals, expected_vals, dimensions, /*opset*/ -1, /*axis*/ 0, {kTensorrtExecutionProvider, - kNnapiExecutionProvider, // NNAPI softmax does not support empty input - kQnnExecutionProvider} // QNN doesn't support dim 0 + kNnapiExecutionProvider, // NNAPI softmax does not support empty input + kWebGpuExecutionProvider, // WebGPU does not support dim 0 + kQnnExecutionProvider} // QNN doesn't support dim 0 ); } diff --git a/onnxruntime/test/providers/cpu/nn/conv_integer_test.cc b/onnxruntime/test/providers/cpu/nn/conv_integer_test.cc index a5378fa3cefd7..c98d9e28b2f46 100644 --- a/onnxruntime/test/providers/cpu/nn/conv_integer_test.cc +++ b/onnxruntime/test/providers/cpu/nn/conv_integer_test.cc @@ -254,5 +254,45 @@ TEST(ConvIntegerTest, WithStride3_2D_u8u8) { test.Run(); } +TEST(ConvIntegerTest, NoXZeroPoint) { + OpTester test("ConvInteger", 10); + std::vector x_dims{1, 1, 3, 3}; + test.AddInput("x", x_dims, + {2, 3, 4, + 5, 6, 7, + 8, 9, 10}); + std::vector w_dims{1, 1, 2, 2}; + test.AddInput("w", w_dims, + {2, 2, + 2, 2}); + test.AddOptionalInputEdge(); + test.AddInput("w_zero_point", {}, {1}); + std::vector y_dims{1, 1, 2, 2}; + test.AddOutput("y", y_dims, + {16, 20, + 28, 32}); + test.Run(); +} + +// provide optional input with empty name for w. tests that input args == 4 but the w_zero_point does not exist. +TEST(ConvIntegerTest, NoWZeroPoint) { + OpTester test("ConvInteger", 10); + std::vector x_dims{1, 1, 3, 3}; + test.AddInput("x", x_dims, + {2, 3, 4, + 5, 6, 7, + 8, 9, 10}); + std::vector w_dims{1, 1, 2, 2}; + test.AddInput("w", w_dims, + {2, 2, + 2, 2}); + test.AddInput("x_zero_point", {}, {1}); + test.AddOptionalInputEdge(); + std::vector y_dims{1, 1, 2, 2}; + test.AddOutput("y", y_dims, + {24, 32, + 48, 56}); + test.Run(); +} } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.cc b/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.cc index b753bc386d722..ee0aff6d26444 100644 --- a/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.cc +++ b/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.cc @@ -111,6 +111,7 @@ DataLayout InternalTestingExecutionProvider::GetPreferredLayout() const { std::vector> InternalTestingExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer, const IKernelLookup& kernel_lookup, + const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* /* resource_accountant */) const { // find nodes that have ops in our supported list std::unordered_set supported_static_nodes; diff --git a/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.h b/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.h index d2ed8259ee974..0caa0febc2796 100644 --- a/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.h +++ b/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.h @@ -20,6 +20,7 @@ class InternalTestingExecutionProvider : public IExecutionProvider { std::vector> GetCapability(const onnxruntime::GraphViewer& graph_view, const IKernelLookup& /*kernel_lookup*/, + const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* /* resource_accountant */) const override; common::Status Compile(const std::vector& fused_nodes, diff --git a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc index 07843c30a61df..3dec74599abdf 100644 --- a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc +++ b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc @@ -43,6 +43,35 @@ static const std::string& GetNodeAttr(const Node& node, const std::string& attr_ return default_val; } +// from the context cache Onnx model, find the EPContext node with main_context=1, +// and get the QNN context binary file name +static void GetContextBinaryFileName(const std::string onnx_ctx_file, + std::string& last_ctx_bin_file, + const Logger& logger) { + std::shared_ptr ctx_model; + ASSERT_STATUS_OK(Model::Load(ToPathString(onnx_ctx_file), ctx_model, nullptr, logger)); + auto& ctx_graph = ctx_model->MainGraph(); + for (auto& node : ctx_graph.Nodes()) { + if (node.OpType() == "EPContext") { + int64_t is_main_context = GetNodeAttr(node, "main_context", static_cast(0)); + if (1 == is_main_context) { + last_ctx_bin_file = GetNodeAttr(node, "ep_cache_context", ""); + return; + } + } + } +} + +// Get context binary file name from Context model file and remove it with the context model file +void CleanUpCtxFile(std::string context_file_path) { + std::string qnn_ctx_binary_file_name; + GetContextBinaryFileName(context_file_path, qnn_ctx_binary_file_name, + DefaultLoggingManager().DefaultLogger()); + + ASSERT_EQ(std::remove(qnn_ctx_binary_file_name.c_str()), 0); + ASSERT_EQ(std::remove(context_file_path.c_str()), 0); +} + // Create a model with FusedMatMul + Add (quantized) // input1 -> Add -> Q -> DQ ---- // | @@ -123,22 +152,22 @@ void QnnContextBinaryMultiPartitionTestBody(bool single_ep_node = true) { const auto model_data_span = AsByteSpan(model_data.data(), model_data.size()); - const std::string context_binary_file = "./qnn_context_binary_multi_partition_test.onnx"; - std::remove(context_binary_file.c_str()); + const std::string context_model_file = "./qnn_context_binary_multi_partition_test.onnx"; + std::remove(context_model_file.c_str()); Ort::SessionOptions so; so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1"); - so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, context_binary_file.c_str()); + so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, context_model_file.c_str()); so.AppendExecutionProvider("QNN", provider_options); Ort::Session session(*ort_env, model_data_span.data(), model_data_span.size(), so); // Make sure the Qnn context cache binary file is generated - EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str())); + EXPECT_TRUE(std::filesystem::exists(context_model_file.c_str())); int ep_context_node_count = 0; int non_ep_context_node_count = 0; std::shared_ptr ctx_model; - ASSERT_STATUS_OK(Model::Load(ToPathString(context_binary_file), ctx_model, nullptr, DefaultLoggingManager().DefaultLogger())); + ASSERT_STATUS_OK(Model::Load(ToPathString(context_model_file), ctx_model, nullptr, DefaultLoggingManager().DefaultLogger())); auto& ctx_graph = ctx_model->MainGraph(); for (auto& node : ctx_graph.Nodes()) { if (node.OpType() == "EPContext") { @@ -156,7 +185,7 @@ void QnnContextBinaryMultiPartitionTestBody(bool single_ep_node = true) { Ort::SessionOptions so2; // context file path is required if it's non-embed mode and the model is loaded from memory - so2.AddConfigEntry(kOrtSessionOptionEpContextFilePath, context_binary_file.c_str()); + so2.AddConfigEntry(kOrtSessionOptionEpContextFilePath, context_model_file.c_str()); so2.AppendExecutionProvider("QNN", provider_options); std::string ctx_model_data; @@ -164,7 +193,7 @@ void QnnContextBinaryMultiPartitionTestBody(bool single_ep_node = true) { Ort::Session session2(*ort_env, ctx_model_data.data(), ctx_model_data.size(), so2); // clean up - ASSERT_EQ(std::remove(context_binary_file.c_str()), 0); + CleanUpCtxFile(context_model_file); } // Test that models with 1 non-quantized FusedMatMul node and 1 quantized Add node can still generate the context binary @@ -237,7 +266,7 @@ void EpCtxCpuNodeWithExternalIniFileTestBody(bool expect_external_ini_file) { // clean up ASSERT_EQ(std::remove(model_with_ext.c_str()), 0); ASSERT_EQ(std::remove(model_ext_file_full_path.c_str()), 0); - ASSERT_EQ(std::remove(ep_context_model_file.c_str()), 0); + CleanUpCtxFile(ep_context_model_file); } // Set the external initializer size threshold to 1024 so FusedMatMul (which fallback on CPU) @@ -333,7 +362,7 @@ TEST_F(QnnHTPBackendTests, QnnContextBinaryGenerationNoOverWrite) { const auto model_data_span = AsByteSpan(model_data.data(), model_data.size()); const std::string ep_context_onnx_file = "./ep_context_no_over_write.onnx"; - const std::string ep_context_binary_file = "./ep_context_no_over_write.onnx_QNNExecutionProvider_QNN_10880527342279992768_1_0.bin"; + const std::string ep_context_binary_file = "./ep_context_no_over_write_QNN_10880527342279992768_1_0.bin"; std::remove(ep_context_onnx_file.c_str()); Ort::SessionOptions so; @@ -444,21 +473,21 @@ TEST_F(QnnHTPBackendTests, QnnContextBinaryGeneration2InputTypes) { const auto model_data_span = AsByteSpan(model_data.data(), model_data.size()); - const std::string context_binary_file = "./qnn_context_binary_int32_fp32_inputs_test.onnx"; - std::remove(context_binary_file.c_str()); + const std::string context_model_file = "./qnn_context_binary_int32_fp32_inputs_test.onnx"; + std::remove(context_model_file.c_str()); Ort::SessionOptions so; so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1"); - so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, context_binary_file.c_str()); + so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, context_model_file.c_str()); so.AppendExecutionProvider("QNN", provider_options); Ort::Session session(*ort_env, model_data_span.data(), model_data_span.size(), so); // Make sure the Qnn context cache binary file is generated - EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str())); + EXPECT_TRUE(std::filesystem::exists(context_model_file.c_str())); // clean up - ASSERT_EQ(std::remove(context_binary_file.c_str()), 0); + CleanUpCtxFile(context_model_file); } // Generate context cache model from the ONNX models with 2 inputs. @@ -481,26 +510,26 @@ TEST_F(QnnHTPBackendTests, QnnContextGeneration2InputsOrderIssue) { auto& logging_manager = DefaultLoggingManager(); logging_manager.SetDefaultLoggerSeverity(logging::Severity::kERROR); - const std::string context_binary_file = "./qnn_ctx_2_inputs_order_test_gen.onnx"; + const std::string context_model_file = "./qnn_ctx_2_inputs_order_test_gen.onnx"; Ort::SessionOptions so; so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1"); - so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, context_binary_file.c_str()); + so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, context_model_file.c_str()); so.AppendExecutionProvider("QNN", provider_options); Ort::Session session(*ort_env, ORT_TSTR("testdata/qnn_ctx_2_inputs_order_test.onnx"), so); // Make sure the Qnn context cache binary file is generated - EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str())); + EXPECT_TRUE(std::filesystem::exists(context_model_file.c_str())); std::shared_ptr model; - ASSERT_STATUS_OK(Model::Load(ToPathString(context_binary_file), model, nullptr, DefaultLoggingManager().DefaultLogger())); + ASSERT_STATUS_OK(Model::Load(ToPathString(context_model_file), model, nullptr, DefaultLoggingManager().DefaultLogger())); auto inputs = model->MainGraph().GetInputs(); EXPECT_TRUE(inputs.size() == 2); EXPECT_TRUE(inputs[0]->Name() == "attention_mask"); EXPECT_TRUE(inputs[1]->Name() == "Add_input_0"); // clean up - ASSERT_EQ(std::remove(context_binary_file.c_str()), 0); + CleanUpCtxFile(context_model_file); } TEST_F(QnnHTPBackendTests, QnnContextGenerationNodeNamePrefix) { @@ -519,20 +548,20 @@ TEST_F(QnnHTPBackendTests, QnnContextGenerationNodeNamePrefix) { auto& logging_manager = DefaultLoggingManager(); logging_manager.SetDefaultLoggerSeverity(logging::Severity::kERROR); - const std::string context_binary_file = "./qnn_ctx_2_inputs_order_test_gen.onnx"; + const std::string context_model_file = "./qnn_ctx_2_inputs_order_test_gen.onnx"; Ort::SessionOptions so; so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1"); - so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, context_binary_file.c_str()); + so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, context_model_file.c_str()); so.AddConfigEntry(kOrtSessionOptionEpContextNodeNamePrefix, node_name_prefix.c_str()); so.AppendExecutionProvider("QNN", provider_options); Ort::Session session(*ort_env, ORT_TSTR("testdata/qnn_ctx_2_inputs_order_test.onnx"), so); // Make sure the Qnn context cache binary file is generated - EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str())); + EXPECT_TRUE(std::filesystem::exists(context_model_file.c_str())); std::shared_ptr model; - ASSERT_STATUS_OK(Model::Load(ToPathString(context_binary_file), model, nullptr, DefaultLoggingManager().DefaultLogger())); + ASSERT_STATUS_OK(Model::Load(ToPathString(context_model_file), model, nullptr, DefaultLoggingManager().DefaultLogger())); for (auto& node : model->MainGraph().Nodes()) { if (node.OpType() == "EPContext") { EXPECT_TRUE(node.Name().find(node_name_prefix) != std::string::npos); @@ -540,7 +569,7 @@ TEST_F(QnnHTPBackendTests, QnnContextGenerationNodeNamePrefix) { } // clean up - ASSERT_EQ(std::remove(context_binary_file.c_str()), 0); + CleanUpCtxFile(context_model_file); } // Run QDQ model on HTP 3 times @@ -554,12 +583,12 @@ TEST_F(QnnHTPBackendTests, QnnContextBinaryCacheEmbedModeTest) { provider_options["backend_path"] = "libQnnHtp.so"; #endif provider_options["offload_graph_io_quantization"] = "0"; - const std::string context_binary_file = "./qnn_context_binary_test.onnx"; - std::remove(context_binary_file.c_str()); + const std::string context_model_file = "./qnn_context_binary_test.onnx"; + std::remove(context_model_file.c_str()); std::unordered_map session_option_pairs; session_option_pairs.emplace(kOrtSessionOptionEpContextEnable, "1"); - session_option_pairs.emplace(kOrtSessionOptionEpContextFilePath, context_binary_file); + session_option_pairs.emplace(kOrtSessionOptionEpContextFilePath, context_model_file); const TestInputDef input_def({1, 2, 3}, false, -10.0f, 10.0f); const std::string op_type = "Atan"; @@ -577,9 +606,11 @@ TEST_F(QnnHTPBackendTests, QnnContextBinaryCacheEmbedModeTest) { session_option_pairs); // Make sure the Qnn context cache binary file is generated - EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str())); + EXPECT_TRUE(std::filesystem::exists(context_model_file.c_str())); // 2nd run directly loads and run from Qnn context cache model + std::unordered_map session_option_pairs2; + session_option_pairs2.emplace(kOrtSessionOptionEpContextFilePath, context_model_file); TestQDQModelAccuracy(BuildOpTestCase(op_type, {input_def}, {}, {}), BuildQDQOpTestCase(op_type, {input_def}, {}, {}), provider_options, @@ -587,9 +618,10 @@ TEST_F(QnnHTPBackendTests, QnnContextBinaryCacheEmbedModeTest) { ExpectedEPNodeAssignment::All, QDQTolerance(), logging::Severity::kERROR, - context_binary_file); + context_model_file, + session_option_pairs2); // Clean up - ASSERT_EQ(std::remove(context_binary_file.c_str()), 0); + CleanUpCtxFile(context_model_file); } // Run QDQ model on HTP 3 times @@ -604,7 +636,7 @@ TEST_F(QnnHTPBackendTests, QnnContextBinaryCacheNonEmbedModeTest) { #endif provider_options["offload_graph_io_quantization"] = "0"; const std::string context_binary_file = "./testdata/qnn_context_cache_non_embed.onnx"; - std::string qnn_ctx_bin = "./testdata/qnn_context_cache_non_embed.onnx_QNNExecutionProvider_QNN_8283143575221199085_1_0.bin"; + std::string qnn_ctx_bin = "./testdata/qnn_context_cache_non_embed_QNN_8283143575221199085_1_0.bin"; std::unordered_map session_option_pairs; session_option_pairs.emplace(kOrtSessionOptionEpContextEnable, "1"); @@ -686,7 +718,7 @@ TEST_F(QnnHTPBackendTests, QnnContextBinaryCache_InvalidGraph) { #endif provider_options["offload_graph_io_quantization"] = "0"; const std::string context_binary_file = "./qnn_context_cache_non_embed.onnx"; - std::filesystem::path context_bin = "qnn_context_cache_non_embed.onnx_QNNExecutionProvider_QNN_8283143575221199085_1_0.bin"; + std::filesystem::path context_bin = "qnn_context_cache_non_embed_QNN_8283143575221199085_1_0.bin"; std::remove(context_binary_file.c_str()); std::remove(context_bin.string().c_str()); @@ -828,6 +860,7 @@ TEST_F(QnnHTPBackendTests, QnnContextBinaryFileNotExistTest) { SessionOptions so; so.session_logid = "qnn_ctx_model_logger"; + ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kOrtSessionOptionEpContextFilePath, "./qnn_context_not_exist.onnx")); RunOptions run_options; run_options.run_tag = so.session_logid; @@ -841,7 +874,7 @@ TEST_F(QnnHTPBackendTests, QnnContextBinaryFileNotExistTest) { #endif provider_options["offload_graph_io_quantization"] = "0"; - ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(QnnExecutionProviderWithOptions(provider_options))); + ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(QnnExecutionProviderWithOptions(provider_options, &so))); ASSERT_STATUS_OK(session_object.Load(model_data.data(), static_cast(model_data.size()))); // Verify the return status with code INVALID_GRAPH ASSERT_TRUE(session_object.Initialize().Code() == common::StatusCode::INVALID_GRAPH); @@ -854,6 +887,7 @@ TEST_F(QnnHTPBackendTests, QnnContextBinaryFileEmptyStringTest) { SessionOptions so; so.session_logid = "qnn_ctx_model_logger"; + ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kOrtSessionOptionEpContextFilePath, "./test_ctx.onnx")); RunOptions run_options; run_options.run_tag = so.session_logid; @@ -867,7 +901,7 @@ TEST_F(QnnHTPBackendTests, QnnContextBinaryFileEmptyStringTest) { #endif provider_options["offload_graph_io_quantization"] = "0"; - ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(QnnExecutionProviderWithOptions(provider_options))); + ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(QnnExecutionProviderWithOptions(provider_options, &so))); ASSERT_STATUS_OK(session_object.Load(model_data.data(), static_cast(model_data.size()))); // Verify the return status with code INVALID_GRAPH ASSERT_TRUE(session_object.Initialize().Code() == common::StatusCode::INVALID_GRAPH); @@ -884,12 +918,12 @@ TEST_F(QnnHTPBackendTests, QnnContextBinary2InputsTest) { provider_options["backend_path"] = "libQnnHtp.so"; #endif provider_options["offload_graph_io_quantization"] = "0"; - const std::string context_binary_file = "./qnn_context_binary_2inputs_test.onnx"; - std::remove(context_binary_file.c_str()); + const std::string context_model_file = "./qnn_context_binary_2inputs_test.onnx"; + std::remove(context_model_file.c_str()); std::unordered_map session_option_pairs; session_option_pairs.emplace(kOrtSessionOptionEpContextEnable, "1"); - session_option_pairs.emplace(kOrtSessionOptionEpContextFilePath, context_binary_file); + session_option_pairs.emplace(kOrtSessionOptionEpContextFilePath, context_model_file); const TestInputDef input_def1({1, 2, 3}, false, -10.0f, 10.0f); const TestInputDef input_def2({1, 2, 3}, false, -10.0f, 10.0f); @@ -908,9 +942,11 @@ TEST_F(QnnHTPBackendTests, QnnContextBinary2InputsTest) { session_option_pairs); // Make sure the Qnn context cache binary file is generated - EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str())); + EXPECT_TRUE(std::filesystem::exists(context_model_file.c_str())); // 2nd run directly loads and run from Qnn context cache model + std::unordered_map session_option_pairs2; + session_option_pairs2.emplace(kOrtSessionOptionEpContextFilePath, context_model_file); TestQDQModelAccuracy(BuildOpTestCase(op_type, {input_def1, input_def2}, {}, {}), BuildQDQOpTestCase(op_type, {input_def1, input_def2}, {}, {}), provider_options, @@ -918,9 +954,10 @@ TEST_F(QnnHTPBackendTests, QnnContextBinary2InputsTest) { ExpectedEPNodeAssignment::All, QDQTolerance(), logging::Severity::kERROR, - context_binary_file); + context_model_file, + session_option_pairs2); // Clean up - ASSERT_EQ(std::remove(context_binary_file.c_str()), 0); + CleanUpCtxFile(context_model_file); } // Context binary only contains a single QNN graph, generated context cache model (detached mode) only has 1 EPContext node @@ -936,14 +973,14 @@ TEST_F(QnnHTPBackendTests, QnnContextBinaryCache_SingleNodeNameNotMatchGraphName provider_options["backend_path"] = "libQnnHtp.so"; #endif provider_options["offload_graph_io_quantization"] = "0"; - const std::string context_binary_file = "./qnn_context_cache_non_embed.onnx"; - std::filesystem::path context_bin = "qnn_context_cache_non_embed.onnx_QNNExecutionProvider_QNN_8283143575221199085_1_0.bin"; - std::remove(context_binary_file.c_str()); + const std::string context_model_file = "./qnn_context_cache_non_embed.onnx"; + std::filesystem::path context_bin = "qnn_context_cache_non_embed_QNN_8283143575221199085_1_0.bin"; + std::remove(context_model_file.c_str()); std::remove(context_bin.string().c_str()); std::unordered_map session_option_pairs; session_option_pairs.emplace(kOrtSessionOptionEpContextEnable, "1"); - session_option_pairs.emplace(kOrtSessionOptionEpContextFilePath, context_binary_file); + session_option_pairs.emplace(kOrtSessionOptionEpContextFilePath, context_model_file); session_option_pairs.emplace(kOrtSessionOptionEpContextEmbedMode, "0"); const TestInputDef input_def({1, 2, 3}, false, -10.0f, 10.0f); @@ -962,7 +999,7 @@ TEST_F(QnnHTPBackendTests, QnnContextBinaryCache_SingleNodeNameNotMatchGraphName session_option_pairs); // Check the Onnx skeleton file is generated - EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str())); + EXPECT_TRUE(std::filesystem::exists(context_model_file.c_str())); // Check the Qnn context cache binary file is generated EXPECT_TRUE(std::filesystem::exists(context_bin)); @@ -990,18 +1027,19 @@ TEST_F(QnnHTPBackendTests, QnnContextBinaryCache_SingleNodeNameNotMatchGraphName SessionOptions so; so.session_logid = "qnn_ctx_model_logger"; + ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kOrtSessionOptionEpContextFilePath, context_model_file.c_str())); RunOptions run_options; run_options.run_tag = so.session_logid; InferenceSessionWrapper session_object{so, GetEnvironment()}; - ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(QnnExecutionProviderWithOptions(provider_options))); + ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(QnnExecutionProviderWithOptions(provider_options, &so))); ASSERT_STATUS_OK(session_object.Load(model_data.data(), static_cast(model_data.size()))); // Verify the return status with code INVALID_GRAPH ASSERT_TRUE(session_object.Initialize().Code() == common::StatusCode::OK); // Clean up - ASSERT_EQ(std::remove(context_binary_file.c_str()), 0); + ASSERT_EQ(std::remove(context_model_file.c_str()), 0); ASSERT_EQ(std::remove(context_bin.string().c_str()), 0); } @@ -1053,44 +1091,20 @@ static void CreateQdqModel(const std::string& model_file_name, const Logger& log static void DumpModelWithSharedCtx(const ProviderOptions& provider_options, const std::string& onnx_model_path1, const std::string& onnx_model_path2) { - SessionOptions so; - so.session_logid = "qnn_ctx_model_logger"; - ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1")); - ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kOrtSessionOptionEpContextEmbedMode, "0")); - RunOptions run_options; - run_options.run_tag = so.session_logid; - - auto qnn_ep = QnnExecutionProviderWithOptions(provider_options, &so); - std::shared_ptr qnn_ep_shared(std::move(qnn_ep)); + Ort::SessionOptions so; + so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1"); + so.AddConfigEntry(kOrtSessionOptionEpContextEmbedMode, "0"); + // enable ep.share_ep_contexts so that QNNEP share the QnnBackendManager across sessions + so.AddConfigEntry(kOrtSessionOptionShareEpContexts, "1"); - InferenceSessionWrapper session_object1{so, GetEnvironment()}; - ASSERT_STATUS_OK(session_object1.RegisterExecutionProvider(qnn_ep_shared)); - ASSERT_STATUS_OK(session_object1.Load(ToPathString(onnx_model_path1))); - ASSERT_STATUS_OK(session_object1.Initialize()); + so.AppendExecutionProvider("QNN", provider_options); - InferenceSessionWrapper session_object2{so, GetEnvironment()}; - ASSERT_STATUS_OK(session_object2.RegisterExecutionProvider(qnn_ep_shared)); - ASSERT_STATUS_OK(session_object2.Load(ToPathString(onnx_model_path2))); - ASSERT_STATUS_OK(session_object2.Initialize()); -} + // Create 2 sessions to generate context binary models, the 1st session will share the QnnBackendManager + // to the 2nd session, so graphs from these 2 models are all included in the 2nd context binary + Ort::Session session1(*ort_env, ToPathString(onnx_model_path1).c_str(), so); -// from the last context ache Onnx model, find the EPContext node with main_context=1, -// and get the QNN context binary file name, thie context binary contains all graphs from all Onnx models -static void GetLastContextBinaryFileName(const std::string last_onnx_ctx_file, - std::string& last_ctx_bin_file, - const Logger& logger) { - std::shared_ptr ctx_model; - ASSERT_STATUS_OK(Model::Load(ToPathString(last_onnx_ctx_file), ctx_model, nullptr, logger)); - auto& ctx_graph = ctx_model->MainGraph(); - for (auto& node : ctx_graph.Nodes()) { - if (node.OpType() == "EPContext") { - int64_t is_main_context = GetNodeAttr(node, "main_context", static_cast(0)); - if (1 == is_main_context) { - last_ctx_bin_file = GetNodeAttr(node, "ep_cache_context", ""); - return; - } - } - } + so.AddConfigEntry(kOrtSessionOptionStopShareEpContexts, "1"); + Ort::Session session2(*ort_env, ToPathString(onnx_model_path2).c_str(), so); } // Update generated context cache Onnx model to make the main EPContext node point to @@ -1167,15 +1181,21 @@ TEST_F(QnnHTPBackendTests, QnnContextShareAcrossSessions1) { for (auto model_path : onnx_model_paths) { CreateQdqModel(model_path, DefaultLoggingManager().DefaultLogger()); EXPECT_TRUE(std::filesystem::exists(model_path.c_str())); - ctx_model_paths.push_back(model_path + "_ctx.onnx"); + auto pos = model_path.find_last_of("."); + if (pos != std::string::npos) { + model_path = model_path.substr(0, pos) + "_ctx.onnx"; + } else { + model_path = model_path + "_ctx.onnx"; + } + ctx_model_paths.push_back(model_path); } DumpModelWithSharedCtx(provider_options, onnx_model_paths[0], onnx_model_paths[1]); - // Get the last context binary file name + // Get the last context binary file name, the latest context binary file holds all graphs generated from all models std::string last_qnn_ctx_binary_file_name; - GetLastContextBinaryFileName(ctx_model_paths.back(), last_qnn_ctx_binary_file_name, - DefaultLoggingManager().DefaultLogger()); + GetContextBinaryFileName(ctx_model_paths.back(), last_qnn_ctx_binary_file_name, + DefaultLoggingManager().DefaultLogger()); EXPECT_TRUE(!last_qnn_ctx_binary_file_name.empty()); // Update generated context cache Onnx model to make the main EPContext node point to @@ -1265,15 +1285,21 @@ TEST_F(QnnHTPBackendTests, QnnContextShareAcrossSessions2) { for (auto model_path : onnx_model_paths) { CreateQdqModel(model_path, DefaultLoggingManager().DefaultLogger()); EXPECT_TRUE(std::filesystem::exists(model_path.c_str())); - ctx_model_paths.push_back(model_path + "_ctx.onnx"); + auto pos = model_path.find_last_of("."); + if (pos != std::string::npos) { + model_path = model_path.substr(0, pos) + "_ctx.onnx"; + } else { + model_path = model_path + "_ctx.onnx"; + } + ctx_model_paths.push_back(model_path); } DumpModelWithSharedCtx(provider_options, onnx_model_paths[0], onnx_model_paths[1]); // Get the last context binary file name std::string last_qnn_ctx_binary_file_name; - GetLastContextBinaryFileName(ctx_model_paths.back(), last_qnn_ctx_binary_file_name, - DefaultLoggingManager().DefaultLogger()); + GetContextBinaryFileName(ctx_model_paths.back(), last_qnn_ctx_binary_file_name, + DefaultLoggingManager().DefaultLogger()); EXPECT_TRUE(!last_qnn_ctx_binary_file_name.empty()); // Update generated context cache Onnx model to make the main EPContext node point to @@ -1336,6 +1362,69 @@ TEST_F(QnnHTPBackendTests, QnnContextShareAcrossSessions2) { } std::remove(last_qnn_ctx_binary_file_name.c_str()); } + +// For Ort sessions to generate the context binary, with session option ep.share_ep_contexts enabled +// Ort sessions will share the QnnBackendManager, so that all graphs from all models compile into the same Qnn context +TEST_F(QnnHTPBackendTests, QnnContextGenWeightSharingSessionAPI) { + ProviderOptions provider_options; +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + provider_options["offload_graph_io_quantization"] = "0"; + + // Create QDQ models + std::vector onnx_model_paths{"./weight_share1.onnx", "./weight_share2.onnx"}; + std::vector ctx_model_paths; + for (auto model_path : onnx_model_paths) { + CreateQdqModel(model_path, DefaultLoggingManager().DefaultLogger()); + EXPECT_TRUE(std::filesystem::exists(model_path.c_str())); + auto pos = model_path.find_last_of("."); + if (pos != std::string::npos) { + model_path = model_path.substr(0, pos) + "_ctx.onnx"; + } else { + model_path = model_path + "_ctx.onnx"; + } + ctx_model_paths.push_back(model_path); + } + + Ort::SessionOptions so; + so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1"); + so.AddConfigEntry(kOrtSessionOptionEpContextEmbedMode, "0"); + // enable ep.share_ep_contexts so that QNNEP share the QnnBackendManager across sessions + so.AddConfigEntry(kOrtSessionOptionShareEpContexts, "1"); + + so.AppendExecutionProvider("QNN", provider_options); + + Ort::Session session1(*ort_env, ToPathString(onnx_model_paths[0]).c_str(), so); + std::string qnn_ctx_binary_file_name1; + GetContextBinaryFileName(ctx_model_paths[0], qnn_ctx_binary_file_name1, + DefaultLoggingManager().DefaultLogger()); + EXPECT_TRUE(!qnn_ctx_binary_file_name1.empty()); + + // Tell the EP stop share the QnnBackendManager from this session then on + so.AddConfigEntry(kOrtSessionOptionStopShareEpContexts, "1"); + Ort::Session session2(*ort_env, ToPathString(onnx_model_paths[1]).c_str(), so); + std::string qnn_ctx_binary_file_name2; + GetContextBinaryFileName(ctx_model_paths[1], qnn_ctx_binary_file_name2, + DefaultLoggingManager().DefaultLogger()); + EXPECT_TRUE(!qnn_ctx_binary_file_name2.empty()); + + auto file_size_1 = std::filesystem::file_size(qnn_ctx_binary_file_name1); + auto file_size_2 = std::filesystem::file_size(qnn_ctx_binary_file_name2); + EXPECT_TRUE(file_size_2 > file_size_1); + + // clean up + for (auto model_path : onnx_model_paths) { + ASSERT_EQ(std::remove(model_path.c_str()), 0); + } + for (auto ctx_model_path : ctx_model_paths) { + ASSERT_EQ(std::remove(ctx_model_path.c_str()), 0); + } + ASSERT_EQ(std::remove(qnn_ctx_binary_file_name1.c_str()), 0); + ASSERT_EQ(std::remove(qnn_ctx_binary_file_name2.c_str()), 0); +} #endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) } // namespace test diff --git a/onnxruntime/test/providers/qnn/qnn_test_utils.cc b/onnxruntime/test/providers/qnn/qnn_test_utils.cc index e2deccc4fff0f..2361e179d1cf1 100644 --- a/onnxruntime/test/providers/qnn/qnn_test_utils.cc +++ b/onnxruntime/test/providers/qnn/qnn_test_utils.cc @@ -14,6 +14,7 @@ #include "core/framework/compute_capability.h" #include "core/graph/graph.h" #include "core/session/onnxruntime_session_options_config_keys.h" +#include "core/optimizer/graph_optimizer_registry.h" namespace onnxruntime { namespace test { @@ -279,9 +280,10 @@ static BackendSupport GetHTPSupport(const onnxruntime::logging::Logger& logger) onnxruntime::GraphViewer graph_viewer(graph); std::unique_ptr qnn_ep = QnnExecutionProviderWithOptions( {{"backend_path", "QnnHtp.dll"}, {"offload_graph_io_quantization", "0"}}); + GraphOptimizerRegistry graph_optimizer_registry(nullptr, nullptr, nullptr); // as a placeholder to feed into GetCapability qnn_ep->SetLogger(&logger); - auto result = qnn_ep->GetCapability(graph_viewer, kernel_lookup, nullptr); + auto result = qnn_ep->GetCapability(graph_viewer, kernel_lookup, graph_optimizer_registry, nullptr); return result.empty() ? BackendSupport::UNSUPPORTED : BackendSupport::SUPPORTED; } @@ -342,9 +344,10 @@ static BackendSupport GetCPUSupport(const onnxruntime::logging::Logger& logger) onnxruntime::GraphViewer graph_viewer(graph); std::unique_ptr qnn_ep = QnnExecutionProviderWithOptions( {{"backend_path", "QnnCpu.dll"}, {"offload_graph_io_quantization", "0"}}); + GraphOptimizerRegistry graph_optimizer_registry(nullptr, nullptr, nullptr); // as a placeholder to feed into GetCapability qnn_ep->SetLogger(&logger); - auto result = qnn_ep->GetCapability(graph_viewer, kernel_lookup, nullptr); + auto result = qnn_ep->GetCapability(graph_viewer, kernel_lookup, graph_optimizer_registry, nullptr); return result.empty() ? BackendSupport::UNSUPPORTED : BackendSupport::SUPPORTED; } diff --git a/onnxruntime/test/python/quantization/test_get_qdq_config.py b/onnxruntime/test/python/quantization/test_get_qdq_config.py index 25f058d8f6eac..4a71b3694822c 100644 --- a/onnxruntime/test/python/quantization/test_get_qdq_config.py +++ b/onnxruntime/test/python/quantization/test_get_qdq_config.py @@ -156,6 +156,62 @@ def should_exclude_node_(model: onnx.ModelProto, node: onnx.NodeProto) -> bool: self.assertTrue(bool(expected_excluded_nodes)) self.assertEqual(set(qdq_config.nodes_to_exclude), expected_excluded_nodes) + def test_op_types_to_quantize(self): + """ + Test that get_qdq_config() returns a config that sets the op_types_to_quantize arg. + """ + shape = [1, 8, 8] + tensor_type = onnx.TensorProto.FLOAT + np_dtype = onnx.helper.tensor_dtype_to_np_dtype(tensor_type) + weight = onnx.numpy_helper.from_array(np.ones(shape, dtype=np_dtype), "weight") + float_model = self.build_add_model(shape, tensor_type, weight) + + input_data_list = [ + {"input_0": np.ones(shape, dtype=np_dtype) * np.array(-2, dtype=np_dtype)}, + {"input_0": np.ones(shape, dtype=np_dtype) * np.array(2, dtype=np_dtype)}, + ] + data_reader = TestDataFeeds(input_data_list) + + # No op_types_to_quantize arg means all ops are quantized. + qdq_config = get_qdq_config(float_model, data_reader, op_types_to_quantize=None) + self.assertEqual(set(qdq_config.op_types_to_quantize), {"Add"}) + + # specify custom op_types_to_quantize arg. + qdq_config = get_qdq_config(float_model, data_reader, op_types_to_quantize=["Mul"]) + self.assertEqual(set(qdq_config.op_types_to_quantize), {"Mul"}) + + # exclude op_type indirectly by specifying nodes_to_exclude arg. + qdq_config = get_qdq_config( + float_model, + data_reader, + nodes_to_exclude=[node.name for node in float_model.graph.node if node.op_type == "Add"], + ) + self.assertEqual(set(qdq_config.op_types_to_quantize), set()) + + def test_calibration_providers(self): + """ + Test that get_qdq_config() returns a config that sets the calibration providers arg. + """ + + shape = [1, 8, 8] + tensor_type = onnx.TensorProto.FLOAT + np_dtype = onnx.helper.tensor_dtype_to_np_dtype(tensor_type) + weight = onnx.numpy_helper.from_array(np.ones(shape, dtype=np_dtype), "weight") + float_model = self.build_add_model(shape, tensor_type, weight) + + input_data_list = [ + {"input_0": np.ones(shape, dtype=np_dtype) * np.array(-2, dtype=np_dtype)}, + {"input_0": np.ones(shape, dtype=np_dtype) * np.array(2, dtype=np_dtype)}, + ] + data_reader = TestDataFeeds(input_data_list) + + qdq_config = get_qdq_config( + float_model, + data_reader, + calibration_providers=["CPUExecutionProvider"], + ) + self.assertEqual(qdq_config.calibration_providers, ["CPUExecutionProvider"]) + def test_external_data(self): """ Test that get_qdq_config() returns a config that enables external data diff --git a/onnxruntime/test/qnn_ctx_gen/main.cc b/onnxruntime/test/qnn_ctx_gen/main.cc deleted file mode 100644 index bb5007b40b072..0000000000000 --- a/onnxruntime/test/qnn_ctx_gen/main.cc +++ /dev/null @@ -1,250 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -// onnxruntime dependencies -#include "test_configuration.h" -#include -#include -#include -#include "command_args_parser.h" -#include - -#include "core/session/onnxruntime_session_options_config_keys.h" -#include "core/session/inference_session.h" -#include "core/session/ort_env.h" -#include "core/providers/provider_factory_creators.h" -#include "core/common/logging/sinks/clog_sink.h" - -#include "core/graph/model.h" -#include "core/session/environment.h" -#include "core/common/logging/logging.h" - -using namespace onnxruntime; -const OrtApi* g_ort = NULL; -std::unique_ptr ort_env; - -static void CheckStatus(const Status& status) { - if (status.Code() != common::StatusCode::OK) { - std::string msg = status.ErrorMessage(); - throw Ort::Exception(std::move(msg), OrtErrorCode::ORT_FAIL); - } -} - -static int64_t GetNodeAttr(const Node& node, const std::string& attr_name, int64_t default_val) { - const auto& attributes = node.GetAttributes(); - if (auto entry = attributes.find(attr_name); entry != attributes.end()) { - return entry->second.i(); - } - - return default_val; -} - -static const std::string& GetNodeAttr(const Node& node, const std::string& attr_name, const std::string& default_val) { - const auto& attributes = node.GetAttributes(); - if (auto entry = attributes.find(attr_name); entry != attributes.end()) { - return entry->second.s(); - } - - return default_val; -} - -// from the last context cache Onnx model, find the EPContext node with main_context=1, -// and get the QNN context binary file name, this context binary contains all graphs from all Onnx models -// get the max spill fill buffer size -static void GetLastContextBinaryFileName(const std::basic_string last_onnx_ctx_file, - std::string& last_ctx_bin_file, - int64_t& max_size) { - max_size = 0; - std::shared_ptr ctx_model; - CheckStatus(Model::Load(ToPathString(last_onnx_ctx_file), ctx_model, nullptr, - (*((OrtEnv*)*ort_env.get())->GetEnvironment().GetLoggingManager()).DefaultLogger())); - auto& ctx_graph = ctx_model->MainGraph(); - for (auto& node : ctx_graph.Nodes()) { - if (node.OpType() == "EPContext") { - int64_t is_main_context = GetNodeAttr(node, "main_context", static_cast(0)); - max_size = GetNodeAttr(node, "max_size", static_cast(0)); - if (1 == is_main_context) { - last_ctx_bin_file = GetNodeAttr(node, "ep_cache_context", ""); - return; - } - } - } -} - -// Update generated context cache Onnx model to make the main EPContext node point to -// the last QNN context binary file -// Remove not used QNN context binary file, only keep the last one which contains all graphs -static void UpdateEpContextModel(const std::vector>& ep_ctx_files, - const std::string& last_qnn_ctx_binary_file_name, - int64_t max_size) { - for (auto ep_ctx_file : ep_ctx_files) { - std::shared_ptr ctx_model; - auto path_str = ToPathString(ep_ctx_file); - CheckStatus(Model::Load(path_str, ctx_model, nullptr, - (*((OrtEnv*)*ort_env.get())->GetEnvironment().GetLoggingManager()).DefaultLogger())); - auto& ctx_graph = ctx_model->MainGraph(); - GraphViewer graph_viewer(ctx_graph); - auto path = std::filesystem::path(path_str); - - for (auto& node : ctx_graph.Nodes()) { - if (node.OpType() == "EPContext") { - int64_t is_main_context = GetNodeAttr(node, "main_context", static_cast(0)); - if (1 == is_main_context) { - std::string old_qnn_ctx_binary_file_name = GetNodeAttr(node, "ep_cache_context", ""); - auto file_path = path.replace_filename(old_qnn_ctx_binary_file_name); - std::remove(file_path.string().c_str()); - node.ClearAttribute("ep_cache_context"); - node.AddAttribute("ep_cache_context", last_qnn_ctx_binary_file_name); - node.ClearAttribute("max_size"); - node.AddAttribute("max_size", max_size); - } - } - } - std::remove(ToUTF8String(ep_ctx_file).c_str()); - CheckStatus(Model::Save(*ctx_model.get(), ToPathString(ep_ctx_file))); - } -} - -#ifdef _WIN32 -int real_main(int argc, wchar_t* argv[]) { -#else -int real_main(int argc, char* argv[]) { -#endif - g_ort = OrtGetApiBase()->GetApi(ORT_API_VERSION); - qnnctxgen::TestConfig test_config; - if (!qnnctxgen::CommandLineParser::ParseArguments(test_config, argc, argv)) { - qnnctxgen::CommandLineParser::ShowUsage(); - return -1; - } - - { - bool failed = false; - ORT_TRY { - OrtLoggingLevel logging_level = test_config.run_config.f_verbose - ? ORT_LOGGING_LEVEL_VERBOSE - : ORT_LOGGING_LEVEL_WARNING; - - ort_env = std::make_unique(logging_level, "Default"); - } - ORT_CATCH(const Ort::Exception& e) { - ORT_HANDLE_EXCEPTION([&]() { - fprintf(stderr, "Error creating environment. Error-> %s \n", e.what()); - failed = true; - }); - } - - if (failed) - return -1; - } - - ORT_TRY { - SessionOptions so; - so.session_logid = "qnn_ctx_gen_session_logger"; - // Set default session option to dump QNN context model with non-embed mode - CheckStatus(so.config_options.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1")); - CheckStatus(so.config_options.AddConfigEntry(kOrtSessionOptionEpContextEmbedMode, "0")); - RunOptions run_options; - run_options.run_tag = so.session_logid; - - ProviderOptions provider_options; -#if defined(_WIN32) - provider_options["backend_path"] = "QnnHtp.dll"; -#else - provider_options["backend_path"] = "libQnnHtp.so"; -#endif - // set default QNN EP option to enable weight sharing - provider_options["enable_htp_weight_sharing"] = "1"; - - for (auto it : test_config.run_config.qnn_options) { - provider_options[it.first] = it.second; - } - - for (auto it : test_config.run_config.session_config_entries) { - if (it.first == kOrtSessionOptionEpContextEnable && it.second != "1") { - std::cerr << "Need to enable ep context cache." << std::endl; - continue; - } - if (it.first == kOrtSessionOptionEpContextEmbedMode && it.second != "0") { - std::cerr << "Only support non-embed model for weight sharing." << std::endl; - continue; - } - if (it.first == kOrtSessionOptionEpContextFilePath) { - std::cout << "Not support to specify the generated Onnx context cache file name." << std::endl; - continue; - } - CheckStatus(so.config_options.AddConfigEntry(it.first.c_str(), it.second.c_str())); - } - - for (auto model_path : test_config.model_file_paths) { - std::cout << "Model file path: " << ToUTF8String(model_path) << std::endl; - } - - // Generate context cache model files with QNN context binary files - // The context binary file generated later includes all graphs from previous models - { - auto ep = QNNProviderFactoryCreator::Create(provider_options, &so)->CreateProvider(); - std::shared_ptr qnn_ep(std::move(ep)); - - for (auto model_path : test_config.model_file_paths) { - std::cout << "Generate context cache model for: " << ToUTF8String(model_path) << std::endl; - InferenceSession session_object{so, ((OrtEnv*)*ort_env.get())->GetEnvironment()}; - CheckStatus(session_object.RegisterExecutionProvider(qnn_ep)); - CheckStatus(session_object.Load(ToPathString(model_path))); - CheckStatus(session_object.Initialize()); - } - } - - std::cout << "Start to update the generated Onnx model." << std::endl; - std::vector> ep_ctx_files; - ep_ctx_files.reserve(test_config.model_file_paths.size()); - for (auto model_path : test_config.model_file_paths) { - ep_ctx_files.push_back(model_path + ORT_TSTR("_ctx.onnx")); - } - - // Get the last context binary file name - std::string last_qnn_ctx_binary_file_name; - int64_t max_size = 0; - GetLastContextBinaryFileName(ep_ctx_files.back(), last_qnn_ctx_binary_file_name, max_size); - std::cout << "The last context binary file: " << last_qnn_ctx_binary_file_name << std::endl; - if (last_qnn_ctx_binary_file_name.empty()) { - throw Ort::Exception("Can't find QNN context binary file from the Onnx model.", OrtErrorCode::ORT_FAIL); - } - ep_ctx_files.pop_back(); - - // Update generated context cache Onnx model to make the main EPContext node point to - // the last QNN context binary file - // Remove not used QNN context binary file, only keep the last one which contains all graphs - UpdateEpContextModel(ep_ctx_files, last_qnn_ctx_binary_file_name, max_size); - } - ORT_CATCH(const Ort::Exception& e) { - fprintf(stderr, "Failed to generate context cache file: %s \n", e.what()); - - ort_env.reset(); - return -1; - } - - ort_env.reset(); - - return 0; -} - -#ifdef _WIN32 -int wmain(int argc, wchar_t* argv[]) { -#else -int main(int argc, char* argv[]) { -#endif - int retval = -1; - ORT_TRY { - retval = real_main(argc, argv); - } - ORT_CATCH(const std::exception& ex) { - ORT_HANDLE_EXCEPTION([&]() { - fprintf(stderr, "%s\n", ex.what()); - retval = -1; - }); - } - - ::google::protobuf::ShutdownProtobufLibrary(); - - return retval; -} diff --git a/onnxruntime/test/shared_lib/custom_op_utils.cc b/onnxruntime/test/shared_lib/custom_op_utils.cc index bf7efacdbb505..a624479bcd00b 100644 --- a/onnxruntime/test/shared_lib/custom_op_utils.cc +++ b/onnxruntime/test/shared_lib/custom_op_utils.cc @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include #include "gtest/gtest.h" #include "custom_op_utils.h" @@ -639,3 +640,22 @@ void StandaloneCustomKernel::Compute(OrtKernelContext* context) { StandaloneCustomKernel::~StandaloneCustomKernel() { } + +OrtStatusPtr CustomCastKernel::ComputeV2(OrtKernelContext* context) { + Ort::KernelContext ctx(context); + + auto in = ctx.GetInput(0); + std::vector shape = in.GetTensorTypeAndShapeInfo().GetShape(); + int64_t num_elements = std::accumulate(shape.cbegin(), shape.cend(), int64_t(1), std::multiplies()); + + // CustomCast::GetInputType constraint ensures we only get float input + const float* data = in.GetTensorData(); + double* out_data = ctx.GetOutput(0, shape).GetTensorMutableData(); + gsl::span input_span(data, num_elements); + gsl::span output_span(out_data, num_elements); + + std::transform(input_span.begin(), input_span.end(), output_span.begin(), + [](float val) { return static_cast(val); }); + + return nullptr; +} diff --git a/onnxruntime/test/shared_lib/custom_op_utils.h b/onnxruntime/test/shared_lib/custom_op_utils.h index e11540aaa5691..424c2e2fe3a08 100644 --- a/onnxruntime/test/shared_lib/custom_op_utils.h +++ b/onnxruntime/test/shared_lib/custom_op_utils.h @@ -8,12 +8,6 @@ #include #endif -struct Input { - const char* name = nullptr; - std::vector dims; - std::vector values; -}; - struct MyCustomKernel { MyCustomKernel(const OrtApi& ort_api, const OrtKernelInfo* /*info*/) : ort_(ort_api) { @@ -464,4 +458,63 @@ struct MulTopOpFloat16 : Ort::CustomOpBase OrtCustomOpInputOutputCharacteristic GetInputCharacteristic(size_t) const { return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_OPTIONAL; } -}; \ No newline at end of file +}; + +// +// Example overriding an operator where type inference is required for the output so kernel matching works correctly +// +struct CustomCastKernel { + CustomCastKernel(const OrtApi& /*ort_api*/, const OrtKernelInfo* /*info*/) + /*: ort_(ort_api)*/ { + } + + OrtStatusPtr ComputeV2(OrtKernelContext* context); + + private: + // const OrtApi& ort_; +}; + +// Custom Cast op that takes float input and converts based on 'to' attribute. +// Example implementation only supports cast to double. +struct CustomCast : Ort::CustomOpBase { + explicit CustomCast(const char* provider) : provider_(provider) { + // if overriding an ONNX op you need to set the opset versions you are overriding + start_ver_ = 7; // should match minimum ONNX schema you implement + // end_ver_ = ...; should match maximum ONNX schema you implement or unset for unlimited. + } + + // static method used by Ort::CustomOpBase::SetShapeInferFn + static OrtStatusPtr InferOutputShape(Ort::ShapeInferContext& context) { + auto shape = context.GetInputShape(0); + + // infer output type based on 'to'. + auto to = context.GetAttrInt("to"); + if (to != ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE) { + return Ort::Status("Unexpected type", ORT_INVALID_ARGUMENT).release(); + } + + context.SetOutputShape(0, shape, ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE); + return nullptr; + } + + OrtStatusPtr CreateKernelV2(const OrtApi& api, const OrtKernelInfo* info, void** op_kernel) const { + Ort::ConstKernelInfo ki(info); + *op_kernel = new CustomCastKernel(api, info); + return nullptr; + }; + + const char* GetName() const { return "Cast"; }; + const char* GetExecutionProviderType() const { return provider_; }; + + size_t GetInputTypeCount() const { return 1; }; + ONNXTensorElementDataType GetInputType(size_t /*index*/) const { + // example only accepts float input + return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; + }; + + size_t GetOutputTypeCount() const { return 1; }; + ONNXTensorElementDataType GetOutputType(size_t /*index*/) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; }; + + private: + const char* provider_{"CPUExecutionProvider"}; +}; diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc index ca9ca0f82a25a..b517ba7032886 100644 --- a/onnxruntime/test/shared_lib/test_inference.cc +++ b/onnxruntime/test/shared_lib/test_inference.cc @@ -1,17 +1,19 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include -#include -#include -#include -#include +#include #include +#include +#include +#include #include -#include +#include #include +#include #include +#include + #include "gtest/gtest.h" #include "gmock/gmock.h" @@ -25,13 +27,13 @@ #include "core/session/onnxruntime_run_options_config_keys.h" #include "core/util/thread_utils.h" -#include "onnxruntime_config.h" -#include "providers.h" -#include "test_allocator.h" -#include "test_fixture.h" -#include "utils.h" -#include "custom_op_utils.h" -#include +#include "test/shared_lib/custom_op_utils.h" +#include "test/shared_lib/test_fixture.h" +#include "test/shared_lib/utils.h" +#include "test/util/include/providers.h" +#include "test/util/include/test_allocator.h" + +#include "onnxruntime_config.h" // generated file in build output dir #ifdef _WIN32 #include @@ -63,48 +65,6 @@ constexpr size_t countof(T (&)[N]) { return N; } extern std::unique_ptr ort_env; -template -void RunSession(OrtAllocator* allocator, Ort::Session& session_object, - const std::vector& inputs, - const char* output_name, - const std::vector& dims_y, - const std::vector& values_y, - Ort::Value* output_tensor) { - std::vector ort_inputs; - std::vector input_names; - for (size_t i = 0; i < inputs.size(); i++) { - input_names.emplace_back(inputs[i].name); - ort_inputs.emplace_back( - Ort::Value::CreateTensor(allocator->Info(allocator), const_cast(inputs[i].values.data()), - inputs[i].values.size(), inputs[i].dims.data(), inputs[i].dims.size())); - } - - std::vector ort_outputs; - if (output_tensor) - session_object.Run(Ort::RunOptions{nullptr}, input_names.data(), ort_inputs.data(), ort_inputs.size(), - &output_name, output_tensor, 1); - else { - ort_outputs = session_object.Run(Ort::RunOptions{}, input_names.data(), ort_inputs.data(), ort_inputs.size(), - &output_name, 1); - ASSERT_EQ(ort_outputs.size(), 1u); - output_tensor = &ort_outputs[0]; - } - - auto type_info = output_tensor->GetTensorTypeAndShapeInfo(); - ASSERT_EQ(type_info.GetShape(), dims_y); - size_t total_len = type_info.GetElementCount(); - ASSERT_EQ(values_y.size(), total_len); - - OutT* f = output_tensor->GetTensorMutableData(); - for (size_t i = 0; i != total_len; ++i) { - if constexpr (std::is_same::value || std::is_same::value) { - ASSERT_NEAR(values_y[i], f[i], 1e-3); - } else { - ASSERT_EQ(values_y[i], f[i]); - } - } -} - #ifdef USE_DML struct DmlObjects { ComPtr d3d12_device; @@ -300,12 +260,12 @@ Ort::Value CreateTensorValueFromExistingD3DResource( #endif -template +template > static void TestInference(Ort::Env& env, const std::basic_string& model_uri, const std::vector& inputs, const char* output_name, const std::vector& expected_dims_y, - const std::vector& expected_values_y, + const std::vector& expected_values_y, int provider_type, OrtCustomOpDomain* custom_op_domain_ptr, const ORTCHAR_T* custom_op_library_filename, @@ -362,26 +322,26 @@ static void TestInference(Ort::Env& env, const std::basic_string& mod auto default_allocator = std::make_unique(); // without preallocated output tensor - RunSession(default_allocator.get(), - session, - inputs, - output_name, - expected_dims_y, - expected_values_y, - nullptr); + RunSession(default_allocator.get(), + session, + inputs, + output_name, + expected_dims_y, + expected_values_y, + nullptr); // with preallocated output tensor - Ort::Value value_y = Ort::Value::CreateTensor(default_allocator.get(), - expected_dims_y.data(), expected_dims_y.size()); + Ort::Value value_y = Ort::Value::CreateTensor(default_allocator.get(), + expected_dims_y.data(), expected_dims_y.size()); // test it twice for (int i = 0; i != 2; ++i) - RunSession(default_allocator.get(), - session, - inputs, - output_name, - expected_dims_y, - expected_values_y, - &value_y); + RunSession(default_allocator.get(), + session, + inputs, + output_name, + expected_dims_y, + expected_values_y, + &value_y); } } @@ -450,8 +410,8 @@ class CApiTestWithProvider : public testing::Test, public ::testing::WithParamIn TEST_P(CApiTestWithProvider, simple) { // simple inference test // prepare inputs - std::vector inputs(1); - Input& input = inputs.back(); + std::vector> inputs(1); + auto& input = inputs.back(); input.name = "X"; input.dims = {3, 2}; input.values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; @@ -621,8 +581,8 @@ TEST(CApiTest, SparseInputModel) { TEST(CApiTest, custom_op_handler) { std::cout << "Running custom op inference" << std::endl; - std::vector inputs(1); - Input& input = inputs[0]; + std::vector> inputs(1); + auto& input = inputs[0]; input.name = "X"; input.dims = {3, 2}; input.values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; @@ -657,8 +617,8 @@ TEST(CApiTest, custom_op_handler) { TEST(CApiTest, custom_op_set_input_memory_type) { std::cout << "Running custom op inference" << std::endl; - std::vector inputs(1); - Input& input = inputs[0]; + std::vector> inputs(1); + auto& input = inputs[0]; input.name = "X"; input.dims = {3, 2}; input.values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; @@ -687,8 +647,8 @@ TEST(CApiTest, custom_op_set_input_memory_type) { #if !defined(ORT_MINIMAL_BUILD) TEST(CApiTest, StandaloneOpHandler) { - std::vector inputs(1); - Input& input = inputs[0]; + std::vector> inputs(1); + auto& input = inputs[0]; input.name = "X"; input.dims = {3, 2}; input.values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; @@ -811,7 +771,7 @@ TEST(CApiTest, test_enable_ort_customops_stringlower) { // test custom op which accepts float and double as inputs TEST(CApiTest, varied_input_custom_op_handler) { - std::vector inputs(2); + std::vector> inputs(2); inputs[0].name = "X"; inputs[0].dims = {3}; inputs[0].values = {2.0f, 3.0f, 4.0f}; @@ -1422,8 +1382,8 @@ TEST(CApiTest, custom_op_with_attributes_handler) { TEST(CApiTest, RegisterCustomOpForCPUAndCUDA) { std::cout << "Tests registration of a custom op of the same name for both CPU and CUDA EPs" << std::endl; - std::vector inputs(1); - Input& input = inputs[0]; + std::vector> inputs(1); + auto& input = inputs[0]; input.name = "X"; input.dims = {3, 2}; input.values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; @@ -1531,7 +1491,7 @@ TEST(CApiTest, test_custom_op_openvino_wrapper_library) { // The custom op extracts the serialized .xml/.bin bytes and creates an in-memory OpenVINO model // during kernel creation. The custom op is passed an image of a hand-drawn "1" as an input during computation, which // is then inferenced using OpenVINO C++ APIs. - std::vector inputs(1); + std::vector> inputs(1); inputs[0].name = "Input3"; inputs[0].dims = {1, 1, 28, 28}; @@ -1630,7 +1590,7 @@ TEST(CApiTest, test_custom_op_library) { #endif std::cout << "Running inference using custom op shared library" << std::endl; - std::vector inputs(2); + std::vector> inputs(2); inputs[0].name = "input_1"; inputs[0].dims = {3, 5}; inputs[0].values = {1.1f, 2.2f, 3.3f, 4.4f, 5.5f, @@ -1682,7 +1642,7 @@ TEST(CApiTest, DISABLED_test_custom_op_shape_infer_attr) { #else TEST(CApiTest, test_custom_op_shape_infer_attr) { #endif - std::vector inputs(1); + std::vector> inputs(1); inputs[0].name = "input_0"; inputs[0].dims = {5}; inputs[0].values = {1.f, 2.f, 3.f, 4.f, 5.f}; @@ -1715,7 +1675,7 @@ TEST(CApiTest, test_custom_op_library_copy_variadic) { #endif std::cout << "Running inference using custom op shared library" << std::endl; - std::vector inputs(2); + std::vector> inputs(2); inputs[0].name = "input_0"; inputs[0].dims = {15}; inputs[0].values = {1.1f, 2.2f, 3.3f, 4.4f, 5.5f, @@ -1869,8 +1829,8 @@ void PrepareModule() { TEST(CApiTest, test_pyop) { std::call_once(my_module_flag, PrepareModule); - std::vector inputs(1); - Input& input = inputs[0]; + std::vector> inputs(1); + auto& input = inputs[0]; input.name = "X"; input.dims = {2, 2}; input.values = {1.0f, 2.0f, 3.0f, 4.0f}; @@ -1882,8 +1842,8 @@ TEST(CApiTest, test_pyop) { TEST(CApiTest, test_pyop_multi) { std::call_once(my_module_flag, PrepareModule); - std::vector inputs(1); - Input& input = inputs[0]; + std::vector> inputs(1); + auto& input = inputs[0]; input.name = "X"; input.dims = {2, 2}; input.values = {1.0f, 2.0f, 3.0f, 4.0f}; @@ -1895,8 +1855,8 @@ TEST(CApiTest, test_pyop_multi) { TEST(CApiTest, test_pyop_kwarg) { std::call_once(my_module_flag, PrepareModule); - std::vector inputs(1); - Input& input = inputs[0]; + std::vector> inputs(1); + auto& input = inputs[0]; input.name = "X"; input.dims = {2, 2}; input.values = {1.0f, 2.0f, 3.0f, 4.0f}; @@ -1920,7 +1880,7 @@ TEST(ReducedOpsBuildTest, test_excluded_ops) { // In reduced ops build, test a model containing ops not included in required_ops.config cannot be loaded. // See onnxruntime/test/testdata/reduced_build_test.readme.txt for more details of the setup constexpr PATH_TYPE model_uri = TSTR("testdata/reduced_build_test.onnx_model_with_excluded_ops"); - std::vector inputs = {{"X", {3}, {-1.0f, 2.0f, -3.0f}}}; + std::vector> inputs = {{"X", {3}, {-1.0f, 2.0f, -3.0f}}}; std::vector expected_dims_y = {3}; std::vector expected_values_y = {0.1f, 0.1f, 0.1f}; bool failed = false; @@ -3322,8 +3282,8 @@ TEST(CApiTest, TestSharedAllocators) { OrtEnv* env_ptr = (OrtEnv*)(*ort_env); // prepare inputs - std::vector inputs(1); - Input& input = inputs.back(); + std::vector> inputs(1); + auto& input = inputs.back(); input.name = "X"; input.dims = {3, 2}; input.values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; @@ -3509,8 +3469,8 @@ TEST(CApiTest, TestSharedAllocators) { TEST(CApiTest, TestSharingOfInitializerAndItsPrepackedVersion) { // simple inference test // prepare inputs - std::vector inputs(1); - Input& input = inputs.back(); + std::vector> inputs(1); + auto& input = inputs.back(); input.name = "X"; input.dims = {3, 2}; input.values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; @@ -3905,8 +3865,8 @@ TEST_P(CApiTensorRTTest, TestConfigureTensorRTProviderOptions) { // simple inference test // prepare inputs - std::vector inputs(1); - Input& input = inputs.back(); + std::vector> inputs(1); + auto& input = inputs.back(); input.name = "X"; input.dims = {3, 2}; input.values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; @@ -4845,4 +4805,32 @@ TEST(CApiTest, GenerateNodeStatsFile) { output_names, 1); } -#endif \ No newline at end of file +#endif + +// Test that creates a custom Cast kernel which requires type inference of the output type to work. +// Also demonstrates overriding an ONNX operator as we register the custom op in the ONNX domain. +TEST(CApiTest, custom_cast) { + std::vector> inputs(1); + auto& input = inputs[0]; + input.name = "input"; + input.dims = {3, 4}; + input.values = {1.0f, 2.0f, 3.0f, 4.0f, + -1.0f, -2.0f, -3.0f, -4.0f, + 1.0f, 2.0f, 3.0f, 4.0f}; + + // prepare expected inputs and outputs + std::vector expected_dims_y = {3, 4}; + std::vector expected_values_y = {1.0, 2.0, 3.0, 4.0, + -1.0, -2.0, -3.0, -4.0, + 1.0, 2.0, 3.0, 4.0}; + + CustomCast custom_op{onnxruntime::kCpuExecutionProvider}; + + Ort::CustomOpDomain custom_op_domain(""); // onnx domain is empty string + custom_op_domain.Add(&custom_op); + + // model with Cast from ONNX test data + TestInference(*ort_env, TSTR("testdata/cast_float_to_double.onnx"), + inputs, "output", expected_dims_y, expected_values_y, 0, + custom_op_domain, nullptr); +} diff --git a/onnxruntime/test/shared_lib/test_model_builder_api.cc b/onnxruntime/test/shared_lib/test_model_builder_api.cc new file mode 100644 index 0000000000000..9807fcca06ed4 --- /dev/null +++ b/onnxruntime/test/shared_lib/test_model_builder_api.cc @@ -0,0 +1,701 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include + +#include "gtest/gtest.h" +#include "gmock/gmock.h" + +#include "core/common/narrow.h" +#include "core/graph/constants.h" +#include "core/session/onnxruntime_c_api.h" +#include "core/session/onnxruntime_cxx_api.h" +#include "core/session/onnxruntime_lite_custom_op.h" +#include "core/session/onnxruntime_session_options_config_keys.h" + +#include "test/shared_lib/test_fixture.h" +#include "test/shared_lib/utils.h" +#include "test/util/include/test_allocator.h" + +#include "onnxruntime_config.h" // generated file in build output dir + +extern std::unique_ptr ort_env; + +using namespace Ort; + +namespace { + +Ort::Session CreateSession(Ort::Env& env, + Model& graph_api_model, + Ort::SessionOptions* session_options_for_test = nullptr) { + Ort::SessionOptions default_session_options; + Ort::SessionOptions& session_options = session_options_for_test ? *session_options_for_test + : default_session_options; + + // Set this to save the model if you want to debug. + // session_options.SetOptimizedModelFilePath(ORT_TSTR("model_builder_output.onnx")); + + Ort::Session session(env, graph_api_model, session_options); + + // Session should not require the model to stay alive so free it now to validate. + graph_api_model = Model(nullptr); + + return session; +} + +template +void TestInference(Ort::Session& session, + const std::vector>& inputs, + const char* output_name, + const std::vector& expected_dims, + const std::vector& expected_values) { + auto default_allocator = std::make_unique(); + + // without preallocated output tensor + RunSession(default_allocator.get(), + session, + inputs, + output_name, + expected_dims, + expected_values, + nullptr); +} + +// Create OrtNode using the C API +OrtNode* CreateNode(const OrtModelEditorApi& api, + const char* operator_name, const char* node_name, + const gsl::span input_names, + const gsl::span output_names, + const gsl::span attributes = {}, + const char* domain_name = onnxruntime::kOnnxDomain) { + OrtNode* node = nullptr; + Ort::ThrowOnError(api.CreateNode(operator_name, domain_name, node_name, + input_names.data(), input_names.size(), + output_names.data(), output_names.size(), + attributes.data(), attributes.size(), + &node)); + return node; +} + +// convenience func to convert initalizer lists to gsl::span +OrtNode* CreateNode(const OrtModelEditorApi& api, + const char* operator_name, const char* node_name, + const std::initializer_list input_names, + const std::initializer_list output_names, + const std::initializer_list attributes = {}, + const char* domain_name = onnxruntime::kOnnxDomain) { + std::vector inputs(input_names); + std::vector outputs(output_names); + std::vector attrs(attributes); + return CreateNode(api, operator_name, node_name, inputs, outputs, attrs, domain_name); +} +} // namespace + +struct TestAllocator : public OrtAllocator { + TestAllocator() { + version = ORT_API_VERSION; + Info = [](const struct OrtAllocator* this_ptr) -> const struct OrtMemoryInfo* { + auto* test_allocator = static_cast(this_ptr); + return test_allocator->memory_info; + }; + + Free = [](struct OrtAllocator* allocator, void* p) -> void { + auto* test_allocator = static_cast(allocator); + // find the matching pointer and remove it + auto it = std::find_if(test_allocator->weights.begin(), test_allocator->weights.end(), + [p](const std::unique_ptr>& v) { return v->data() == p; }); + if (it == test_allocator->weights.end()) { + throw std::runtime_error("Free called with unknown pointer"); + } + + test_allocator->weights.erase(it); + }; + + Alloc = [](struct OrtAllocator* /*this*/, size_t /*size*/) -> void* { + throw std::runtime_error("This should not be used"); + }; + + Reserve = [](struct OrtAllocator* /*this*/, size_t /*size*/) -> void* { + throw std::runtime_error("This should not be used"); + }; + } + + // initializers that are used directly by the model. as there's no copy they must remain valid. + // we store them in the test allocator so we can validate that Free is called + std::vector>> weights; + Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtAllocatorType::OrtDeviceAllocator, + OrtMemType::OrtMemTypeDefault); +}; + +// Test the ModelEditorAPI C api +// Uses the ORT C++ api for the rest for simplicity +TEST(ModelEditorAPITest, Basic_CApi) { + const auto& api = Ort::GetApi(); + const auto& model_editor_api = Ort::GetModelEditorApi(); + + TestAllocator deleter; + + // return void so we can use ASSERT_* in the lambda + const auto build_model = [&](bool use_constant_node, OrtModel*& model) -> void { + OrtGraph* graph = nullptr; + Ort::ThrowOnError(model_editor_api.CreateGraph(&graph)); + + // + // Create OrtModel with a Gemm. X input is 3x4, Y input is 4x8, Z output is 3x8. + // X is model input. Y is initializer. + // Set the alpha attribute of the Gemm node to 2.0 to test attribute handling. + // + + // model input + OrtTensorTypeAndShapeInfo* tensor_type_info = nullptr; + std::vector input_dims = {3, 4}; + // can use api.SetSymbolicDimensions to set symbolic dimensions. + // the input array should have the same rank as the call to SetDimensions. + // e.g. call SetDimensions with {-1, 3, 2} and SetSymbolicDimensions with {"N", nullptr, nullptr} to create + // a shape of {"N", 3, 2} + + Ort::ThrowOnError(api.CreateTensorTypeAndShapeInfo(&tensor_type_info)); + Ort::ThrowOnError(api.SetTensorElementType(tensor_type_info, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT)); + Ort::ThrowOnError(api.SetDimensions(tensor_type_info, input_dims.data(), input_dims.size())); + + OrtTypeInfo* input_type_info = nullptr; + Ort::ThrowOnError(model_editor_api.CreateTensorTypeInfo(tensor_type_info, &input_type_info)); + api.ReleaseTensorTypeAndShapeInfo(tensor_type_info); // input_type_info took a copy + + // create ValueInfo and release the type info as CreateValueInfo takes a copy. + OrtValueInfo* input_value_info = nullptr; + Ort::ThrowOnError(model_editor_api.CreateValueInfo("X", input_type_info, &input_value_info)); + api.ReleaseTypeInfo(input_type_info); // input_value_info took a copy + tensor_type_info = nullptr; + + // model outputs + OrtTypeInfo* output_type_info = nullptr; + std::vector output_dims = {3, 8}; + + Ort::ThrowOnError(api.CreateTensorTypeAndShapeInfo(&tensor_type_info)); + Ort::ThrowOnError(api.SetTensorElementType(tensor_type_info, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT)); + Ort::ThrowOnError(api.SetDimensions(tensor_type_info, output_dims.data(), output_dims.size())); + + Ort::ThrowOnError(model_editor_api.CreateTensorTypeInfo(tensor_type_info, &output_type_info)); + api.ReleaseTensorTypeAndShapeInfo(tensor_type_info); // input_type_info took a copy + + OrtValueInfo* output_value_info = nullptr; + Ort::ThrowOnError(model_editor_api.CreateValueInfo("Z", output_type_info, &output_value_info)); + api.ReleaseTypeInfo(output_type_info); + + std::vector graph_inputs = {input_value_info}; + std::vector graph_outputs = {output_value_info}; + Ort::ThrowOnError(model_editor_api.SetGraphInputs(graph, graph_inputs.data(), graph_inputs.size())); + Ort::ThrowOnError(model_editor_api.SetGraphOutputs(graph, graph_outputs.data(), graph_outputs.size())); + input_value_info = nullptr; // graph now owns the input/output values + output_value_info = nullptr; + + // + // Gemm node + // + + OrtOpAttr* alpha_attr = nullptr; + float alpha_value = 2.0; + Ort::ThrowOnError(api.CreateOpAttr("alpha", &alpha_value, 1, OrtOpAttrType::ORT_OP_ATTR_FLOAT, &alpha_attr)); + + std::vector node_input_names = {"X", "Y"}; + const std::string gemm_output_name = use_constant_node ? "Z_temp" : "Z"; + std::vector node_output_names = {gemm_output_name.c_str()}; + std::vector node_attributes{alpha_attr}; + OrtNode* node = CreateNode(model_editor_api, "Gemm", "Gemm1", node_input_names, node_output_names, node_attributes); + alpha_attr = nullptr; // Node now owns + + Ort::ThrowOnError(model_editor_api.AddNodeToGraph(graph, node)); + node = nullptr; // graph now owns node + + // Y input + // As it's 128 bytes it could either be allocated using CreateTensorAsOrtValue or use existing memory. + // Under 128 bytes must use CreateTensorAsOrtValue. + std::vector y_dims = {4, 8}; + + deleter.weights.emplace_back(std::make_unique>(32)); + auto& y_values = *deleter.weights.back(); + std::iota(y_values.begin(), y_values.end(), 1.0f); + + // create an initializer for the Y input. add to `weights` so the memory remains valid. + OrtValue* y_tensor = nullptr; + Ort::ThrowOnError( + api.CreateTensorWithDataAndDeleterAsOrtValue(&deleter, + y_values.data(), y_values.size() * sizeof(y_values[0]), + y_dims.data(), y_dims.size(), + ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, + &y_tensor)); + + Ort::ThrowOnError(model_editor_api.AddInitializerToGraph(graph, "Y", y_tensor, /*data is external*/ true)); + y_tensor = nullptr; // graph now owns + + if (use_constant_node) { + // Test that a Constant node is converted to an initializer + + // create Constant nodes for min/max to limit output range + OrtOpAttr* min_attr = nullptr; + float min = 400.0f; + Ort::ThrowOnError(api.CreateOpAttr("value", &min, sizeof(min), ORT_OP_ATTR_FLOAT, &min_attr)); + node = CreateNode(model_editor_api, "Constant", "clip_min", {}, {"min"}, {min_attr}); + Ort::ThrowOnError(model_editor_api.AddNodeToGraph(graph, node)); + node = nullptr; // graph now owns node + + OrtOpAttr* max_attr = nullptr; + float max = 900.0f; + Ort::ThrowOnError(api.CreateOpAttr("value", &max, sizeof(max), ORT_OP_ATTR_FLOAT, &max_attr)); + node = CreateNode(model_editor_api, "Constant", "clip_max", {}, {"max"}, {max_attr}); + Ort::ThrowOnError(model_editor_api.AddNodeToGraph(graph, node)); + node = nullptr; // graph now owns node + + node = CreateNode(model_editor_api, "Clip", "Clip1", {gemm_output_name.c_str(), "min", "max"}, {"Z"}); + Ort::ThrowOnError(model_editor_api.AddNodeToGraph(graph, node)); + node = nullptr; // graph now owns node + } + + std::vector domain_names = {onnxruntime::kOnnxDomain}; + std::vector opset_versions = {18}; + Ort::ThrowOnError(model_editor_api.CreateModel(domain_names.data(), opset_versions.data(), domain_names.size(), + &model)); + Ort::ThrowOnError(model_editor_api.AddGraphToModel(model, graph)); + graph = nullptr; // model now owns + }; + + auto run_test = [&](bool use_constant_node) -> void { + OrtModel* model = nullptr; + build_model(use_constant_node, model); + + ASSERT_NE(model, nullptr) << "build_model should have created a model"; + + std::vector> inputs(1); + auto& input = inputs[0]; + input.name = "X"; + input.dims = {3, 4}; + input.values = {1.0f, 2.0f, 3.0f, 4.0f, + 8.0f, 7.0f, 6.0f, 5.0f, + 9.0f, 3.0f, 5.0f, 7.0f}; + + std::vector expected_dims = {3, 8}; + Model cxx_model(model); + auto session = CreateSession(*ort_env, cxx_model); + + std::vector expected_output; + if (use_constant_node) { + // clipped with min 400 and max 900 + expected_output = {400.0f, 400.0f, 400.0f, 400.0f, 420.0f, 440.0f, 460.0f, 480.0f, + 596.0f, 648.0f, 700.0f, 752.0f, 804.0f, 856.0f, 900.0f, 900.0f, + 592.0f, 640.0f, 688.0f, 736.0f, 784.0f, 832.0f, 880.0f, 900.0f}; + } else { + expected_output = {340.0f, 360.0f, 380.0f, 400.0f, 420.0f, 440.0f, 460.0f, 480.0f, + 596.0f, 648.0f, 700.0f, 752.0f, 804.0f, 856.0f, 908.0f, 960.0f, + 592.0f, 640.0f, 688.0f, 736.0f, 784.0f, 832.0f, 880.0f, 928.0f}; + } + + TestInference(session, inputs, "Z", expected_dims, expected_output); + + api.ReleaseSession(session.release()); + + ASSERT_EQ(deleter.weights.size(), size_t(0)) << "All weights should have been freed"; + }; + + run_test(false); + run_test(true); // use Constant node for initializer +} + +TEST(ModelEditorAPITest, Basic_CxxApi) { + // initializers that are used directly by the model. as there's no copy they must remain valid + std::vector>> weights; + + Ort::Graph graph; + + // + // Create OrtModel with a Gemm. X input is 3x4, Y input is 4x8, Z output is 3x8. + // X is model input. Y is initializer. + // Set the alpha attribute of the Gemm node to 2.0 to test attribute handling. + // + + std::vector graph_inputs; + std::vector graph_outputs; + + // model input. it's {3, 4} but use a symbolic dim to test that works. + std::vector input_dims({-1, 4}); + std::vector input_symbolic_dims({"multiple_of_3", ""}); + TensorTypeAndShapeInfo input_tensor_info(ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, + input_dims, + &input_symbolic_dims); + auto input_type_info = TypeInfo::CreateTensorInfo(input_tensor_info.GetConst()); + graph_inputs.emplace_back("X", input_type_info.GetConst()); + + // model outputs + std::vector output_dims = {-1, 8}; + std::vector output_symbolic_dims({"multiple_of_3", ""}); + TensorTypeAndShapeInfo output_tensor_info(ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, + output_dims, + &output_symbolic_dims); + auto output_type_info = TypeInfo::CreateTensorInfo(output_tensor_info.GetConst()); + graph_outputs.emplace_back("Z", output_type_info.GetConst()); + + graph.SetInputs(graph_inputs); + graph.SetOutputs(graph_outputs); + + // + // Gemm node + // + + std::vector attributes; + float alpha_value = 2.0; + attributes.push_back(OpAttr("alpha", &alpha_value, 1, OrtOpAttrType::ORT_OP_ATTR_FLOAT)); + + Node node("Gemm", onnxruntime::kOnnxDomain, "Gemm1", {"X", "Y"}, {"Z"}, attributes); + + graph.AddNode(node); + + // create an initializer for the Y input. + // add to `weights` so it remains valid for the lifetime of the session and we can avoid copying the data. + // As it's 128 bytes it could either be allocated using CreateTensorAsOrtValue or use existing memory. + // Under 128 bytes must use CreateTensorAsOrtValue. + std::vector y_dims = {4, 8}; + + weights.emplace_back(std::make_unique>(32)); + auto& y_values = *weights.back(); + std::iota(y_values.begin(), y_values.end(), 1.0f); + + auto info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); + + // if you use this API the initializer data MUST remain valid for the lifetime of the InferenceSession + auto y_tensor = Value::CreateTensor(info, y_values.data(), y_values.size(), y_dims.data(), y_dims.size()); + graph.AddInitializer("Y", y_tensor, /*data is external*/ true); + + std::vector opsets{{onnxruntime::kOnnxDomain, 18}}; + Model model(opsets); + model.AddGraph(graph); + + std::vector> inputs(1); + auto& input = inputs[0]; + input.name = "X"; + input.dims = {3, 4}; + input.values = {1.0f, 2.0f, 3.0f, 4.0f, + 8.0f, 7.0f, 6.0f, 5.0f, + 9.0f, 3.0f, 5.0f, 7.0f}; + + std::vector expected_dims = {3, 8}; + + auto session = CreateSession(*ort_env, model); + TestInference(session, inputs, "Z", expected_dims, + {340.0f, 360.0f, 380.0f, 400.0f, 420.0f, 440.0f, 460.0f, 480.0f, + 596.0f, 648.0f, 700.0f, 752.0f, 804.0f, 856.0f, 908.0f, 960.0f, + 592.0f, 640.0f, 688.0f, 736.0f, 784.0f, 832.0f, 880.0f, 928.0f}); +} + +TEST(ModelEditorAPITest, BasicModelEdit_CxxApi) { + // + // Load existing model + // Add Cast to change the model input from float to int64 + // Update model inputs to match + // Run + // + + SessionOptions so; + + // Set this to save the model if you want to debug. + // so.SetOptimizedModelFilePath(ORT_TSTR("model_builder_edited.onnx")); + + Session session = Session::CreateModelEditorSession(*ort_env, TSTR("testdata/mnist.onnx"), so); + + ASSERT_EQ(session.GetOpset(""), 8); // ONNX domain is empty string + + // we augment the original model with nodes, initializers and the updated model inputs/outputs from this model. + // the original graph is unchanged. nodes can be added before/after it. initializers can be added. + // new nodes must conform to the original domain:opset of the model. + // additional operator domain:opset pairs can be added. + std::vector opsets; // no additional opsets required + Model model(opsets); + + std::vector graph_inputs = session.GetInputs(); + ASSERT_EQ(graph_inputs.size(), size_t(1)); + ASSERT_EQ(graph_inputs[0].TypeInfo().GetTensorTypeAndShapeInfo().GetElementType(), + ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT); + + // typically this isn't needed. we replace this input but need to read info from it later on in the test + // validation so we save the info locally to keep it accessible. + auto orig_input_name = graph_inputs[0].Name(); + auto input_shape = graph_inputs[0].TypeInfo().GetTensorTypeAndShapeInfo().GetShape(); + const std::string new_input_name = "Int64Input"; + + // Add Cast node to convert input from float to int64 + std::vector attributes; + int64_t to = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; + attributes.push_back(OpAttr("to", &to, 1, OrtOpAttrType::ORT_OP_ATTR_INT)); + + Ort::Node node("Cast", onnxruntime::kOnnxDomain, new_input_name, {"Int64Input"}, + // the existing node will now consume the output from the Cast instead of a graph input + {orig_input_name}, + attributes); + + // we're replacing the only input. the shape is the same but the name and data type change. + TensorTypeAndShapeInfo input_tensor_info(ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, + input_shape); + auto input_type_info = TypeInfo::CreateTensorInfo(input_tensor_info.GetConst()); + graph_inputs[0] = ValueInfo(new_input_name, input_type_info.GetConst()); + + Graph graph; // new info to augment the model with + + graph.AddNode(node); + graph.SetInputs(graph_inputs); + + // the node we added does not require any new opsets. + model.AddGraph(graph); + session.FinalizeModelEditorSession(model, so); + + std::vector> inputs(1); + auto& input = inputs[0]; + input.name = new_input_name.c_str(); + input.dims = input_shape; + + auto num_values = std::accumulate(input.dims.begin(), input.dims.end(), int64_t(1), std::multiplies()); + input.values.resize(size_t(num_values)); + std::iota(input.values.begin(), input.values.end(), 1); + + std::vector expected_dims = {1, 10}; + std::vector expected_output = {-48.5088f, -1040.2948f, -347.0959f, 101.7392f, 421.3352f, + 750.92145f, 231.5060f, -1694.4152f, 681.5623f, 378.1689f}; + + TestInference(session, inputs, session.GetOutputNames()[0].c_str(), expected_dims, expected_output); + + // double check with original model + { + SessionOptions expected_so; + Session expected_session = Session(*ort_env, TSTR("testdata/mnist.onnx"), expected_so); + std::vector> expected_inputs(1); + auto& expected_input = expected_inputs[0]; + expected_input.name = orig_input_name.c_str(); + expected_input.dims = input_shape; + expected_input.values.reserve(size_t(num_values)); + std::transform(input.values.begin(), input.values.end(), std::back_inserter(expected_input.values), + [&](int64_t value) { return float(value); }); + + TestInference(expected_session, expected_inputs, session.GetOutputNames()[0].c_str(), + expected_dims, expected_output); + } +} + +TEST(ModelEditorAPITest, InvalidDimension) { + try { + std::vector input_dims = {-2, 2}; + TensorTypeAndShapeInfo tensor_type_info(ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, + input_dims); + // invalid dim of -2 should cause exception + TypeInfo::CreateTensorInfo(tensor_type_info.GetConst()); + FAIL() << "Expected exception for invalid dimension"; + } catch (const Ort::Exception& e) { + ASSERT_STREQ(e.what(), "dim_values must be -1 (symbolic dimension) or larger."); + } +} + +TEST(ModelEditorAPITest, CreateInvalidModel_NoOpsets) { + Ort::Graph graph; + std::vector graph_inputs; + std::vector graph_outputs; + + std::vector dims({4}); + TensorTypeAndShapeInfo tensor_info(ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, dims); + auto type_info = TypeInfo::CreateTensorInfo(tensor_info.GetConst()); + graph_inputs.emplace_back("X", type_info.GetConst()); + graph_outputs.emplace_back("Z", type_info.GetConst()); + + graph.SetInputs(graph_inputs); + graph.SetOutputs(graph_outputs); + + Ort::Node node("Add", onnxruntime::kOnnxDomain, "Add1", {"X", "X"}, {"Z"}); + + graph.AddNode(node); + + std::vector opsets; + Model model(opsets); + model.AddGraph(graph); + + try { + auto session = CreateSession(*ort_env, model); + FAIL(); + } catch (const Ort::Exception& e) { + ASSERT_THAT(e.what(), ::testing::HasSubstr("Error No opset import for domain")); + } +} + +TEST(ModelEditorAPITest, CreateInvalidModel_MissingValue) { + Ort::Graph graph; + + std::vector graph_inputs; + std::vector graph_outputs; + + std::vector dims({4}); + TensorTypeAndShapeInfo tensor_info(ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, dims); + auto type_info = TypeInfo::CreateTensorInfo(tensor_info.GetConst()); + graph_inputs.emplace_back("X", type_info.GetConst()); + graph_outputs.emplace_back("Z", type_info.GetConst()); + + graph.SetInputs(graph_inputs); + graph.SetOutputs(graph_outputs); + + Ort::Node node("Add", onnxruntime::kOnnxDomain, "Add1", {"X", "missing"}, {"Z"}); + graph.AddNode(node); + + std::vector opsets{{onnxruntime::kOnnxDomain, 18}}; + Model model(opsets); + model.AddGraph(graph); + + try { + auto session = CreateSession(*ort_env, model); + FAIL(); + } catch (const Ort::Exception& e) { + ASSERT_THAT(e.what(), ::testing::HasSubstr("Node input 'missing' is not a graph input, " + "initializer, or output of a previous node.")); + } +} + +TEST(ModelEditorAPITest, InvalidModelEdit) { + // Add a node but make the edit invalid in various ways + // - add node but don't update graph inputs + // - add node with invalid domain + const auto edit_model = [](bool invalid_domain) { + SessionOptions so; + + // Set this to save the model if you want to debug. + // so.SetOptimizedModelFilePath(ORT_TSTR("model_builder_edited.onnx")); + + Session session = Session::CreateModelEditorSession(*ort_env, TSTR("testdata/mnist.onnx"), so); + + ASSERT_EQ(session.GetOpset(""), 8); // ONNX domain is empty string + + std::vector opsets; // no additional opsets required + Model model(opsets); + Graph graph; // new info to augment the model with + + const char* domain = invalid_domain ? "invalid_domain" : onnxruntime::kOnnxDomain; + + std::vector graph_inputs = session.GetInputs(); + ASSERT_EQ(graph_inputs.size(), size_t(1)); + ASSERT_EQ(graph_inputs[0].TypeInfo().GetTensorTypeAndShapeInfo().GetElementType(), + ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT); + + const std::string new_input_name = "Int64Input"; + + // Add Cast node to convert input from float to int64 + std::vector attributes; + int64_t to = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; + attributes.push_back(OpAttr("to", &to, 1, OrtOpAttrType::ORT_OP_ATTR_INT)); + + Node node("Cast", domain, "NewInputNode", {new_input_name}, + // the existing node will now consume the output from the Cast instead of a graph input + {graph_inputs[0].Name()}, + attributes); + graph.AddNode(node); + + if (invalid_domain) { + // we're replacing the only input. the shape is the same but the name and data type change. + TensorTypeAndShapeInfo input_tensor_info(ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, + graph_inputs[0].TypeInfo().GetTensorTypeAndShapeInfo().GetShape()); + auto input_type_info = TypeInfo::CreateTensorInfo(input_tensor_info.GetConst()); + graph_inputs[0] = ValueInfo(new_input_name, input_type_info.GetConst()); + graph.SetInputs(graph_inputs); + } else { + // model should be invalid as we didn't connect the new node up to the graph inputs + } + + // the node we added does not require any new opsets. + model.AddGraph(graph); + + try { + session.FinalizeModelEditorSession(model, so); + FAIL() << "Should have failed to resolve graph due to invalid edits."; + } catch (const Ort::Exception& e) { + if (invalid_domain) { + ASSERT_THAT(e.what(), ::testing::HasSubstr("Error No opset import for domain 'invalid_domain'")); + } else { + ASSERT_THAT(e.what(), ::testing::HasSubstr("This is an invalid model")); + } + } + }; + + edit_model(false); + edit_model(true); // add node with invalid domain +} + +TEST(ModelEditorAPITest, CreateTypeInfo) { + const auto& api = Ort::GetApi(); + const auto& model_editor_api = Ort::GetModelEditorApi(); + + TensorTypeAndShapeInfo base_tensor_info(ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, + {2, 4}); + + OrtTypeInfo* base_tensor_type_info = nullptr; + Ort::ThrowOnError(model_editor_api.CreateTensorTypeInfo(base_tensor_info, &base_tensor_type_info)); + + ONNXType onnx_type = ONNX_TYPE_UNKNOWN; + const OrtTensorTypeAndShapeInfo* tensor_info = nullptr; + ONNXTensorElementDataType onnx_element_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; + + // sparse tensor + OrtTypeInfo* sparse_tensor_type_info = nullptr; + Ort::ThrowOnError(model_editor_api.CreateSparseTensorTypeInfo(base_tensor_info, &sparse_tensor_type_info)); + Ort::ThrowOnError(api.GetOnnxTypeFromTypeInfo(sparse_tensor_type_info, &onnx_type)); + ASSERT_EQ(onnx_type, ONNXType::ONNX_TYPE_SPARSETENSOR); + Ort::ThrowOnError(api.CastTypeInfoToTensorInfo(sparse_tensor_type_info, &tensor_info)); + Ort::ThrowOnError(api.GetTensorElementType(tensor_info, &onnx_element_type)); + ASSERT_EQ(onnx_element_type, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT); + api.ReleaseTypeInfo(sparse_tensor_type_info); + + // sequence + OrtTypeInfo* sequence_type_info = nullptr; + const OrtSequenceTypeInfo* sequence_info = nullptr; + OrtTypeInfo* sequence_element_type_info = nullptr; + + Ort::ThrowOnError(model_editor_api.CreateSequenceTypeInfo(base_tensor_type_info, &sequence_type_info)); + Ort::ThrowOnError(api.GetOnnxTypeFromTypeInfo(sequence_type_info, &onnx_type)); + ASSERT_EQ(onnx_type, ONNXType::ONNX_TYPE_SEQUENCE); + Ort::ThrowOnError(api.CastTypeInfoToSequenceTypeInfo(sequence_type_info, &sequence_info)); + Ort::ThrowOnError(api.GetSequenceElementType(sequence_info, &sequence_element_type_info)); + Ort::ThrowOnError(api.GetOnnxTypeFromTypeInfo(sequence_element_type_info, &onnx_type)); + ASSERT_EQ(onnx_type, ONNXType::ONNX_TYPE_TENSOR); + Ort::ThrowOnError(api.CastTypeInfoToTensorInfo(sequence_element_type_info, &tensor_info)); + Ort::ThrowOnError(api.GetTensorElementType(tensor_info, &onnx_element_type)); + ASSERT_EQ(onnx_element_type, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT); + api.ReleaseTypeInfo(sequence_element_type_info); + api.ReleaseTypeInfo(sequence_type_info); + + // map + OrtTypeInfo* map_type_info = nullptr; + const OrtMapTypeInfo* map_info = nullptr; + ONNXTensorElementDataType map_key_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; + OrtTypeInfo* map_value_type_info = nullptr; + Ort::ThrowOnError(model_editor_api.CreateMapTypeInfo(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, base_tensor_type_info, + &map_type_info)); // clones map_type_info + Ort::ThrowOnError(api.GetOnnxTypeFromTypeInfo(map_type_info, &onnx_type)); + ASSERT_EQ(onnx_type, ONNXType::ONNX_TYPE_MAP); + Ort::ThrowOnError(api.CastTypeInfoToMapTypeInfo(map_type_info, &map_info)); + Ort::ThrowOnError(api.GetMapKeyType(map_info, &map_key_type)); + ASSERT_EQ(map_key_type, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64); + Ort::ThrowOnError(api.GetMapValueType(map_info, &map_value_type_info)); + Ort::ThrowOnError(api.GetOnnxTypeFromTypeInfo(map_value_type_info, &onnx_type)); + ASSERT_EQ(onnx_type, ONNXType::ONNX_TYPE_TENSOR); + Ort::ThrowOnError(api.CastTypeInfoToTensorInfo(map_value_type_info, &tensor_info)); + Ort::ThrowOnError(api.GetTensorElementType(tensor_info, &onnx_element_type)); + ASSERT_EQ(onnx_element_type, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT); + api.ReleaseTypeInfo(map_value_type_info); + api.ReleaseTypeInfo(map_type_info); + + // optional + OrtTypeInfo* optional_type_info = nullptr; + const OrtOptionalTypeInfo* optional_info = nullptr; + OrtTypeInfo* optional_contained_type_info = nullptr; + Ort::ThrowOnError(model_editor_api.CreateOptionalTypeInfo(base_tensor_type_info, &optional_type_info)); + Ort::ThrowOnError(api.GetOnnxTypeFromTypeInfo(optional_type_info, &onnx_type)); + ASSERT_EQ(onnx_type, ONNXType::ONNX_TYPE_OPTIONAL); + Ort::ThrowOnError(api.CastTypeInfoToOptionalTypeInfo(optional_type_info, &optional_info)); + Ort::ThrowOnError(api.GetOptionalContainedTypeInfo(optional_info, &optional_contained_type_info)); + Ort::ThrowOnError(api.GetOnnxTypeFromTypeInfo(optional_contained_type_info, &onnx_type)); + ASSERT_EQ(onnx_type, ONNXType::ONNX_TYPE_TENSOR); + api.ReleaseTypeInfo(optional_contained_type_info); + api.ReleaseTypeInfo(optional_type_info); + + api.ReleaseTypeInfo(base_tensor_type_info); +} diff --git a/onnxruntime/test/shared_lib/test_ort_format_models.cc b/onnxruntime/test/shared_lib/test_ort_format_models.cc index 99a9ebc3362ae..b3491e3476f23 100644 --- a/onnxruntime/test/shared_lib/test_ort_format_models.cc +++ b/onnxruntime/test/shared_lib/test_ort_format_models.cc @@ -17,7 +17,7 @@ extern std::unique_ptr ort_env; [[maybe_unused]] static void TestInference(Ort::Env& env, const std::basic_string& model_uri, - const std::vector& inputs, const char* output_name, + const std::vector>& inputs, const char* output_name, const std::vector& expected_dims_y, const std::vector& expected_values_y, Ort::CustomOpDomain& custom_op_domain, void* cuda_compute_stream = nullptr) { Ort::SessionOptions session_options; @@ -100,8 +100,8 @@ TEST(OrtFormatCustomOpTests, ConvertOnnxModelToOrt) { } // now load the ORT format model and execute it - std::vector inputs(1); - Input& input = inputs[0]; + std::vector> inputs(1); + auto& input = inputs[0]; input.name = "X"; input.dims = {3, 2}; input.values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; @@ -130,8 +130,8 @@ TEST(OrtFormatCustomOpTests, LoadOrtModel) { custom_op_domain.Add(&custom_op); // load the ORT format model and execute it - std::vector inputs(1); - Input& input = inputs[0]; + std::vector> inputs(1); + auto& input = inputs[0]; input.name = "X"; input.dims = {3, 2}; input.values = {6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f}; @@ -151,8 +151,8 @@ TEST(OrtFormatCustomOpTests, LoadOrtModelStandaloneCustomOpImplementation) { custom_op_domain.Add(&standalone_op); // load the ORT format model and execute it - std::vector inputs(1); - Input& input = inputs[0]; + std::vector> inputs(1); + auto& input = inputs[0]; input.name = "X"; input.dims = {3, 2}; input.values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; diff --git a/onnxruntime/test/shared_lib/utils.h b/onnxruntime/test/shared_lib/utils.h index 483753f2ae6b2..5d15582b86cb9 100644 --- a/onnxruntime/test/shared_lib/utils.h +++ b/onnxruntime/test/shared_lib/utils.h @@ -5,4 +5,56 @@ #include "core/session/onnxruntime_cxx_api.h" +#include "gtest/gtest.h" + OrtCUDAProviderOptions CreateDefaultOrtCudaProviderOptionsWithCustomStream(void* cuda_compute_stream = nullptr); + +template +struct Input { + const char* name = nullptr; + std::vector dims; + std::vector values; +}; + +template > +void RunSession(OrtAllocator* allocator, + Ort::Session& session_object, + const std::vector& inputs, + const char* output_name, + const std::vector& output_dims, + const std::vector& expected_output, + Ort::Value* output_tensor) { + std::vector ort_inputs; + std::vector input_names; + for (size_t i = 0; i < inputs.size(); i++) { + input_names.emplace_back(inputs[i].name); + ort_inputs.emplace_back( + Ort::Value::CreateTensor(allocator->Info(allocator), const_cast(inputs[i].values.data()), + inputs[i].values.size(), inputs[i].dims.data(), inputs[i].dims.size())); + } + + std::vector ort_outputs; + if (output_tensor) + session_object.Run(Ort::RunOptions{nullptr}, input_names.data(), ort_inputs.data(), ort_inputs.size(), + &output_name, output_tensor, 1); + else { + ort_outputs = session_object.Run(Ort::RunOptions{}, input_names.data(), ort_inputs.data(), ort_inputs.size(), + &output_name, 1); + ASSERT_EQ(ort_outputs.size(), 1u); + output_tensor = &ort_outputs[0]; + } + + auto type_info = output_tensor->GetTensorTypeAndShapeInfo(); + ASSERT_EQ(type_info.GetShape(), output_dims); + size_t total_len = type_info.GetElementCount(); + ASSERT_EQ(expected_output.size(), total_len); + + auto* actual = output_tensor->GetTensorMutableData(); + for (size_t i = 0; i != total_len; ++i) { + if constexpr (std::is_same::value || std::is_same::value) { + EXPECT_NEAR(expected_output[i], actual[i], 1e-3) << "i=" << i; + } else { + EXPECT_EQ(expected_output[i], actual[i]) << "i=" << i; + } + } +} diff --git a/onnxruntime/test/testdata/cast_float_to_double.onnx b/onnxruntime/test/testdata/cast_float_to_double.onnx new file mode 100644 index 0000000000000000000000000000000000000000..dc7997cddd8a8c762e354316662fb0d734e25e86 GIT binary patch literal 136 zcmdfpOwLZtOVKS!EiSPt;8NgX&CDw(EfHeNFD(JmN-WNa#U)ytTudeT65I-kD&v!ZqVaA%{*EE>CHe6#{-I7ju2JGJ&3s%u9E?I7TudCyK+KXP!38x=2qeRe Mka1$+Vh|7o0L&R4`v3p{ literal 0 HcmV?d00001 diff --git a/onnxruntime/test/testdata/custom_execution_provider_library/my_execution_provider.cc b/onnxruntime/test/testdata/custom_execution_provider_library/my_execution_provider.cc index 57471f7c029c2..27a4b06a99e64 100644 --- a/onnxruntime/test/testdata/custom_execution_provider_library/my_execution_provider.cc +++ b/onnxruntime/test/testdata/custom_execution_provider_library/my_execution_provider.cc @@ -1,5 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. -// Confidential and Proprietary. +// Licensed under the MIT License. #include "my_execution_provider.h" #include "my_allocator.h" diff --git a/onnxruntime/test/testdata/custom_execution_provider_library/my_execution_provider.h b/onnxruntime/test/testdata/custom_execution_provider_library/my_execution_provider.h index ff0c7e80c4eeb..efb359a9e5e43 100644 --- a/onnxruntime/test/testdata/custom_execution_provider_library/my_execution_provider.h +++ b/onnxruntime/test/testdata/custom_execution_provider_library/my_execution_provider.h @@ -1,5 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. -// Confidential and Proprietary. +// Licensed under the MIT License. #pragma once diff --git a/onnxruntime/wasm/api.cc b/onnxruntime/wasm/api.cc index 7adfc6a2b2ccb..1ad35b51bb1c1 100644 --- a/onnxruntime/wasm/api.cc +++ b/onnxruntime/wasm/api.cc @@ -8,6 +8,14 @@ #include "core/session/onnxruntime_cxx_api.h" #include "api.h" +#ifdef USE_WEBGPU +namespace onnxruntime { +namespace webgpu { +WGPUDevice GetDevice(int); +} +} // namespace onnxruntime +#endif + #include #include #include @@ -164,8 +172,12 @@ OrtSessionOptions* OrtCreateSessionOptions(size_t graph_optimization_level, return UNREGISTER_AUTO_RELEASE(session_options); } -int OrtAppendExecutionProvider(ort_session_options_handle_t session_options, const char* name) { - return CHECK_STATUS(SessionOptionsAppendExecutionProvider, session_options, name, nullptr, nullptr, 0); +int OrtAppendExecutionProvider(ort_session_options_handle_t session_options, + const char* name, + const char* const* provider_options_keys, + const char* const* provider_options_values, + size_t num_keys) { + return CHECK_STATUS(SessionOptionsAppendExecutionProvider, session_options, name, provider_options_keys, provider_options_values, num_keys); } int OrtAddFreeDimensionOverride(ort_session_options_handle_t session_options, @@ -507,6 +519,16 @@ char* OrtEndProfiling(ort_session_handle_t session) { : nullptr; } +// WebGPU API Section + +#ifdef USE_WEBGPU + +WGPUDevice OrtGetWebGpuDevice(int device_id) { + return onnxruntime::webgpu::GetDevice(device_id); +} + +#endif + // Training API Section #ifdef ENABLE_TRAINING_APIS diff --git a/onnxruntime/wasm/api.h b/onnxruntime/wasm/api.h index f44c515d98f6b..9ff1eb55ecedc 100644 --- a/onnxruntime/wasm/api.h +++ b/onnxruntime/wasm/api.h @@ -10,6 +10,10 @@ #include +#ifdef USE_WEBGPU +#include +#endif + #include struct OrtSession; @@ -85,7 +89,10 @@ ort_session_options_handle_t EMSCRIPTEN_KEEPALIVE OrtCreateSessionOptions(size_t * @returns ORT error code. If not zero, call OrtGetLastError() to get detailed error message. */ int EMSCRIPTEN_KEEPALIVE OrtAppendExecutionProvider(ort_session_options_handle_t session_options, - const char* name); + const char* name, + const char* const* provider_options_keys, + const char* const* provider_options_values, + size_t num_keys); /** * add a free dimension override for one dimension of a session's input. @@ -294,6 +301,21 @@ int EMSCRIPTEN_KEEPALIVE OrtRun(ort_session_handle_t session, */ char* EMSCRIPTEN_KEEPALIVE OrtEndProfiling(ort_session_handle_t session); +// WebGPU API Section + +#ifdef USE_WEBGPU + +/** + * get the GPU Device by device ID. + * + * This function is only available after the GPU Device is initialized in WebGpuContextFactory. + * + * @returns a WGPUDevice handle. + */ +WGPUDevice EMSCRIPTEN_KEEPALIVE OrtGetWebGpuDevice(int device_id); + +#endif + // Training API Section #ifdef ENABLE_TRAINING_APIS diff --git a/onnxruntime/wasm/js_post_js.js b/onnxruntime/wasm/js_post_js.js index b77d82fbd7d10..56d3246fd07f0 100644 --- a/onnxruntime/wasm/js_post_js.js +++ b/onnxruntime/wasm/js_post_js.js @@ -2,6 +2,4 @@ // Licensed under the MIT License. -'use strict'; - Module["PTR_SIZE"] = 4; diff --git a/onnxruntime/wasm/js_post_js_64.js b/onnxruntime/wasm/js_post_js_64.js index b140df927ebbd..cfd79523f7900 100644 --- a/onnxruntime/wasm/js_post_js_64.js +++ b/onnxruntime/wasm/js_post_js_64.js @@ -2,6 +2,4 @@ // Licensed under the MIT License. -'use strict'; - Module["PTR_SIZE"] = 8; diff --git a/onnxruntime/wasm/post-webgpu.js b/onnxruntime/wasm/post-webgpu.js new file mode 100644 index 0000000000000..146355f6a44d3 --- /dev/null +++ b/onnxruntime/wasm/post-webgpu.js @@ -0,0 +1,261 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// +// This file contains the post-run code for the ORT WebAssembly module. The code in this file will be injected into the +// final module using Emscripten's `--post-js` option. +// +// This file will only be used in build with flag `--use_webgpu`. + +/** + * This function is called only once when initializing the WebGPU backend. + * + * @param {(gpuDevice: GPUDevice) => void} setDefaultDevice A callback function to set the default device. + */ +Module["webgpuInit"] = (setDefaultDevice) => { + /** + * a map from GPUDevice to [deviceId, instanceHandle, deviceHandle] + * + * only stores custom devices (ie. devices created by the user, not the default device created by ORT) + * + * key is the GPUDevice object. + * + * value is a tuple of 3 elements: + * - deviceId: a unique ID for the device. Must be positive integer. + * - instanceHandle: the instance handle(pointer) of the device. + * - deviceHandle: the device handle(pointer) of the device. + * + * @type {WeakMap} + */ + const webgpuActiveDevices = new WeakMap(); + /** + * a number that is used to assign a unique ID to the next custom device. + */ + let webgpuNextDeviceId = 1; + /** + * a function to set the default device. + * + * @type {(gpuDevice: GPUDevice) => void} + */ + const webgpuSetDefaultDevice = setDefaultDevice; + /** + * the current device that is being used to create a WebGPU EP inference session. + * + * the value of this variable is only valid during the creation of a WebGPU EP inference session. + * + * @type {GPUDevice|undefined} + */ + let webgpuCurrentDevice = undefined; + /** + * the current device ID that is being used to create a WebGPU EP inference session. + * + * the value of this variable is only valid during the creation of a WebGPU EP inference session. + * + * @type {number|undefined} + */ + let webgpuCurrentDeviceId = undefined; + + /** + * This function is called only when a custom device is used, during preparation of session options. + * + * @param {GPUDevice} device the user provided device object. + * @returns {undefined|[number, number, number]} a tuple of device id, instance handle, and device handle. + */ + Module["webgpuRegisterDevice"] = (device) => { + if (webgpuCurrentDeviceId !== undefined) { + throw new Error("another WebGPU EP inference session is being created."); + } + + if (device) { + let deviceInfo = webgpuActiveDevices.get(device); + if (!deviceInfo) { + const instanceHandle = _wgpuCreateInstance(0); + const deviceHandle = WebGPU.importJsDevice(device, instanceHandle); + deviceInfo = [webgpuNextDeviceId++, instanceHandle, deviceHandle]; + webgpuActiveDevices.set(device, deviceInfo); + } + + // The current device ID is a temporary storage for the device ID to be used in the session that is being created. + // + // Soon after `webgpuRegisterDevice` (this function) is called, `webgpuOnCreateSession` will be called so that the + // value of `webgpuCurrentDeviceId` is used and reset then. + webgpuCurrentDevice = device; + webgpuCurrentDeviceId = deviceInfo[0]; + return deviceInfo; + } else { + webgpuCurrentDevice = undefined; + webgpuCurrentDeviceId = 0; + return undefined; + } + }; + + const webgpuActiveSessions = new Map(); + Module["webgpuOnCreateSession"] = (sessionHandle) => { + if (webgpuCurrentDeviceId === undefined) { + // do nothing if webgpuCurrentDeviceId is undefined. + // this means no WebGPU EP is being created. + return; + } + + const deviceId = webgpuCurrentDeviceId; + webgpuCurrentDeviceId = undefined; + + if (sessionHandle) { + // when session created successfully + const deviceHandle = _OrtGetWebGpuDevice(deviceId); + webgpuActiveSessions.set(sessionHandle, deviceHandle); + + if (deviceId === 0) { + const device = webgpuCurrentDevice ?? WebGPU.getJsObject(deviceHandle); + webgpuSetDefaultDevice(device); + } + } + webgpuCurrentDevice = undefined; + }; + + Module["webgpuOnReleaseSession"] = (sessionHandle) => { + webgpuActiveSessions.delete(sessionHandle); + }; + + const gpuBufferMetadataSymbol = Symbol("gpuBufferMetadata"); + + Module["webgpuRegisterBuffer"] = (buffer, sessionHandle, bufferHandle) => { + if (bufferHandle) { + // This is a buffer that was created by ORT. Metadata is [bufferHandle, NaN] + + buffer[gpuBufferMetadataSymbol] = [bufferHandle, NaN]; + return bufferHandle; + } else { + // This is a buffer that was created by the user. Metadata is [bufferHandle, refCount] + + const metadata = buffer[gpuBufferMetadataSymbol]; + if (metadata) { + metadata[1]++; + return metadata[0]; + } + + const deviceHandle = webgpuActiveSessions.get(sessionHandle); + if (deviceHandle === undefined) { + throw new Error( + "Invalid session handle passed to webgpuRegisterBuffer" + ); + } + + const bufferHandle = WebGPU.importJsBuffer(buffer, deviceHandle); + buffer[gpuBufferMetadataSymbol] = [bufferHandle, 1]; + return bufferHandle; + } + }; + + Module["webgpuUnregisterBuffer"] = (buffer) => { + const metadata = buffer[gpuBufferMetadataSymbol]; + if (!metadata) { + throw new Error("Buffer is not registered"); + } + metadata[1]--; + // For buffers created by ORT, metadata[1] will always be NaN. This function will not release the buffer. + // Instead, the buffer will be released when user calls `Tensor.dispose()` in JavaScript. + if (metadata[1] === 0) { + _wgpuBufferRelease(metadata[0]); + delete buffer[gpuBufferMetadataSymbol]; + } + }; + + Module["webgpuGetBuffer"] = (bufferHandle) => { + return WebGPU.getJsObject(bufferHandle); + }; + + Module["webgpuCreateDownloader"] = (gpuBuffer, bufferSize, sessionHandle) => { + const deviceHandle = webgpuActiveSessions.get(sessionHandle); + if (deviceHandle === undefined) { + throw new Error("Invalid session handle passed to webgpuRegisterBuffer"); + } + + const buffer = gpuBuffer; + const device = WebGPU.getJsObject(deviceHandle); + const originalSize = bufferSize; + const size = Math.ceil(Number(originalSize) / 16) * 16; + + return async () => { + // prettier-ignore + // + // the line above is used to force prettier to skip formatting the next statement. + // this is because prettier will remove the quotes around the property names, but we need to keep them + // because otherwise closure compiler may rename them and break the code. + const gpuReadBufferDescriptor = { + "size": size, + "usage": 9 /* GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ */, + }; + const gpuReadBuffer = device.createBuffer(gpuReadBufferDescriptor); + try { + const commandEncoder = device.createCommandEncoder(); + commandEncoder.copyBufferToBuffer( + buffer /* source buffer */, + 0 /* source offset */, + gpuReadBuffer /* destination buffer */, + 0 /* destination offset */, + size /* size */ + ); + device.queue.submit([commandEncoder.finish()]); + + await gpuReadBuffer.mapAsync(GPUMapMode.READ); + + const arrayBuffer = gpuReadBuffer.getMappedRange(); + return arrayBuffer.slice(0, originalSize); + } finally { + gpuReadBuffer.destroy(); + } + }; + }; + + // Setup a callback function for loading external buffers (model weights). + Module.webgpuUploadExternalBuffer = (bufferHandle, data) => { + const srcArrayBuffer = data.buffer; + const srcOffset = data.byteOffset; + const srcLength = data.byteLength; + const size = Math.ceil(Number(srcLength) / 16) * 16; + + const gpuBuffer = WebGPU.getJsObject(bufferHandle); + + // get current device + if (!webgpuCurrentDevice) { + const deviceHandle = _OrtGetWebGpuDevice(webgpuCurrentDeviceId); + webgpuCurrentDevice = WebGPU.getJsObject(deviceHandle); + } + + // create gpu buffer + + // prettier-ignore + // + // the line above is used to force prettier to skip formatting the next statement. + // this is because prettier will remove the quotes around the property names, but we need to keep them + // because otherwise closure compiler may rename them and break the code. + const gpuBufferForUploadingDescriptor = { + "mappedAtCreation": true, + "size": size, + "usage": 6 /* GPUBufferUsage.MAP_WRITE | GPUBufferUsage.COPY_SRC */, + }; + const gpuBufferForUploading = webgpuCurrentDevice.createBuffer( + gpuBufferForUploadingDescriptor + ); + + // copy (upload) data + const arrayBuffer = gpuBufferForUploading.getMappedRange(); + new Uint8Array(arrayBuffer).set( + new Uint8Array(srcArrayBuffer, srcOffset, srcLength) + ); + gpuBufferForUploading.unmap(); + + // GPU copy + const commandEncoder = webgpuCurrentDevice.createCommandEncoder(); + commandEncoder.copyBufferToBuffer( + gpuBufferForUploading, + 0, + gpuBuffer, + 0, + size + ); + webgpuCurrentDevice.queue.submit([commandEncoder.finish()]); + gpuBufferForUploading.destroy(); + }; +}; diff --git a/onnxruntime/wasm/pre-async.js b/onnxruntime/wasm/pre-async.js new file mode 100644 index 0000000000000..8c75dc7c5cf1e --- /dev/null +++ b/onnxruntime/wasm/pre-async.js @@ -0,0 +1,132 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// +// This file contains the pre-run code for the ORT WebAssembly module. The code in this file will be injected into the +// final module using Emscripten's `--pre-js` option. +// +// This file will only be used in build with flag `-s ASYNCIFY=1`. + +/** + * initialize for asyncify support. + */ +let initAsyncImpl = () => { + // This is a simplified version of cwrap() with options.async === true (-sASYNCIFY=1) + // It removes some overhead in cwarp() and ccall() that we don't need. + // + // Currently in ASYNCIFY build, we only use this for the following functions: + // - OrtCreateSession() + // - OrtRun() + // - OrtRunWithBinding() + // - OrtBindInput() + // + // Note: about parameters "getFunc" and "setFunc": + // - Emscripten has different behaviors for Debug and Release builds for generating exported function wrapper. + // + // - In Debug build, it will generate a wrapper function for each exported function. For example, it generates a + // wrapper for OrtRun() like this (minified): + // ``` + // var _OrtRun = Module["_OrtRun"] = createExportWrapper("OrtRun"); + // ``` + // + // - In Release build, it will generate a lazy loading wrapper for each exported function. For example, it generates + // a wrapper for OrtRun() like this (minified): + // ``` + // d._OrtRun = (a, b, c, e, f, h, l, q) => (d._OrtRun = J.ka)(a, b, c, e, f, h, l, q); + // ``` + // + // The behavior of these two wrappers are different. The debug build will assign `Module["_OrtRun"]` only once + // because `createExportWrapper()` does not reset `Module["_OrtRun"]` inside. The release build, however, will + // reset d._OrtRun to J.ka when the first time it is called. + // + // The difference is important because we need to design the async wrapper in a way that it can handle both cases. + // + // Now, let's look at how the async wrapper is designed to work for both cases: + // + // - Debug build: + // 1. When Web assembly is being loaded, `Module["_OrtRun"]` is assigned to `createExportWrapper("OrtRun")`. + // 2. When the first time `Module["initAsync"]` is called, `Module["_OrtRun"]` is re-assigned to a new async + // wrapper function. + // Value of `Module["_OrtRun"]` will not be changed again. + // + // - Release build: + // 1. When Web assembly is being loaded, `Module["_OrtRun"]` is assigned to a lazy loading wrapper function. + // 2. When the first time `Module["initAsync"]` is called, `Module["_OrtRun"]` is re-assigned to a new async + // wrapper function. + // 3. When the first time `Module["_OrtRun"]` is called, the async wrapper will be called. It will call into this + // function: + // ``` + // (a, b, c, e, f, h, l, q) => (d._OrtRun = J.ka)(a, b, c, e, f, h, l, q); + // ``` + // This function will assign d._OrtRun (ie. the minimized `Module["_OrtRun"]`) to the real function (J.ka). + // 4. Since d._OrtRun is re-assigned, we need to update the async wrapper to re-assign its stored + // function to the updated value (J.ka), and re-assign the value of `d._OrtRun` back to the async wrapper. + // Value of `Module["_OrtRun"]` will not be changed again. + // + // The value of `Module["_OrtRun"]` will need to be assigned for 2 times for debug build and 4 times for release + // build. + // + // This is why we need this `getFunc` and `setFunc` parameters. They are used to get the current value of an + // exported function and set the new value of an exported function. + // + const wrapAsync = (func, getFunc, setFunc) => { + return (...args) => { + // cache the async data before calling the function. + const previousAsync = Asyncify.currData; + + const previousFunc = getFunc?.(); + const ret = func(...args); + const newFunc = getFunc?.(); + if (previousFunc !== newFunc) { + // The exported function has been updated. + // Set the sync function reference to the new function. + func = newFunc; + // Set the exported function back to the async wrapper. + setFunc(previousFunc); + // Remove getFunc and setFunc. They are no longer needed. + setFunc = null; + getFunc = null; + } + + // If the async data has been changed, it means that the function started an async operation. + if (Asyncify.currData != previousAsync) { + // returns the promise + return Asyncify.whenDone(); + } + // the function is synchronous. returns the result. + return ret; + }; + }; + + // replace the original functions with asyncified versions + const wrapAsyncAPIs = (funcNames) => { + for (const funcName of funcNames) { + Module[funcName] = wrapAsync( + Module[funcName], + () => Module[funcName], + (v) => (Module[funcName] = v) + ); + } + }; + + wrapAsyncAPIs([ + "_OrtAppendExecutionProvider", + "_OrtCreateSession", + "_OrtRun", + "_OrtRunWithBinding", + "_OrtBindInput", + ]); + + // If JSEP is enabled, wrap OrtRun() and OrtRunWithBinding() with asyncify. + if (typeof jsepRunAsync !== "undefined") { + Module["_OrtRun"] = jsepRunAsync(Module["_OrtRun"]); + Module["_OrtRunWithBinding"] = jsepRunAsync(Module["_OrtRunWithBinding"]); + } + + // remove this function to make sure it is called only once. + initAsyncImpl = undefined; +}; + +Module["asyncInit"] = () => { + initAsyncImpl?.(); +}; diff --git a/onnxruntime/wasm/pre-jsep.js b/onnxruntime/wasm/pre-jsep.js index 0c83e71a921cb..5b2f044d4c27b 100644 --- a/onnxruntime/wasm/pre-jsep.js +++ b/onnxruntime/wasm/pre-jsep.js @@ -1,255 +1,157 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -'use strict'; - // // This file contains the pre-run code for the ORT WebAssembly module. The code in this file will be injected into the // final module using Emscripten's `--pre-js` option. // // This file will only be used in build with flag `--use_jsep`. - -/** - * initialize JSEP for asyncify support. - */ -let jsepInitAsync = () => { - // This is a simplified version of cwrap() with options.async === true (-sASYNCIFY=1) - // It removes some overhead in cwarp() and ccall() that we don't need. - // - // Currently in JSEP build, we only use this for the following functions: - // - OrtRun() - // - OrtRunWithBinding() - // - OrtBindInput() - // - // Note: about parameters "getFunc" and "setFunc": - // - Emscripten has different behaviors for Debug and Release builds for generating exported function wrapper. - // - // - In Debug build, it will generate a wrapper function for each exported function. For example, it generates a - // wrapper for OrtRun() like this (minified): - // ``` - // var _OrtRun = Module["_OrtRun"] = createExportWrapper("OrtRun"); - // ``` - // - // - In Release build, it will generate a lazy loading wrapper for each exported function. For example, it generates - // a wrapper for OrtRun() like this (minified): - // ``` - // d._OrtRun = (a, b, c, e, f, h, l, q) => (d._OrtRun = J.ka)(a, b, c, e, f, h, l, q); - // ``` - // - // The behavior of these two wrappers are different. The debug build will assign `Module["_OrtRun"]` only once - // because `createExportWrapper()` does not reset `Module["_OrtRun"]` inside. The release build, however, will - // reset d._OrtRun to J.ka when the first time it is called. - // - // The difference is important because we need to design the async wrapper in a way that it can handle both cases. - // - // Now, let's look at how the async wrapper is designed to work for both cases: - // - // - Debug build: - // 1. When Web assembly is being loaded, `Module["_OrtRun"]` is assigned to `createExportWrapper("OrtRun")`. - // 2. When the first time `Module["jsepInit"]` is called, `Module["_OrtRun"]` is re-assigned to a new async - // wrapper function. - // Value of `Module["_OrtRun"]` will not be changed again. - // - // - Release build: - // 1. When Web assembly is being loaded, `Module["_OrtRun"]` is assigned to a lazy loading wrapper function. - // 2. When the first time `Module["jsepInit"]` is called, `Module["_OrtRun"]` is re-assigned to a new async - // wrapper function. - // 3. When the first time `Module["_OrtRun"]` is called, the async wrapper will be called. It will call into this - // function: - // ``` - // (a, b, c, e, f, h, l, q) => (d._OrtRun = J.ka)(a, b, c, e, f, h, l, q); - // ``` - // This function will assign d._OrtRun (ie. the minimized `Module["_OrtRun"]`) to the real function (J.ka). - // 4. Since d._OrtRun is re-assigned, we need to update the async wrapper to re-assign its stored - // function to the updated value (J.ka), and re-assign the value of `d._OrtRun` back to the async wrapper. - // Value of `Module["_OrtRun"]` will not be changed again. - // - // The value of `Module["_OrtRun"]` will need to be assigned for 2 times for debug build and 4 times for release - // build. - // - // This is why we need this `getFunc` and `setFunc` parameters. They are used to get the current value of an - // exported function and set the new value of an exported function. - // - const jsepWrapAsync = (func, getFunc, setFunc) => { - return (...args) => { - // cache the async data before calling the function. - const previousAsync = Asyncify.currData; - - const previousFunc = getFunc?.(); - const ret = func(...args); - const newFunc = getFunc?.(); - if (previousFunc !== newFunc) { - // The exported function has been updated. - // Set the sync function reference to the new function. - func = newFunc; - // Set the exported function back to the async wrapper. - setFunc(previousFunc); - // Remove getFunc and setFunc. They are no longer needed. - setFunc = null; - getFunc = null; +// This is a wrapper for OrtRun() and OrtRunWithBinding() to ensure that Promises are handled correctly. +const jsepRunAsync = (runAsyncFunc) => { + return async (...args) => { + try { + // Module.jsepSessionState should be null, unless we are in the middle of a session. + // If it is not null, it means that the previous session has not finished yet. + if (Module.jsepSessionState) { + throw new Error("Session already started"); } + const state = (Module.jsepSessionState = { + sessionHandle: args[0], + errors: [], + }); - // If the async data has been changed, it means that the function started an async operation. - if (Asyncify.currData != previousAsync) { - // returns the promise - return Asyncify.whenDone(); - } - // the function is synchronous. returns the result. - return ret; - }; - }; - - // This is a wrapper for OrtRun() and OrtRunWithBinding() to ensure that Promises are handled correctly. - const runAsync = (runAsyncFunc) => { - return async (...args) => { - try { - // Module.jsepSessionState should be null, unless we are in the middle of a session. - // If it is not null, it means that the previous session has not finished yet. - if (Module.jsepSessionState) { - throw new Error('Session already started'); - } - const state = Module.jsepSessionState = {sessionHandle: args[0], errors: []}; - - // Run the acyncified function: OrtRun() or OrtRunWithBinding() - const ret = await runAsyncFunc(...args); + // Run the acyncified function: OrtRun() or OrtRunWithBinding() + const ret = await runAsyncFunc(...args); - // Check if the session is still valid. this object should be the same as the one we set above. - if (Module.jsepSessionState !== state) { - throw new Error('Session mismatch'); - } + // Check if the session is still valid. this object should be the same as the one we set above. + if (Module.jsepSessionState !== state) { + throw new Error("Session mismatch"); + } - // Flush the backend. This will submit all pending commands to the GPU. - Module.jsepBackend?.['flush'](); + // Flush the backend. This will submit all pending commands to the GPU. + Module.jsepBackend?.["flush"](); - // Await all pending promises. This includes GPU validation promises for diagnostic purposes. - const errorPromises = state.errors; - if (errorPromises.length > 0) { - let errors = await Promise.all(errorPromises); - errors = errors.filter(e => e); - if (errors.length > 0) { - throw new Error(errors.join('\n')); - } + // Await all pending promises. This includes GPU validation promises for diagnostic purposes. + const errorPromises = state.errors; + if (errorPromises.length > 0) { + let errors = await Promise.all(errorPromises); + errors = errors.filter((e) => e); + if (errors.length > 0) { + throw new Error(errors.join("\n")); } - - return ret; - } finally { - Module.jsepSessionState = null; } - }; - }; - // replace the original functions with asyncified versions - Module['_OrtCreateSession'] = jsepWrapAsync( - Module['_OrtCreateSession'], - () => Module['_OrtCreateSession'], - v => Module['_OrtCreateSession'] = v); - Module['_OrtRun'] = runAsync(jsepWrapAsync( - Module['_OrtRun'], - () => Module['_OrtRun'], - v => Module['_OrtRun'] = v)); - Module['_OrtRunWithBinding'] = runAsync(jsepWrapAsync( - Module['_OrtRunWithBinding'], - () => Module['_OrtRunWithBinding'], - v => Module['_OrtRunWithBinding'] = v)); - Module['_OrtBindInput'] = jsepWrapAsync( - Module['_OrtBindInput'], - () => Module['_OrtBindInput'], - v => Module['_OrtBindInput'] = v); - - // remove this function to make sure it is called only once. - jsepInitAsync = undefined; + return ret; + } finally { + Module.jsepSessionState = null; + } + }; }; - /** - * initialize JSEP for WebGPU. + * initialize JSEP for WebGPU and WebNN. */ -Module['jsepInit'] = (name, params) => { - jsepInitAsync?.(); - - if (name === 'webgpu') { - [Module.jsepBackend, - Module.jsepAlloc, - Module.jsepFree, - Module.jsepCopy, - Module.jsepCopyAsync, - Module.jsepCreateKernel, - Module.jsepReleaseKernel, - Module.jsepRunKernel, - Module.jsepCaptureBegin, - Module.jsepCaptureEnd, - Module.jsepReplay] = params; +Module["jsepInit"] = (name, params) => { + if (name === "webgpu") { + [ + Module.jsepBackend, + Module.jsepAlloc, + Module.jsepFree, + Module.jsepCopy, + Module.jsepCopyAsync, + Module.jsepCreateKernel, + Module.jsepReleaseKernel, + Module.jsepRunKernel, + Module.jsepCaptureBegin, + Module.jsepCaptureEnd, + Module.jsepReplay, + ] = params; // expose webgpu backend functions const backend = Module.jsepBackend; - Module['jsepRegisterBuffer'] = (sessionId, index, buffer, size) => { - return backend['registerBuffer'](sessionId, index, buffer, size); + Module["jsepRegisterBuffer"] = (sessionId, index, buffer, size) => { + return backend["registerBuffer"](sessionId, index, buffer, size); }; - Module['jsepGetBuffer'] = (dataId) => { - return backend['getBuffer'](dataId); + Module["jsepGetBuffer"] = (dataId) => { + return backend["getBuffer"](dataId); }; - Module['jsepCreateDownloader'] = (gpuBuffer, size, type) => { - return backend['createDownloader'](gpuBuffer, size, type); + Module["jsepCreateDownloader"] = (gpuBuffer, size, type) => { + return backend["createDownloader"](gpuBuffer, size, type); }; - Module['jsepOnCreateSession'] = sessionId => { - backend['onCreateSession'](sessionId); + Module["jsepOnCreateSession"] = (sessionId) => { + backend["onCreateSession"](sessionId); }; - Module['jsepOnReleaseSession'] = sessionId => { - backend['onReleaseSession'](sessionId); + Module["jsepOnReleaseSession"] = (sessionId) => { + backend["onReleaseSession"](sessionId); }; - Module['jsepOnRunStart'] = sessionId => { - return backend['onRunStart'](sessionId); + Module["jsepOnRunStart"] = (sessionId) => { + return backend["onRunStart"](sessionId); }; Module.jsepUploadExternalBuffer = (dataId, buffer) => { - backend['upload'](dataId, buffer); + backend["upload"](dataId, buffer); }; - } else if (name === 'webnn') { + } else if (name === "webnn") { // Functions called from EM_ASM need to be assigned in a way that can be minified. // Functions called via emscripten::val::module_property need to be assigned by name so that the minifier doesn't // change the name. - [Module.jsepBackend, - Module.jsepReserveTensorId, - Module.jsepReleaseTensorId, - Module['jsepEnsureTensor'], - Module.jsepUploadTensor, - Module['jsepDownloadTensor'], + [ + Module.jsepBackend, + Module.jsepReserveTensorId, + Module.jsepReleaseTensorId, + Module["jsepEnsureTensor"], + Module.jsepUploadTensor, + Module["jsepDownloadTensor"], ] = params; // This function is called from both JS and an EM_ASM block, it needs both a minifiable name and an explicit name. - Module['jsepReleaseTensorId'] = Module.jsepReleaseTensorId; - Module['jsepUploadTensor'] = Module.jsepUploadTensor; + Module["jsepReleaseTensorId"] = Module.jsepReleaseTensorId; + Module["jsepUploadTensor"] = Module.jsepUploadTensor; // Functions called from JS also need to have explicit names. const backend = Module.jsepBackend; - Module['jsepOnRunStart'] = sessionId => { - return backend['onRunStart'](sessionId); + Module["jsepOnRunStart"] = (sessionId) => { + return backend["onRunStart"](sessionId); }; - Module['jsepOnRunEnd'] = backend['onRunEnd'].bind(backend); - Module['jsepRegisterMLContext'] = (sessionId, mlContext) => { - backend['registerMLContext'](sessionId, mlContext); + Module["jsepOnRunEnd"] = backend["onRunEnd"].bind(backend); + Module["jsepRegisterMLContext"] = (sessionId, mlContext) => { + backend["registerMLContext"](sessionId, mlContext); }; - Module['jsepOnReleaseSession'] = sessionId => { - backend['onReleaseSession'](sessionId); + Module["jsepOnReleaseSession"] = (sessionId) => { + backend["onReleaseSession"](sessionId); }; - Module['jsepCreateMLTensorDownloader'] = (tensorId, type) => { - return backend['createMLTensorDownloader'](tensorId, type); - } - Module['jsepRegisterMLTensor'] = (sessionId, tensor, dataType, shape) => { - return backend['registerMLTensor'](sessionId, tensor, dataType, shape); + Module["jsepCreateMLTensorDownloader"] = (tensorId, type) => { + return backend["createMLTensorDownloader"](tensorId, type); + }; + Module["jsepRegisterMLTensor"] = (sessionId, tensor, dataType, shape) => { + return backend["registerMLTensor"](sessionId, tensor, dataType, shape); }; - Module['jsepCreateMLContext'] = (optionsOrGpuDevice) => { - return backend['createMLContext'](optionsOrGpuDevice); + Module["jsepCreateMLContext"] = (optionsOrGpuDevice) => { + return backend["createMLContext"](optionsOrGpuDevice); }; - Module['jsepRegisterMLConstant'] = (externalFilePath, dataOffset, dataLength, builder, desc) => { - return backend['registerMLConstant']( - externalFilePath, dataOffset, dataLength, builder, desc, Module.MountedFiles); + Module["jsepRegisterMLConstant"] = ( + externalFilePath, + dataOffset, + dataLength, + builder, + desc + ) => { + return backend["registerMLConstant"]( + externalFilePath, + dataOffset, + dataLength, + builder, + desc, + Module.MountedFiles + ); }; - Module['jsepRegisterGraphInput'] = backend['registerGraphInput'].bind(backend); - Module['jsepIsGraphInput'] = backend['isGraphInput'].bind(backend); + Module["jsepRegisterGraphInput"] = + backend["registerGraphInput"].bind(backend); + Module["jsepIsGraphInput"] = backend["isGraphInput"].bind(backend); - Module['jsepCreateTemporaryTensor'] = backend['createTemporaryTensor'].bind(backend); + Module["jsepCreateTemporaryTensor"] = + backend["createTemporaryTensor"].bind(backend); } }; diff --git a/onnxruntime/wasm/pre.js b/onnxruntime/wasm/pre.js index 9b5f3ce545b78..636a9713519a7 100644 --- a/onnxruntime/wasm/pre.js +++ b/onnxruntime/wasm/pre.js @@ -1,21 +1,18 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -'use strict'; - // // This file contains the pre-run code for the ORT WebAssembly module. The code in this file will be injected into the // final module using Emscripten's `--pre-js` option. - /** * Mount external data files of a model to an internal map, which will be used during session initialization. * * @param {string} externalDataFilesPath * @param {Uint8Array} externalDataFilesData */ -Module['mountExternalData'] = (externalDataFilePath, externalDataFileData) => { - if (externalDataFilePath.startsWith('./')) { +Module["mountExternalData"] = (externalDataFilePath, externalDataFileData) => { + if (externalDataFilePath.startsWith("./")) { externalDataFilePath = externalDataFilePath.substring(2); } const files = Module.MountedFiles || (Module.MountedFiles = new Map()); @@ -25,7 +22,7 @@ Module['mountExternalData'] = (externalDataFilePath, externalDataFileData) => { /** * Unmount external data files of a model. */ -Module['unmountExternalData'] = () => { +Module["unmountExternalData"] = () => { delete Module.MountedFiles; }; @@ -48,5 +45,7 @@ Module['unmountExternalData'] = () => { * * @suppress {checkVars} */ -var SharedArrayBuffer = globalThis.SharedArrayBuffer ?? - new WebAssembly.Memory({'initial': 0, 'maximum': 0, 'shared': true}).buffer.constructor; +var SharedArrayBuffer = + globalThis.SharedArrayBuffer ?? + new WebAssembly.Memory({ initial: 0, maximum: 0, shared: true }).buffer + .constructor; diff --git a/setup.py b/setup.py index ced2f28e38778..53e533050b245 100644 --- a/setup.py +++ b/setup.py @@ -356,7 +356,7 @@ def finalize_options(self): "libQnnSaver.so", "libQnnSystem.so", "libHtpPrepare.so", - "onnxruntime_qnn_ctx_gen", + "ep_weight_sharing_ctx_gen", ] dl_libs.extend(qnn_deps) if nightly_build: diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index 8607887072347..db7dbed23a2d2 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -35,7 +35,8 @@ def version_to_tuple(version: str) -> tuple: import util.android as android # noqa: E402 from util import ( # noqa: E402 generate_android_triplets, - generate_posix_triplets, + generate_linux_triplets, + generate_macos_triplets, generate_vcpkg_triplets_for_emscripten, generate_windows_triplets, get_logger, @@ -1115,7 +1116,6 @@ def generate_build_tree( cmake_extra_args, ): log.info("Generating CMake build tree") - cmake_dir = os.path.join(source_dir, "cmake") cmake_args = [cmake_path, cmake_dir] if not use_dev_mode(args): @@ -1330,8 +1330,16 @@ def generate_build_tree( generate_android_triplets(build_dir, args.android_cpp_shared, args.android_api) elif is_windows(): generate_windows_triplets(build_dir) + elif is_macOS(): + osx_target = args.apple_deploy_target + if args.apple_deploy_target is None: + osx_target = os.environ.get("MACOSX_DEPLOYMENT_TARGET") + if osx_target is not None: + log.info(f"Setting VCPKG_OSX_DEPLOYMENT_TARGET to {osx_target}") + generate_macos_triplets(build_dir, osx_target) else: - generate_posix_triplets(build_dir) + # Linux, *BSD, AIX or other platforms + generate_linux_triplets(build_dir) add_default_definition(cmake_extra_defines, "CMAKE_TOOLCHAIN_FILE", str(vcpkg_toolchain_path)) vcpkg_install_options = generate_vcpkg_install_options(build_dir, args) @@ -1592,8 +1600,11 @@ def generate_build_tree( raise BuildError("WebNN is only available for WebAssembly build.") cmake_args += ["-Donnxruntime_USE_WEBNN=ON"] - if args.use_jsep and args.use_webgpu: - raise BuildError("JSEP (--use_jsep) and WebGPU (--use_webgpu) cannot be enabled at the same time.") + # TODO: currently we allows building with both --use_jsep and --use_webgpu in this working branch. + # This situation is temporary. Eventually, those two flags will be mutually exclusive. + # + # if args.use_jsep and args.use_webgpu: + # raise BuildError("JSEP (--use_jsep) and WebGPU (--use_webgpu) cannot be enabled at the same time.") if args.use_external_dawn and not args.use_webgpu: raise BuildError("External Dawn (--use_external_dawn) must be enabled with WebGPU (--use_webgpu).") diff --git a/tools/ci_build/github/azure-pipelines/custom-nuget-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/custom-nuget-packaging-pipeline.yml new file mode 100644 index 0000000000000..8aaaa0e85585a --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/custom-nuget-packaging-pipeline.yml @@ -0,0 +1,142 @@ +parameters: +- name: CudaVersion + type: string + default: '12.2' + +- name: QnnSdk + displayName: QNN SDK Version + type: string + default: 2.31.0.250130 + +- name: IsReleaseBuild + displayName: Is a release build? Set it to true if you are doing an Onnx Runtime release. + type: boolean + default: false + +- name: PackageName + displayName: What is the package name? + type: string + default: 'Microsoft.ML.OnnxRuntime.Flamingo' + +variables: + - template: templates/common-variables.yml + - name: ReleaseVersionSuffix + value: '' + - name: win_cuda_home + ${{ if eq(parameters.CudaVersion, '11.8') }}: + value: $(Agent.TempDirectory)\v11.8 + ${{ if eq(parameters.CudaVersion, '12.2') }}: + value: $(Agent.TempDirectory)\v12.2 + +stages: + - template: templates/win-ci.yml + parameters: + ort_build_pool_name: 'onnxruntime-Win2022-GPU-A10' + DoCompliance: false + DoEsrp: true + stage_name_suffix: CUDA + buildArch: x64 + msbuildPlatform: x64 + packageName: x64-cuda + CudaVersion: ${{ parameters.CudaVersion }} + buildparameter: --use_cuda --cuda_home=${{ variables.win_cuda_home }} --enable_onnx_tests --enable_wcos --use_webgpu --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=52-real;61-real;75-real;86-real;89-real;90-virtual" + runTests: false + buildJava: false + java_artifact_id: onnxruntime_gpu + UseIncreasedTimeoutForTests: false + SpecificArtifact: false + BuildId: '0' + + - template: templates/qnn-ep-win.yml + parameters: + qnn_ep_build_pool_name: 'Onnxruntime-QNNEP-Windows-2022-CPU' + QnnSdk: ${{ parameters.QnnSdk }} + IsReleaseBuild: ${{ parameters.IsReleaseBuild }} + DoEsrp: true + ArtifactName: 'drop-nuget-qnn-arm64' + # Add --use_webgpu to enable WebGPU + buildParameter: '--arm64' + buildPlatform: 'ARM64' + buildArch: 'ARM64' + StageName: 'OnnxRuntime_QNN_Nuget_Win_Arm64' + build_config: 'RelWithDebInfo' + Is1ES: false + PublishArchive: true + + - stage: NugetPackaging + dependsOn: [Windows_Packaging_CUDA, OnnxRuntime_QNN_Nuget_Win_Arm64] + jobs: + - job: CreateNugetPackage + pool: 'Onnxruntime-Win2022-GPU-A10' + timeoutInMinutes: 120 + steps: + - checkout: self + clean: true + submodules: none + + - task: UsePythonVersion@0 + inputs: + versionSpec: '3.12' + addToPath: true + + - task: DownloadPipelineArtifact@0 + displayName: 'Download Pipeline Artifact - managed nuget' + inputs: + artifactName: 'drop-nuget-qnn-arm64' + targetPath: '$(Build.BinariesDirectory)/managed-nuget' + + - task: DownloadPipelineArtifact@0 + displayName: 'Download Pipeline Artifact - win-x64' + inputs: + artifactName: 'onnxruntime-win-x64-cuda' + targetPath: '$(Build.BinariesDirectory)/win-x64' + + - task: DownloadPipelineArtifact@0 + displayName: 'Download Pipeline Artifact - win-arm64' + inputs: + artifactName: 'onnxruntime-win-ARM64-qnn' + targetPath: '$(Build.BinariesDirectory)/win-arm64' + + - task: PowerShell@2 + displayName: 'Extract Nuget Package Version' + inputs: + targetType: 'inline' + script: | + $nupkgs = (Get-ChildItem $(Build.BinariesDirectory)/managed-nuget -Filter Microsoft.ML.OnnxRuntime.Managed.*.nupkg -Recurse) + $package_name = $nupkgs[0].Name + $version_length = $package_name.Length - "Microsoft.ML.OnnxRuntime.Managed.".Length - ".nupkg".Length + $package_version = $package_name.Substring("Microsoft.ML.OnnxRuntime.Managed.".Length, $version_length) + Write-Host "##vso[task.setvariable variable=package_version;]$package_version" + workingDirectory: $(Build.BinariesDirectory) + + - task: PowerShell@2 + displayName: 'Extract Archives' + inputs: + targetType: 'inline' + script: | + Expand-Archive -Path $(Build.BinariesDirectory)/win-x64/onnxruntime-win-x64-cuda*.zip -DestinationPath $(Build.BinariesDirectory)/win-x64 + Expand-Archive -Path $(Build.BinariesDirectory)/win-arm64/onnxruntime-win-ARM64-qnn*.zip -DestinationPath $(Build.BinariesDirectory)/win-arm64 + $win_x64 = (Get-ChildItem -Path $(Build.BinariesDirectory)/win-x64 -Filter onnxruntime-win-x64-cuda*)[0].FullName + $win_arm64 = (Get-ChildItem -Path $(Build.BinariesDirectory)/win-arm64 -Filter onnxruntime-win-ARM64-qnn*)[0].FullName + Write-Host "##vso[task.setvariable variable=win_x64;]$win_x64" + Write-Host "##vso[task.setvariable variable=win_arm64;]$win_arm64" + workingDirectory: $(Build.BinariesDirectory) + + - task: PythonScript@0 + displayName: 'Generate Nuget Package' + inputs: + scriptPath: '$(Build.SourcesDirectory)/tools/nuget/generate_nuspec_for_custom_nuget.py' + arguments: '--nuspec_path "$(Build.BinariesDirectory)/${{ parameters.PackageName }}.nuspec" --root_dir "$(Build.SourcesDirectory)" --commit_id "$(Build.SourceVersion)" --win_arm64 "$(win_arm64)" --win_x64 "$(win_x64)" --package_version "$(package_version)" --package_name "${{ parameters.PackageName }}"' + + - task: NuGetCommand@2 + displayName: 'Pack Nuget Package' + inputs: + command: 'pack' + packagesToPack: '$(Build.BinariesDirectory)/${{ parameters.PackageName }}.nuspec' + packDestination: $(Build.ArtifactStagingDirectory)\ + + - task: PublishBuildArtifacts@1 + displayName: 'Publish Artifact: Nuget' + inputs: + pathtoPublish: '$(Build.ArtifactStagingDirectory)' + artifactName: '${{ parameters.PackageName }}' diff --git a/tools/ci_build/github/azure-pipelines/py-package-test-pipeline.yml b/tools/ci_build/github/azure-pipelines/py-package-test-pipeline.yml index a0e49692220f9..7a78c6ba0fcdf 100644 --- a/tools/ci_build/github/azure-pipelines/py-package-test-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/py-package-test-pipeline.yml @@ -31,10 +31,12 @@ stages: machine_pool: vmImage: 'macOS-13' itemPattern: '*/*mac*x86_64.whl' + arch: 'x86_64' - template: templates/py-package-smoking-test.yml parameters: job_name: Test_LINUX_x86_64_Wheels itemPattern: '*/*manylinux*x86_64.whl' + arch: 'x86_64' machine_pool: name: 'onnxruntime-Ubuntu2204-AMD-CPU' diff --git a/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml index 01d30d0e1ba86..28ddd29ec63e6 100644 --- a/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml @@ -50,10 +50,10 @@ parameters: displayName: 'Linux packages cmake build type. Linux Only.' default: 'Release' values: - - Debug - - Release - - RelWithDebInfo - - MinSizeRel + - Debug + - Release + - RelWithDebInfo + - MinSizeRel # Only applies to QNN packages. - name: qnn_sdk_version @@ -63,17 +63,33 @@ parameters: trigger: none -stages: -- template: stages/py-cpu-packaging-stage.yml +resources: + repositories: + - repository: 1esPipelines + type: git + name: 1ESPipelineTemplates/1ESPipelineTemplates + ref: refs/tags/release +extends: + # The pipeline extends the 1ES PT which will inject different SDL and compliance tasks. + # For non-production pipelines, use "Unofficial" as defined below. + # For productions pipelines, use "Official". + template: v1/1ES.Official.PipelineTemplate.yml@1esPipelines parameters: - enable_linux_cpu: ${{ parameters.enable_linux_cpu }} - enable_windows_cpu: ${{ parameters.enable_windows_cpu }} - enable_mac_cpu: ${{ parameters.enable_mac_cpu }} - enable_linux_arm: ${{ parameters.enable_linux_arm }} - enable_windows_arm64_qnn: ${{ parameters.enable_windows_arm64_qnn }} - enable_windows_arm64ec_qnn: ${{ parameters.enable_windows_arm64ec_qnn }} - enable_windows_x64_qnn: ${{ parameters.enable_windows_x64_qnn }} - enable_linux_x64_qnn: ${{ parameters.enable_linux_x64_qnn }} - build_py_parameters: ${{ parameters.build_py_parameters }} - cmake_build_type: ${{ parameters.cmake_build_type }} - qnn_sdk_version: ${{ parameters.qnn_sdk_version }} + sdl: + sourceAnalysisPool: + name: onnxruntime-Win-CPU-2022 + os: windows + stages: + - template: stages/py-cpu-packaging-stage.yml + parameters: + enable_linux_cpu: ${{ parameters.enable_linux_cpu }} + enable_windows_cpu: ${{ parameters.enable_windows_cpu }} + enable_mac_cpu: ${{ parameters.enable_mac_cpu }} + enable_linux_arm: ${{ parameters.enable_linux_arm }} + enable_windows_arm64_qnn: ${{ parameters.enable_windows_arm64_qnn }} + enable_windows_arm64ec_qnn: ${{ parameters.enable_windows_arm64ec_qnn }} + enable_windows_x64_qnn: ${{ parameters.enable_windows_x64_qnn }} + enable_linux_x64_qnn: ${{ parameters.enable_linux_x64_qnn }} + build_py_parameters: ${{ parameters.build_py_parameters }} + cmake_build_type: ${{ parameters.cmake_build_type }} + qnn_sdk_version: ${{ parameters.qnn_sdk_version }} diff --git a/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml index 055ef58e4524a..cfca998e0f06c 100644 --- a/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml @@ -29,108 +29,58 @@ parameters: displayName: Pipeline BuildId, you could find it in the URL type: string default: '0' - -stages: - -- template: templates/qnn-ep-win.yml - parameters: - qnn_ep_build_pool_name: 'Onnxruntime-QNNEP-Windows-2022-CPU' - QnnSdk: ${{ parameters.QnnSdk }} - IsReleaseBuild: ${{ parameters.IsReleaseBuild }} - DoEsrp: ${{ parameters.DoEsrp }} - ArtifactName: 'drop-nuget-qnn-x64' - StageName: 'OnnxRuntime_QNN_Nuget_Win_x64' - build_config: ${{ parameters.build_config }} - -- template: templates/qnn-ep-win.yml +resources: + repositories: + - repository: 1esPipelines + type: git + name: 1ESPipelineTemplates/1ESPipelineTemplates + ref: refs/tags/release +extends: + # The pipeline extends the 1ES PT which will inject different SDL and compliance tasks. + # For non-production pipelines, use "Unofficial" as defined below. + # For productions pipelines, use "Official". + template: v1/1ES.Official.PipelineTemplate.yml@1esPipelines parameters: - qnn_ep_build_pool_name: 'Onnxruntime-QNNEP-Windows-2022-CPU' - QnnSdk: ${{ parameters.QnnSdk }} - IsReleaseBuild: ${{ parameters.IsReleaseBuild }} - DoEsrp: ${{ parameters.DoEsrp }} - ArtifactName: 'drop-nuget-qnn-arm64' - buildParameter: '--arm64' - buildPlatform: 'ARM64' - buildArch: 'ARM64' - StageName: 'OnnxRuntime_QNN_Nuget_Win_Arm64' - build_config: ${{ parameters.build_config }} - -- stage: NuGet_Packaging_QNN - pool: 'Onnxruntime-QNNEP-Windows-2022-CPU' - dependsOn: - - OnnxRuntime_QNN_Nuget_Win_x64 - - OnnxRuntime_QNN_Nuget_Win_Arm64 - condition: succeeded() - jobs: - - job: NuGet_Packaging_QNN - workspace: - clean: all - steps: - - task: DownloadPipelineArtifact@0 - displayName: 'Download Pipeline Artifact - QNN NuGet x64' - inputs: - artifactName: 'drop-nuget-qnn-x64' - targetPath: '$(Build.BinariesDirectory)/nuget-artifact-x64' - - - task: DownloadPipelineArtifact@0 - displayName: 'Download Pipeline Artifact - QNN NuGet arm64' - inputs: - artifactName: 'drop-nuget-qnn-arm64' - targetPath: '$(Build.BinariesDirectory)/nuget-artifact-arm64' - - - task: PowerShell@2 - displayName: 'Bundle NuGet' - inputs: - targetType: 'inline' - script: | - - $x64_nupkgs = (Get-ChildItem $(Build.BinariesDirectory)/nuget-artifact-x64 -Filter Microsoft.ML.OnnxRuntime.QNN*.nupkg -Recurse) - $nuget_package_name = $x64_nupkgs[0].Name - $x64_nuget_package = $x64_nupkgs[0].FullName - - $nupkg_unzipped_directory = [System.IO.Path]::Combine($Env:BUILD_ARTIFACTSTAGINGDIRECTORY, 'nuget_unzip_merged', [System.IO.Path]::GetFileNameWithoutExtension($nuget_package_name)) - - $x64_unzip_cmd = "7z.exe x $x64_nuget_package -y -o$nupkg_unzipped_directory" - Invoke-Expression -Command $x64_unzip_cmd - - $arm64_nupkgs = (Get-ChildItem $(Build.BinariesDirectory)/nuget-artifact-arm64 -Filter Microsoft.ML.OnnxRuntime.QNN*.nupkg -Recurse) - $arm64_nuget_package = $arm64_nupkgs[0].FullName + sdl: + sourceAnalysisPool: + name: onnxruntime-Win-CPU-2022 + os: windows + stages: - $arm64_unzip_cmd = "7z.exe x $arm64_nuget_package -y -o$nupkg_unzipped_directory" - Invoke-Expression -Command $arm64_unzip_cmd - - $merged_nuget_path = [System.IO.Path]::Combine($Env:BUILD_ARTIFACTSTAGINGDIRECTORY, 'nuget-artifact-merged') - if (!(Test-Path $merged_nuget_path)) { - New-Item -Path $merged_nuget_path -ItemType Directory - } - - $merged_zip = [System.IO.Path]::Combine($merged_nuget_path, 'qnn_nuget.zip') - $zip_cmd = "7z.exe a -r $merged_zip $nupkg_unzipped_directory/*" - Invoke-Expression -Command $zip_cmd - - $merged_nuget = [System.IO.Path]::Combine($merged_nuget_path, $nuget_package_name) - move $merged_zip $merged_nuget - workingDirectory: $(Build.BinariesDirectory) - - - template: templates/esrp_nuget.yml + - template: templates/qnn-ep-win.yml parameters: - DisplayName: 'ESRP - sign NuGet package' - FolderPath: '$(Build.ArtifactStagingDirectory)/nuget-artifact-merged' + qnn_ep_build_pool_name: 'Onnxruntime-QNNEP-Windows-2022-CPU' + QnnSdk: ${{ parameters.QnnSdk }} + IsReleaseBuild: ${{ parameters.IsReleaseBuild }} DoEsrp: ${{ parameters.DoEsrp }} + ArtifactName: 'drop-nuget-qnn-x64' + StageName: 'OnnxRuntime_QNN_Nuget_Win_x64' + build_config: ${{ parameters.build_config }} - - task: PublishPipelineArtifact@0 - displayName: 'Publish Pipeline NuGet Artifact' - inputs: - artifactName: 'drop-signed-nuget-qnn' - targetPath: '$(Build.ArtifactStagingDirectory)/nuget-artifact-merged' + - template: templates/qnn-ep-win.yml + parameters: + qnn_ep_build_pool_name: 'Onnxruntime-QNNEP-Windows-2022-CPU' + QnnSdk: ${{ parameters.QnnSdk }} + IsReleaseBuild: ${{ parameters.IsReleaseBuild }} + DoEsrp: ${{ parameters.DoEsrp }} + ArtifactName: 'drop-nuget-qnn-arm64' + buildParameter: '--arm64' + buildPlatform: 'ARM64' + buildArch: 'ARM64' + StageName: 'OnnxRuntime_QNN_Nuget_Win_Arm64' + build_config: ${{ parameters.build_config }} + + - template: stages/nuget-qnn-packaging-stage.yml + parameters: + DoEsrp: ${{ parameters.DoEsrp }} -- template: templates/publish-nuget-steps.yml - parameters: - download_artifacts_steps: - - template: templates/flex-downloadPipelineArtifact.yml - parameters: - StepName: 'Download Pipeline Artifact - Signed NuGet Qnn Package' - ArtifactName: 'drop-signed-nuget-qnn' - targetPath: '$(Build.BinariesDirectory)/nuget-artifact/final-package' - SpecificArtifact: ${{ parameters.specificArtifact }} - BuildId: ${{ parameters.BuildId }} + - template: templates/publish-nuget-steps.yml + parameters: + download_artifacts_steps: + - template: templates/flex-downloadPipelineArtifact.yml + parameters: + StepName: 'Download Pipeline Artifact - Signed NuGet Qnn Package' + ArtifactName: 'drop-signed-nuget-qnn' + targetPath: '$(Build.BinariesDirectory)/nuget-artifact/final-package' + SpecificArtifact: ${{ parameters.specificArtifact }} + BuildId: ${{ parameters.BuildId }} diff --git a/tools/ci_build/github/azure-pipelines/rocm-nuget-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/rocm-nuget-packaging-pipeline.yml deleted file mode 100644 index f7f5c7b1494e8..0000000000000 --- a/tools/ci_build/github/azure-pipelines/rocm-nuget-packaging-pipeline.yml +++ /dev/null @@ -1,339 +0,0 @@ -parameters: -- name: RunOnnxRuntimeTests - displayName: Run Tests? - type: boolean - default: true - -- name: UseIncreasedTimeoutForTests - displayName: Increase timeout for tests? Set it to false if you are doing an Onnx Runtime release. - type: boolean - default: false - -- name: DoCompliance - displayName: Run Compliance Tasks? - type: boolean - default: true - -- name: DoEsrp - displayName: Run code sign tasks? Must be true if you are doing an ONNX Runtime release - type: boolean - default: true - -- name: IsReleaseBuild - displayName: Is a release build? Set it to true if you are doing an ONNX Runtime release. - type: boolean - default: false - -- name: PreReleaseVersionSuffixString - displayName: Suffix added to pre-release package version. Only used if IsReleaseBuild is true. Denotes the type of pre-release package. - type: string - values: - - alpha - - beta - - rc - - none - default: none - -- name: PreReleaseVersionSuffixNumber - displayName: Number added to pre-release package version. Only used if IsReleaseBuild is true. Denotes the sequence of a pre-release package. - type: number - default: 0 - -# these 2 parameters are used for debugging. -- name: SpecificArtifact - displayName: Use Specific Artifact (Debugging only) - type: boolean - default: false - -- name: BuildId - displayName: Pipeline BuildId, you could find it in the URL - type: string - default: '0' - -- name: NugetPackageSuffix - displayName: Suffix to append to nuget package - type: string - default: 'NONE' - -resources: - repositories: - - repository: onnxruntime-inference-examples # The name used to reference this repository in the checkout step - type: github - endpoint: ort-examples - name: microsoft/onnxruntime-inference-examples - - repository: manylinux - type: Github - endpoint: Microsoft - name: pypa/manylinux - ref: 5eda9aded5462201e6310105728d33016e637ea7 - -variables: -- name: ReleaseVersionSuffix - value: '' - -stages: -- template: stages/set_packaging_variables_stage.yml - parameters: - IsReleaseBuild: ${{ parameters.IsReleaseBuild }} - PreReleaseVersionSuffixString: ${{ parameters.PreReleaseVersionSuffixString }} - PreReleaseVersionSuffixNumber: ${{ parameters.PreReleaseVersionSuffixNumber }} - -# ROCm -- stage: Linux_C_API_Packaging_ROCm_x64 - dependsOn: [] - jobs: - - job: Linux_C_API_Packaging_ROCm_x64 - workspace: - clean: all - timeoutInMinutes: 480 - pool: onnxruntime-Ubuntu2204-AMD-CPU - variables: - RocmVersion: '6.2' - RocmVersionPatchSuffix: '' - steps: - - checkout: self # due to checkout multiple repos, the root directory is $(Build.SourcesDirectory)/onnxruntime - submodules: recursive - - checkout: manylinux # due to checkout multiple repos, the root directory is $(Build.SourcesDirectory)/manylinux, for get-docker-image-steps.yml - submodules: false - - # get-docker-image-steps.yml will move the $(Build.SourcesDirectory)/manylinux into $(Build.SourcesDirectory)/onnxruntime, - # then rename $(Build.SourcesDirectory)/onnxruntime as $(Build.SourcesDirectory) - - template: templates/get-docker-image-steps.yml - parameters: - Dockerfile: tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_rocm - Context: tools/ci_build/github/linux/docker - DockerBuildArgs: >- - --build-arg INSTALL_DEPS_EXTRA_ARGS=-tmur - --build-arg BUILD_UID=$(id -u) - --network=host --build-arg POLICY=manylinux_2_28 --build-arg PLATFORM=x86_64 - --build-arg ROCM_VERSION=$(RocmVersion)$(RocmVersionPatchSuffix) - --build-arg DEVTOOLSET_ROOTPATH=/opt/rh/gcc-toolset-12/root - --build-arg PREPEND_PATH=/opt/rh/gcc-toolset-12/root/usr/bin: - --build-arg LD_LIBRARY_PATH_ARG=/opt/rh/gcc-toolset-12/root/usr/lib64:/opt/rh/gcc-toolset-12/root/usr/lib:/opt/rh/gcc-toolset-12/root/usr/lib64/dyninst:/opt/rh/gcc-toolset-12/root/usr/lib/dyninst:/usr/local/lib64:/usr/local/lib - Repository: onnxruntimetrainingrocmbuild-rocm$(RocmVersion) - CheckOutManyLinux: true - - - template: templates/set-version-number-variables-step.yml - - - task: Bash@3 - displayName: 'Build' - inputs: - targetType: filePath - filePath: tools/ci_build/github/linux/build_rocm_c_api_package.sh - arguments: >- - -S $(Build.SourcesDirectory) - -B $(Build.BinariesDirectory) - -V $(RocmVersion) - -I onnxruntimetrainingrocmbuild-rocm$(RocmVersion) - -P python3.10 - - - script: | - set -e -x - mkdir $(Build.ArtifactStagingDirectory)/testdata - cp $(Build.BinariesDirectory)/Release/libcustom_op_library.so* $(Build.ArtifactStagingDirectory)/testdata - ls -al $(Build.ArtifactStagingDirectory) - displayName: 'Create Artifacts for CustomOp' # libcustom_op_library.so from cpu build is built with fp8, ROCm does not support it. - - - template: templates/c-api-artifacts-package-and-publish-steps-posix.yml - parameters: - buildConfig: 'Release' - artifactName: 'onnxruntime-linux-x64-rocm-$(OnnxRuntimeVersion)' - artifactNameNoVersionString: 'onnxruntime-linux-x64-rocm' - libraryName: 'libonnxruntime.so.$(OnnxRuntimeVersion)' - - - template: templates/component-governance-component-detection-steps.yml - parameters: - condition: 'succeeded' - - template: templates/clean-agent-build-directory-step.yml - -- stage: NuGet_Packaging_ROCm - dependsOn: - - Setup - - Linux_C_API_Packaging_ROCm_x64 - condition: succeeded() - jobs: - - job: NuGet_Packaging_ROCm - workspace: - clean: all - # we need to use a 2022 pool to create the nuget package with MAUI targets. - # VS2019 has no support for net6/MAUI and we need to use msbuild (from the VS install) to do the packing - pool: 'Onnxruntime-Win-CPU-2022' - variables: - breakCodesignValidationInjection: ${{ parameters.DoEsrp }} - ReleaseVersionSuffix: $[stageDependencies.Setup.Set_Variables.outputs['Set_Release_Version_Suffix.ReleaseVersionSuffix']] - BuildDate : $[stageDependencies.Setup.Set_Variables.outputs['Set_Build_Date.BuildDate']] - BuildTime : $[stageDependencies.Setup.Set_Variables.outputs['Set_Build_Time.BuildTime']] - - steps: - - checkout: self - submodules: true - fetchDepth: 1 - - - template: templates/flex-downloadPipelineArtifact.yml - parameters: - StepName: 'Download Pipeline Artifact - NuGet' - ArtifactName: 'onnxruntime-linux-x64-rocm' - targetPath: '$(Build.BinariesDirectory)/nuget-artifact' - SpecificArtifact: ${{ parameters.specificArtifact }} - BuildId: ${{ parameters.BuildId }} - - - task: PowerShell@2 - displayName: 'Reconstruct Build Directory' - inputs: - targetType: inline - script: | - Get-ChildItem $(Build.BinariesDirectory)\nuget-artifact -Filter *.tgz | % { - # *.tar will be created after *.tgz is extracted - $cmd = "7z.exe x $($_.FullName) -y -o$(Build.BinariesDirectory)\nuget-artifact" - Write-Output $cmd - Invoke-Expression -Command $cmd - } - - Get-ChildItem $(Build.BinariesDirectory)\nuget-artifact -Filter *.tar | % { - $cmd = "7z.exe x $($_.FullName) -y -o$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\nuget-artifacts" - Write-Output $cmd - Invoke-Expression -Command $cmd - } - - $ort_dirs = Get-ChildItem -Path $(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\nuget-artifacts\onnxruntime-* -Directory - foreach ($ort_dir in $ort_dirs) - { - $dirname = Split-Path -Path $ort_dir -Leaf - $dirname = $dirname.SubString(0, $dirname.LastIndexOf('-')) - Write-Output "Renaming $ort_dir to $dirname" - Rename-Item -Path $ort_dir -NewName $(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\nuget-artifacts\$dirname - } - - Copy-Item -Path $(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\nuget-artifacts\onnxruntime-linux-x64-rocm\lib\* -Destination $(Build.BinariesDirectory)\RelWithDebInfo - - - script: | - tree /F - workingDirectory: '$(Build.BinariesDirectory)' - displayName: 'Inspect Build Binaries Directory' - - - script: | - mklink /D /J models C:\local\models - workingDirectory: '$(Build.BinariesDirectory)' - displayName: 'Create models link' - - - task: NuGetToolInstaller@0 - displayName: Use Nuget 6.10.x - inputs: - versionSpec: 6.10.x - - - task: MSBuild@1 - displayName: 'Restore NuGet Packages and create project.assets.json' - inputs: - solution: '$(Build.SourcesDirectory)\csharp\OnnxRuntime.CSharp.sln' - platform: 'Any CPU' - configuration: RelWithDebInfo - msbuildArguments: '-t:restore -p:OrtPackageId="Microsoft.ML.OnnxRuntime.ROCm"' - workingDirectory: '$(Build.SourcesDirectory)\csharp' - - - task: MSBuild@1 - displayName: 'Build C# bindings' - inputs: - solution: '$(Build.SourcesDirectory)\csharp\OnnxRuntime.CSharp.sln' - platform: 'Any CPU' - configuration: RelWithDebInfo - msbuildArguments: > - -p:OnnxRuntimeBuildDirectory="$(Build.BinariesDirectory)" - -p:OrtPackageId="Microsoft.ML.OnnxRuntime.ROCm" - -p:IsReleaseBuild=${{ parameters.IsReleaseBuild }} - -p:ReleaseVersionSuffix=$(ReleaseVersionSuffix) - -p:IsLinuxBuild=true - -p:IsWindowsBuild=false - -p:IsMacOSBuild=false - workingDirectory: '$(Build.SourcesDirectory)\csharp' - - - template: templates/win-esrp-dll.yml - parameters: - FolderPath: '$(Build.SourcesDirectory)\csharp\src\Microsoft.ML.OnnxRuntime\bin\RelWithDebInfo' - DisplayName: 'ESRP - Sign C# dlls' - DoEsrp: ${{ parameters.DoEsrp }} - - - task: UsePythonVersion@0 - displayName: 'Use Python' - inputs: - versionSpec: 3.12 - - - task: MSBuild@1 - displayName: 'Build Nuget Packages' - inputs: - solution: '$(Build.SourcesDirectory)\csharp\OnnxRuntime.CSharp.proj' - configuration: RelWithDebInfo - platform: 'Any CPU' - msbuildArguments: > - -t:CreatePackage - -p:OnnxRuntimeBuildDirectory="$(Build.BinariesDirectory)" - -p:OrtPackageId=Microsoft.ML.OnnxRuntime.ROCm - -p:IsReleaseBuild=${{ parameters.IsReleaseBuild }} - -p:ReleaseVersionSuffix=$(ReleaseVersionSuffix) - -p:CurrentTime=$(BuildTime) - -p:CurrentDate=$(BuildDate) - workingDirectory: '$(Build.SourcesDirectory)\csharp' - - - task: CopyFiles@2 - displayName: 'Copy nuget packages to: $(Build.ArtifactStagingDirectory)' - inputs: - SourceFolder: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo' - Contents: '*.snupkg' - TargetFolder: '$(Build.ArtifactStagingDirectory)' - - - task: CopyFiles@2 - displayName: 'Copy nuget packages to: $(Build.ArtifactStagingDirectory)' - inputs: - SourceFolder: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo' - Contents: '*.nupkg' - TargetFolder: '$(Build.ArtifactStagingDirectory)' - - - task: CopyFiles@2 - displayName: 'Copy nuget packages to: $(Build.ArtifactStagingDirectory)' - inputs: - SourceFolder: '$(Build.SourcesDirectory)\csharp\src\Microsoft.ML.OnnxRuntime\bin\RelWithDebInfo' - Contents: '*.nupkg' - TargetFolder: '$(Build.ArtifactStagingDirectory)' - - - template: templates/esrp_nuget.yml - parameters: - DisplayName: 'ESRP - sign NuGet package' - FolderPath: '$(Build.ArtifactStagingDirectory)' - DoEsrp: ${{ parameters.DoEsrp }} - - - template: templates/validate-package.yml - parameters: - PackageType: 'nuget' - PackagePath: '$(Build.ArtifactStagingDirectory)' - PackageName: 'Microsoft.ML.OnnxRuntime.*nupkg' - PlatformsSupported: 'linux-x64' - VerifyNugetSigning: false - - - task: PublishPipelineArtifact@0 - displayName: 'Publish Pipeline NuGet Artifact' - inputs: - artifactName: 'drop-signed-nuget-ROCm' - targetPath: '$(Build.ArtifactStagingDirectory)' - - - task: MSBuild@1 - displayName: 'Clean C#' - inputs: - solution: '$(Build.SourcesDirectory)\csharp\OnnxRuntime.CSharp.sln' - platform: 'Any CPU' - configuration: RelWithDebInfo - msbuildArguments: '-t:Clean -p:OnnxRuntimeBuildDirectory="$(Build.BinariesDirectory)" -p:OrtPackageId=Microsoft.ML.OnnxRuntime.ROCm' - workingDirectory: '$(Build.SourcesDirectory)\csharp' - - - template: templates/component-governance-component-detection-steps.yml - parameters : - condition : 'succeeded' - - -- template: nuget/templates/test_linux.yml - parameters: - AgentPool: AMD-GPU - ArtifactSuffix: 'ROCm' - StageSuffix: 'ROCm' - NugetPackageName: 'Microsoft.ML.OnnxRuntime.ROCm' - SpecificArtifact: ${{ parameters.specificArtifact }} - CustomOpArtifactName: 'onnxruntime-linux-x64-rocm' - BuildId: ${{ parameters.BuildId }} diff --git a/tools/ci_build/github/azure-pipelines/rocm-publish-nuget-pipeline.yml b/tools/ci_build/github/azure-pipelines/rocm-publish-nuget-pipeline.yml deleted file mode 100644 index 1d2393d8f96d5..0000000000000 --- a/tools/ci_build/github/azure-pipelines/rocm-publish-nuget-pipeline.yml +++ /dev/null @@ -1,21 +0,0 @@ -resources: - pipelines: - - pipeline: build - source: 'Nuget ROCM Packaging pipeline' - trigger: - branches: - include: - - main - - rel-* - branch: main - -# ROCm -stages: -- template: templates/publish-nuget-steps.yml - parameters: - stage_name: 'Publish_ROCM_NuGet_Package' - download_artifacts_steps: - - download: build - displayName: 'Download Pipeline Artifact - Signed NuGet Package' - artifact: 'drop-signed-nuget-ROCm' - - script: move "$(Pipeline.Workspace)\build\drop-signed-nuget-ROCm\*" $(Build.BinariesDirectory)\nuget-artifact\final-package diff --git a/tools/ci_build/github/azure-pipelines/stages/nuget-cuda-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/nuget-cuda-packaging-stage.yml index 8fabb80a73869..5ae60aac8f9b4 100644 --- a/tools/ci_build/github/azure-pipelines/stages/nuget-cuda-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/nuget-cuda-packaging-stage.yml @@ -96,18 +96,10 @@ stages: inputs: versionSpec: 6.10.x - - task: PowerShell@2 - displayName: Install MAUI workloads - inputs: - targetType: 'inline' - script: | - dotnet workload install android ios maccatalyst - workingDirectory: '$(Build.SourcesDirectory)\csharp' - - task: MSBuild@1 displayName: 'Restore NuGet Packages and create project.assets.json' inputs: - solution: '$(Build.SourcesDirectory)\csharp\OnnxRuntime.CSharp.sln' + solution: '$(Build.SourcesDirectory)\csharp\OnnxRuntime.DesktopOnly.CSharp.sln' platform: 'Any CPU' configuration: RelWithDebInfo msbuildArguments: '-t:restore -p:OrtPackageId="Microsoft.ML.OnnxRuntime.Gpu"' @@ -116,7 +108,7 @@ stages: - task: MSBuild@1 displayName: 'Build C# bindings' inputs: - solution: '$(Build.SourcesDirectory)\csharp\OnnxRuntime.CSharp.sln' + solution: '$(Build.SourcesDirectory)\csharp\OnnxRuntime.DesktopOnly.CSharp.sln' configuration: RelWithDebInfo platform: 'Any CPU' msbuildArguments: > @@ -208,7 +200,7 @@ stages: - task: MSBuild@1 displayName: 'Clean C#' inputs: - solution: '$(Build.SourcesDirectory)\csharp\OnnxRuntime.CSharp.sln' + solution: '$(Build.SourcesDirectory)\csharp\OnnxRuntime.DesktopOnly.CSharp.sln' platform: 'Any CPU' configuration: RelWithDebInfo msbuildArguments: '-t:Clean -p:OnnxRuntimeBuildDirectory="$(Build.BinariesDirectory)" -p:OrtPackageId=Microsoft.ML.OnnxRuntime.Gpu' @@ -223,4 +215,3 @@ stages: inputs: artifactName: 'drop-signed-nuget-GPU' targetPath: '$(Build.ArtifactStagingDirectory)' - diff --git a/tools/ci_build/github/azure-pipelines/stages/nuget-qnn-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/nuget-qnn-packaging-stage.yml new file mode 100644 index 0000000000000..03802746cec3d --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/stages/nuget-qnn-packaging-stage.yml @@ -0,0 +1,76 @@ +parameters: +- name: DoEsrp + displayName: Run code sign tasks? Must be true if you are doing an Onnx Runtime release. + type: boolean + default: true + +stages: +- stage: NuGet_Packaging_QNN + pool: + name: 'Onnxruntime-QNNEP-Windows-2022-CPU' + dependsOn: + - OnnxRuntime_QNN_Nuget_Win_x64 + - OnnxRuntime_QNN_Nuget_Win_Arm64 + condition: succeeded() + jobs: + - job: NuGet_Packaging_QNN + workspace: + clean: all + steps: + - task: DownloadPipelineArtifact@0 + displayName: 'Download Pipeline Artifact - QNN NuGet x64' + inputs: + artifactName: 'drop-nuget-qnn-x64' + targetPath: '$(Build.BinariesDirectory)/nuget-artifact-x64' + + - task: DownloadPipelineArtifact@0 + displayName: 'Download Pipeline Artifact - QNN NuGet arm64' + inputs: + artifactName: 'drop-nuget-qnn-arm64' + targetPath: '$(Build.BinariesDirectory)/nuget-artifact-arm64' + + - task: PowerShell@2 + displayName: 'Bundle NuGet' + inputs: + targetType: 'inline' + script: | + + $x64_nupkgs = (Get-ChildItem $(Build.BinariesDirectory)/nuget-artifact-x64 -Filter Microsoft.ML.OnnxRuntime.QNN*.nupkg -Recurse) + $nuget_package_name = $x64_nupkgs[0].Name + $x64_nuget_package = $x64_nupkgs[0].FullName + + $nupkg_unzipped_directory = [System.IO.Path]::Combine($Env:BUILD_ARTIFACTSTAGINGDIRECTORY, 'nuget_unzip_merged', [System.IO.Path]::GetFileNameWithoutExtension($nuget_package_name)) + + $x64_unzip_cmd = "7z.exe x $x64_nuget_package -y -o$nupkg_unzipped_directory" + Invoke-Expression -Command $x64_unzip_cmd + + $arm64_nupkgs = (Get-ChildItem $(Build.BinariesDirectory)/nuget-artifact-arm64 -Filter Microsoft.ML.OnnxRuntime.QNN*.nupkg -Recurse) + $arm64_nuget_package = $arm64_nupkgs[0].FullName + + $arm64_unzip_cmd = "7z.exe x $arm64_nuget_package -y -o$nupkg_unzipped_directory" + Invoke-Expression -Command $arm64_unzip_cmd + + $merged_nuget_path = [System.IO.Path]::Combine($Env:BUILD_ARTIFACTSTAGINGDIRECTORY, 'nuget-artifact-merged') + if (!(Test-Path $merged_nuget_path)) { + New-Item -Path $merged_nuget_path -ItemType Directory + } + + $merged_zip = [System.IO.Path]::Combine($merged_nuget_path, 'qnn_nuget.zip') + $zip_cmd = "7z.exe a -r $merged_zip $nupkg_unzipped_directory/*" + Invoke-Expression -Command $zip_cmd + + $merged_nuget = [System.IO.Path]::Combine($merged_nuget_path, $nuget_package_name) + move $merged_zip $merged_nuget + workingDirectory: $(Build.BinariesDirectory) + + - template: ../templates/esrp_nuget.yml + parameters: + DisplayName: 'ESRP - sign NuGet package' + FolderPath: '$(Build.ArtifactStagingDirectory)/nuget-artifact-merged' + DoEsrp: ${{ parameters.DoEsrp }} + + - task: 1ES.PublishPipelineArtifact@1 + displayName: 'Publish Pipeline NuGet Artifact' + inputs: + artifactName: 'drop-signed-nuget-qnn' + targetPath: '$(Build.ArtifactStagingDirectory)/nuget-artifact-merged' diff --git a/tools/ci_build/github/azure-pipelines/stages/py-cpu-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/py-cpu-packaging-stage.yml index 4ff539df9f914..5e783607e3622 100644 --- a/tools/ci_build/github/azure-pipelines/stages/py-cpu-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/py-cpu-packaging-stage.yml @@ -123,7 +123,7 @@ stages: --skip_submodule_sync --cmake_generator "Visual Studio 17 2022" --enable_pybind - --enable_onnx_tests + --enable_onnx_tests --use_vcpkg --use_vcpkg_ms_internal_asset_cache ${{ parameters.build_py_parameters }} --parallel --use_binskim_compliant_compile_flags --update --build $(TelemetryOption) @@ -151,10 +151,11 @@ stages: Contents: '*.whl' TargetFolder: '$(Build.ArtifactStagingDirectory)' - - task: PublishBuildArtifacts@1 + - task: 1ES.PublishPipelineArtifact@1 displayName: 'Publish Artifact: ONNXRuntime python wheel' inputs: - ArtifactName: onnxruntime + artifactName: onnxruntime-win-$(PythonVersion) + targetPath: '$(Build.ArtifactStagingDirectory)' - script: | 7z x *.whl @@ -199,7 +200,9 @@ stages: workspace: clean: all pool: - vmImage: 'macOS-13' + name: "Azure Pipelines" + image: "macOS-13" + os: macOS variables: MACOSX_DEPLOYMENT_TARGET: '13.3' strategy: @@ -251,74 +254,81 @@ stages: Contents: '*.whl' TargetFolder: '$(Build.ArtifactStagingDirectory)' - - task: PublishBuildArtifacts@1 + - task: 1ES.PublishPipelineArtifact@1 displayName: 'Publish Artifact: ONNXRuntime python wheel' inputs: - ArtifactName: onnxruntime + artifactName: onnxruntime-macos-$(PythonVersion) + targetPath: '$(Build.ArtifactStagingDirectory)' - template: ../templates/component-governance-component-detection-steps.yml parameters: condition: 'succeeded' - - ${{ if eq(parameters.enable_linux_arm, true) }}: - - stage: Python_Packaging_Linux_ARM - dependsOn: [] - jobs: - - template: ../templates/py-linux.yml - parameters: - arch: 'aarch64' - machine_pool: 'onnxruntime-linux-ARM64-CPU-2019' - extra_build_arg: ${{ parameters.build_py_parameters }} - cmake_build_type: ${{ parameters.cmake_build_type }} - - - ${{ if eq(parameters.enable_linux_cpu, true) }}: - - stage: Python_Packaging_Linux_CPU - dependsOn: [] - jobs: +- ${{ if eq(parameters.enable_linux_arm, true) }}: + - stage: Python_Packaging_Linux_ARM + dependsOn: [] + jobs: - template: ../templates/py-linux.yml parameters: - arch: 'x86_64' - machine_pool: 'onnxruntime-Ubuntu2204-AMD-CPU-Large' + arch: 'aarch64' + machine_pool: 'onnxruntime-linux-ARM64-CPU-2019' extra_build_arg: ${{ parameters.build_py_parameters }} cmake_build_type: ${{ parameters.cmake_build_type }} + is1ES: true - - ${{ if eq(parameters.enable_windows_arm64_qnn, true) }}: - - stage: Python_Packaging_Windows_ARM64_QNN - dependsOn: [] - jobs: - - template: ../templates/py-win-arm64-qnn.yml +- ${{ if eq(parameters.enable_linux_cpu, true) }}: + - stage: Python_Packaging_Linux_CPU + dependsOn: [] + jobs: + - template: ../templates/py-linux.yml + parameters: + arch: 'x86_64' + machine_pool: 'onnxruntime-Ubuntu2204-AMD-CPU-Large' + extra_build_arg: ${{ parameters.build_py_parameters }} + cmake_build_type: ${{ parameters.cmake_build_type }} + is1ES: true + +- ${{ if eq(parameters.enable_windows_arm64_qnn, true) }}: + - stage: Python_Packaging_Windows_ARM64_QNN + dependsOn: [] + jobs: + - template: ../templates/py-win-arm64-qnn.yml + parameters: + MACHINE_POOL: 'onnxruntime-qnn-windows-vs-2022-arm64' + QNN_SDK: ${{ parameters.qnn_sdk_version }} + BUILD_PY_PARAMETERS: ${{ parameters.build_py_parameters }} + is1ES: true + +- ${{ if eq(parameters.enable_windows_arm64ec_qnn, true) }}: + - stage: Python_Packaging_Windows_arm64ec_QNN + dependsOn: [] + jobs: + - template: ../templates/py-win-arm64ec-qnn.yml parameters: - MACHINE_POOL: 'onnxruntime-qnn-windows-vs-2022-arm64' + MACHINE_POOL: 'Onnxruntime-QNNEP-Windows-2022-CPU' QNN_SDK: ${{ parameters.qnn_sdk_version }} BUILD_PY_PARAMETERS: ${{ parameters.build_py_parameters }} + is1ES: true - - ${{ if eq(parameters.enable_windows_arm64ec_qnn, true) }}: - - stage: Python_Packaging_Windows_arm64ec_QNN - dependsOn: [] - jobs: - - template: ../templates/py-win-arm64ec-qnn.yml - parameters: - MACHINE_POOL: 'Onnxruntime-QNNEP-Windows-2022-CPU' - QNN_SDK: ${{ parameters.qnn_sdk_version }} - BUILD_PY_PARAMETERS: ${{ parameters.build_py_parameters }} - - - ${{ if eq(parameters.enable_windows_x64_qnn, true) }}: - - stage: Python_Packaging_Windows_x64_QNN - dependsOn: [] - jobs: - - template: ../templates/py-win-x64-qnn.yml - parameters: - MACHINE_POOL: 'Onnxruntime-QNNEP-Windows-2022-CPU' - QNN_SDK: ${{ parameters.qnn_sdk_version }} - BUILD_PY_PARAMETERS: ${{ parameters.build_py_parameters }} - - - ${{ if eq(parameters.enable_linux_x64_qnn, true) }}: - - stage: Python_Packaging_Linux_x64_QNN - dependsOn: [] - jobs: - - template: ../templates/py-linux-qnn.yml +- ${{ if eq(parameters.enable_windows_x64_qnn, true) }}: + - stage: Python_Packaging_Windows_x64_QNN + dependsOn: [] + jobs: + - template: ../templates/py-win-x64-qnn.yml parameters: - machine_pool: 'onnxruntime-Ubuntu2204-AMD-CPU' - extra_build_arg: ${{ parameters.build_py_parameters }} - cmake_build_type: ${{ parameters.cmake_build_type }} + MACHINE_POOL: 'Onnxruntime-QNNEP-Windows-2022-CPU' + QNN_SDK: ${{ parameters.qnn_sdk_version }} + BUILD_PY_PARAMETERS: ${{ parameters.build_py_parameters }} + is1ES: true + +- ${{ if eq(parameters.enable_linux_x64_qnn, true) }}: + - stage: Python_Packaging_Linux_x64_QNN + dependsOn: [] + jobs: + - template: ../templates/py-linux-qnn.yml + parameters: + machine_pool: 'onnxruntime-Ubuntu2204-AMD-CPU' + extra_build_arg: ${{ parameters.build_py_parameters }} + cmake_build_type: ${{ parameters.cmake_build_type }} + is1ES: true diff --git a/tools/ci_build/github/azure-pipelines/templates/c-api-artifacts-package-and-publish-steps-windows.yml b/tools/ci_build/github/azure-pipelines/templates/c-api-artifacts-package-and-publish-steps-windows.yml index 5ee425405ac70..e1a514ea54123 100644 --- a/tools/ci_build/github/azure-pipelines/templates/c-api-artifacts-package-and-publish-steps-windows.yml +++ b/tools/ci_build/github/azure-pipelines/templates/c-api-artifacts-package-and-publish-steps-windows.yml @@ -57,6 +57,22 @@ steps: copy $(Build.BinariesDirectory)\${{parameters.buildConfig}}\${{parameters.buildConfig}}\onnxruntime_providers_cuda.pdb $(Build.BinariesDirectory)\${{parameters.artifactName}}\lib copy $(Build.BinariesDirectory)\${{parameters.buildConfig}}\${{parameters.buildConfig}}\onnxruntime_providers_cuda.lib $(Build.BinariesDirectory)\${{parameters.artifactName}}\lib + # Copy WebGPU dependencies if required + copy $(Build.BinariesDirectory)\${{parameters.buildConfig}}\${{parameters.buildConfig}}\dxcompiler.dll $(Build.BinariesDirectory)\${{parameters.artifactName}}\lib + copy $(Build.BinariesDirectory)\${{parameters.buildConfig}}\${{parameters.buildConfig}}\dxil.dll $(Build.BinariesDirectory)\${{parameters.artifactName}}\lib + + # Copy QNN dependencies if required + copy $(Build.BinariesDirectory)\${{parameters.buildConfig}}\${{parameters.buildConfig}}\onnxruntime_providers_qnn.dll $(Build.BinariesDirectory)\${{parameters.artifactName}}\lib + copy $(Build.BinariesDirectory)\${{parameters.buildConfig}}\${{parameters.buildConfig}}\libQnnHtp*.so $(Build.BinariesDirectory)\${{parameters.artifactName}}\lib /Y + copy $(Build.BinariesDirectory)\${{parameters.buildConfig}}\${{parameters.buildConfig}}\libqnnhtp*.cat $(Build.BinariesDirectory)\${{parameters.artifactName}}\lib /Y + copy $(Build.BinariesDirectory)\${{parameters.buildConfig}}\${{parameters.buildConfig}}\QnnCpu.dll $(Build.BinariesDirectory)\${{parameters.artifactName}}\lib + copy $(Build.BinariesDirectory)\${{parameters.buildConfig}}\${{parameters.buildConfig}}\QnnHtp.dll $(Build.BinariesDirectory)\${{parameters.artifactName}}\lib + copy $(Build.BinariesDirectory)\${{parameters.buildConfig}}\${{parameters.buildConfig}}\QnnHtpPrepare.dll $(Build.BinariesDirectory)\${{parameters.artifactName}}\lib + copy $(Build.BinariesDirectory)\${{parameters.buildConfig}}\${{parameters.buildConfig}}\QnnHtpV68Stub.dll $(Build.BinariesDirectory)\${{parameters.artifactName}}\lib + copy $(Build.BinariesDirectory)\${{parameters.buildConfig}}\${{parameters.buildConfig}}\QnnHtpV73Stub.dll $(Build.BinariesDirectory)\${{parameters.artifactName}}\lib + copy $(Build.BinariesDirectory)\${{parameters.buildConfig}}\${{parameters.buildConfig}}\QnnSaver.dll $(Build.BinariesDirectory)\${{parameters.artifactName}}\lib + copy $(Build.BinariesDirectory)\${{parameters.buildConfig}}\${{parameters.buildConfig}}\QnnSystem.dll $(Build.BinariesDirectory)\${{parameters.artifactName}}\lib + # copy trt ep libraries only when trt ep is enabled copy $(Build.BinariesDirectory)\${{parameters.buildConfig}}\${{parameters.buildConfig}}\onnxruntime_providers_tensorrt.dll $(Build.BinariesDirectory)\${{parameters.artifactName}}\lib copy $(Build.BinariesDirectory)\${{parameters.buildConfig}}\${{parameters.buildConfig}}\onnxruntime_providers_tensorrt.pdb $(Build.BinariesDirectory)\${{parameters.artifactName}}\lib diff --git a/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_openvino.yml b/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_openvino.yml new file mode 100644 index 0000000000000..f6956b426ddfc --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_openvino.yml @@ -0,0 +1,64 @@ +parameters: + - name: OpenVINOVersion + type: string + default: '2025.0.0' + +steps: + - powershell: | + $Url = "https://storage.openvinotoolkit.org/repositories/openvino/packages/2025.0/windows/openvino_toolkit_windows_2025.0.0.17942.1f68be9f594_x86_64.zip" + $OutputPath = "$env:Agent_TempDirectory\openvino.zip" + $ExtractPath = "$env:Agent_TempDirectory\openvino-v$env:OpenVINOVersion" + $TempExtractPath = "$env:Agent_TempDirectory\openvino_temp" + + # Ensure directories exist + if (Test-Path $ExtractPath) { + Remove-Item -Recurse -Force $ExtractPath + } + New-Item -ItemType Directory -Path $ExtractPath | Out-Null + New-Item -ItemType Directory -Path $TempExtractPath | Out-Null + + # Download OpenVINO ZIP + Write-Output "Downloading OpenVINO" + Invoke-WebRequest -Uri $Url -OutFile $OutputPath + + # Extract to temporary directory first + Write-Output "Extracting OpenVINO to a temporary directory" + Expand-Archive -Path $OutputPath -DestinationPath $TempExtractPath -Force + + # Locate the nested subdirectory + $InnerFolder = Get-ChildItem -Path $TempExtractPath -Directory | Select-Object -First 1 + + if ($InnerFolder) { + Write-Output "Moving extracted files to final destination" + Move-Item -Path "$($InnerFolder.FullName)\*" -Destination $ExtractPath -Force + } else { + Write-Error "Extraction failed: No expected subdirectory found in $TempExtractPath." + Write-Error "The archive may not have extracted correctly, or its structure is different than expected." + exit 1 + } + + # Clean up temporary files + Remove-Item -Recurse -Force $TempExtractPath + Remove-Item -Force $OutputPath + + # Confirm success + Write-Output "OpenVINO extracted to $ExtractPath" + displayName: 'Download OpenVINO Toolkit v${{ parameters.OpenVINOVersion }}' + env: + OpenVINOVersion: ${{ parameters.OpenVINOVersion }} + + - powershell: | + echo "##vso[task.setvariable variable=OpenVINORootDir]$(Agent.TempDirectory)\openvino-v${{ parameters.OpenVINOVersion }}" + displayName: 'Set OpenVINORootDir' + + - task: CmdLine@2 + inputs: + script: | + echo $(OpenVINORootDir) + displayName: 'Print OpenVINORootDir after downloading OpenVINO' + + - task: CmdLine@2 + displayName: 'Print contents of OpenVINO Toolkit' + inputs: + script: | + dir $(OpenVINORootDir) diff --git a/tools/ci_build/github/azure-pipelines/templates/linux-web-init-and-check.yml b/tools/ci_build/github/azure-pipelines/templates/linux-web-init-and-check.yml index a4d5a73118ea2..2b73f82615bba 100644 --- a/tools/ci_build/github/azure-pipelines/templates/linux-web-init-and-check.yml +++ b/tools/ci_build/github/azure-pipelines/templates/linux-web-init-and-check.yml @@ -1,4 +1,8 @@ steps: +- task: NodeTool@0 + inputs: + # requires Node.js v22 for float16 testing (the V8 flag "--js-float16array") + versionSpec: '22.x' - script: | npm ci workingDirectory: '$(Build.SourcesDirectory)/js' @@ -11,6 +15,10 @@ steps: npm test workingDirectory: '$(Build.SourcesDirectory)/js/common' displayName: 'run onnxruntime-common tests' +- script: | + npm run test:f16 + workingDirectory: '$(Build.SourcesDirectory)/js/common' + displayName: 'run onnxruntime-common tests (enable Float16Array)' - script: | npm ci workingDirectory: '$(Build.SourcesDirectory)/js/web' diff --git a/tools/ci_build/github/azure-pipelines/templates/py-linux-qnn.yml b/tools/ci_build/github/azure-pipelines/templates/py-linux-qnn.yml index 347a3145e8c70..8126cda449daa 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-linux-qnn.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-linux-qnn.yml @@ -6,10 +6,10 @@ parameters: type: string default: 'Release' values: - - Debug - - Release - - RelWithDebInfo - - MinSizeRel + - Debug + - Release + - RelWithDebInfo + - MinSizeRel - name: device type: string @@ -27,68 +27,82 @@ parameters: displayName: QNN SDK version type: string default: 2.31.0.250130 + +- name: is1ES + displayName: 'Whether the pipeline is running in 1ES' + type: boolean + default: false jobs: - job: Linux_py_qnn_Wheels_x64 timeoutInMinutes: 240 workspace: clean: all - pool: ${{ parameters.machine_pool }} + pool: + name: ${{ parameters.machine_pool }} + os: linux variables: - # The build machine pool doesn't have dotnet, so it can't run CG. - - name: skipComponentGovernanceDetection - value: true - - name: ORT_CACHE_DIR - value: $(Agent.TempDirectory)/ort_ccache - - name: TODAY - value: $[format('{0:dd}{0:MM}{0:yyyy}', pipeline.startTime)] - - name: extra_build_args - ${{ if ne(parameters.extra_build_arg, '') }}: - value: -x ${{ parameters.extra_build_arg }} - ${{ if eq(parameters.extra_build_arg, '') }}: - value: '' + # The build machine pool doesn't have dotnet, so it can't run CG. + - name: skipComponentGovernanceDetection + value: true + - name: ORT_CACHE_DIR + value: $(Agent.TempDirectory)/ort_ccache + - name: TODAY + value: $[format('{0:dd}{0:MM}{0:yyyy}', pipeline.startTime)] + - name: extra_build_args + ${{ if ne(parameters.extra_build_arg, '') }}: + value: -x ${{ parameters.extra_build_arg }} + ${{ if eq(parameters.extra_build_arg, '') }}: + value: '' steps: - - checkout: self - clean: true - submodules: none + - checkout: self + clean: true + submodules: none - - template: jobs/download_linux_qnn_sdk.yml - parameters: - QnnSDKVersion: ${{ parameters.QnnSdk }} + - template: jobs/download_linux_qnn_sdk.yml + parameters: + QnnSDKVersion: ${{ parameters.QnnSdk }} - - template: set-nightly-build-option-variable-step.yml + - template: set-nightly-build-option-variable-step.yml - - template: get-docker-image-steps.yml - parameters: - Dockerfile: tools/ci_build/github/linux/docker/inference/x86_64/python/cpu/Dockerfile - Context: tools/ci_build/github/linux/docker/inference/x86_64/python/cpu - DockerBuildArgs: "--build-arg BUILD_UID=$( id -u )" - Repository: onnxruntimecpubuildpythonx86_64_qnn + - template: get-docker-image-steps.yml + parameters: + Dockerfile: tools/ci_build/github/linux/docker/inference/x86_64/python/cpu/Dockerfile + Context: tools/ci_build/github/linux/docker/inference/x86_64/python/cpu + DockerBuildArgs: "--build-arg BUILD_UID=$( id -u )" + Repository: onnxruntimecpubuildpythonx86_64_qnn - - template: linux-build-step-with-cache.yml - parameters: - WithCache: ${{parameters.with_cache}} - Today: $(TODAY) - AdditionalKey: Linux_py_qnn_Wheels_x64 - CacheDir: $(ORT_CACHE_DIR) - ChangeEveryCommit: true - BuildStep: - - task: Bash@3 - displayName: 'Build Python Wheel' - inputs: - targetType: filePath - filePath: tools/ci_build/github/linux/run_python_dockerbuild.sh - arguments: -i onnxruntimecpubuildpythonx86_64_qnn -d "${{ parameters.device }}" -c ${{ parameters.cmake_build_type }} $(extra_build_args) - env: - ADDITIONAL_DOCKER_PARAMETER: "--volume $(QnnSDKRootDir):/qnn_sdk" + - template: linux-build-step-with-cache.yml + parameters: + WithCache: ${{parameters.with_cache}} + Today: $(TODAY) + AdditionalKey: Linux_py_qnn_Wheels_x64 + CacheDir: $(ORT_CACHE_DIR) + ChangeEveryCommit: true + BuildStep: + - task: Bash@3 + displayName: 'Build Python Wheel' + inputs: + targetType: filePath + filePath: tools/ci_build/github/linux/run_python_dockerbuild.sh + arguments: -i onnxruntimecpubuildpythonx86_64_qnn -d "${{ parameters.device }}" -c ${{ parameters.cmake_build_type }} $(extra_build_args) + env: + ADDITIONAL_DOCKER_PARAMETER: "--volume $(QnnSDKRootDir):/qnn_sdk" + - ${{ if eq(parameters.is1ES, true) }}: + - task: 1ES.PublishPipelineArtifact@1 + displayName: 'Publish Artifact: Linux ONNXRuntime QNN python wheel' + inputs: + targetPath: '$(Build.BinariesDirectory)/dist' + artifactName: onnxruntime-linux-qnn-x64 - - task: PublishBuildArtifacts@1 + - ${{ if eq(parameters.is1ES, false) }}: + - task: PublishPipelineArtifact@1 displayName: 'Publish Artifact: Linux ONNXRuntime QNN python wheel' inputs: - PathtoPublish: '$(Build.BinariesDirectory)/dist' - ArtifactName: onnxruntime-linux-qnn-x64 + targetPath: '$(Build.BinariesDirectory)/dist' + artifactName: onnxruntime-linux-qnn-x64 - - template: component-governance-component-detection-steps.yml - parameters : - condition : 'succeeded' + - template: component-governance-component-detection-steps.yml + parameters: + condition: 'succeeded' diff --git a/tools/ci_build/github/azure-pipelines/templates/py-linux.yml b/tools/ci_build/github/azure-pipelines/templates/py-linux.yml index e591b719ecfa9..8d0c4334f4874 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-linux.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-linux.yml @@ -9,10 +9,10 @@ parameters: type: string default: 'Release' values: - - Debug - - Release - - RelWithDebInfo - - MinSizeRel + - Debug + - Release + - RelWithDebInfo + - MinSizeRel - name: device type: string @@ -34,76 +34,98 @@ parameters: type: string default: '' +- name: is1ES + displayName: 'Whether the pipeline is running in 1ES' + type: boolean + default: false + jobs: - job: Linux_py_Wheels_${{ parameters.arch }}_${{parameters.ep}} timeoutInMinutes: 240 workspace: clean: all - pool: ${{ parameters.machine_pool }} + pool: + name: ${{ parameters.machine_pool }} + os: 'linux' + ${{ if eq(parameters.arch, 'aarch64') }}: + hostArchitecture: Arm64 variables: - # The build machine pool doesn't have dotnet, so it can't run CG. - - name: skipComponentGovernanceDetection - value: true - - name: ORT_CACHE_DIR - value: $(Agent.TempDirectory)/ort_ccache - - name: TODAY - value: $[format('{0:dd}{0:MM}{0:yyyy}', pipeline.startTime)] - - name: extra_build_args - ${{ if ne(parameters.extra_build_arg, '') }}: - value: '-x ${{ parameters.extra_build_arg }}' - ${{ if eq(parameters.extra_build_arg, '') }}: - value: '' - - name: python_exe_path - ${{ if ne(parameters.python_exe_path, '') }}: - value: '-p ${{ parameters.python_exe_path }}' - ${{ if eq(parameters.python_exe_path, '') }}: - value: '' + # The build machine pool doesn't have dotnet, so it can't run CG. + - name: skipComponentGovernanceDetection + value: true + - name: ORT_CACHE_DIR + value: $(Agent.TempDirectory)/ort_ccache + - name: TODAY + value: $[format('{0:dd}{0:MM}{0:yyyy}', pipeline.startTime)] + - name: extra_build_args + ${{ if ne(parameters.extra_build_arg, '') }}: + value: '-x ${{ parameters.extra_build_arg }}' + ${{ if eq(parameters.extra_build_arg, '') }}: + value: '' + - name: python_exe_path + ${{ if ne(parameters.python_exe_path, '') }}: + value: '-p ${{ parameters.python_exe_path }}' + ${{ if eq(parameters.python_exe_path, '') }}: + value: '' steps: - - checkout: self - clean: true - submodules: none - - - template: set-nightly-build-option-variable-step.yml - - - template: get-docker-image-steps.yml - parameters: - Dockerfile: tools/ci_build/github/linux/docker/inference/${{ parameters.arch }}/python/cpu/Dockerfile - Context: tools/ci_build/github/linux/docker/inference/${{ parameters.arch }}/python/cpu - DockerBuildArgs: "--build-arg BUILD_UID=$( id -u )" - Repository: onnxruntimecpubuildpython${{ parameters.arch }} - - - template: linux-build-step-with-cache.yml - parameters: - WithCache: ${{parameters.with_cache}} - Today: $(TODAY) - AdditionalKey: Linux_py_Wheels_${{ parameters.arch }} - CacheDir: $(ORT_CACHE_DIR) - ChangeEveryCommit: true - BuildStep: - - task: Bash@3 - displayName: 'Build Python Wheel' - inputs: - targetType: filePath - filePath: tools/ci_build/github/linux/run_python_dockerbuild.sh - arguments: -i onnxruntimecpubuildpython${{ parameters.arch }} -d "${{ parameters.device }}" -c ${{ parameters.cmake_build_type }} $(extra_build_args) $(python_exe_path) - ${{ if eq(parameters.with_cache, 'true') }}: - env: - ADDITIONAL_DOCKER_PARAMETER: "--volume $(ORT_CACHE_DIR):/cache -e CCACHE_DIR=/cache -e ORT_BUILD_WITH_CACHE=1" - - - task: PublishBuildArtifacts@1 + - checkout: self + clean: true + submodules: none + + - template: set-nightly-build-option-variable-step.yml + + - template: get-docker-image-steps.yml + parameters: + Dockerfile: tools/ci_build/github/linux/docker/inference/${{ parameters.arch }}/python/cpu/Dockerfile + Context: tools/ci_build/github/linux/docker/inference/${{ parameters.arch }}/python/cpu + DockerBuildArgs: "--build-arg BUILD_UID=$( id -u )" + Repository: onnxruntimecpubuildpython${{ parameters.arch }} + + - template: linux-build-step-with-cache.yml + parameters: + WithCache: ${{parameters.with_cache}} + Today: $(TODAY) + AdditionalKey: Linux_py_Wheels_${{ parameters.arch }} + CacheDir: $(ORT_CACHE_DIR) + ChangeEveryCommit: true + BuildStep: + - task: Bash@3 + displayName: 'Build Python Wheel' + inputs: + targetType: filePath + filePath: tools/ci_build/github/linux/run_python_dockerbuild.sh + arguments: -i onnxruntimecpubuildpython${{ parameters.arch }} -d "${{ parameters.device }}" -c ${{ parameters.cmake_build_type }} $(extra_build_args) $(python_exe_path) + ${{ if eq(parameters.with_cache, 'true') }}: + env: + ADDITIONAL_DOCKER_PARAMETER: "--volume $(ORT_CACHE_DIR):/cache -e CCACHE_DIR=/cache -e ORT_BUILD_WITH_CACHE=1" + + - ${{ if eq(parameters.is1ES, true) }}: + - task: 1ES.PublishPipelineArtifact@1 displayName: 'Publish Artifact: ONNXRuntime python wheel' inputs: - PathtoPublish: '$(Build.BinariesDirectory)/dist' - ArtifactName: onnxruntime-${{ parameters.ep }} - - - task: PublishPipelineArtifact@0 + targetPath: '$(Build.BinariesDirectory)/dist' + artifactName: onnxruntime-${{ parameters.arch }}-${{ parameters.ep }} + - task: 1ES.PublishPipelineArtifact@1 + displayName: 'Publish Test Binaries' + inputs: + artifactName: 'drop-linux-cpu-${{ parameters.arch }}-${{ parameters.ep }}' + targetPath: '$(Build.BinariesDirectory)/${{ parameters.cmake_build_type }}' + - ${{ if eq(parameters.is1ES, false) }}: + - task: PublishPipelineArtifact@1 + displayName: 'Publish Artifact: ONNXRuntime python wheel' + inputs: + targetPath: '$(Build.BinariesDirectory)/dist' + artifactName: onnxruntime-${{ parameters.arch }}-${{ parameters.ep }} + - task: PublishPipelineArtifact@1 displayName: 'Publish Test Binaries' inputs: artifactName: 'drop-linux-cpu-${{ parameters.arch }}-${{ parameters.ep }}' targetPath: '$(Build.BinariesDirectory)/${{ parameters.cmake_build_type }}' - - template: component-governance-component-detection-steps.yml - parameters : - condition : 'succeeded' + + + - template: component-governance-component-detection-steps.yml + parameters: + condition: 'succeeded' diff --git a/tools/ci_build/github/azure-pipelines/templates/py-package-smoking-test.yml b/tools/ci_build/github/azure-pipelines/templates/py-package-smoking-test.yml index 3a3da0f8f5afa..c0bd740b2d483 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-package-smoking-test.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-package-smoking-test.yml @@ -9,9 +9,13 @@ parameters: - name: machine_pool type: object -- name: python_arch +- name: ep type: string - default: 'x64' + default: 'cpu' + +- name: arch + type: string + default: 'x86_64' jobs: - job: ${{ parameters.job_name }} @@ -37,10 +41,9 @@ jobs: displayName: 'Use Python' inputs: versionSpec: $(PythonVersion) - architecture: ${{ parameters.python_arch }} - download: build # pipeline resource identifier. - artifact: 'onnxruntime' + artifact: 'onnxruntime-${{ parameters.arch }}-${{ parameters.ep }}' - task: Bash@3 inputs: @@ -51,7 +54,7 @@ jobs: FILE_NAME="${files[0]}" FILE_NAME=$(basename $FILE_NAME) PYTHON_PACKAGE_NAME=$(echo "$FILE_NAME" | cut -f 1 -d '-') - python3 -m pip install --find-links "$(Pipeline.Workspace)/build/onnxruntime" $PYTHON_PACKAGE_NAME + python3 -m pip install --find-links "$(Pipeline.Workspace)/build/onnxruntime-${{ parameters.arch }}-${{ parameters.ep }}" $PYTHON_PACKAGE_NAME python3 -m pip show $PYTHON_PACKAGE_NAME python3 -c "import onnxruntime as ort; print(ort.__version__)" workingDirectory: $(Pipeline.Workspace)/build/onnxruntime diff --git a/tools/ci_build/github/azure-pipelines/templates/py-packaging-linux-test-cpu.yml b/tools/ci_build/github/azure-pipelines/templates/py-packaging-linux-test-cpu.yml index c475feaef0018..eef97341b8d53 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-packaging-linux-test-cpu.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-packaging-linux-test-cpu.yml @@ -19,10 +19,10 @@ parameters: type: string default: 'Release' values: - - Debug - - Release - - RelWithDebInfo - - MinSizeRel + - Debug + - Release + - RelWithDebInfo + - MinSizeRel - name: timeout type: number @@ -50,29 +50,31 @@ jobs: artifact: 'drop-linux-cpu-${{ parameters.arch }}-${{parameters.ep}}' - download: current # pipeline resource identifier. - artifact: 'onnxruntime${{ parameters.python_wheel_suffix }}-${{ parameters.ep }}' + artifact: 'onnxruntime-${{ parameters.arch }}-${{ parameters.ep }}' - bash: | set -e -x mv "$(Pipeline.Workspace)/drop-linux-cpu-${{ parameters.arch }}-${{parameters.ep}}" $(Build.BinariesDirectory)/${{parameters.cmake_build_type}} - mv "$(Pipeline.Workspace)/onnxruntime${{ parameters.python_wheel_suffix }}-${{parameters.ep}}" "$(Build.BinariesDirectory)/whl" + mv "$(Pipeline.Workspace)/onnxruntime-${{ parameters.arch }}-${{ parameters.ep }}" "$(Build.BinariesDirectory)/whl" cp -r "$(Build.BinariesDirectory)/whl" $(Build.BinariesDirectory)/tmp find "$(Build.BinariesDirectory)/tmp" -name '*.whl' -exec bash -c 'unzip -d "${1%.*}" "$1"' _ {} \; + displayName: 'Move the artifacts to the binaries directory' # The private ADO project - ${{ if eq(variables['System.CollectionId'], 'bc038106-a83b-4dab-9dd3-5a41bc58f34c') }}: - download: build # pipeline resource identifier. artifact: 'drop-linux-cpu-${{ parameters.arch }}-${{parameters.ep}}' - download: build # pipeline resource identifier. - artifact: 'onnxruntime${{ parameters.python_wheel_suffix }}-${{ parameters.ep }}' + artifact: 'onnxruntime-${{ parameters.arch }}-${{ parameters.ep }}' - bash: | set -e -x ls $(Pipeline.Workspace)/build mv "$(Pipeline.Workspace)/build/drop-linux-cpu-${{ parameters.arch }}-${{parameters.ep}}" $(Build.BinariesDirectory)/${{parameters.cmake_build_type}} - mv "$(Pipeline.Workspace)/build/onnxruntime${{ parameters.python_wheel_suffix }}-${{parameters.ep}}" "$(Build.BinariesDirectory)/whl" + mv "$(Pipeline.Workspace)/build/onnxruntime-${{ parameters.arch }}-${{ parameters.ep }}" "$(Build.BinariesDirectory)/whl" cp -r "$(Build.BinariesDirectory)/whl" $(Build.BinariesDirectory)/tmp find "$(Build.BinariesDirectory)/tmp" -name '*.whl' -exec bash -c 'unzip -d "${1%.*}" "$1"' _ {} \; + displayName: 'Move the artifacts to the binaries directory' # The BinSkim task uses a dotnet program which doesn't support ARM CPUs yet - ${{ if eq(parameters.arch, 'x86_64') }}: diff --git a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml index 4c9d0dccaf48d..10ea7f6203bb1 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml @@ -19,6 +19,11 @@ parameters: type: string default: '' +- name: is1ES + displayName: 'Whether the pipeline is running in 1ES' + type: boolean + default: false + jobs: - job: Win_py_arm64_qnn_Wheels timeoutInMinutes: 210 @@ -26,6 +31,8 @@ jobs: clean: all pool: name: ${{ parameters.MACHINE_POOL }} + os: windows + hostArchitecture: Arm64 strategy: matrix: Python311_arm64: @@ -41,132 +48,140 @@ jobs: GRADLE_OPTS: '-Dorg.gradle.daemon=false' VSGenerator: 'Visual Studio 17 2022' steps: - - checkout: self - clean: true - submodules: recursive - - - template: telemetry-steps.yml - - - script: | - MKDIR $(Agent.ToolsDirectory)\Python\$(PythonVersion)\arm64 - XCOPY /s /y /h /e /c /q "$(LocalPythonDir)\*.*" $(Agent.ToolsDirectory)\Python\$(PythonVersion)\arm64\ - COPY NUL $(Agent.ToolsDirectory)\Python\$(PythonVersion)\arm64.complete - DIR $(Agent.ToolsDirectory)\Python - DIR $(Agent.ToolsDirectory)\Python\$(PythonVersion) - DIR $(Agent.ToolsDirectory)\Python\$(PythonVersion)\arm64 - displayName: Copy python $(PythonVersion) version to agent tools directory - - - task: UsePythonVersion@0 - inputs: - versionSpec: $(PythonVersion) - addToPath: true - architecture: 'arm64' - - - task: PipAuthenticate@1 - displayName: 'Pip Authenticate' - inputs: - artifactFeeds: 'Lotus' - - - task: onebranch.pipeline.tsaoptions@1 - displayName: 'OneBranch TSAOptions' - inputs: - tsaConfigFilePath: '$(Build.SourcesDirectory)\.config\tsaoptions.json' - appendSourceBranchName: false - - - task: PythonScript@0 - inputs: - scriptSource: inline - script: | - import subprocess - subprocess.call(['pip', 'install', '-q', 'setuptools', 'wheel']) - workingDirectory: '$(Build.BinariesDirectory)' - displayName: 'Install python modules' - - - template: set-nightly-build-option-variable-step.yml - - - template: jobs/download_win_qnn_sdk.yml - parameters: - QnnSDKVersion: ${{ parameters.QNN_SDK }} - - - task: PythonScript@0 - displayName: 'Generate cmake config' - inputs: - scriptPath: '$(Build.SourcesDirectory)\tools\ci_build\build.py' - arguments: > - --config RelWithDebInfo - --build_dir $(Build.BinariesDirectory) - --skip_submodule_sync - --cmake_generator "$(VSGenerator)" - --build_shared_lib - --use_qnn - --qnn_home $(QnnSDKRootDir) - --enable_pybind - --parallel --use_vcpkg --use_vcpkg_ms_internal_asset_cache --update - $(TelemetryOption) ${{ parameters.BUILD_PY_PARAMETERS }} - workingDirectory: '$(Build.BinariesDirectory)' - - - task: VSBuild@1 - displayName: 'Build' - inputs: - solution: '$(Build.BinariesDirectory)\RelWithDebInfo\onnxruntime.sln' - platform: 'arm64' - configuration: RelWithDebInfo - msbuildArchitecture: 'arm64' - maximumCpuCount: true - logProjectEvents: true - workingFolder: '$(Build.BinariesDirectory)\RelWithDebInfo' - createLogFile: true - - # Esrp signing - - template: win-esrp-dll.yml - parameters: - FolderPath: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\onnxruntime\capi' - DisplayName: 'ESRP - Sign Native dlls' - DoEsrp: true - Pattern: '*.pyd' - - - task: PythonScript@0 - displayName: 'Build wheel' - inputs: - scriptPath: '$(Build.SourcesDirectory)\setup.py' - arguments: 'bdist_wheel $(NightlyBuildOption) --wheel_name_suffix=qnn' - workingDirectory: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo' - - - task: CopyFiles@2 - displayName: 'Copy Python Wheel to: $(Build.ArtifactStagingDirectory)' - inputs: - SourceFolder: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\dist' - Contents: '*.whl' - TargetFolder: '$(Build.ArtifactStagingDirectory)' - - - task: PublishBuildArtifacts@1 - displayName: 'Publish Artifact: ONNXRuntime python wheel' - inputs: - ArtifactName: onnxruntime_qnn_arm64 - - - script: | - 7z x *.whl - workingDirectory: '$(Build.ArtifactStagingDirectory)' - displayName: 'unzip the package' - - - task: CredScan@3 - displayName: 'Run CredScan' - inputs: - debugMode: false - continueOnError: true - - - task: BinSkim@4 - displayName: 'Run BinSkim' - inputs: - AnalyzeTargetGlob: '+:file|$(Build.ArtifactStagingDirectory)\**\*.dll' - - - task: TSAUpload@2 - displayName: 'TSA upload' - condition: and (succeeded(), eq(variables['Build.SourceBranch'], 'refs/heads/main')) - inputs: - GdnPublishTsaOnboard: false - GdnPublishTsaConfigFile: '$(Build.sourcesDirectory)\.gdn\.gdntsa' - - - template: component-governance-component-detection-steps.yml - parameters: - condition: 'succeeded' + - checkout: self + clean: true + submodules: recursive + + - template: telemetry-steps.yml + + - script: | + MKDIR $(Agent.ToolsDirectory)\Python\$(PythonVersion)\arm64 + XCOPY /s /y /h /e /c /q "$(LocalPythonDir)\*.*" $(Agent.ToolsDirectory)\Python\$(PythonVersion)\arm64\ + COPY NUL $(Agent.ToolsDirectory)\Python\$(PythonVersion)\arm64.complete + DIR $(Agent.ToolsDirectory)\Python + DIR $(Agent.ToolsDirectory)\Python\$(PythonVersion) + DIR $(Agent.ToolsDirectory)\Python\$(PythonVersion)\arm64 + displayName: Copy python $(PythonVersion) version to agent tools directory + + - task: UsePythonVersion@0 + inputs: + versionSpec: $(PythonVersion) + addToPath: true + architecture: 'arm64' + + - task: PipAuthenticate@1 + displayName: 'Pip Authenticate' + inputs: + artifactFeeds: 'Lotus' + + - task: onebranch.pipeline.tsaoptions@1 + displayName: 'OneBranch TSAOptions' + inputs: + tsaConfigFilePath: '$(Build.SourcesDirectory)\.config\tsaoptions.json' + appendSourceBranchName: false + + - task: PythonScript@0 + inputs: + scriptSource: inline + script: | + import subprocess + subprocess.call(['pip', 'install', '-q', 'setuptools', 'wheel']) + workingDirectory: '$(Build.BinariesDirectory)' + displayName: 'Install python modules' + + - template: set-nightly-build-option-variable-step.yml + + - template: jobs/download_win_qnn_sdk.yml + parameters: + QnnSDKVersion: ${{ parameters.QNN_SDK }} + + - task: PythonScript@0 + displayName: 'Generate cmake config' + inputs: + scriptPath: '$(Build.SourcesDirectory)\tools\ci_build\build.py' + arguments: > + --config RelWithDebInfo + --build_dir $(Build.BinariesDirectory) + --skip_submodule_sync + --cmake_generator "$(VSGenerator)" + --build_shared_lib + --use_qnn + --qnn_home $(QnnSDKRootDir) + --enable_pybind + --parallel --use_vcpkg --use_vcpkg_ms_internal_asset_cache --update + $(TelemetryOption) ${{ parameters.BUILD_PY_PARAMETERS }} + workingDirectory: '$(Build.BinariesDirectory)' + + - task: VSBuild@1 + displayName: 'Build' + inputs: + solution: '$(Build.BinariesDirectory)\RelWithDebInfo\onnxruntime.sln' + platform: 'arm64' + configuration: RelWithDebInfo + msbuildArchitecture: 'arm64' + maximumCpuCount: true + logProjectEvents: true + workingFolder: '$(Build.BinariesDirectory)\RelWithDebInfo' + createLogFile: true + + # Esrp signing + - template: win-esrp-dll.yml + parameters: + FolderPath: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\onnxruntime\capi' + DisplayName: 'ESRP - Sign Native dlls' + DoEsrp: true + Pattern: '*.pyd' + + - task: PythonScript@0 + displayName: 'Build wheel' + inputs: + scriptPath: '$(Build.SourcesDirectory)\setup.py' + arguments: 'bdist_wheel $(NightlyBuildOption) --wheel_name_suffix=qnn' + workingDirectory: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo' + + - task: CopyFiles@2 + displayName: 'Copy Python Wheel to: $(Build.ArtifactStagingDirectory)' + inputs: + SourceFolder: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\dist' + Contents: '*.whl' + TargetFolder: '$(Build.ArtifactStagingDirectory)' + + - ${{ if eq(parameters.is1ES, true) }}: + - task: 1ES.PublishPipelineArtifact@1 + displayName: 'Publish Artifact: ONNXRuntime python wheel' + inputs: + artifactName: onnxruntime_qnn_arm64_$(PythonVersion) + targetPath: '$(Build.ArtifactStagingDirectory)' + - ${{ if eq(parameters.is1ES, false) }}: + - task: PublishPipelineArtifact@1 + displayName: 'Publish Artifact: ONNXRuntime python wheel' + input: + artifactName: onnxruntime_qnn_arm64_$(PythonVersion) + targetPath: '$(Build.ArtifactStagingDirectory)' + + - script: | + 7z x *.whl + workingDirectory: '$(Build.ArtifactStagingDirectory)' + displayName: 'unzip the package' + + - task: CredScan@3 + displayName: 'Run CredScan' + inputs: + debugMode: false + continueOnError: true + + - task: BinSkim@4 + displayName: 'Run BinSkim' + inputs: + AnalyzeTargetGlob: '+:file|$(Build.ArtifactStagingDirectory)\**\*.dll' + + - task: TSAUpload@2 + displayName: 'TSA upload' + condition: and (succeeded(), eq(variables['Build.SourceBranch'], 'refs/heads/main')) + inputs: + GdnPublishTsaOnboard: false + GdnPublishTsaConfigFile: '$(Build.sourcesDirectory)\.gdn\.gdntsa' + + - template: component-governance-component-detection-steps.yml + parameters: + condition: 'succeeded' diff --git a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64ec-qnn.yml b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64ec-qnn.yml index ed29f1e67515e..24321d2a3e1ec 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64ec-qnn.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64ec-qnn.yml @@ -19,6 +19,11 @@ parameters: type: string default: '' +- name: is1ES + displayName: 'Whether the pipeline is running in 1ES' + type: boolean + default: false + jobs: - job: Win_py_x64_qnn_Wheels timeoutInMinutes: 210 @@ -26,6 +31,7 @@ jobs: clean: all pool: name: ${{ parameters.MACHINE_POOL }} + os: windows strategy: matrix: Python310_x64: @@ -40,117 +46,124 @@ jobs: GRADLE_OPTS: '-Dorg.gradle.daemon=false' VSGenerator: 'Visual Studio 17 2022' steps: - - checkout: self - clean: true - submodules: recursive - - - template: telemetry-steps.yml - - - task: UsePythonVersion@0 - inputs: - versionSpec: $(PythonVersion) - addToPath: true - architecture: 'x64' - - - task: PipAuthenticate@1 - displayName: 'Pip Authenticate' - inputs: - artifactFeeds: 'Lotus' - - - task: onebranch.pipeline.tsaoptions@1 - displayName: 'OneBranch TSAOptions' - inputs: - tsaConfigFilePath: '$(Build.SourcesDirectory)\.config\tsaoptions.json' - appendSourceBranchName: fals - - - script: python -m pip install -r $(Build.SourcesDirectory)\tools\ci_build\github\linux\python\requirements.txt - - - - template: set-nightly-build-option-variable-step.yml - - - template: jobs/download_win_qnn_sdk.yml - parameters: - QnnSDKVersion: ${{ parameters.QNN_SDK }} - - - task: PythonScript@0 - displayName: 'Generate cmake config' - inputs: - scriptPath: '$(Build.SourcesDirectory)\tools\ci_build\build.py' - arguments: > - --config RelWithDebInfo - --build_dir $(Build.BinariesDirectory) - --skip_submodule_sync - --cmake_generator "$(VSGenerator)" - --build_shared_lib - --use_qnn - --qnn_home $(QnnSDKRootDir) - --enable_pybind - --parallel --use_vcpkg --use_vcpkg_ms_internal_asset_cache --update --arm64ec - $(TelemetryOption) ${{ parameters.BUILD_PY_PARAMETERS }} - workingDirectory: '$(Build.BinariesDirectory)' - - - task: VSBuild@1 - displayName: 'Build' - inputs: - solution: '$(Build.BinariesDirectory)\RelWithDebInfo\onnxruntime.sln' - platform: 'arm64ec' - configuration: RelWithDebInfo - msbuildArchitecture: 'x64' - maximumCpuCount: true - logProjectEvents: true - workingFolder: '$(Build.BinariesDirectory)\RelWithDebInfo' - createLogFile: true - - # Esrp signing - - template: win-esrp-dll.yml - parameters: - FolderPath: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\onnxruntime\capi' - DisplayName: 'ESRP - Sign Native dlls' - DoEsrp: true - Pattern: '*.pyd' - - - task: PythonScript@0 - displayName: 'Build wheel' - inputs: - scriptPath: '$(Build.SourcesDirectory)\setup.py' - arguments: 'bdist_wheel $(NightlyBuildOption) --wheel_name_suffix=qnn' - workingDirectory: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo' - - - task: CopyFiles@2 - displayName: 'Copy Python Wheel to: $(Build.ArtifactStagingDirectory)' - inputs: - SourceFolder: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\dist' - Contents: '*.whl' - TargetFolder: '$(Build.ArtifactStagingDirectory)' - - - task: PublishBuildArtifacts@1 - displayName: 'Publish Artifact: ONNXRuntime python wheel' - inputs: - ArtifactName: onnxruntime_qnn_arm64ec - - - script: | - 7z x *.whl - workingDirectory: '$(Build.ArtifactStagingDirectory)' - displayName: 'unzip the package' - - - task: CredScan@3 - displayName: 'Run CredScan' - inputs: - debugMode: false - continueOnError: true - - - task: BinSkim@4 - displayName: 'Run BinSkim' - inputs: - AnalyzeTargetGlob: '+:file|$(Build.ArtifactStagingDirectory)\**\*.dll' - - - task: TSAUpload@2 - displayName: 'TSA upload' - condition: and (succeeded(), eq(variables['Build.SourceBranch'], 'refs/heads/main')) - inputs: - GdnPublishTsaOnboard: false - GdnPublishTsaConfigFile: '$(Build.sourcesDirectory)\.gdn\.gdntsa' - - - template: component-governance-component-detection-steps.yml - parameters: - condition: 'succeeded' + - checkout: self + clean: true + submodules: recursive + + - template: telemetry-steps.yml + + - task: UsePythonVersion@0 + inputs: + versionSpec: $(PythonVersion) + addToPath: true + architecture: 'x64' + + - task: PipAuthenticate@1 + displayName: 'Pip Authenticate' + inputs: + artifactFeeds: 'Lotus' + + - task: onebranch.pipeline.tsaoptions@1 + displayName: 'OneBranch TSAOptions' + inputs: + tsaConfigFilePath: '$(Build.SourcesDirectory)\.config\tsaoptions.json' + appendSourceBranchName: fals + + - script: python -m pip install -r $(Build.SourcesDirectory)\tools\ci_build\github\linux\python\requirements.txt + + + - template: set-nightly-build-option-variable-step.yml + + - template: jobs/download_win_qnn_sdk.yml + parameters: + QnnSDKVersion: ${{ parameters.QNN_SDK }} + + - task: PythonScript@0 + displayName: 'Generate cmake config' + inputs: + scriptPath: '$(Build.SourcesDirectory)\tools\ci_build\build.py' + arguments: > + --config RelWithDebInfo + --build_dir $(Build.BinariesDirectory) + --skip_submodule_sync + --cmake_generator "$(VSGenerator)" + --build_shared_lib + --use_qnn + --qnn_home $(QnnSDKRootDir) + --enable_pybind + --parallel --use_vcpkg --use_vcpkg_ms_internal_asset_cache --update --arm64ec + $(TelemetryOption) ${{ parameters.BUILD_PY_PARAMETERS }} + workingDirectory: '$(Build.BinariesDirectory)' + + - task: VSBuild@1 + displayName: 'Build' + inputs: + solution: '$(Build.BinariesDirectory)\RelWithDebInfo\onnxruntime.sln' + platform: 'arm64ec' + configuration: RelWithDebInfo + msbuildArchitecture: 'x64' + maximumCpuCount: true + logProjectEvents: true + workingFolder: '$(Build.BinariesDirectory)\RelWithDebInfo' + createLogFile: true + + # Esrp signing + - template: win-esrp-dll.yml + parameters: + FolderPath: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\onnxruntime\capi' + DisplayName: 'ESRP - Sign Native dlls' + DoEsrp: true + Pattern: '*.pyd' + + - task: PythonScript@0 + displayName: 'Build wheel' + inputs: + scriptPath: '$(Build.SourcesDirectory)\setup.py' + arguments: 'bdist_wheel $(NightlyBuildOption) --wheel_name_suffix=qnn' + workingDirectory: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo' + + - task: CopyFiles@2 + displayName: 'Copy Python Wheel to: $(Build.ArtifactStagingDirectory)' + inputs: + SourceFolder: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\dist' + Contents: '*.whl' + TargetFolder: '$(Build.ArtifactStagingDirectory)' + + - ${{ if eq(parameters.is1ES, true) }}: + - task: 1ES.PublishPipelineArtifact@1 + displayName: 'Publish Artifact: ONNXRuntime python wheel' + inputs: + artifactName: onnxruntime_qnn_arm64ec_$(PythonVersion) + targetPath: '$(Build.ArtifactStagingDirectory)' + - ${{ if eq(parameters.is1ES, false) }}: + - task: PublishPipelineArtifact@1 + displayName: 'Publish Artifact: ONNXRuntime python wheel' + inputs: + artifactName: onnxruntime_qnn_arm64ec_$(PythonVersion) + targetPath: '$(Build.ArtifactStagingDirectory)' + - script: | + 7z x *.whl + workingDirectory: '$(Build.ArtifactStagingDirectory)' + displayName: 'unzip the package' + + - task: CredScan@3 + displayName: 'Run CredScan' + inputs: + debugMode: false + continueOnError: true + + - task: BinSkim@4 + displayName: 'Run BinSkim' + inputs: + AnalyzeTargetGlob: '+:file|$(Build.ArtifactStagingDirectory)\**\*.dll' + + - task: TSAUpload@2 + displayName: 'TSA upload' + condition: and (succeeded(), eq(variables['Build.SourceBranch'], 'refs/heads/main')) + inputs: + GdnPublishTsaOnboard: false + GdnPublishTsaConfigFile: '$(Build.sourcesDirectory)\.gdn\.gdntsa' + + - template: component-governance-component-detection-steps.yml + parameters: + condition: 'succeeded' diff --git a/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml b/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml index 13069846da342..175b343e55d57 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml @@ -19,6 +19,11 @@ parameters: type: string default: '' +- name: is1ES + displayName: 'Whether the pipeline is running in 1ES' + type: boolean + default: false + jobs: - job: Win_py_x64_qnn_Wheels timeoutInMinutes: 210 @@ -116,10 +121,18 @@ jobs: Contents: '*.whl' TargetFolder: '$(Build.ArtifactStagingDirectory)' - - task: PublishBuildArtifacts@1 - displayName: 'Publish Artifact: ONNXRuntime python wheel' - inputs: - ArtifactName: onnxruntime_qnn_x64 + - ${{ if eq(parameters.is1ES, true) }}: + - task: 1ES.PublishPipelineArtifact@1 + displayName: 'Publish Artifact: ONNXRuntime python wheel' + inputs: + artifactName: onnxruntime_qnn_x64_$(PythonVersion) + targetPath: '$(Build.ArtifactStagingDirectory)' + - ${{ if eq(parameters.is1ES, false) }}: + - task: PublishPipelineArtifact@1 + displayName: 'Publish Artifact: ONNXRuntime python wheel' + inputs: + artifactName: onnxruntime_qnn_x64_$(PythonVersion) + targetPath: '$(Build.ArtifactStagingDirectory)' - script: | 7z x *.whl diff --git a/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml b/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml index a93d6b5ff8419..3fa4799ec9c0e 100644 --- a/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml +++ b/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml @@ -10,6 +10,8 @@ parameters: buildPlatform: 'x64' buildArch: 'x64' StageName: 'OnnxRuntime_QNN_Nuget_Win_x64' + Is1ES: true + PublishArchive: false stages: - stage: ${{ parameters.StageName }} @@ -18,7 +20,8 @@ stages: - job: ${{ parameters.StageName }} timeoutInMinutes: 120 - pool: ${{ parameters.qnn_ep_build_pool_name }} + pool: + name: ${{ parameters.qnn_ep_build_pool_name }} variables: ${{ if eq(parameters.buildArch, 'ARM64') }}: targetArchitecture: 'arm64' @@ -28,133 +31,148 @@ stages: commonBuildArgs: '--update --compile_no_warning_as_error --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --build_shared_lib --cmake_generator "Visual Studio 17 2022" --config ${{ parameters.build_config }} --parallel --use_binskim_compliant_compile_flags ${{ parameters.buildParameter }} ' steps: - - template: set-version-number-variables-step.yml - - - task: UsePythonVersion@0 - inputs: - versionSpec: '3.12' - addToPath: true - - - template: jobs/download_win_qnn_sdk.yml + - template: set-version-number-variables-step.yml + + - task: UsePythonVersion@0 + inputs: + versionSpec: '3.12' + addToPath: true + + - template: jobs/download_win_qnn_sdk.yml + parameters: + QnnSDKVersion: ${{ parameters.QnnSdk }} + + - task: PythonScript@0 + displayName: 'Generate project' + inputs: + scriptPath: '$(Build.SourcesDirectory)\tools\ci_build\build.py' + arguments: '--use_qnn --qnn_home $(QnnSDKRootDir) $(commonBuildArgs)' + + - task: VSBuild@1 + displayName: 'Build onnxruntime' + inputs: + solution: '$(Build.BinariesDirectory)\${{ parameters.build_config }}\onnxruntime.vcxproj' + platform: ${{ parameters.buildPlatform }} + configuration: ${{ parameters.build_config }} + msbuildArchitecture: ${{ parameters.buildArch }} + maximumCpuCount: true + logProjectEvents: true + workingFolder: '$(Build.BinariesDirectory)\${{ parameters.build_config }}' + createLogFile: true + + - task: VSBuild@1 + displayName: 'Build onnx_test_runner' + inputs: + solution: '$(Build.BinariesDirectory)\${{ parameters.build_config }}\onnx_test_runner.vcxproj' + platform: ${{ parameters.buildPlatform }} + configuration: ${{ parameters.build_config }} + msbuildArchitecture: ${{ parameters.buildArch }} + maximumCpuCount: true + logProjectEvents: true + workingFolder: '$(Build.BinariesDirectory)\${{ parameters.build_config }}' + createLogFile: true + + - task: VSBuild@1 + displayName: 'Build onnxruntime_perf_test' + inputs: + solution: '$(Build.BinariesDirectory)\${{ parameters.build_config }}\onnxruntime_perf_test.vcxproj' + platform: ${{ parameters.buildPlatform }} + configuration: ${{ parameters.build_config }} + msbuildArchitecture: ${{ parameters.buildArch }} + maximumCpuCount: true + logProjectEvents: true + workingFolder: '$(Build.BinariesDirectory)\${{ parameters.build_config }}' + createLogFile: true + + - task: VSBuild@1 + displayName: 'Build onnxruntime_test_all (to copy Qnn libs)' + inputs: + solution: '$(Build.BinariesDirectory)\${{ parameters.build_config }}\onnxruntime_test_all.vcxproj' + platform: ${{ parameters.buildPlatform }} + configuration: ${{ parameters.build_config }} + msbuildArchitecture: ${{ parameters.buildArch }} + maximumCpuCount: true + logProjectEvents: true + workingFolder: '$(Build.BinariesDirectory)\${{ parameters.build_config }}' + createLogFile: true + + - task: CmdLine@2 + displayName: 'Print contents of binaries directory' + inputs: + script: | + dir $(Build.BinariesDirectory)\${{ parameters.build_config }}\${{ parameters.build_config }} + + - template: win-esrp-dll.yml + parameters: + FolderPath: '$(Build.BinariesDirectory)\${{ parameters.build_config }}\${{ parameters.build_config }}' + DisplayName: 'ESRP - Sign dlls' + DoEsrp: ${{ parameters.DoEsrp }} + Pattern: 'onnxruntime*.dll' + + - ${{ if eq(parameters.PublishArchive, true) }}: + - template: c-api-artifacts-package-and-publish-steps-windows.yml parameters: - QnnSDKVersion: ${{ parameters.QnnSdk }} - - - task: PythonScript@0 - displayName: 'Generate project' - inputs: - scriptPath: '$(Build.SourcesDirectory)\tools\ci_build\build.py' - arguments: '--use_qnn --qnn_home $(QnnSDKRootDir) $(commonBuildArgs)' - - - task: VSBuild@1 - displayName: 'Build onnxruntime' - inputs: - solution: '$(Build.BinariesDirectory)\${{ parameters.build_config }}\onnxruntime.vcxproj' - platform: ${{ parameters.buildPlatform }} - configuration: ${{ parameters.build_config }} - msbuildArchitecture: ${{ parameters.buildArch }} - maximumCpuCount: true - logProjectEvents: true - workingFolder: '$(Build.BinariesDirectory)\${{ parameters.build_config }}' - createLogFile: true - - - task: VSBuild@1 - displayName: 'Build onnx_test_runner' - inputs: - solution: '$(Build.BinariesDirectory)\${{ parameters.build_config }}\onnx_test_runner.vcxproj' - platform: ${{ parameters.buildPlatform }} - configuration: ${{ parameters.build_config }} - msbuildArchitecture: ${{ parameters.buildArch }} - maximumCpuCount: true - logProjectEvents: true - workingFolder: '$(Build.BinariesDirectory)\${{ parameters.build_config }}' - createLogFile: true - - - task: VSBuild@1 - displayName: 'Build onnxruntime_perf_test' - inputs: - solution: '$(Build.BinariesDirectory)\${{ parameters.build_config }}\onnxruntime_perf_test.vcxproj' - platform: ${{ parameters.buildPlatform }} - configuration: ${{ parameters.build_config }} - msbuildArchitecture: ${{ parameters.buildArch }} - maximumCpuCount: true - logProjectEvents: true - workingFolder: '$(Build.BinariesDirectory)\${{ parameters.build_config }}' - createLogFile: true - - - task: VSBuild@1 - displayName: 'Build onnxruntime_test_all (to copy Qnn libs)' - inputs: - solution: '$(Build.BinariesDirectory)\${{ parameters.build_config }}\onnxruntime_test_all.vcxproj' - platform: ${{ parameters.buildPlatform }} - configuration: ${{ parameters.build_config }} - msbuildArchitecture: ${{ parameters.buildArch }} - maximumCpuCount: true - logProjectEvents: true - workingFolder: '$(Build.BinariesDirectory)\${{ parameters.build_config }}' - createLogFile: true - - - task: CmdLine@2 - displayName: 'Print contents of binaries directory' - inputs: - script: | - dir $(Build.BinariesDirectory)\${{ parameters.build_config }}\${{ parameters.build_config }} + buildConfig: ${{ parameters.build_config }} + artifactName: 'onnxruntime-win-${{ parameters.buildPlatform }}-qnn' + artifactNameNoVersionString: 'onnxruntime-win-${{ parameters.buildPlatform }}-qnn' + DoEsrp: ${{ parameters.DoEsrp }} + - task: MSBuild@1 + displayName: 'Restore NuGet Packages and create project.assets.json' + inputs: + solution: '$(Build.SourcesDirectory)\csharp\OnnxRuntime.DesktopOnly.CSharp.sln' + platform: 'Any CPU' + configuration: ${{ parameters.build_config }} + msbuildArguments: '-t:restore -p:OrtPackageId=$(OrtPackageId)' + workingDirectory: '$(Build.SourcesDirectory)\csharp' + + - task: MSBuild@1 + displayName: 'Build C# bindings' + inputs: + solution: '$(Build.SourcesDirectory)\csharp\OnnxRuntime.DesktopOnly.CSharp.sln' + platform: 'Any CPU' + configuration: ${{ parameters.build_config }} + msbuildArguments: '-p:OnnxRuntimeBuildDirectory="$(Build.BinariesDirectory)" -p:OrtPackageId=$(OrtPackageId) -p:IsReleaseBuild=${{ parameters.IsReleaseBuild }}' + workingDirectory: '$(Build.SourcesDirectory)\csharp' + + - ${{ if eq(parameters.DoEsrp, true) }}: - template: win-esrp-dll.yml parameters: - FolderPath: '$(Build.BinariesDirectory)\${{ parameters.build_config }}\${{ parameters.build_config }}' - DisplayName: 'ESRP - Sign dlls' + FolderPath: '$(Build.SourcesDirectory)\csharp\src\Microsoft.ML.OnnxRuntime\bin\${{ parameters.build_config }}' + DisplayName: 'ESRP - Sign C# dlls' DoEsrp: ${{ parameters.DoEsrp }} - Pattern: 'onnxruntime*.dll' - - - task: MSBuild@1 - displayName: 'Restore NuGet Packages and create project.assets.json' - inputs: - solution: '$(Build.SourcesDirectory)\csharp\OnnxRuntime.DesktopOnly.CSharp.sln' - platform: 'Any CPU' - configuration: ${{ parameters.build_config }} - msbuildArguments: '-t:restore -p:OrtPackageId=$(OrtPackageId)' - workingDirectory: '$(Build.SourcesDirectory)\csharp' - - - task: MSBuild@1 - displayName: 'Build C# bindings' - inputs: - solution: '$(Build.SourcesDirectory)\csharp\OnnxRuntime.DesktopOnly.CSharp.sln' - platform: 'Any CPU' - configuration: ${{ parameters.build_config }} - msbuildArguments: '-p:OnnxRuntimeBuildDirectory="$(Build.BinariesDirectory)" -p:OrtPackageId=$(OrtPackageId) -p:IsReleaseBuild=${{ parameters.IsReleaseBuild }}' - workingDirectory: '$(Build.SourcesDirectory)\csharp' - - - ${{ if eq(parameters.DoEsrp, true) }}: - - template: win-esrp-dll.yml - parameters: - FolderPath: '$(Build.SourcesDirectory)\csharp\src\Microsoft.ML.OnnxRuntime\bin\${{ parameters.build_config }}' - DisplayName: 'ESRP - Sign C# dlls' - DoEsrp: ${{ parameters.DoEsrp }} - - - task: MSBuild@1 - displayName: 'Build Nuget Packages' - inputs: - solution: '$(Build.SourcesDirectory)\csharp\OnnxRuntime.CSharp.proj' - platform: 'Any CPU' - configuration: ${{ parameters.build_config }} - msbuildArguments: '-t:CreatePackage -p:OnnxRuntimeBuildDirectory="$(Build.BinariesDirectory)" -p:OrtPackageId=$(OrtPackageId) -p:IsReleaseBuild=${{ parameters.IsReleaseBuild }} -p:TargetArchitecture=$(targetArchitecture)' - workingDirectory: '$(Build.SourcesDirectory)\csharp' - - - task: CopyFiles@2 - displayName: 'Copy native nuget package to: $(Build.ArtifactStagingDirectory)' - inputs: - SourceFolder: '$(Build.BinariesDirectory)\${{ parameters.build_config }}\${{ parameters.build_config }}' - Contents: '*.nupkg' - TargetFolder: '$(Build.ArtifactStagingDirectory)' - - task: CopyFiles@2 - displayName: 'Copy native nuget symbols package to: $(Build.ArtifactStagingDirectory)' + - task: MSBuild@1 + displayName: 'Build Nuget Packages' + inputs: + solution: '$(Build.SourcesDirectory)\csharp\OnnxRuntime.CSharp.proj' + platform: 'Any CPU' + configuration: ${{ parameters.build_config }} + msbuildArguments: '-t:CreatePackage -p:OnnxRuntimeBuildDirectory="$(Build.BinariesDirectory)" -p:OrtPackageId=$(OrtPackageId) -p:IsReleaseBuild=${{ parameters.IsReleaseBuild }} -p:TargetArchitecture=$(targetArchitecture)' + workingDirectory: '$(Build.SourcesDirectory)\csharp' + + - task: CopyFiles@2 + displayName: 'Copy native nuget package to: $(Build.ArtifactStagingDirectory)' + inputs: + SourceFolder: '$(Build.BinariesDirectory)\${{ parameters.build_config }}\${{ parameters.build_config }}' + Contents: '*.nupkg' + TargetFolder: '$(Build.ArtifactStagingDirectory)' + + - task: CopyFiles@2 + displayName: 'Copy native nuget symbols package to: $(Build.ArtifactStagingDirectory)' + inputs: + SourceFolder: '$(Build.BinariesDirectory)\${{ parameters.build_config }}\${{ parameters.build_config }}' + Contents: '*.snupkg' + TargetFolder: '$(Build.ArtifactStagingDirectory)' + + - ${{ if eq(parameters.Is1ES, true) }}: + - task: 1ES.PublishPipelineArtifact@1 + displayName: 'Publish Pipeline x64 NuGet Artifact' inputs: - SourceFolder: '$(Build.BinariesDirectory)\${{ parameters.build_config }}\${{ parameters.build_config }}' - Contents: '*.snupkg' - TargetFolder: '$(Build.ArtifactStagingDirectory)' - - - task: PublishPipelineArtifact@0 + artifactName: ${{ parameters.ArtifactName }} + targetPath: '$(Build.ArtifactStagingDirectory)' + - ${{ else }}: + - task: PublishPipelineArtifact@1 displayName: 'Publish Pipeline x64 NuGet Artifact' inputs: artifactName: ${{ parameters.ArtifactName }} diff --git a/tools/ci_build/github/azure-pipelines/templates/react-native-ci.yml b/tools/ci_build/github/azure-pipelines/templates/react-native-ci.yml index 7991916a47ca4..52dbb76632e0c 100644 --- a/tools/ci_build/github/azure-pipelines/templates/react-native-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/react-native-ci.yml @@ -62,10 +62,14 @@ stages: dependsOn: '${{parameters.InitialStageDependsOn}}' jobs: - job: ReactNative_CI_iOS - pool: - name: 'Azure Pipelines' - image: 'macOS-13' - os: 'macOS' + ${{ if eq(parameters.is1ES, false) }}: + pool: + vmImage: 'macOS-13' + ${{ if eq(parameters.is1ES, true) }}: + pool: + name: 'Azure Pipelines' + image: 'macOS-13' + os: 'macOS' timeoutInMinutes: 120 diff --git a/tools/ci_build/github/azure-pipelines/templates/web-ci.yml b/tools/ci_build/github/azure-pipelines/templates/web-ci.yml index 87836880cbdb8..2e3589ee87c29 100644 --- a/tools/ci_build/github/azure-pipelines/templates/web-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/web-ci.yml @@ -83,9 +83,6 @@ stages: git submodule update --init -- cmake/external/onnx workingDirectory: '$(Build.SourcesDirectory)' displayName: 'Checkout submodule onnx' - - task: NodeTool@0 - inputs: - versionSpec: '20.x' - template: linux-web-init-and-check.yml - task: Bash@3 displayName: 'Extract commit SHA and save to __commit.txt' diff --git a/tools/ci_build/github/azure-pipelines/templates/win-ci.yml b/tools/ci_build/github/azure-pipelines/templates/win-ci.yml index 600e6d857185f..69a06c3db24b8 100644 --- a/tools/ci_build/github/azure-pipelines/templates/win-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/win-ci.yml @@ -161,7 +161,7 @@ stages: displayName: 'Generate cmake config' inputs: scriptPath: '$(Build.SourcesDirectory)\tools\ci_build\build.py' - arguments: '--config RelWithDebInfo --use_binskim_compliant_compile_flags --enable_lto --disable_rtti --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --build_shared_lib --update --build --use_vcpkg --use_vcpkg_ms_internal_asset_cache --cmake_generator "$(VSGenerator)" --enable_onnx_tests $(TelemetryOption) ${{ parameters.buildparameter }} $(timeoutParameter) $(buildJavaParameter)' + arguments: '--config RelWithDebInfo --use_binskim_compliant_compile_flags --enable_lto --disable_rtti --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --build_shared_lib --update --build --cmake_generator "$(VSGenerator)" --enable_onnx_tests $(TelemetryOption) ${{ parameters.buildparameter }} $(timeoutParameter) $(buildJavaParameter)' workingDirectory: '$(Build.BinariesDirectory)' diff --git a/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml b/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml index b77cab6a19ba0..6868043f64d81 100644 --- a/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml @@ -88,10 +88,18 @@ jobs: inputs: sourceFolder: $(Pipeline.Workspace)\artifacts contents: | - **\*.* + **\ort-*.wasm targetFolder: $(Build.SourcesDirectory)\js\web\dist flattenFolders: true - displayName: 'Binplace dist files' + displayName: 'Binplace dist files (.wasm)' + - task: CopyFiles@2 + inputs: + sourceFolder: $(Pipeline.Workspace)\artifacts + contents: | + **\ort-*.mjs + targetFolder: $(Build.SourcesDirectory)\js\web\dist + flattenFolders: true + displayName: 'Binplace dist files (.mjs)' - script: | npm ci workingDirectory: '$(Build.SourcesDirectory)\js' diff --git a/tools/ci_build/github/azure-pipelines/templates/win-web-multi-browsers.yml b/tools/ci_build/github/azure-pipelines/templates/win-web-multi-browsers.yml index e201cc0ffdd5a..00df695889b1d 100644 --- a/tools/ci_build/github/azure-pipelines/templates/win-web-multi-browsers.yml +++ b/tools/ci_build/github/azure-pipelines/templates/win-web-multi-browsers.yml @@ -44,10 +44,18 @@ jobs: inputs: sourceFolder: $(Pipeline.Workspace)\artifacts contents: | - **\*.* + **\ort-*.wasm targetFolder: $(Build.SourcesDirectory)\js\web\dist flattenFolders: true - displayName: 'Binplace dist files' + displayName: 'Binplace dist files (.wasm)' + - task: CopyFiles@2 + inputs: + sourceFolder: $(Pipeline.Workspace)\artifacts + contents: | + **\ort-*.mjs + targetFolder: $(Build.SourcesDirectory)\js\web\dist + flattenFolders: true + displayName: 'Binplace dist files (.mjs)' - script: | npm ci workingDirectory: '$(Build.SourcesDirectory)\js' diff --git a/tools/ci_build/github/azure-pipelines/templates/windowsai-steps.yml b/tools/ci_build/github/azure-pipelines/templates/windowsai-steps.yml index fb3ebdc760a7b..355a575307f0b 100644 --- a/tools/ci_build/github/azure-pipelines/templates/windowsai-steps.yml +++ b/tools/ci_build/github/azure-pipelines/templates/windowsai-steps.yml @@ -89,7 +89,7 @@ jobs: # must call vsdevcmd first to add cmake to PATH - script: | python --version - python "$(Build.SourcesDirectory)\tools\ci_build\build.py" --build_dir $(Build.BinariesDirectory) --parallel --use_binskim_compliant_compile_flags --build_shared_lib --enable_onnx_tests --ms_experimental --use_dml --use_winml --cmake_generator "Visual Studio 17 2022" --update --config RelWithDebInfo --enable_lto --use_telemetry --disable_rtti --enable_wcos --windows_sdk_version "10.0.22621.0" $(BuildFlags) --cmake_extra_defines "CMAKE_EXE_LINKER_FLAGS_RELWITHDEBINFO=/PROFILE" "CMAKE_SHARED_LINKER_FLAGS_RELWITHDEBINFO=/PROFILE" + python "$(Build.SourcesDirectory)\tools\ci_build\build.py" --build_dir $(Build.BinariesDirectory) --parallel --use_binskim_compliant_compile_flags --build_shared_lib --enable_onnx_tests --ms_experimental --use_dml --use_winml --cmake_generator "Visual Studio 17 2022" --update --config RelWithDebInfo --enable_lto --use_telemetry --disable_rtti --enable_wcos --use_vcpkg --use_vcpkg_ms_internal_asset_cache --windows_sdk_version "10.0.22621.0" $(BuildFlags) --cmake_extra_defines "CMAKE_EXE_LINKER_FLAGS_RELWITHDEBINFO=/PROFILE" "CMAKE_SHARED_LINKER_FLAGS_RELWITHDEBINFO=/PROFILE" workingDirectory: '$(Build.BinariesDirectory)' displayName: 'Generate cmake config' diff --git a/tools/ci_build/github/azure-pipelines/win-gpu-webgpu-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-gpu-webgpu-ci-pipeline.yml index bb6c210161952..a0f22fcfce14e 100644 --- a/tools/ci_build/github/azure-pipelines/win-gpu-webgpu-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/win-gpu-webgpu-ci-pipeline.yml @@ -105,3 +105,31 @@ stages: onnxruntime_webgpu_external_dawn_test.exe --no_proc_table displayName: Run tests (onnxruntime_webgpu_external_dawn_test) workingDirectory: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo' + +- stage: webgpu_minimal_build_edge + dependsOn: [] + jobs: + - template: templates/jobs/win-ci-vs-2022-job.yml + parameters: + BuildConfig: 'RelWithDebInfo' + EnvSetupScript: setup_env.bat + buildArch: x64 + additionalBuildFlags: >- + --build_shared_lib + --disable_exceptions + --disable_rtti + --enable_msvc_static_runtime + --enable_reduced_operator_type_support + --skip_tests + --use_binskim_compliant_compile_flags + --cmake_extra_defines onnxruntime_BUILD_UNIT_TESTS=OFF onnxruntime_DISABLE_SPARSE_TENSORS=ON onnxruntime_DISABLE_OPTIONAL_TYPE=ON + --minimal_build extended + --use_webgpu + msbuildPlatform: x64 + isX86: false + job_name_suffix: x64_RelWithDebInfo + RunOnnxRuntimeTests: false + ORT_EP_NAME: WebGPU + EnablePython: false + WITH_CACHE: true + MachinePool: onnxruntime-Win2022-VS2022-webgpu-A10 diff --git a/tools/ci_build/github/azure-pipelines/win-openvino-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-openvino-ci-pipeline.yml new file mode 100644 index 0000000000000..f95ac526886fa --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/win-openvino-ci-pipeline.yml @@ -0,0 +1,116 @@ +##### start trigger Don't edit it manually, Please do edit set-trigger-rules.py #### +### please do rerun set-trigger-rules.py ### +trigger: + branches: + include: + - main + - rel-* + paths: + exclude: + - docs/** + - README.md + - CONTRIBUTING.md + - BUILD.md + - 'js/web' + - 'onnxruntime/core/providers/js' +pr: + branches: + include: + - main + - rel-* + paths: + exclude: + - docs/** + - README.md + - CONTRIBUTING.md + - BUILD.md + - 'js/web' + - 'onnxruntime/core/providers/js' +#### end trigger #### + +jobs: +- job: 'BUILD_OPENVINO_EP' + pool: 'onnxruntime-Win-CPU-2022' + variables: + MsbuildArguments: '-detailedsummary -maxcpucount -consoleloggerparameters:PerformanceSummary' + OnnxRuntimeBuildDirectory: '$(Build.BinariesDirectory)' + DOTNET_SKIP_FIRST_TIME_EXPERIENCE: true + buildArch: x64 + setVcvars: true + BuildConfig: 'RelWithDebInfo' + ALLOW_RELEASED_ONNX_OPSET_ONLY: '1' + TODAY: $[format('{0:dd}{0:MM}{0:yyyy}', pipeline.startTime)] + timeoutInMinutes: 240 + workspace: + clean: all + steps: + + - task: UsePythonVersion@0 + inputs: + versionSpec: '3.12' + addToPath: true + architecture: $(buildArch) + + - template: templates/jobs/download_win_openvino.yml + + - powershell: | + Write-Output "Setting up OpenVINO environment variables" + . "$(OpenVINORootDir)\setupvars.ps1" + + Write-Output "Exporting selected environment variables to pipeline" + + $vars = @( + "INTEL_OPENVINO_DIR", + "OpenVINO_DIR", + "OpenVINOGenAI_DIR", + "OPENVINO_LIB_PATHS", + "TBB_DIR", + "PATH", + "PYTHONPATH" + ) + + foreach ($var in $vars) { + if (Test-Path "Env:$var") { + $value = [System.Environment]::GetEnvironmentVariable($var, "Process") + Write-Output "Setting $var" + Write-Output "##vso[task.setvariable variable=$var;]$value" + } else { + Write-Output "Warning: $var is not set." + } + } + + Write-Output "Selected environment variables exported successfully" + displayName: 'Set up OpenVINO environment' + + - template: templates/jobs/win-ci-build-steps.yml + parameters: + WithCache: True + Today: $(TODAY) + AdditionalKey: "win-openvino | $(BuildConfig)" + BuildPyArguments: >- + --config $(BuildConfig) + --build_dir $(Build.BinariesDirectory) + --cmake_generator "Visual Studio 17 2022" + --build_shared_lib + --use_openvino CPU + --use_binskim_compliant_compile_flags + --update --parallel + MsbuildArguments: $(MsbuildArguments) + BuildArch: $(buildArch) + Platform: 'x64' + BuildConfig: $(BuildConfig) + + - powershell: | + Write-Output "Getting CPU information" + Get-WmiObject Win32_Processor | Select-Object Name, NumberOfCores, NumberOfLogicalProcessors, Architecture | Format-Table -AutoSize + + Write-Output "Starting unit tests" + python "$(Build.SourcesDirectory)\tools\ci_build\build.py" ` + --config "$(BuildConfig)" ` + --build_dir "$(Build.BinariesDirectory)" ` + --cmake_generator "Visual Studio 17 2022" ` + --build_shared_lib ` + --use_openvino CPU ` + --use_binskim_compliant_compile_flags ` + --test --enable_onnx_tests + displayName: 'Run unit tests' diff --git a/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml index e08d7eb2b12de..1c3d911fa7dbb 100644 --- a/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml @@ -90,7 +90,7 @@ jobs: --config $(BuildConfig) --build_dir $(Build.BinariesDirectory) --cmake_generator "Visual Studio 17 2022" - --build_shared_lib + --build_shared_lib --use_vcpkg --use_vcpkg_ms_internal_asset_cache --use_qnn $(QnnLibKind) --qnn_home $(QnnSDKRootDir) --update --build --parallel diff --git a/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml index 81de3335a07d2..faef469e010f6 100644 --- a/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml @@ -78,7 +78,7 @@ jobs: --build_dir $(Build.BinariesDirectory) --cmake_generator "Visual Studio 17 2022" --build_java - --build_shared_lib + --build_shared_lib --use_vcpkg --use_vcpkg_ms_internal_asset_cache --use_qnn $(QnnLibKind) --qnn_home $(QnnSDKRootDir) --use_binskim_compliant_compile_flags diff --git a/tools/ci_build/set-trigger-rules.py b/tools/ci_build/set-trigger-rules.py index 78f59452d1284..899aaaa95216a 100644 --- a/tools/ci_build/set-trigger-rules.py +++ b/tools/ci_build/set-trigger-rules.py @@ -16,8 +16,6 @@ "android-x86_64-crosscompile-ci-pipeline.yml", "bigmodels-ci-pipeline.yml", "linux-ci-pipeline.yml", - "linux-cpu-aten-pipeline.yml", - "linux-cpu-eager-pipeline.yml", "linux-dnnl-ci-pipeline.yml", "linux-gpu-ci-pipeline.yml", "linux-gpu-tensorrt-ci-pipeline.yml", @@ -36,6 +34,7 @@ "win-gpu-doc-gen-ci-pipeline.yml", "win-gpu-tensorrt-ci-pipeline.yml", "win-gpu-webgpu-ci-pipeline.yml", + "win-openvino-ci-pipeline.yml", "win-qnn-arm64-ci-pipeline.yml", "win-qnn-ci-pipeline.yml", ] diff --git a/tools/nuget/generate_nuspec_for_custom_nuget.py b/tools/nuget/generate_nuspec_for_custom_nuget.py new file mode 100644 index 0000000000000..baf46743cbf1b --- /dev/null +++ b/tools/nuget/generate_nuspec_for_custom_nuget.py @@ -0,0 +1,150 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import argparse +import glob +import os +import shutil + +from generate_nuspec_for_native_nuget import generate_metadata + + +def generate_files(lines, args): + files_list = [""] + platform_map = { + "win-arm64": args.win_arm64, + "win-x64": args.win_x64, + } + + avoid_keywords = {"pdb"} + processed_includes = set() + for platform, platform_dir in platform_map.items(): + for file in glob.glob(os.path.join(platform_dir, "lib", "*")): + if not os.path.isfile(file): + continue + if any(keyword in file for keyword in avoid_keywords): + continue + file_name = os.path.basename(file) + + files_list.append(f'') + + for file in glob.glob(os.path.join(platform_dir, "include", "*")): + if not os.path.isfile(file): + continue + file_name = os.path.basename(file) + if file_name in processed_includes: + continue + processed_includes.add(file_name) + files_list.append(f'') + + files_list.append( + f'' + ) + + files_list.append(f'') + files_list.append( + f'' + ) + files_list.append(f'') + files_list.append( + f'' + ) + + source_props = os.path.join( + args.root_dir, + "csharp", + "src", + "Microsoft.ML.OnnxRuntime", + "targets", + "netstandard", + "props.xml", + ) + target_props = os.path.join( + args.root_dir, + "csharp", + "src", + "Microsoft.ML.OnnxRuntime", + "targets", + "netstandard", + f"{args.package_name}.props", + ) + shutil.copyfile(source_props, target_props) + files_list.append(f'') + files_list.append(f'') + + source_targets = os.path.join( + args.root_dir, + "csharp", + "src", + "Microsoft.ML.OnnxRuntime", + "targets", + "netstandard", + "targets.xml", + ) + target_targets = os.path.join( + args.root_dir, + "csharp", + "src", + "Microsoft.ML.OnnxRuntime", + "targets", + "netstandard", + f"{args.package_name}.targets", + ) + shutil.copyfile(source_targets, target_targets) + files_list.append(f'') + files_list.append(f'') + + files_list.append("") + lines.extend(files_list) + + +def parse_arguments(): + parser = argparse.ArgumentParser( + description="Create a nuspec file for the custom nuget package.", + ) + + parser.add_argument("--nuspec_path", required=True, help="Nuspec output file path.") + parser.add_argument("--root_dir", required=True, help="ORT repository root.") + parser.add_argument( + "--commit_id", + required=True, + help="The last commit id included in this package.", + ) + parser.add_argument("--win_arm64", required=True, help="Ort win-arm64 directory") + parser.add_argument("--win_x64", required=True, help="Ort win-x64 directory") + parser.add_argument("--package_version", required=True, help="Version of the package") + parser.add_argument("--package_name", required=True, help="Name of the package") + + args = parser.parse_args() + + args.sdk_info = "" + + return args + + +def generate_nuspec(args: argparse.Namespace): + lines = [''] + lines.append("") + + generate_metadata(lines, args) + generate_files(lines, args) + + lines.append("") + return lines + + +def main(): + args = parse_arguments() + + lines = generate_nuspec(args) + + with open(os.path.join(args.nuspec_path), "w") as f: + for line in lines: + # Uncomment the printing of the line if you need to debug what's produced on a CI machine + print(line) + f.write(line) + f.write("\n") + + +if __name__ == "__main__": + main() diff --git a/tools/python/run_CIs_for_external_pr.py b/tools/python/run_CIs_for_external_pr.py index 1546a9143831a..aca5f1df7d18b 100644 --- a/tools/python/run_CIs_for_external_pr.py +++ b/tools/python/run_CIs_for_external_pr.py @@ -24,6 +24,7 @@ def get_pipeline_names(): "Windows GPU DML CI Pipeline", "Windows GPU Doc Gen CI Pipeline", "Windows GPU TensorRT CI Pipeline", + "Windows OpenVINO CI Pipeline", "ONNX Runtime Web CI Pipeline", "Win_TRT_Minimal_CUDA_Test_CI", # linux diff --git a/tools/python/util/__init__.py b/tools/python/util/__init__.py index a669963e84bcf..8631218ca9e00 100644 --- a/tools/python/util/__init__.py +++ b/tools/python/util/__init__.py @@ -7,7 +7,8 @@ from .run import run # noqa: F401 from .vcpkg_helpers import ( # noqa: F401 generate_android_triplets, - generate_posix_triplets, + generate_linux_triplets, + generate_macos_triplets, generate_vcpkg_triplets_for_emscripten, generate_windows_triplets, ) diff --git a/tools/python/util/vcpkg_helpers.py b/tools/python/util/vcpkg_helpers.py index d33b2f7675690..875a6186e55c2 100644 --- a/tools/python/util/vcpkg_helpers.py +++ b/tools/python/util/vcpkg_helpers.py @@ -222,6 +222,7 @@ def generate_triplet_for_posix_platform( enable_asan: bool, crt_linkage: str, target_abi: str, + osx_deployment_target: str, ) -> None: """ Generate triplet file for POSIX platforms (Linux, macOS). @@ -235,6 +236,7 @@ def generate_triplet_for_posix_platform( enable_asan (bool): Flag indicating if AddressSanitizer is enabled. crt_linkage (str): The CRT linkage type ("static" or "dynamic"). target_abi (str): The target ABI, which maps to the VCPKG_TARGET_ARCHITECTURE variable. Valid options include x86, x64, arm, arm64, arm64ec, s390x, ppc64le, riscv32, riscv64, loongarch32, loongarch64, mips64. + osx_deployment_target (str, optional): The macOS deployment target version. The parameter sets the minimum macOS version for compiled binaries. It also changes what versions of the macOS platform SDK CMake will search for. See the CMake documentation for CMAKE_OSX_DEPLOYMENT_TARGET for more information. """ folder_name_parts = [] if enable_asan: @@ -341,6 +343,8 @@ def generate_triplet_for_posix_platform( else: osx_abi = target_abi f.write(f'set(VCPKG_OSX_ARCHITECTURES "{osx_abi}")\n') + if osx_deployment_target: + f.write(f'set(VCPKG_OSX_DEPLOYMENT_TARGET "{osx_deployment_target}")\n') f.write("set(CMAKE_POSITION_INDEPENDENT_CODE ON)\n") f.write( "list(APPEND VCPKG_CMAKE_CONFIGURE_OPTIONS --compile-no-warning-as-error -DBENCHMARK_ENABLE_WERROR=OFF)\n" @@ -501,32 +505,58 @@ def generate_windows_triplets(build_dir: str) -> None: add_port_configs(f, enable_exception, False) -def generate_posix_triplets(build_dir: str) -> None: +def generate_linux_triplets(build_dir: str) -> None: """ - Generate triplet files for POSIX platforms (Linux, macOS). + Generate triplet files for Linux platforms. Args: build_dir (str): The directory to save the generated triplet files. """ - for os_name in ["linux", "osx"]: - if os_name == "linux": - target_abis = ["x86", "x64", "arm", "arm64", "s390x", "ppc64le", "riscv64", "loongarch64", "mips64"] - else: - target_abis = ["x64", "arm64", "universal2"] - for enable_rtti in [True, False]: - for enable_exception in [True, False]: - for enable_binskim in [True, False]: - for enable_asan in [True, False]: - if enable_asan and enable_binskim: - continue - for target_abi in target_abis: - generate_triplet_for_posix_platform( - build_dir, - os_name, - enable_rtti, - enable_exception, - enable_binskim, - enable_asan, - "dynamic", - target_abi, - ) + target_abis = ["x86", "x64", "arm", "arm64", "s390x", "ppc64le", "riscv64", "loongarch64", "mips64"] + for enable_rtti in [True, False]: + for enable_exception in [True, False]: + for enable_binskim in [True, False]: + for enable_asan in [True, False]: + if enable_asan and enable_binskim: + continue + for target_abi in target_abis: + generate_triplet_for_posix_platform( + build_dir, + "linux", + enable_rtti, + enable_exception, + enable_binskim, + enable_asan, + "dynamic", + target_abi, + None, + ) + + +def generate_macos_triplets(build_dir: str, osx_deployment_target: str) -> None: + """ + Generate triplet files for macOS platforms. + + Args: + build_dir (str): The directory to save the generated triplet files. + osx_deployment_target (str, optional): The macOS deployment target version. + """ + target_abis = ["x64", "arm64", "universal2"] + for enable_rtti in [True, False]: + for enable_exception in [True, False]: + for enable_binskim in [True, False]: + for enable_asan in [True, False]: + if enable_asan and enable_binskim: + continue + for target_abi in target_abis: + generate_triplet_for_posix_platform( + build_dir, + "osx", + enable_rtti, + enable_exception, + enable_binskim, + enable_asan, + "dynamic", + target_abi, + osx_deployment_target, + ) diff --git a/winml/adapter/winml_adapter_model.cpp b/winml/adapter/winml_adapter_model.cpp index 195bf6e5f0ffd..cf02c6fa2328b 100644 --- a/winml/adapter/winml_adapter_model.cpp +++ b/winml/adapter/winml_adapter_model.cpp @@ -593,13 +593,13 @@ ORT_API_STATUS_IMPL( input.set_name(input_name); if (info->type == ONNXType::ONNX_TYPE_TENSOR) { - auto num_dims = info->data->shape.NumDimensions(); + auto num_dims = info->tensor_type_info->shape.NumDimensions(); CreateTypeProto_Tensor( input.mutable_type()->mutable_tensor_type(), input_name, - (num_dims == 0) ? nullptr : &info->data->shape[0], + (num_dims == 0) ? nullptr : &info->tensor_type_info->shape[0], num_dims, - ONNXTensorElementDataTypeToTensorProto_DataType(info->data->type) + ONNXTensorElementDataTypeToTensorProto_DataType(info->tensor_type_info->type) ); } return nullptr; @@ -619,12 +619,12 @@ ORT_API_STATUS_IMPL( ONNX_NAMESPACE::TensorProto& input = *graph.add_initializer(); input.set_name(input_name); - auto num_dims = info->data->shape.NumDimensions(); + auto num_dims = info->tensor_type_info->shape.NumDimensions(); for (size_t i = 0; i < num_dims; i++) { - input.add_dims(info->data->shape[i]); + input.add_dims(info->tensor_type_info->shape[i]); } - input.set_data_type(ONNXTensorElementDataTypeToTensorProto_DataType(info->data->type)); + input.set_data_type(ONNXTensorElementDataTypeToTensorProto_DataType(info->tensor_type_info->type)); auto tensor = value->GetMutable(); input.set_raw_data(tensor->DataRaw(), tensor->SizeInBytes()); @@ -645,9 +645,9 @@ ORT_API_STATUS_IMPL( CreateTypeProto_Tensor( output.mutable_type()->mutable_tensor_type(), output_name, - &info->data->shape[0], - info->data->shape.NumDimensions(), - ONNXTensorElementDataTypeToTensorProto_DataType(info->data->type) + &info->tensor_type_info->shape[0], + info->tensor_type_info->shape.NumDimensions(), + ONNXTensorElementDataTypeToTensorProto_DataType(info->tensor_type_info->type) ); } return nullptr; From cdc209cd386e97fe4aedf32c6c4595104aee2f56 Mon Sep 17 00:00:00 2001 From: Ankit Maheshkar Date: Mon, 10 Mar 2025 12:15:46 +0530 Subject: [PATCH 009/138] Revert "Rebasing with msft commits (#607)" This reverts commit a6cdf62176c116e3a1e07f6cec1681c041d653b9. --- ThirdPartyNotices.txt | 35 - cmake/deps.txt | 1 + .../external/onnxruntime_external_deps.cmake | 54 +- cmake/nuget_helpers.cmake | 2 +- cmake/onnxruntime_framework.cmake | 5 +- cmake/onnxruntime_optimizer.cmake | 1 - cmake/onnxruntime_providers_js.cmake | 6 +- cmake/onnxruntime_python.cmake | 2 +- cmake/onnxruntime_session.cmake | 1 - cmake/onnxruntime_unittests.cmake | 43 +- cmake/onnxruntime_webassembly.cmake | 37 +- cmake/patches/dawn/dawn.patch | 113 +-- cmake/winml_sdk_helpers.cmake | 2 +- ...oft.ML.OnnxRuntime.FasterRcnnSample.csproj | 2 +- .../ManagedProjections.shared.cs | 3 +- .../NativeMethods.shared.cs | 4 +- .../core/framework/execution_provider.h | 16 - include/onnxruntime/core/graph/graph.h | 32 +- include/onnxruntime/core/graph/graph_viewer.h | 6 - .../core/graph/indexed_sub_graph.h | 6 - .../core/session/onnxruntime_c_api.h | 491 +----------- .../core/session/onnxruntime_cxx_api.h | 261 +------ .../core/session/onnxruntime_cxx_inline.h | 350 +-------- .../onnxruntime_session_options_config_keys.h | 5 +- js/build_webgpu.bat | 79 -- js/common/lib/tensor-impl-type-mapping.ts | 9 +- js/common/lib/tensor-impl.ts | 7 - js/common/package.json | 3 +- js/common/test/unit-tests/common.ts | 5 +- .../test/unit-tests/tensor/constructor-f16.ts | 62 -- .../unit-tests/tensor/constructor-type.ts | 8 + js/web/lib/build-def.d.ts | 7 - js/web/lib/wasm/jsep/backend-webgpu.ts | 28 +- js/web/lib/wasm/jsep/backend-webnn.ts | 3 +- js/web/lib/wasm/jsep/init.ts | 144 ++-- .../ops/3rd-party/conv_backprop_webgpu.ts | 96 +-- js/web/lib/wasm/jsep/webgpu/ops/attention.ts | 2 +- .../lib/wasm/jsep/webgpu/program-manager.ts | 1 + js/web/lib/wasm/jsep/webgpu/types.ts | 10 + js/web/lib/wasm/proxy-wrapper.ts | 8 +- js/web/lib/wasm/session-options.ts | 116 +-- js/web/lib/wasm/wasm-core-impl.ts | 97 +-- js/web/lib/wasm/wasm-types.ts | 68 +- js/web/lib/wasm/wasm-utils-import.ts | 50 +- js/web/script/build.ts | 36 +- js/web/test/data/ops/conv-transpose.jsonc | 122 --- js/web/test/e2e/exports/main.js | 11 +- js/web/test/e2e/exports/test.js | 22 - .../contrib_ops/webgpu/bert/bias_add.cc | 80 -- .../contrib_ops/webgpu/bert/bias_add.h | 32 - .../contrib_ops/webgpu/bert/fast_gelu.cc | 4 +- .../webgpu/bert/flash_attention.cc | 4 +- .../webgpu/bert/rotary_embedding.cc | 14 +- .../webgpu/bert/skip_layer_norm.cc | 4 +- .../webgpu/quantization/dp4a_matmul_nbits.cc | 326 -------- .../webgpu/quantization/dp4a_matmul_nbits.h | 56 -- .../webgpu/quantization/matmul_nbits.cc | 322 +++++++- .../webgpu/quantization/matmul_nbits.h | 19 + .../subgroup_matrix_matmul_nbits.cc | 8 +- .../webgpu/webgpu_contrib_kernels.cc | 4 +- .../core/framework/compute_capability.h | 20 - .../core/framework/execution_provider.cc | 1 - .../core/framework/external_data_loader.cc | 7 +- .../core/framework/external_data_loader.h | 2 +- .../core/framework/fallback_cpu_capability.cc | 4 - .../core/framework/fallback_cpu_capability.h | 4 - .../core/framework/graph_partitioner.cc | 248 +++---- .../core/framework/graph_partitioner.h | 9 +- .../core/framework/onnxruntime_typeinfo.cc | 71 +- .../core/framework/onnxruntime_typeinfo.h | 2 +- .../core/framework/session_state_utils.cc | 35 +- .../core/framework/tensor_type_and_shape.cc | 35 +- .../core/framework/tensorprotoutils.cc | 29 +- onnxruntime/core/framework/tensorprotoutils.h | 10 +- onnxruntime/core/graph/graph.cc | 295 +------- .../core/graph/graph_flatbuffers_utils.cc | 14 +- onnxruntime/core/graph/model.cc | 32 +- onnxruntime/core/graph/model.h | 8 +- .../core/graph/model_editor_api_types.h | 47 -- .../core/optimizer/constant_folding.cc | 13 +- onnxruntime/core/optimizer/constant_folding.h | 18 - .../optimizer/graph_optimizer_registry.cc | 49 -- .../core/optimizer/graph_optimizer_registry.h | 77 -- .../constant_folding_dq_node.cc | 26 - .../constant_folding_dq_node.h | 37 - .../selection_and_optimization_func.cc | 99 --- .../selection_and_optimization_func.h | 31 - .../providers/acl/acl_execution_provider.cc | 1 - .../providers/acl/acl_execution_provider.h | 1 - .../providers/cann/cann_execution_provider.cc | 1 - .../providers/cann/cann_execution_provider.h | 1 - .../coreml/coreml_execution_provider.cc | 1 - .../coreml/coreml_execution_provider.h | 1 - .../core/providers/cpu/controlflow/loop.cc | 4 +- .../cpu/quantization/conv_integer.cc | 7 +- .../core/providers/cuda/controlflow/loop.cc | 4 +- .../providers/cuda/cuda_execution_provider.cc | 1 - .../providers/cuda/cuda_execution_provider.h | 1 - .../core/providers/cuda/tensor/upsample.cc | 20 +- .../providers/cuda/tensor/upsample_impl.cu | 94 +-- .../providers/cuda/tensor/upsample_impl.h | 20 +- .../src/ExecutionProvider.cpp | 6 +- .../src/ExecutionProvider.h | 3 - .../providers/dnnl/dnnl_execution_provider.cc | 1 - .../providers/dnnl/dnnl_execution_provider.h | 1 - .../providers/js/js_execution_provider.cc | 1 - .../core/providers/js/js_execution_provider.h | 1 - .../migraphx/migraphx_execution_provider.cc | 1 - .../migraphx/migraphx_execution_provider.h | 1 - .../nnapi_builtin/nnapi_execution_provider.cc | 1 - .../nnapi_builtin/nnapi_execution_provider.h | 1 - .../openvino/backends/basic_backend.cc | 2 +- .../openvino/openvino_execution_provider.cc | 1 - .../openvino/openvino_execution_provider.h | 1 - .../qnn/builder/onnx_ctx_model_helper.cc | 38 +- .../qnn/builder/onnx_ctx_model_helper.h | 7 +- .../qnn/builder/qnn_backend_manager.cc | 2 - .../core/providers/qnn/qnn_allocator.cc | 4 +- .../providers/qnn/qnn_execution_provider.cc | 73 +- .../providers/qnn/qnn_execution_provider.h | 2 - .../core/providers/qnn/shared_context.h | 26 - .../rknpu/rknpu_execution_provider.cc | 1 - .../rknpu/rknpu_execution_provider.h | 1 - .../providers/rocm/rocm_execution_provider.cc | 1 - .../providers/rocm/rocm_execution_provider.h | 1 - .../providers/shared_library/provider_api.h | 1 - .../provider_bridge_provider.cc | 3 +- .../shared_library/provider_interfaces.h | 9 - .../shared_library/provider_wrappedtypes.h | 3 - .../providers/snpe/snpe_execution_provider.cc | 1 - .../providers/snpe/snpe_execution_provider.h | 1 - .../tensorrt/tensorrt_execution_provider.cc | 55 +- .../tensorrt/tensorrt_execution_provider.h | 31 - .../tensorrt_execution_provider_helper.cc | 129 ---- .../vitisai/vitisai_execution_provider.cc | 2 +- .../vitisai/vitisai_execution_provider.h | 1 - .../vsinpu/vsinpu_execution_provider.cc | 1 - .../vsinpu/vsinpu_execution_provider.h | 1 - .../providers/webgpu/external_data_loader.cc | 40 - .../providers/webgpu/external_data_loader.h | 30 - .../core/providers/webgpu/generator/range.cc | 2 +- .../webgpu/math/binary_elementwise_ops.cc | 2 +- .../core/providers/webgpu/math/softmax.cc | 238 ------ .../core/providers/webgpu/math/softmax.h | 54 -- .../webgpu/math/unary_elementwise_ops.cc | 2 +- .../core/providers/webgpu/nn/layer_norm.cc | 6 +- onnxruntime/core/providers/webgpu/program.cc | 20 - onnxruntime/core/providers/webgpu/program.h | 1 - .../core/providers/webgpu/program_manager.cc | 10 +- .../webgpu/reduction/reduction_ops.cc | 168 ----- .../webgpu/reduction/reduction_ops.h | 62 -- .../core/providers/webgpu/shader_helper.cc | 3 + .../core/providers/webgpu/shader_variable.cc | 2 +- .../core/providers/webgpu/tensor/cast.cc | 2 +- .../core/providers/webgpu/tensor/cast.h | 2 +- .../core/providers/webgpu/tensor/concat.cc | 2 +- .../core/providers/webgpu/tensor/expand.cc | 2 +- .../core/providers/webgpu/tensor/gather.cc | 2 +- .../core/providers/webgpu/tensor/pad.cc | 261 ------- .../core/providers/webgpu/tensor/pad.h | 40 - .../providers/webgpu/tensor/resize_impl.cc | 8 +- .../core/providers/webgpu/tensor/split.cc | 6 +- .../core/providers/webgpu/tensor/transpose.cc | 62 +- .../core/providers/webgpu/tensor/transpose.h | 2 - .../core/providers/webgpu/tensor/where.cc | 2 +- .../core/providers/webgpu/webgpu_context.cc | 61 +- .../webgpu/webgpu_execution_provider.cc | 38 +- .../webgpu/webgpu_execution_provider.h | 4 - .../webgpu/webgpu_pix_frame_generator.cc | 4 +- .../webgpu/webgpu_pix_frame_generator.h | 2 +- .../webgpu/webgpu_provider_factory.cc | 6 - .../impl/rotaryEmbedding_op_builder.cc | 14 +- .../providers/webnn/builders/model_builder.cc | 6 +- .../providers/webnn/builders/model_builder.h | 10 +- .../webnn/webnn_execution_provider.cc | 1 - .../webnn/webnn_execution_provider.h | 1 - .../xnnpack/xnnpack_execution_provider.cc | 1 - .../xnnpack/xnnpack_execution_provider.h | 1 - .../core/session/abi_session_options.cc | 17 +- onnxruntime/core/session/api_utils.cc | 25 + onnxruntime/core/session/api_utils.h | 9 + onnxruntime/core/session/custom_ops.cc | 25 +- onnxruntime/core/session/inference_session.cc | 78 +- onnxruntime/core/session/inference_session.h | 35 +- onnxruntime/core/session/model_editor_api.h | 65 -- .../core/session/model_editor_c_api.cc | 358 --------- onnxruntime/core/session/onnxruntime_c_api.cc | 328 ++++---- onnxruntime/core/session/ort_apis.h | 16 - .../core/session/provider_bridge_ort.cc | 23 +- onnxruntime/core/session/utils.cc | 125 ---- onnxruntime/core/session/utils.h | 28 - .../execution_providers/qnn/quant_config.py | 6 +- .../python/tools/quantization/quantize.py | 32 +- .../tools/transformers/models/sam2/README.md | 31 +- .../models/sam2/benchmark_sam2.py | 15 +- .../models/sam2/benchmark_sam2.sh | 310 +++----- .../models/sam2/convert_to_onnx.py | 14 +- .../transformers/models/sam2/image_decoder.py | 2 +- .../transformers/models/sam2/image_encoder.py | 74 +- .../transformers/models/sam2/mask_decoder.py | 2 +- .../models/sam2/prompt_encoder.py | 2 +- .../test/ep_weight_sharing_ctx_gen/main.cc | 247 ------ .../test/framework/inference_session_test.cc | 1 - .../test/framework/session_state_test.cc | 27 +- onnxruntime/test/framework/type_info_test.cc | 26 +- onnxruntime/test/providers/base_tester.cc | 6 +- onnxruntime/test/providers/base_tester.h | 6 +- .../test/providers/cpu/math/softmax_test.cc | 13 +- .../providers/cpu/nn/conv_integer_test.cc | 40 - .../internal_testing_execution_provider.cc | 1 - .../internal_testing_execution_provider.h | 1 - .../test/providers/qnn/qnn_ep_context_test.cc | 267 +++---- .../test/providers/qnn/qnn_test_utils.cc | 7 +- .../quantization/test_get_qdq_config.py | 56 -- .../README.md | 10 +- .../command_args_parser.cc | 47 +- .../command_args_parser.h | 0 onnxruntime/test/qnn_ctx_gen/main.cc | 250 +++++++ .../test_configuration.h | 7 +- .../test/shared_lib/custom_op_utils.cc | 20 - onnxruntime/test/shared_lib/custom_op_utils.h | 67 +- onnxruntime/test/shared_lib/test_inference.cc | 192 ++--- .../test/shared_lib/test_model_builder_api.cc | 701 ------------------ .../test/shared_lib/test_ort_format_models.cc | 14 +- onnxruntime/test/shared_lib/utils.h | 52 -- .../test/testdata/cast_float_to_double.onnx | Bin 136 -> 0 bytes .../my_execution_provider.cc | 2 +- .../my_execution_provider.h | 2 +- onnxruntime/wasm/api.cc | 26 +- onnxruntime/wasm/api.h | 24 +- onnxruntime/wasm/js_post_js.js | 2 + onnxruntime/wasm/js_post_js_64.js | 2 + onnxruntime/wasm/post-webgpu.js | 261 ------- onnxruntime/wasm/pre-async.js | 132 ---- onnxruntime/wasm/pre-jsep.js | 308 +++++--- onnxruntime/wasm/pre.js | 15 +- setup.py | 2 +- tools/ci_build/build.py | 21 +- .../custom-nuget-packaging-pipeline.yml | 142 ---- .../py-package-test-pipeline.yml | 2 - .../azure-pipelines/py-packaging-pipeline.yml | 50 +- .../qnn-ep-nuget-packaging-pipeline.yml | 148 ++-- .../rocm-nuget-packaging-pipeline.yml | 339 +++++++++ .../rocm-publish-nuget-pipeline.yml | 21 + .../stages/nuget-cuda-packaging-stage.yml | 15 +- .../stages/nuget-qnn-packaging-stage.yml | 76 -- .../stages/py-cpu-packaging-stage.yml | 124 ++-- ...acts-package-and-publish-steps-windows.yml | 16 - .../templates/jobs/download_win_openvino.yml | 64 -- .../templates/linux-web-init-and-check.yml | 8 - .../templates/py-linux-qnn.yml | 118 ++- .../azure-pipelines/templates/py-linux.yml | 144 ++-- .../templates/py-package-smoking-test.yml | 13 +- .../templates/py-packaging-linux-test-cpu.yml | 18 +- .../templates/py-win-arm64-qnn.yml | 273 ++++--- .../templates/py-win-arm64ec-qnn.yml | 241 +++--- .../templates/py-win-x64-qnn.yml | 21 +- .../azure-pipelines/templates/qnn-ep-win.yml | 260 +++---- .../templates/react-native-ci.yml | 12 +- .../azure-pipelines/templates/web-ci.yml | 3 + .../azure-pipelines/templates/win-ci.yml | 2 +- .../azure-pipelines/templates/win-web-ci.yml | 12 +- .../templates/win-web-multi-browsers.yml | 12 +- .../templates/windowsai-steps.yml | 2 +- .../win-gpu-webgpu-ci-pipeline.yml | 28 - .../win-openvino-ci-pipeline.yml | 116 --- .../win-qnn-arm64-ci-pipeline.yml | 2 +- .../azure-pipelines/win-qnn-ci-pipeline.yml | 2 +- tools/ci_build/set-trigger-rules.py | 3 +- .../nuget/generate_nuspec_for_custom_nuget.py | 150 ---- tools/python/run_CIs_for_external_pr.py | 1 - tools/python/util/__init__.py | 3 +- tools/python/util/vcpkg_helpers.py | 78 +- winml/adapter/winml_adapter_model.cpp | 18 +- 274 files changed, 3285 insertions(+), 10047 deletions(-) delete mode 100644 js/build_webgpu.bat delete mode 100644 js/common/test/unit-tests/tensor/constructor-f16.ts delete mode 100644 onnxruntime/contrib_ops/webgpu/bert/bias_add.cc delete mode 100644 onnxruntime/contrib_ops/webgpu/bert/bias_add.h delete mode 100644 onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc delete mode 100644 onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h delete mode 100644 onnxruntime/core/graph/model_editor_api_types.h delete mode 100644 onnxruntime/core/optimizer/graph_optimizer_registry.cc delete mode 100644 onnxruntime/core/optimizer/graph_optimizer_registry.h delete mode 100644 onnxruntime/core/optimizer/qdq_transformer/constant_folding_dq_node.cc delete mode 100644 onnxruntime/core/optimizer/qdq_transformer/constant_folding_dq_node.h delete mode 100644 onnxruntime/core/optimizer/selection_and_optimization_func.cc delete mode 100644 onnxruntime/core/optimizer/selection_and_optimization_func.h delete mode 100644 onnxruntime/core/providers/webgpu/external_data_loader.cc delete mode 100644 onnxruntime/core/providers/webgpu/external_data_loader.h delete mode 100644 onnxruntime/core/providers/webgpu/math/softmax.cc delete mode 100644 onnxruntime/core/providers/webgpu/math/softmax.h delete mode 100644 onnxruntime/core/providers/webgpu/reduction/reduction_ops.cc delete mode 100644 onnxruntime/core/providers/webgpu/reduction/reduction_ops.h delete mode 100644 onnxruntime/core/providers/webgpu/tensor/pad.cc delete mode 100644 onnxruntime/core/providers/webgpu/tensor/pad.h create mode 100644 onnxruntime/core/session/api_utils.cc create mode 100644 onnxruntime/core/session/api_utils.h delete mode 100644 onnxruntime/core/session/model_editor_api.h delete mode 100644 onnxruntime/core/session/model_editor_c_api.cc delete mode 100644 onnxruntime/core/session/utils.cc delete mode 100644 onnxruntime/core/session/utils.h delete mode 100644 onnxruntime/test/ep_weight_sharing_ctx_gen/main.cc rename onnxruntime/test/{ep_weight_sharing_ctx_gen => qnn_ctx_gen}/README.md (82%) rename onnxruntime/test/{ep_weight_sharing_ctx_gen => qnn_ctx_gen}/command_args_parser.cc (68%) rename onnxruntime/test/{ep_weight_sharing_ctx_gen => qnn_ctx_gen}/command_args_parser.h (100%) create mode 100644 onnxruntime/test/qnn_ctx_gen/main.cc rename onnxruntime/test/{ep_weight_sharing_ctx_gen => qnn_ctx_gen}/test_configuration.h (75%) delete mode 100644 onnxruntime/test/shared_lib/test_model_builder_api.cc delete mode 100644 onnxruntime/test/testdata/cast_float_to_double.onnx delete mode 100644 onnxruntime/wasm/post-webgpu.js delete mode 100644 onnxruntime/wasm/pre-async.js delete mode 100644 tools/ci_build/github/azure-pipelines/custom-nuget-packaging-pipeline.yml create mode 100644 tools/ci_build/github/azure-pipelines/rocm-nuget-packaging-pipeline.yml create mode 100644 tools/ci_build/github/azure-pipelines/rocm-publish-nuget-pipeline.yml delete mode 100644 tools/ci_build/github/azure-pipelines/stages/nuget-qnn-packaging-stage.yml delete mode 100644 tools/ci_build/github/azure-pipelines/templates/jobs/download_win_openvino.yml delete mode 100644 tools/ci_build/github/azure-pipelines/win-openvino-ci-pipeline.yml delete mode 100644 tools/nuget/generate_nuspec_for_custom_nuget.py diff --git a/ThirdPartyNotices.txt b/ThirdPartyNotices.txt index a449e42f6bf19..26084ab42ec1c 100644 --- a/ThirdPartyNotices.txt +++ b/ThirdPartyNotices.txt @@ -6045,38 +6045,3 @@ https://github.com/intel/neural-speed terms, and open source software license terms. These separate license terms govern your use of the third party programs as set forth in the "THIRD-PARTY-PROGRAMS" file. - -_____ - -dawn - -https://dawn.googlesource.com/dawn - - BSD 3-Clause License - - Copyright 2017-2023 The Dawn & Tint Authors - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are met: - - 1. Redistributions of source code must retain the above copyright notice, this - list of conditions and the following disclaimer. - - 2. Redistributions in binary form must reproduce the above copyright notice, - this list of conditions and the following disclaimer in the documentation - and/or other materials provided with the distribution. - - 3. Neither the name of the copyright holder nor the names of its - contributors may be used to endorse or promote products derived from - this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/cmake/deps.txt b/cmake/deps.txt index c7db8ef51505d..d0bab93d3c16f 100644 --- a/cmake/deps.txt +++ b/cmake/deps.txt @@ -53,6 +53,7 @@ re2;https://github.com/google/re2/archive/refs/tags/2024-07-02.zip;646e1728269cd safeint;https://github.com/dcleblanc/SafeInt/archive/refs/tags/3.0.28.zip;23f252040ff6cb9f1fd18575b32fa8fb5928daac tensorboard;https://github.com/tensorflow/tensorboard/archive/373eb09e4c5d2b3cc2493f0949dc4be6b6a45e81.zip;67b833913605a4f3f499894ab11528a702c2b381 cutlass;https://github.com/NVIDIA/cutlass/archive/refs/tags/v3.5.1.zip;e49b2b964163d27765a5002d210a2f3c73771835 +utf8_range;https://github.com/protocolbuffers/utf8_range/archive/72c943dea2b9240cd09efde15191e144bc7c7d38.zip;9925739c9debc0efa2adcb194d371a35b6a03156 extensions;https://github.com/microsoft/onnxruntime-extensions/archive/c24b7bab0c12f53da76d0c31b03b9f0f8ec8f3b4.zip;239063aee4946a9af147b473a4c3da78ba7413b4 composable_kernel;https://github.com/ROCmSoftwarePlatform/composable_kernel/archive/204da9c522cebec5220bba52cd3542ebcaf99e7a.zip;1827348efd47831c13074245274d41b7cae8a557 directx_headers;https://github.com/microsoft/DirectX-Headers/archive/refs/tags/v1.613.1.zip;47653509a3371eabb156360f42faf582f314bf2e diff --git a/cmake/external/onnxruntime_external_deps.cmake b/cmake/external/onnxruntime_external_deps.cmake index a477d6edb3a3f..ebf20ab21bbd2 100644 --- a/cmake/external/onnxruntime_external_deps.cmake +++ b/cmake/external/onnxruntime_external_deps.cmake @@ -107,6 +107,23 @@ if(onnxruntime_USE_MIMALLOC) FetchContent_MakeAvailable(mimalloc) endif() +#Protobuf depends on utf8_range +onnxruntime_fetchcontent_declare( + utf8_range + URL ${DEP_URL_utf8_range} + URL_HASH SHA1=${DEP_SHA1_utf8_range} + EXCLUDE_FROM_ALL + FIND_PACKAGE_ARGS NAMES utf8_range +) + +set(utf8_range_ENABLE_TESTS OFF CACHE BOOL "Build test suite" FORCE) +set(utf8_range_ENABLE_INSTALL OFF CACHE BOOL "Configure installation" FORCE) + +# The next line will generate an error message "fatal: not a git repository", but it is ok. It is from flatbuffers +onnxruntime_fetchcontent_makeavailable(utf8_range) +# protobuf's cmake/utf8_range.cmake has the following line +include_directories(${utf8_range_SOURCE_DIR}) + # Download a protoc binary from Internet if needed if(NOT ONNX_CUSTOM_PROTOC_EXECUTABLE AND NOT onnxruntime_USE_VCPKG) # This part of code is only for users' convenience. The code couldn't handle all cases. Users always can manually @@ -287,7 +304,7 @@ if(NOT TARGET Boost::mp11) EXCLUDE_FROM_ALL FIND_PACKAGE_ARGS NAMES Boost ) - onnxruntime_fetchcontent_makeavailable(mp11) + onnxruntime_fetchcontent_makeavailable(mp11) if(NOT TARGET Boost::mp11) add_library(Boost::mp11 ALIAS Boost::headers) endif() @@ -425,9 +442,6 @@ target_include_directories(safeint_interface INTERFACE ${safeint_SOURCE_DIR}) # Flatbuffers -if(onnxruntime_USE_VCPKG) - find_package(flatbuffers REQUIRED) -else() # We do not need to build flatc for iOS or Android Cross Compile if (CMAKE_SYSTEM_NAME STREQUAL "iOS" OR CMAKE_SYSTEM_NAME STREQUAL "Android" OR CMAKE_SYSTEM_NAME STREQUAL "Emscripten") set(FLATBUFFERS_BUILD_FLATC OFF CACHE BOOL "FLATBUFFERS_BUILD_FLATC" FORCE) @@ -478,7 +492,6 @@ namespace std { using ::getenv; } endif() endif() endif() -endif() # ONNX if (NOT onnxruntime_USE_FULL_PROTOBUF) @@ -659,10 +672,17 @@ if (onnxruntime_USE_WEBGPU) # disable things we don't use set(DAWN_DXC_ENABLE_ASSERTS_IN_NDEBUG OFF) + set(DAWN_ENABLE_DESKTOP_GL OFF CACHE BOOL "" FORCE) + set(DAWN_ENABLE_OPENGLES OFF CACHE BOOL "" FORCE) + set(DAWN_SUPPORTS_GLFW_FOR_WINDOWING OFF CACHE BOOL "" FORCE) + set(DAWN_USE_GLFW OFF CACHE BOOL "" FORCE) + set(DAWN_USE_WINDOWS_UI OFF CACHE BOOL "" FORCE) set(DAWN_USE_X11 OFF CACHE BOOL "" FORCE) set(TINT_BUILD_TESTS OFF CACHE BOOL "" FORCE) set(TINT_BUILD_CMD_TOOLS OFF CACHE BOOL "" FORCE) + set(TINT_BUILD_GLSL_WRITER OFF CACHE BOOL "" FORCE) + set(TINT_BUILD_GLSL_VALIDATOR OFF CACHE BOOL "" FORCE) set(TINT_BUILD_IR_BINARY OFF CACHE BOOL "" FORCE) set(TINT_BUILD_SPV_READER OFF CACHE BOOL "" FORCE) # don't need. disabling is a large binary size saving set(TINT_BUILD_WGSL_WRITER ON CACHE BOOL "" FORCE) # needed to create cache key. runtime error if not enabled. @@ -712,29 +732,7 @@ if (onnxruntime_USE_WEBGPU) # # if we need to apply patches in the future, we can uncomment the following line. # # The dawn.patch contains the following changes: - # - # - (public) CMake fix to support Emscripten v4.0.3+ - # This change allows Dawn to find the file "gen_struct_info.py" in the correct location. - # https://dawn-review.googlesource.com/c/dawn/+/225514 - # - # - (public) Fix emwgpu C++ implementation for buffer destroy - # In native implementation, wgpuBufferRelease will trigger the buffer destroy (if refcount decreased to 0). But - # in emwgpu implementation, the buffer destroy won't happen. This change fixes the bug. - # https://dawn-review.googlesource.com/c/dawn/+/226315 - # - # - (private) Allow "external" buffer in emwgpu C++ implementation - # This change allows WGPUBufferImpl to destroy the buffer when the refcount decreased to 0 only for non-external - # buffer. - # "external buffer" means the GPUBuffer instance created in JavaScript and imported to C++ by `importJsBuffer`. - # - # - (private) Remove hard-coded CMAKE_OSX_DEPLOYMENT_TARGET in Dawn's CMake files - # https://github.com/microsoft/onnxruntime/pull/23729 - # - # - (private) Fix external ref count for "external" device in emwgpu C++ implementation - # This change fixes the incorrect external ref count for class WGPUDeviceImpl when used with "external" device. - # "external device" means the GPUDevice instance created in JavaScript and imported to C++ by `importJsDevice`. - # - # + # - https://dawn-review.googlesource.com/c/dawn/+/225514 PATCH_COMMAND ${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PROJECT_SOURCE_DIR}/patches/dawn/dawn.patch EXCLUDE_FROM_ALL ) diff --git a/cmake/nuget_helpers.cmake b/cmake/nuget_helpers.cmake index b066d1e9fb50e..22143ac422e9f 100644 --- a/cmake/nuget_helpers.cmake +++ b/cmake/nuget_helpers.cmake @@ -1,7 +1,7 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -cmake_minimum_required(VERSION 3.5) +cmake_minimum_required(VERSION 3.0) # Determines the version of a native nuget package from the root packages.config. # diff --git a/cmake/onnxruntime_framework.cmake b/cmake/onnxruntime_framework.cmake index 9c9a25f8ee77e..b1e98a9e0411c 100644 --- a/cmake/onnxruntime_framework.cmake +++ b/cmake/onnxruntime_framework.cmake @@ -36,7 +36,10 @@ elseif(onnxruntime_ENABLE_TRITON) endif() if (onnxruntime_MINIMAL_BUILD) - set(onnxruntime_framework_src_exclude) + set(onnxruntime_framework_src_exclude + "${ONNXRUNTIME_ROOT}/core/framework/fallback_cpu_capability.h" + "${ONNXRUNTIME_ROOT}/core/framework/fallback_cpu_capability.cc" + ) # custom ops support must be explicitly enabled in a minimal build. exclude if not. if (NOT onnxruntime_MINIMAL_BUILD_CUSTOM_OPS) diff --git a/cmake/onnxruntime_optimizer.cmake b/cmake/onnxruntime_optimizer.cmake index 173c872d4cc06..9d680cd04af10 100644 --- a/cmake/onnxruntime_optimizer.cmake +++ b/cmake/onnxruntime_optimizer.cmake @@ -9,7 +9,6 @@ if (onnxruntime_MINIMAL_BUILD) list(APPEND onnxruntime_optimizer_src_patterns "${ONNXRUNTIME_INCLUDE_DIR}/core/optimizer/graph_transformer.h" "${ONNXRUNTIME_ROOT}/core/optimizer/graph_transformer.cc" - "${ONNXRUNTIME_ROOT}/core/optimizer/graph_optimizer_registry.cc" ) if (onnxruntime_EXTENDED_MINIMAL_BUILD) diff --git a/cmake/onnxruntime_providers_js.cmake b/cmake/onnxruntime_providers_js.cmake index fefbab5082da4..9811eae611463 100644 --- a/cmake/onnxruntime_providers_js.cmake +++ b/cmake/onnxruntime_providers_js.cmake @@ -1,10 +1,6 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. - if (onnxruntime_MINIMAL_BUILD AND NOT onnxruntime_EXTENDED_MINIMAL_BUILD) - message(FATAL_ERROR "JSEP can not be used in a basic minimal build. Please build with '--minimal_build extended'") - endif() - add_compile_definitions(USE_JSEP=1) file(GLOB_RECURSE onnxruntime_providers_js_cc_srcs @@ -22,4 +18,4 @@ onnxruntime_common onnxruntime_framework onnx onnx_proto ${PROTOBUF_LIB} flatbuffers Boost::mp11 Eigen3::Eigen ) - add_dependencies(onnxruntime_providers_js ${onnxruntime_EXTERNAL_DEPENDENCIES}) + add_dependencies(onnxruntime_providers_js ${onnxruntime_EXTERNAL_DEPENDENCIES}) \ No newline at end of file diff --git a/cmake/onnxruntime_python.cmake b/cmake/onnxruntime_python.cmake index 64b53c2912be0..aee6d2ff7655c 100644 --- a/cmake/onnxruntime_python.cmake +++ b/cmake/onnxruntime_python.cmake @@ -1029,7 +1029,7 @@ if (onnxruntime_USE_QNN) add_custom_command( TARGET onnxruntime_pybind11_state POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy - $ + $ $/onnxruntime/capi/ ) if (EXISTS "${onnxruntime_QNN_HOME}/Qualcomm AI Hub Proprietary License.pdf") diff --git a/cmake/onnxruntime_session.cmake b/cmake/onnxruntime_session.cmake index 2c2c59091fae5..3d63285d50e72 100644 --- a/cmake/onnxruntime_session.cmake +++ b/cmake/onnxruntime_session.cmake @@ -22,7 +22,6 @@ endif() if (onnxruntime_MINIMAL_BUILD) set(onnxruntime_session_src_exclude "${ONNXRUNTIME_ROOT}/core/session/provider_bridge_ort.cc" - "${ONNXRUNTIME_ROOT}/core/session/model_builder_c_api.cc" ) list(REMOVE_ITEM onnxruntime_session_srcs ${onnxruntime_session_src_exclude}) diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index 87aee2a174fab..0916aeb3dd92c 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -236,14 +236,14 @@ function(AddTest) ) endif() # Set test timeout to 3 hours. - set_tests_properties(${_UT_TARGET} PROPERTIES TIMEOUT 10800) + set_tests_properties(${_UT_TARGET} PROPERTIES TIMEOUT 7200) else() add_test(NAME ${_UT_TARGET} COMMAND ${_UT_TARGET} ${TEST_ARGS} WORKING_DIRECTORY $ ) # Set test timeout to 3 hours. - set_tests_properties(${_UT_TARGET} PROPERTIES TIMEOUT 10800) + set_tests_properties(${_UT_TARGET} PROPERTIES TIMEOUT 7200) endif() endif() endfunction(AddTest) @@ -503,7 +503,6 @@ set (onnxruntime_shared_lib_test_SRC if (NOT onnxruntime_MINIMAL_BUILD) list(APPEND onnxruntime_shared_lib_test_SRC ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_inference.cc) - list(APPEND onnxruntime_shared_lib_test_SRC ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_model_builder_api.cc) endif() if(onnxruntime_RUN_ONNX_TESTS) @@ -1289,34 +1288,31 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) if(onnxruntime_USE_QNN) #qnn ctx generator - set(ep_weight_sharing_ctx_gen_src_dir ${TEST_SRC_DIR}/ep_weight_sharing_ctx_gen) - set(ep_weight_sharing_ctx_gen_src_patterns - "${ep_weight_sharing_ctx_gen_src_dir}/*.cc" - "${ep_weight_sharing_ctx_gen_src_dir}/*.h") + set(onnxruntime_qnn_ctx_gen_src_dir ${TEST_SRC_DIR}/qnn_ctx_gen) + set(onnxruntime_qnn_ctx_gen_src_patterns + "${onnxruntime_qnn_ctx_gen_src_dir}/*.cc" + "${onnxruntime_qnn_ctx_gen_src_dir}/*.h") - file(GLOB ep_weight_sharing_ctx_gen_src CONFIGURE_DEPENDS - ${ep_weight_sharing_ctx_gen_src_patterns} + file(GLOB onnxruntime_qnn_ctx_gen_src CONFIGURE_DEPENDS + ${onnxruntime_qnn_ctx_gen_src_patterns} ) - onnxruntime_add_executable(ep_weight_sharing_ctx_gen ${ep_weight_sharing_ctx_gen_src}) - target_include_directories(ep_weight_sharing_ctx_gen PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR}) + onnxruntime_add_executable(onnxruntime_qnn_ctx_gen ${onnxruntime_qnn_ctx_gen_src}) + target_include_directories(onnxruntime_qnn_ctx_gen PRIVATE ${onnx_test_runner_src_dir} ${ONNXRUNTIME_ROOT} + ${onnxruntime_graph_header} ${onnxruntime_exec_src_dir} + ${CMAKE_CURRENT_BINARY_DIR}) if (WIN32) - target_compile_options(ep_weight_sharing_ctx_gen PRIVATE ${disabled_warnings}) + target_compile_options(onnxruntime_qnn_ctx_gen PRIVATE ${disabled_warnings}) if (NOT DEFINED SYS_PATH_LIB) set(SYS_PATH_LIB shlwapi) endif() endif() - if (onnxruntime_BUILD_SHARED_LIB) - set(ep_weight_sharing_ctx_gen_libs onnxruntime_common onnxruntime ${onnxruntime_EXTERNAL_LIBRARIES} ${GETOPT_LIB_WIDE}) - target_link_libraries(ep_weight_sharing_ctx_gen PRIVATE ${ep_weight_sharing_ctx_gen_libs}) - if (WIN32) - target_link_libraries(ep_weight_sharing_ctx_gen PRIVATE debug dbghelp advapi32) - endif() - else() - target_link_libraries(ep_weight_sharing_ctx_gen PRIVATE onnxruntime_session ${onnxruntime_test_providers_libs} ${onnxruntime_EXTERNAL_LIBRARIES} ${GETOPT_LIB_WIDE}) + if(WIN32) + target_link_libraries(onnxruntime_qnn_ctx_gen PRIVATE debug dbghelp advapi32) endif() + target_link_libraries(onnxruntime_qnn_ctx_gen PRIVATE onnx_test_runner_common onnxruntime_test_utils onnxruntime_common onnxruntime_graph onnxruntime_session onnxruntime_providers onnxruntime_framework onnxruntime_util onnxruntime_mlas onnxruntime_optimizer onnxruntime_flatbuffers onnx_test_data_proto ${onnxruntime_test_providers_libs} ${onnxruntime_EXTERNAL_LIBRARIES} ${GETOPT_LIB_WIDE} ${SYS_PATH_LIB} ${CMAKE_DL_LIBS}) - set_target_properties(ep_weight_sharing_ctx_gen PROPERTIES FOLDER "ONNXRuntimeTest") + set_target_properties(onnxruntime_qnn_ctx_gen PROPERTIES FOLDER "ONNXRuntimeTest") endif() # shared lib @@ -1363,19 +1359,14 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) LIBS ${onnxruntime_shared_lib_test_LIBS} DEPENDS ${all_dependencies} ) - - target_include_directories(onnxruntime_shared_lib_test PRIVATE ${ONNXRUNTIME_ROOT}) - if (onnxruntime_USE_CUDA) target_include_directories(onnxruntime_shared_lib_test PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) target_sources(onnxruntime_shared_lib_test PRIVATE ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/cuda_ops.cu) endif() - if (onnxruntime_USE_ROCM) target_include_directories(onnxruntime_shared_lib_test PRIVATE ${onnxruntime_ROCM_HOME}/include) target_compile_definitions(onnxruntime_shared_lib_test PRIVATE __HIP_PLATFORM_AMD__) endif() - if (CMAKE_SYSTEM_NAME STREQUAL "Android") target_sources(onnxruntime_shared_lib_test PRIVATE "${ONNXRUNTIME_ROOT}/core/platform/android/cxa_demangle.cc" diff --git a/cmake/onnxruntime_webassembly.cmake b/cmake/onnxruntime_webassembly.cmake index f3afaf7033fd1..8106e46ccf580 100644 --- a/cmake/onnxruntime_webassembly.cmake +++ b/cmake/onnxruntime_webassembly.cmake @@ -211,14 +211,10 @@ else() target_link_libraries(onnxruntime_webassembly PRIVATE tensorboard) endif() - set(onnxruntime_webassembly_script_deps "${ONNXRUNTIME_ROOT}/wasm/pre.js") - - set(EXPORTED_FUNCTIONS "_malloc,_free") if (onnxruntime_USE_JSEP) - string(APPEND EXPORTED_FUNCTIONS ",_JsepOutput,_JsepGetNodeName") - endif() - if (onnxruntime_USE_WEBGPU) - string(APPEND EXPORTED_FUNCTIONS ",_wgpuBufferRelease,_wgpuCreateInstance") + set(EXPORTED_FUNCTIONS "_malloc,_free,_JsepOutput,_JsepGetNodeName") + else() + set(EXPORTED_FUNCTIONS "_malloc,_free") endif() if (onnxruntime_ENABLE_WEBASSEMBLY_MEMORY64) @@ -316,15 +312,13 @@ else() target_compile_options(noexcep_operators PRIVATE ${SMEMORY_FLAG} -Wno-experimental) endif() target_link_options(onnxruntime_webassembly PRIVATE - "SHELL:--post-js \"${ONNXRUNTIME_ROOT}/wasm/js_post_js_64.js\"" + --post-js "${ONNXRUNTIME_ROOT}/wasm/js_post_js_64.js" ) - list(APPEND onnxruntime_webassembly_script_deps "${ONNXRUNTIME_ROOT}/wasm/js_post_js_64.js") else () set(MAXIMUM_MEMORY "4294967296") target_link_options(onnxruntime_webassembly PRIVATE - "SHELL:--post-js \"${ONNXRUNTIME_ROOT}/wasm/js_post_js.js\"" + --post-js "${ONNXRUNTIME_ROOT}/wasm/js_post_js.js" ) - list(APPEND onnxruntime_webassembly_script_deps "${ONNXRUNTIME_ROOT}/wasm/js_post_js.js") endif () target_link_options(onnxruntime_webassembly PRIVATE @@ -378,6 +372,7 @@ jsepDownload:_pp_") "SHELL:-s SIGNATURE_CONVERSIONS='${SIGNATURE_CONVERSIONS}'" ) endif () + set_target_properties(onnxruntime_webassembly PROPERTIES LINK_DEPENDS ${ONNXRUNTIME_ROOT}/wasm/pre.js) if (onnxruntime_USE_JSEP) # NOTE: "-s ASYNCIFY=1" is required for JSEP to work with WebGPU @@ -387,8 +382,10 @@ jsepDownload:_pp_") target_compile_definitions(onnxruntime_webassembly PRIVATE USE_JSEP=1) target_link_options(onnxruntime_webassembly PRIVATE "SHELL:--pre-js \"${ONNXRUNTIME_ROOT}/wasm/pre-jsep.js\"" + "SHELL:-s ASYNCIFY=1" + "SHELL:-s ASYNCIFY_STACK_SIZE=65536" ) - list(APPEND onnxruntime_webassembly_script_deps "${ONNXRUNTIME_ROOT}/wasm/pre-jsep.js") + set_target_properties(onnxruntime_webassembly PROPERTIES LINK_DEPENDS ${ONNXRUNTIME_ROOT}/wasm/pre-jsep.js) if (onnxruntime_ENABLE_WEBASSEMBLY_MEMORY64) target_link_options(onnxruntime_webassembly PRIVATE @@ -400,20 +397,6 @@ jsepDownload:_pp_") if (onnxruntime_USE_WEBGPU) target_compile_definitions(onnxruntime_webassembly PRIVATE USE_WEBGPU=1) - target_link_options(onnxruntime_webassembly PRIVATE - "SHELL:--post-js \"${ONNXRUNTIME_ROOT}/wasm/post-webgpu.js\"" - ) - list(APPEND onnxruntime_webassembly_script_deps "${ONNXRUNTIME_ROOT}/wasm/post-webgpu.js") - endif() - - if (onnxruntime_USE_JSEP OR onnxruntime_USE_WEBGPU OR onnxruntime_USE_WEBNN) - # if any of the above is enabled, we need to use the asyncify library - target_link_options(onnxruntime_webassembly PRIVATE - "SHELL:--pre-js \"${ONNXRUNTIME_ROOT}/wasm/pre-async.js\"" - "SHELL:-s ASYNCIFY=1" - "SHELL:-s ASYNCIFY_STACK_SIZE=65536" - ) - list(APPEND onnxruntime_webassembly_script_deps "${ONNXRUNTIME_ROOT}/wasm/pre-async.js") endif() if (onnxruntime_EMSCRIPTEN_SETTINGS) @@ -475,8 +458,6 @@ jsepDownload:_pp_") ) endif() - set_target_properties(onnxruntime_webassembly PROPERTIES LINK_DEPENDS "${onnxruntime_webassembly_script_deps}") - set(target_name_list ort) if (onnxruntime_ENABLE_TRAINING_APIS) diff --git a/cmake/patches/dawn/dawn.patch b/cmake/patches/dawn/dawn.patch index b578b858eac59..2f85d5ab473b5 100644 --- a/cmake/patches/dawn/dawn.patch +++ b/cmake/patches/dawn/dawn.patch @@ -18,7 +18,7 @@ index 6e8ae37593..633af91eef 100644 @@ -77,9 +77,17 @@ if (${DAWN_ENABLE_EMSCRIPTEN}) "${arg_UNPARSED_ARGUMENTS}") endif() - + + # since Emscripten 4.0.3, file gen_struct_info.py is moved to outside of directory maint. + if (EXISTS "${DAWN_EMSCRIPTEN_TOOLCHAIN}/tools/gen_struct_info.py") + set(EM_GEN_STRUCT_INFO_SCRIPT "${DAWN_EMSCRIPTEN_TOOLCHAIN}/tools/gen_struct_info.py") @@ -34,114 +34,3 @@ index 6e8ae37593..633af91eef 100644 -q "${EM_BUILD_GEN_DIR}/struct_info_webgpu.json" "-I=${EM_BUILD_GEN_DIR}/include" -diff --git a/src/emdawnwebgpu/README.md b/src/emdawnwebgpu/README.md -index efd6491cd6..8ebc5d28b6 100644 ---- a/src/emdawnwebgpu/README.md -+++ b/src/emdawnwebgpu/README.md -@@ -56,7 +56,7 @@ Set up the build directory using emcmake - mkdir out/cmake-wasm - cd out/cmake-wasm - --# Make sure the path is to the source checkout of Emscripten, not emsdk's release. -+# If using Emscripten v4.0.2 or lower, make sure the path is to the source checkout of Emscripten, not emsdk's release. - emcmake cmake -GNinja -DDAWN_EMSCRIPTEN_TOOLCHAIN="path/to/emscripten" ../.. - - ninja -diff --git a/third_party/emdawnwebgpu/webgpu.cpp b/third_party/emdawnwebgpu/webgpu.cpp -index f1c5a7d50e..16f2495712 100644 ---- a/third_party/emdawnwebgpu/webgpu.cpp -+++ b/third_party/emdawnwebgpu/webgpu.cpp -@@ -131,7 +131,6 @@ class RefCounted : NonMovable { - bool Release() { - if (mRefCount.fetch_sub(1u, std::memory_order_release) == 1u) { - std::atomic_thread_fence(std::memory_order_acquire); -- emwgpuDelete(this); - return true; - } - return false; -@@ -234,6 +233,7 @@ class Ref { - static void Release(T value) { - if (value != nullptr && value->RefCounted::Release()) { - delete value; -+ emwgpuDelete(value); - } - } - -@@ -641,7 +641,8 @@ struct WGPUAdapterImpl final : public EventSource, public RefCounted { - struct WGPUBufferImpl final : public EventSource, - public RefCountedWithExternalCount { - public: -- WGPUBufferImpl(const EventSource* source, bool mappedAtCreation); -+ WGPUBufferImpl(const EventSource* source, bool mappedAtCreation, bool isExternal); -+ ~WGPUBufferImpl(); - - void Destroy(); - const void* GetConstMappedRange(size_t offset, size_t size); -@@ -671,6 +672,7 @@ struct WGPUBufferImpl final : public EventSource, - }; - MapRequest mPendingMapRequest; - WGPUBufferMapState mMapState; -+ bool mIsExternal; - }; - - struct WGPUQueueImpl final : public EventSource, public RefCounted { -@@ -1164,11 +1166,15 @@ WGPUAdapter emwgpuCreateAdapter(const EventSource* source) { - - WGPUBuffer emwgpuCreateBuffer(const EventSource* source, - bool mappedAtCreation = false) { -- return new WGPUBufferImpl(source, mappedAtCreation); -+ return new WGPUBufferImpl(source, mappedAtCreation, true); - } - - WGPUDevice emwgpuCreateDevice(const EventSource* source, WGPUQueue queue) { -- return new WGPUDeviceImpl(source, queue); -+ // This function is only called from JS via `importJsDevice()`, which -+ // needs to increment the external ref count to fix the behavior. -+ WGPUDeviceImpl* device = new WGPUDeviceImpl(source, queue); -+ device->AddExternalRef(); -+ return device; - } - - WGPUQueue emwgpuCreateQueue(const EventSource* source) { -@@ -1275,15 +1281,22 @@ WGPUAdapterImpl::WGPUAdapterImpl(const EventSource* source) - // WGPUBuffer implementations. - // ---------------------------------------------------------------------------- - --WGPUBufferImpl::WGPUBufferImpl(const EventSource* source, bool mappedAtCreation) -+WGPUBufferImpl::WGPUBufferImpl(const EventSource* source, bool mappedAtCreation, bool isExternal) - : EventSource(source), - mMapState(mappedAtCreation ? WGPUBufferMapState_Mapped -- : WGPUBufferMapState_Unmapped) { -+ : WGPUBufferMapState_Unmapped), -+ mIsExternal(isExternal) { - if (mappedAtCreation) { - mPendingMapRequest = {kNullFutureId, WGPUMapMode_Write}; - } - } - -+WGPUBufferImpl::~WGPUBufferImpl() { -+ if (!mIsExternal) { -+ Destroy(); -+ } -+} -+ - void WGPUBufferImpl::Destroy() { - emwgpuBufferDestroy(this); - AbortPendingMap("Buffer was destroyed before mapping was resolved."); -@@ -1504,6 +1517,7 @@ WGPUFuture WGPUShaderModuleImpl::GetCompilationInfo( - void wgpu##Name##Release(WGPU##Name o) { \ - if (o->Release()) { \ - delete o; \ -+ emwgpuDelete(o); \ - } \ - } - WGPU_OBJECTS(DEFINE_WGPU_DEFAULT_ADDREF_RELEASE) -@@ -1638,7 +1652,7 @@ void wgpuBufferUnmap(WGPUBuffer buffer) { - - WGPUBuffer wgpuDeviceCreateBuffer(WGPUDevice device, - const WGPUBufferDescriptor* descriptor) { -- WGPUBuffer buffer = new WGPUBufferImpl(device, descriptor->mappedAtCreation); -+ WGPUBuffer buffer = new WGPUBufferImpl(device, descriptor->mappedAtCreation, false); - emwgpuDeviceCreateBuffer(device, descriptor, buffer); - return buffer; - } diff --git a/cmake/winml_sdk_helpers.cmake b/cmake/winml_sdk_helpers.cmake index ca657311b7f14..9241fcd060caf 100644 --- a/cmake/winml_sdk_helpers.cmake +++ b/cmake/winml_sdk_helpers.cmake @@ -1,7 +1,7 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -cmake_minimum_required(VERSION 3.5) +cmake_minimum_required(VERSION 3.0) # utility function(convert_forward_slashes_to_back input output) diff --git a/csharp/sample/Microsoft.ML.OnnxRuntime.FasterRcnnSample/Microsoft.ML.OnnxRuntime.FasterRcnnSample.csproj b/csharp/sample/Microsoft.ML.OnnxRuntime.FasterRcnnSample/Microsoft.ML.OnnxRuntime.FasterRcnnSample.csproj index b1452a64934c2..f00a08a1a3595 100644 --- a/csharp/sample/Microsoft.ML.OnnxRuntime.FasterRcnnSample/Microsoft.ML.OnnxRuntime.FasterRcnnSample.csproj +++ b/csharp/sample/Microsoft.ML.OnnxRuntime.FasterRcnnSample/Microsoft.ML.OnnxRuntime.FasterRcnnSample.csproj @@ -8,7 +8,7 @@ - + diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/ManagedProjections.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/ManagedProjections.shared.cs index 8916f11919cfe..13117f23e8ef9 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/ManagedProjections.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/ManagedProjections.shared.cs @@ -25,7 +25,7 @@ internal class ManagedTypeProjection /// /// /// - /// OrtValue created according to the metadata + /// OrtValye created accoding to the metadata internal static OrtValue CreateProjection(NamedOnnxValue namedOnnxValue, NodeMetadata metadata) { OrtValue result; @@ -191,3 +191,4 @@ private static OrtValue CreateTensorProjection(NamedOnnxValue node, NodeMetadata } } } + diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs index b64a5c3e5a4a2..d628b065ceaa7 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs @@ -847,7 +847,7 @@ internal class NativeLib /// Creates an instance of OrtSession with provided parameters /// /// Native OrtEnv instance - /// Byte array corresponding to the model + /// Byte array correspoonding to the model /// Size of the model in bytes /// Native SessionOptions instance /// Native OrtPrepackedWeightsContainer instance @@ -1258,7 +1258,7 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca /// /// Native SessionOptions instance /// Name of the initializer - /// Native OrtValue instance + /// Native OrtValue instnce [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DOrtAddInitializer(IntPtr /*(OrtSessionOptions*)*/ options, byte[] /*(const char*)*/ name, diff --git a/include/onnxruntime/core/framework/execution_provider.h b/include/onnxruntime/core/framework/execution_provider.h index 2245ff5791feb..c9a15de9ef897 100644 --- a/include/onnxruntime/core/framework/execution_provider.h +++ b/include/onnxruntime/core/framework/execution_provider.h @@ -20,7 +20,6 @@ struct ComputeCapability; class KernelRegistry; struct KernelCreateInfo; class Node; -class GraphOptimizerRegistry; } // namespace onnxruntime #else #include @@ -130,25 +129,10 @@ class IExecutionProvider { and decide whether a node will be assigned to <*this> execution provider. For kernels registered in a kernel registry, `kernel_lookup` must be used to find a matching kernel for this EP. - - The graph_optimizer_registry is designed for enabling L2+ graph optimizations tailored for EPs. - These optimizations are applied after the graph partitioner assigns ComputeCapability to the EP - and before EP's "Compile" or fusion. - - Steps to use graph_optimizer_registry and create the optimization ComputeCapability: - 1. Lookup Optimizer: The EP calls provider bridge API to lookup pre-defined optimizer by name and get selection function. - - Example: g_host->GetOptimizerByName(optimizer_name, graph_optimizer_registry, selection_func) - 2. Run Selection Function: The EP executes the selection function to obtain the selection ComputeCapability. - - ComputeCapability.optimize_func would be set by the optimizer to the function that does the optimization. - 3. Create Optimization ComputeCapability: The EP uses the selection ComputeCapability to create the optimization ComputeCapability. - 4. Return ComputeCapability: The EP returns the final ComputeCapability, with nodes_to_optimize set to the optimization ComputeCapability. - - Note: For more detailed implementations of using graph_optimizer_registry, please refer to TensorRT EP. */ virtual std::vector> GetCapability(const onnxruntime::GraphViewer& graph_viewer, const IKernelLookup& kernel_lookup, - const GraphOptimizerRegistry& graph_optimizer_registry, IResourceAccountant* resource_accountant = nullptr) const; /** diff --git a/include/onnxruntime/core/graph/graph.h b/include/onnxruntime/core/graph/graph.h index 35b568e3f8e28..7798394b045dc 100644 --- a/include/onnxruntime/core/graph/graph.h +++ b/include/onnxruntime/core/graph/graph.h @@ -27,7 +27,6 @@ #include "core/common/span_utils.h" #include "core/common/status.h" #include "core/common/logging/logging.h" -#include "core/framework/ort_value.h" #include "core/framework/prepacked_weights_container.h" #include "core/graph/onnx_protobuf.h" #include "core/graph/basic_types.h" @@ -40,9 +39,6 @@ #include "core/graph/node_arg.h" #include "core/graph/ort_format_load_options.h" -// Type from Model Editor API in ORT C API so can't be in a namespace -struct OrtGraph; - namespace onnxruntime { class Graph; struct IndexedSubGraph; @@ -767,10 +763,6 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi */ bool GetInitializedTensor(const std::string& tensor_name, const ONNX_NAMESPACE::TensorProto*& value) const; - /** Populate `value` if an externally allocated OrtValue exists for an initializer with the given name. - */ - bool GetOrtValueInitializer(const std::string& name, OrtValue& value) const; - /** Gets all the initializer tensors in this Graph. */ const InitializedTensorSet& GetAllInitializedTensors() const noexcept { return name_to_initial_tensor_; } @@ -1438,16 +1430,6 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi const OrtFormatLoadOptions& load_options, const logging::Logger& logger, std::unique_ptr& graph); - static Status LoadFromModelEditorApiModel(const OrtGraph& api_graph, - const Model& owning_model, - const std::unordered_map& domain_to_version, - IOnnxRuntimeOpSchemaCollectionPtr schema_registry, - bool strict_shape_type_inference, - const logging::Logger& logger, - std::unique_ptr& graph); - - Status UpdateUsingModelEditorApiModel(const OrtModel& api_model); - #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) const RuntimeOptimizationRecordContainer& RuntimeOptimizations() const { return runtime_optimizations_; @@ -1648,8 +1630,7 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi // Implementation for initializer replacement Status ReplaceInitializedTensorImpl(ONNX_NAMESPACE::TensorProto new_initializer, bool is_external); - template // range-initializer returning std::string - std::vector CreateNodeArgs(const StringRange& names, + std::vector CreateNodeArgs(const google::protobuf::RepeatedPtrField& names, const ArgNameToTypeMap& name_to_type_map); void ToGraphProtoInternal(ONNX_NAMESPACE::GraphProto& graph_proto) const; @@ -1713,8 +1694,6 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi return nodes_[node_index].get(); } - Status LoadFromModelEditorApiModel(const OrtGraph& api_graph, bool updating_existing_graph = false); - const Model& owning_model_; // GraphProto to store name, version, initializer. @@ -1729,12 +1708,6 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi InitializedTensorSet name_to_initial_tensor_; - // Initializers that are external to the Graph. - // e.g. created from existing memory using CreateTensorWithDataAndDeleterAsOrtValue in the ORT API. - // As we need to convert to TensorProto for the optimizers to work and keep the deleter information we store them - // in the Graph instance and retrieve during session state finalization. - std::unordered_map ortvalue_initializers_; - std::unordered_set, std::hash, std::equal_to> sparse_tensor_names_; @@ -1771,7 +1744,6 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi // in some case, a fused sub-graph will happens multiple times in one model, we use a map // to store reusable-schema in lookup. InlinedHashMap> reusable_fused_schema_map_; - #endif // !defined(ORT_MINIMAL_BUILD) // Graph nodes. @@ -1834,7 +1806,7 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi std::unordered_map> node_arg_to_consumer_nodes_; #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) - std::unordered_map domain_to_version_; + const std::unordered_map domain_to_version_; // Model IR version. Version ir_version_{ONNX_NAMESPACE::Version::IR_VERSION}; diff --git a/include/onnxruntime/core/graph/graph_viewer.h b/include/onnxruntime/core/graph/graph_viewer.h index 6a664d8be9c05..9385e2f092e58 100644 --- a/include/onnxruntime/core/graph/graph_viewer.h +++ b/include/onnxruntime/core/graph/graph_viewer.h @@ -193,12 +193,6 @@ class GraphViewer { IOnnxRuntimeOpSchemaCollectionPtr GetSchemaRegistry() const { return graph_->GetSchemaRegistry(); } #endif - /** Populate `value` if an externally allocated OrtValue exists for an initializer with the given name. - */ - bool GetOrtValueInitializer(const std::string& name, OrtValue& value) const { - return graph_->GetOrtValueInitializer(name, value); - } - private: ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(GraphViewer); GraphViewer(const Graph& graph, const IndexedSubGraph* filter_info); diff --git a/include/onnxruntime/core/graph/indexed_sub_graph.h b/include/onnxruntime/core/graph/indexed_sub_graph.h index 088db79a7e005..e457d3dcad1f1 100644 --- a/include/onnxruntime/core/graph/indexed_sub_graph.h +++ b/include/onnxruntime/core/graph/indexed_sub_graph.h @@ -72,12 +72,6 @@ struct IndexedSubGraph { return meta_def_.get(); } - /** Gets the mutable meta definition needed to represent this subgraph as a FunctionProto. - @returns MetaDef instance if it has been set. nullptr if not. */ - MetaDef* GetMutableMetaDef() { - return meta_def_.get(); - } - // Check if the accounting is enabled for the current EP bool IsAccountingEnabled() const { return resource_accountant != nullptr && diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 098de14bdfd61..47e6389492f30 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -305,10 +305,6 @@ ORT_RUNTIME_CLASS(OpAttr); ORT_RUNTIME_CLASS(Logger); ORT_RUNTIME_CLASS(ShapeInferContext); ORT_RUNTIME_CLASS(LoraAdapter); -ORT_RUNTIME_CLASS(ValueInfo); -ORT_RUNTIME_CLASS(Node); -ORT_RUNTIME_CLASS(Graph); -ORT_RUNTIME_CLASS(Model); #ifdef _WIN32 typedef _Return_type_success_(return == 0) OrtStatus* OrtStatusPtr; @@ -669,9 +665,6 @@ typedef struct OrtApi OrtApi; struct OrtTrainingApi; typedef struct OrtTrainingApi OrtTrainingApi; -struct OrtModelEditorApi; -typedef struct OrtModelEditorApi OrtModelEditorApi; - /** \brief The helper interface to get the right version of OrtApi * * Get a pointer to this structure through ::OrtGetApiBase @@ -854,8 +847,7 @@ struct OrtApi { * * \snippet{doc} snippets.dox OrtStatus Return Value */ - ORT_API2_STATUS(CreateSessionFromArray, _In_ const OrtEnv* env, - _In_ const void* model_data, size_t model_data_length, + ORT_API2_STATUS(CreateSessionFromArray, _In_ const OrtEnv* env, _In_ const void* model_data, size_t model_data_length, _In_ const OrtSessionOptions* options, _Outptr_ OrtSession** out); /** \brief Run the model in an ::OrtSession @@ -1348,8 +1340,6 @@ struct OrtApi { * Create a tensor with user's buffer. You can fill the buffer either before calling this function or after. * p_data is owned by caller. ReleaseValue won't release p_data. * - * If you wish to transfer ownership of p_data to ORT use CreateTensorWithDataAndDeleterAsOrtValue. - * * \param[in] info Memory description of where the p_data buffer resides (CPU vs GPU etc). * \param[in] p_data Pointer to the data buffer. * \param[in] p_data_len The number of bytes in the data buffer. @@ -2007,8 +1997,7 @@ struct OrtApi { /** \brief Get the value type from an ::OrtMapTypeInfo * * \param[in] map_type_info - * \param[out] type_info A copy of the OrtTypeInfo for the map value type. - * The user must free this value with ReleaseTypeInfo. + * \param[out] type_info * * \snippet{doc} snippets.dox OrtStatus Return Value */ @@ -2023,8 +2012,7 @@ struct OrtApi { * This is used by WinML to support model reflection APIs. * * \param[in] sequence_type_info - * \param[out] type_info A copy of the OrtTypeInfo for the sequence element type. - * The user must free this value with ReleaseTypeInfo. + * \param[out] type_info * * \snippet{doc} snippets.dox OrtStatus Return Value */ @@ -2899,8 +2887,7 @@ struct OrtApi { * \snippet{doc} snippets.dox OrtStatus Return Value */ ORT_API2_STATUS(CreateSessionWithPrepackedWeightsContainer, _In_ const OrtEnv* env, _In_ const ORTCHAR_T* model_path, - _In_ const OrtSessionOptions* options, - _Inout_ OrtPrepackedWeightsContainer* prepacked_weights_container, + _In_ const OrtSessionOptions* options, _Inout_ OrtPrepackedWeightsContainer* prepacked_weights_container, _Outptr_ OrtSession** out); /** \brief Create session from memory with prepacked weights container @@ -2923,8 +2910,7 @@ struct OrtApi { */ ORT_API2_STATUS(CreateSessionFromArrayWithPrepackedWeightsContainer, _In_ const OrtEnv* env, _In_ const void* model_data, size_t model_data_length, - _In_ const OrtSessionOptions* options, - _Inout_ OrtPrepackedWeightsContainer* prepacked_weights_container, + _In_ const OrtSessionOptions* options, _Inout_ OrtPrepackedWeightsContainer* prepacked_weights_container, _Outptr_ OrtSession** out); /// @} @@ -4307,8 +4293,8 @@ struct OrtApi { * specific type that is described by the returned ::OrtTypeInfo. * * \param[in] optional_type_info - * \param[out] out A copy of ::OrtTypeInfo for what the optional value could be. - * The user must free this value with ReleaseTypeInfo. + * \param[out] out A pointer to the ::OrtTypeInfo for what the optional value could be. + * it is owned by OrtOptionalTypeInfo instance. * * \snippet{doc} snippets.dox OrtStatus Return Value * @@ -4800,75 +4786,6 @@ struct OrtApi { */ ORT_API2_STATUS(SetEpDynamicOptions, _Inout_ OrtSession* sess, _In_reads_(kv_len) const char* const* keys, _In_reads_(kv_len) const char* const* values, _In_ size_t kv_len); - - /** \brief Release an OrtValueInfo instance if it was not added to an OrtGraph. - * \since Version 1.21. - */ - ORT_CLASS_RELEASE(ValueInfo); - - /** \brief Release an OrtNode if it was not added to an OrtGraph. - * \since Version 1.21. - */ - ORT_CLASS_RELEASE(Node); - - /** \brief Release an OrtGraph. - * \snippet{doc} snippets.dox OrtStatus Return Value - * \since Version 1.21. - */ - ORT_CLASS_RELEASE(Graph); - - /** \brief Release an OrtModel. - * \snippet{doc} snippets.dox OrtStatus Return Value - * \since Version 1.21. - */ - ORT_CLASS_RELEASE(Model); - - /** \brief Get the value name from an OrtValueInfo instance. - * \param[in] value_info The OrtValueInfo instance. - * \snippet{doc} snippets.dox OrtStatus Return Value - * \since Version 1.21. - */ - ORT_API2_STATUS(GetValueInfoName, _In_ const OrtValueInfo* value_info, _Out_ const char** name); - - /** \brief Get the type information from an OrtValueInfo instance. - * \param[in] value_info The OrtValueInfo instance. - * \snippet{doc} snippets.dox OrtStatus Return Value - * \since Version 1.21. - */ - ORT_API2_STATUS(GetValueInfoTypeInfo, _In_ const OrtValueInfo* value_info, _Outptr_ const OrtTypeInfo** type_info); - - /** \brief Get the Model Editor API instance - * - * Get the Model Editor API instance to create a new model or augment an existing model. - * - * \return Model Editor API struct - * - * \since Version 1.21. - */ - const OrtModelEditorApi*(ORT_API_CALL* GetModelEditorApi)(); - - /** \brief Create an OrtValue for a Tensor that uses pre-existing memory. - * - * ORT will take ownership of the memory and free it using the provided deleter when no longer in use. - * - * \param[in] deleter OrtAllocator instance that will be used to free the memory. - * Only the OrtAllocator:Info and OrtAllocator::Release functions are required. - * The OrtMemoryInfo returned by OrtAllocator::Info must match the location of p_data. - * \param[in] p_data Pointer to the memory that will be used by the Tensor. ORT will take ownership of the memory. - * \param[in] p_data_len Length of the memory in bytes. - * \param[in] shape Dimensions of the Tensor. All values should be > 0. - * \param[in] shape_len Number of dimensions in the shape array. - * \param[in] type Data type of the Tensor. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.21. - */ - ORT_API2_STATUS(CreateTensorWithDataAndDeleterAsOrtValue, _In_ OrtAllocator* deleter, - _In_ void* p_data, size_t p_data_len, - _In_ const int64_t* shape, size_t shape_len, - ONNXTensorElementDataType type, - _Outptr_ OrtValue** out); }; /* @@ -4983,400 +4900,6 @@ struct OrtCustomOp { void(ORT_API_CALL* ReleaseAliasMap)(_Frees_ptr_opt_ int* input_index, _Frees_ptr_opt_ int* output_index); }; -/** - * ORT Model Editor API - */ - -/** - * \brief The OrtModelEditorApi struct provides functions to create or edit an ONNX model. - * - * See onnxruntime/test/shared_lib/test_model_editor_api.cc for example usage. - * - * \since Version 1.21. - */ -struct OrtModelEditorApi { - // Model building/editing requires a full build. We return nullptr from GetModelEditorApi if this is a minimal - // build, so it doesn't matter if there are no function pointers in this struct as a user will never get an - // OrtModelEditorApi instance. We do however need a dummy field to avoid empty struct warning. -#if defined(ORT_MINIMAL_BUILD) - const bool not_defined_in_this_build; -#else - /** \brief Create an OrtTypeInfo instance for a Tensor. - * - * Create an OrtTypeInfo instance for a Tensor to use as graph inputs/outputs with the Model Editor API. - * - * User can release `tensor_info` after creating the OrtTypeInfo. - * - * \param[in] tensor_info Tensor type and shape information. - * \param[out] TypeInfo instance for the tensor. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.21. - */ - ORT_API2_STATUS(CreateTensorTypeInfo, _In_ const OrtTensorTypeAndShapeInfo* tensor_info, - _Outptr_ OrtTypeInfo** type_info); - - /** \brief Create an OrtTypeInfo instance for a SparseTensor. - * - * Create an OrtTypeInfo instance for a SparseTensor to use as graph inputs/outputs with the Model Editor API. - * - * User can release `tensor_info` after creating the OrtTypeInfo. - * - * \param[in] tensor_info SparseTensor type and shape information. - * \param[out] TypeInfo instance for the tensor. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.21. - */ - ORT_API2_STATUS(CreateSparseTensorTypeInfo, _In_ const OrtTensorTypeAndShapeInfo* tensor_info, - _Outptr_ OrtTypeInfo** type_info); - - /** \brief Create an OrtTypeInfo instance for a Map. - * - * Create an OrtTypeInfo instance for a Map to use as graph inputs/outputs with the Model Editor API. - * - * User can release `map_value_type` after creating the OrtTypeInfo. - * - * \param[in] map_key_type Key type for the map. - * \param[in] map_value_type Value type for the map. - * \param[out] TypeInfo instance for the map. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.21. - */ - ORT_API2_STATUS(CreateMapTypeInfo, ONNXTensorElementDataType map_key_type, _In_ const OrtTypeInfo* map_value_type, - _Outptr_ OrtTypeInfo** type_info); - - /** \brief Create an OrtTypeInfo instance for a Sequence. - * - * Create an OrtTypeInfo instance for a Sequence to use as graph inputs/outputs with the Model Editor API. - * - * User can release `sequence_type` after creating the OrtTypeInfo. - * - * \param[in] sequence_type Sequence type and shape information. - * \param[out] TypeInfo instance for the sequence. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.21. - */ - ORT_API2_STATUS(CreateSequenceTypeInfo, _In_ const OrtTypeInfo* sequence_type, _Outptr_ OrtTypeInfo** type_info); - - /** \brief Create an OrtTypeInfo instance for an Optional. - * - * Create an OrtTypeInfo instance for an Optional to use as graph inputs/outputs with the Model Editor API. - * - * User can release `contained_type` after creating the OrtTypeInfo. - * - * \param[in] tensor_info Tensor type and shape information. - * \param[out] TypeInfo instance for the tensor. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.21. - */ - ORT_API2_STATUS(CreateOptionalTypeInfo, _In_ const OrtTypeInfo* contained_type, _Outptr_ OrtTypeInfo** type_info); - - /** \brief Create an OrtValueInfo for use as an OrtGraph input or output. - * - * \param[in] name The name of the input or output. - * \param[in] type_info The type information for the input or output. The provided value is copied. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.21. - */ - ORT_API2_STATUS(CreateValueInfo, _In_ const char* name, _In_ const OrtTypeInfo* type_info, - _Outptr_ OrtValueInfo** value_info); - - /** \brief Create an OrtNode to add to an OrtGraph. - * - * Create an OrtNode. - * - * Create attributes with CreateOpAttr. OrtOpAttr instances are copied. - * - * \param[in] operator_name The name of the operator. - * \param[in] domain_name The domain of the operator. Use an empty string for ONNX operators. - * \param[in] node_name The name of the node. - * \param[in] input_names The names of the inputs. - * \param[in] input_names_len The number of input names. - * \param[in] output_names The names of the outputs. - * \param[in] output_names_len The number of output names. - * \param[in] attributes The optional attributes of the node. - * \param[in] attribs_len The number of attributes. May be zero. - * \param[out] node The OrtNode instance. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.21. - */ - ORT_API2_STATUS(CreateNode, _In_ const char* operator_name, _In_ const char* domain_name, _In_ const char* node_name, - _In_reads_(input_names_len) const char* const* input_names, size_t input_names_len, - _In_reads_(output_names_len) const char* const* output_names, size_t output_names_len, - _In_reads_(attribs_len) _In_opt_ OrtOpAttr** attributes, _In_ size_t attribs_len, - _Outptr_ OrtNode** node); - - /** \brief Create an OrtGraph - * \snippet{doc} snippets.dox OrtStatus Return Value - * \since Version 1.21. - */ - ORT_API2_STATUS(CreateGraph, _Outptr_ OrtGraph** graph); - - /** \brief Set the inputs for the OrtGraph. - * - * Set the graph inputs. This will replace any existing inputs with the new values. - * The OrtGraph takes ownership of the OrtValueInfo instances and you should NOT call ReleaseOrtValueInfo. - * - * \param[in] graph The OrtGraph instance to update. - * \param[in] inputs The input OrtValueInfo instances. - * \param[in] inputs_len The number of input OrtValueInfo instances. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.21. - */ - ORT_API2_STATUS(SetGraphInputs, _Inout_ OrtGraph* graph, - _In_reads_(inputs_len) _In_ OrtValueInfo** inputs, _In_ size_t inputs_len); - - /** \brief Set the outputs for the OrtGraph. - * - * Set the graph outputs. This will replace any existing outputs with the new values. - * The OrtGraph takes ownership of the OrtValueInfo instances provided and you should NOT call ReleaseOrtValueInfo. - * - * \param[in] graph The OrtGraph instance to update. - * \param[in] outputs The output OrtValueInfo instances. - * \param[in] outputs_len The number of output OrtValueInfo instances. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.21. - */ - ORT_API2_STATUS(SetGraphOutputs, _Inout_ OrtGraph* graph, - _In_reads_(outputs_len) _In_ OrtValueInfo** outputs, _In_ size_t outputs_len); - - /** \brief Add an initializer to the OrtGraph - * - * ORT will take ownership of the OrtValue and you should NOT call ReleaseOrtValue. - * - * Two options: - * - * Allocated memory: - * Use CreateTensorAsOrtValue (allocates memory) and populate the tensor with the data. - * Set `data_is_external` to false. - * - * Pre-existing memory: - * Use CreateTensorWithDataAsOrtValue or CreateTensorWithDataAndDeleterAsOrtValue to create an OrtValue - * with a tensor that contains a pointer to the existing data. - * Set `data_is_external` to true. - * - * The pointer must remain valid for the duration of the inference session. - * If using CreateTensorWithDataAsOrtValue you are responsible for freeing the memory after the inference session - * is released. - * If using CreateTensorWithDataAndDeleterAsOrtValue, ORT will free the memory using the provided deleter as - * soon as the OrtValue is no longer in use. - * - * NOTE: A tensor containing pre-existing memory MUST have 128 bytes of data or more. - * For smaller tensors use CreateTensorAsOrtValue. - * - * ONNX shape inferencing does not support external data. An initializer involved in shape inferencing is - * typically small (a single value or limited by the rank of a tensor) and uses less than 128 bytes of - * memory, so this limit acts as a simple catch-all rule to avoid issues. - * e.g. Reshape's `shape`, Clip's `min` and `max`, various ops `axes`. - * - * \param[in] graph The OrtGraph instance to update. - * \param[in] name The value name for the initializer. - * \param[in] tensor The OrtValue instance containing the tensor data. - * \param[in] data_is_external Set to true if the data is external and should not be copied. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.21. - */ - ORT_API2_STATUS(AddInitializerToGraph, _Inout_ OrtGraph* graph, _In_ const char* name, _In_ OrtValue* tensor, - bool data_is_external); - - /** \brief Add an OrtNode to an OrtGraph - * - * Add the node to the graph. The OrtGraph will take ownership of OrtNode and you should NOT call ReleaseOrtNode. - * - * \param[in] graph The OrtGraph instance to update. - * \param[in] node The OrtNode instance to add to the graph. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.21. - */ - ORT_API2_STATUS(AddNodeToGraph, _Inout_ OrtGraph* graph, _In_ OrtNode* node); - - /** \brief Create an OrtModel. - * - * Create an OrtModel. - * - * This can be used to build a new model, or to augment an existing model. - * - * \param[in] domain_names The domain names for the model. - * If augmenting an existing model add additional domains if needed. - * \param[in] opset_versions The opset versions for the model. - * If augmenting an existing model add additional opset versions if needed. - * \param[in] opset_entries_len The number of domain_names and opset_versions entries. - * Domain and opset entries should be 1:1 - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.21. - */ - ORT_API2_STATUS(CreateModel, - _In_reads_(opset_entries_len) const char* const* domain_names, - _In_reads_(opset_entries_len) const int* opset_versions, - size_t opset_entries_len, - _Outptr_ OrtModel** model); - - /** \brief Add an OrtGraph to an OrtModel. - * - * Add the graph to a model. This should be called once when creating a new model. - * - * The OrtModel takes ownership of the OrtGraph and you should NOT call ReleaseOrtGraph. - * - * \param[in] model The OrtModel instance to update. - * \param[in] graph The OrtGraph instance to add to the model. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.21. - */ - ORT_API2_STATUS(AddGraphToModel, _Inout_ OrtModel* model, _In_ OrtGraph* graph); - - /** \brief Create an OrtSession using the OrtModel. - * - * Create an inference session using the OrtModel instance. - * The OrtModel should have been populated with an OrtGraph containing nodes and initializers, and SetGraphInputs - * and SetGraphOutputs must have been called. - * This will validate the model, run optimizers, and prepare the session for inferencing. - * - * ReleaseOrtModel must be called to free the OrtModel after session creation. - * - * \param[in] env The OrtEnv instance. - * \param[in] model The OrtModel instance. - * \param[in] options The OrtSessionOptions instance. - * \param[out] out The OrtSession instance. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.21. - */ - ORT_API2_STATUS(CreateSessionFromModel, _In_ const OrtEnv* env, _In_ const OrtModel* model, - _In_ const OrtSessionOptions* options, _Outptr_ OrtSession** out); - - /** \brief Create an OrtSession to augment an existing model. - * - * Create an OrtSession with an existing model that will be augmented with additional nodes and initializers. - * Nodes can be added before or after the existing nodes in the model. ONNX Runtime will connect the nodes when the - * model is finalized. - * - * To add nodes and initializers to the existing model, first create an OrtModel using CreateModel. - * Add nodes and initializers to the OrtModel using AddNodeToGraph and AddInitializerToGraph. - * Graph inputs/outputs should be updated with SetGraphInputs and SetGraphOutputs as needed to reflect changes made - * by the new nodes. The list of graph inputs/outputs should be for the overall model and not just the new nodes. - * - * Add the new information from the OrtModel to the original model using ApplyModelToSession, and prepare the - * session for inferencing by calling FinalizeModelEditorSession. - * - * \param{in} env The OrtEnv instance. - * \param{in} model_path The path to the existing ONNX model to augment. - * \param{in} options The OrtSessionOptions instance. - * \param{out} out The created OrtSession instance. - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.21. - */ - ORT_API2_STATUS(CreateModelEditorSession, _In_ const OrtEnv* env, _In_ const ORTCHAR_T* model_path, - _In_ const OrtSessionOptions* options, - _Outptr_ OrtSession** out); - - /** \brief Create an OrtSession to augment an existing model. - * - * Create an OrtSession with an existing model that will be augmented with additional nodes and initializers. - * Nodes can be added before or after the existing nodes in the model. ONNX Runtime will connect the nodes when the - * model is finalized. - * - * To add nodes and initializers to the existing model, first create an OrtModel using CreateModel. - * Add nodes and initializers to the OrtModel using AddNodeToGraph and AddInitializerToGraph. - * Graph inputs/outputs should be updated with SetGraphInputs and SetGraphOutputs as needed to reflect changes made - * by the new nodes. The list of graph inputs/outputs should be for the overall model and not just the new nodes. - * - * Add the new information from the OrtModel to the original model using ApplyModelToSession, and prepare the - * session for inferencing by calling FinalizeModelEditorSession. - * - * \param{in} env The OrtEnv instance. - * \param{in} model_data The model data for the existing model to augment. - * \param{in} model_data_length The length of the model data. - * \param{in} options The OrtSessionOptions instance. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.21. - */ - ORT_API2_STATUS(CreateModelEditorSessionFromArray, _In_ const OrtEnv* env, - _In_ const void* model_data, size_t model_data_length, - _In_ const OrtSessionOptions* options, - _Outptr_ OrtSession** out); - - /** \brief Query the session for the opset version of a domain. - * - * When using the Model Editor API to augment a model, any new nodes must conform to the opset version of the - * original model. To do that the user must be able to discover that opset version. - * - * \param[in] session OrtSession to query - * \param[in] domain Domain to query. The ONNX domain is an empty string. - * \param[out] opset The opset version of the domain. - * - * \snippet{doc} snippets.dox OrtStatus Return Value. Returns an error if the domain is not used in the model. - * - * \since Version 1.21. - */ - ORT_API2_STATUS(SessionGetOpsetForDomain, _In_ const OrtSession* session, _In_ const char* domain, _Out_ int* opset); - - /** \brief Apply changes to augment the ONNX model in a session created using CreateModelEditorSession[FromArray] - * - * Adds new nodes and updates graph inputs/outputs using `model` to augment the original ONNX model in the session. - * All changes will be validated. - * Call FinalizeModelEditorSession to prepare the session for inferencing. - * - * Existing input/outputs will only be updated if the OrtGraph inputs/outputs are set in the OrtModel. - * i.e. you don't need to call SetGraphInputs/SetGraphOutputs if they are unchanged. - * - * ReleaseOrtModel must be called to free the OrtModel after it is applied to the session. - * - * \param[in] session OrtSession to update. Session must have been created using CreateModelEditorSession[FromArray]. - * \param[in] model OrtModel containing new nodes, new initializers, and updated graph input and/or output info. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.21. - */ - ORT_API2_STATUS(ApplyModelToModelEditorSession, _Inout_ OrtSession* session, _In_ OrtModel* model); - - /** \brief Finalize the Model Editor session that was created using CreateModelEditorSession[FromArray]. - * - * Finalize the Model Editor session that augmented an ONNX model by adding new nodes. - * This will run optimizers and prepare the session for inferencing. - * - * \param[in] session OrtSession to finalize. Session must have been created using CreateModelEditorSession[FromArray]. - * \param[in] options OrtSessionOptions to use for the session. - * \param[in] Optional prepacked_weights_container OrtPrepackedWeightsContainer to use for the session. - Set to nullptr if not used. - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.21. - */ - ORT_API2_STATUS(FinalizeModelEditorSession, _Inout_ OrtSession* session, _In_ const OrtSessionOptions* options, - _In_opt_ OrtPrepackedWeightsContainer* prepacked_weights_container); -#endif // !defined(ORT_MINIMAL_BUILD) -}; - /* * This is the old way to add the CUDA provider to the session, please use SessionOptionsAppendExecutionProvider_CUDA above to access the latest functionality * This function always exists, but will only succeed if Onnxruntime was built with CUDA support and the CUDA provider shared library exists diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index 979b478e2fbb4..123ef98901003 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -26,17 +26,16 @@ #include "onnxruntime_c_api.h" #include "onnxruntime_float16.h" -#include #include #include +#include #include #include #include -#include +#include #include #include -#include -#include +#include #ifdef ORT_NO_EXCEPTIONS #include @@ -121,7 +120,7 @@ const OrtApi* Global::api_ = OrtGetApiBase()->GetApi(ORT_API_VERSION); #endif #endif -/// This returns a reference to the ORT C API. +/// This returns a reference to the OrtApi interface in use inline const OrtApi& GetApi() noexcept { return *Global::api_; } /// @@ -144,20 +143,6 @@ std::string GetBuildInfoString(); /// vector of strings std::vector GetAvailableProviders(); -/// -/// This returns a reference to the ORT C Model Editor API. Used if building or augmenting a model at runtime. -/// -/// ORT C Model Editor API reference -inline const OrtModelEditorApi& GetModelEditorApi() { - auto* api = GetApi().GetModelEditorApi(); - if (api == nullptr) { - // minimal build - ORT_CXX_API_THROW("Model Editor API is not available in this build", ORT_FAIL); - } - - return *api; -} - /** \brief IEEE 754 half-precision floating point data type * * \details This struct is used for converting float to float16 and back @@ -538,10 +523,6 @@ ORT_DEFINE_RELEASE(Status); ORT_DEFINE_RELEASE(OpAttr); ORT_DEFINE_RELEASE(Op); ORT_DEFINE_RELEASE(KernelInfo); -ORT_DEFINE_RELEASE(ValueInfo); -ORT_DEFINE_RELEASE(Node); -ORT_DEFINE_RELEASE(Graph); -ORT_DEFINE_RELEASE(Model); #undef ORT_DEFINE_RELEASE @@ -578,9 +559,7 @@ struct Base { constexpr Base() = default; constexpr explicit Base(contained_type* p) noexcept : p_{p} {} - ~Base() { - OrtRelease(p_); - } + ~Base() { OrtRelease(p_); } Base(const Base&) = delete; Base& operator=(const Base&) = delete; @@ -656,13 +635,9 @@ struct AllocatedFree { struct AllocatorWithDefaultOptions; struct Env; -struct Graph; -struct Model; -struct Node; -struct ModelMetadata; struct TypeInfo; struct Value; -struct ValueInfo; +struct ModelMetadata; /** \brief unique_ptr typedef used to own strings allocated by OrtAllocators * and release them at the end of the scope. The lifespan of the given allocator @@ -1076,10 +1051,6 @@ struct ConstSessionImpl : Base { size_t GetOutputCount() const; ///< Returns the number of model outputs size_t GetOverridableInitializerCount() const; ///< Returns the number of inputs that have defaults that can be overridden - std::vector GetInputNames() const; - std::vector GetOutputNames() const; - std::vector GetOverridableInitializerNames() const; - /** \brief Returns a copy of input name at the specified index. * * \param index must less than the value returned by GetInputCount() @@ -1113,12 +1084,6 @@ struct ConstSessionImpl : Base { TypeInfo GetInputTypeInfo(size_t index) const; ///< Wraps OrtApi::SessionGetInputTypeInfo TypeInfo GetOutputTypeInfo(size_t index) const; ///< Wraps OrtApi::SessionGetOutputTypeInfo TypeInfo GetOverridableInitializerTypeInfo(size_t index) const; ///< Wraps OrtApi::SessionGetOverridableInitializerTypeInfo - - int GetOpset(const std::string& domain) const; ///< Wraps OrtApi::SessionGetOpsetForDomain - - // Will move before checkin if that's the case. - std::vector GetInputs() const; - std::vector GetOutputs() const; }; template @@ -1196,9 +1161,6 @@ struct SessionImpl : ConstSessionImpl { * \param[in] kv_len Number of elements in the keys and values arrays */ void SetEpDynamicOptions(const char* const* keys, const char* const* values, size_t kv_len); - - void FinalizeModelEditorSession(const Model& model, const SessionOptions& options, - OrtPrepackedWeightsContainer* prepacked_weights_container = nullptr); }; } // namespace detail @@ -1210,34 +1172,13 @@ using UnownedSession = detail::SessionImpl>; * */ struct Session : detail::SessionImpl { - /// Create an empty Session object, must be assigned a valid one to be used. Wraps OrtApi::CreateSession - explicit Session(std::nullptr_t) {} - explicit Session(OrtSession* p) : SessionImpl{p} {} ///< C API Interop - - Session(const Env& env, const ORTCHAR_T* model_path, const SessionOptions& options); - - /// Wraps OrtApi::CreateSessionWithPrepackedWeightsContainer + explicit Session(std::nullptr_t) {} ///< Create an empty Session object, must be assigned a valid one to be used + Session(const Env& env, const ORTCHAR_T* model_path, const SessionOptions& options); ///< Wraps OrtApi::CreateSession Session(const Env& env, const ORTCHAR_T* model_path, const SessionOptions& options, - OrtPrepackedWeightsContainer* prepacked_weights_container); - - /// Wraps OrtApi::CreateSessionFromArray - Session(const Env& env, const void* model_data, size_t model_data_length, const SessionOptions& options); - - /// Wraps OrtApi::CreateSessionFromArrayWithPrepackedWeightsContainer + OrtPrepackedWeightsContainer* prepacked_weights_container); ///< Wraps OrtApi::CreateSessionWithPrepackedWeightsContainer + Session(const Env& env, const void* model_data, size_t model_data_length, const SessionOptions& options); ///< Wraps OrtApi::CreateSessionFromArray Session(const Env& env, const void* model_data, size_t model_data_length, const SessionOptions& options, - OrtPrepackedWeightsContainer* prepacked_weights_container); - -#if !defined(ORT_MINIMAL_BUILD) - /// Wraps OrtModelEditorApi::CreateSessionFromModel - Session(const Env& env, const Model& model, const SessionOptions& options); - - /// Wraps OrtModelEditorApi::CreateModelEditorSession - static Session CreateModelEditorSession(const Env& env, const ORTCHAR_T* model_path, const SessionOptions& options); - - /// Wraps OrtModelEditorApi::CreateModelEditorSession - static Session CreateModelEditorSession(const Env& env, const void* model_data, size_t model_data_length, - const SessionOptions& options); -#endif // !defined(ORT_MINIMAL_BUILD) + OrtPrepackedWeightsContainer* prepacked_weights_container); ///< Wraps OrtApi::CreateSessionFromArrayWithPrepackedWeightsContainer ConstSession GetConst() const { return ConstSession{this->p_}; } UnownedSession GetUnowned() const { return UnownedSession{this->p_}; } @@ -1269,7 +1210,7 @@ using ConstMemoryInfo = detail::MemoryInfoImpl { static MemoryInfo CreateCpu(OrtAllocatorType type, OrtMemType mem_type1); explicit MemoryInfo(std::nullptr_t) {} ///< No instance is created - explicit MemoryInfo(OrtMemoryInfo* p) : MemoryInfoImpl{p} {} ///< Take ownership of a pointer created by C API + explicit MemoryInfo(OrtMemoryInfo* p) : MemoryInfoImpl{p} {} ///< Take ownership of a pointer created by C Api MemoryInfo(const char* name, OrtAllocatorType type, int id, OrtMemType mem_type); ConstMemoryInfo GetConst() const { return ConstMemoryInfo{this->p_}; } }; @@ -1292,7 +1233,6 @@ struct TensorTypeAndShapeInfoImpl : Base { [[deprecated("use GetShape()")]] void GetDimensions(int64_t* values, size_t values_count) const; ///< Wraps OrtApi::GetDimensions void GetSymbolicDimensions(const char** values, size_t values_count) const; ///< Wraps OrtApi::GetSymbolicDimensions - std::vector GetSymbolicDimensions() const; std::vector GetShape() const; ///< Uses GetDimensionsCount & GetDimensions to return a std::vector of the shape }; @@ -1308,18 +1248,8 @@ struct TensorTypeAndShapeInfo : detail::TensorTypeAndShapeInfoImpl; using Base::Base; - /// Create an empty TensorTypeAndShapeInfo object, must be assigned a valid one to be used - explicit TensorTypeAndShapeInfo(std::nullptr_t) {} - /// Used for interop with the C API - explicit TensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* p) : TensorTypeAndShapeInfoImpl{p} {} - - // Create a TensorTypeAndShapeInfo object with the specified element type and dimensions - // symbolic_dims are optional, but should be 1:1 with dims. - // The value in symbolic_dims will be used for all entries in dims that are -1. - explicit TensorTypeAndShapeInfo(ONNXTensorElementDataType element_type, - const std::vector& dims, - const std::vector* symbolic_dims = nullptr); - + explicit TensorTypeAndShapeInfo(std::nullptr_t) {} ///< Create an empty TensorTypeAndShapeInfo object, must be assigned a valid one to be used + explicit TensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* p) : TensorTypeAndShapeInfoImpl{p} {} ///< Used for interop with the C API ConstTensorTypeAndShapeInfo GetConst() const { return ConstTensorTypeAndShapeInfo{this->p_}; } }; @@ -1414,18 +1344,9 @@ struct TypeInfo : detail::TypeInfoImpl { using Base = detail::TypeInfoImpl; using Base::Base; - /// Create an empty TypeInfo object, must be assigned a valid one to be used - explicit TypeInfo(std::nullptr_t) {} + explicit TypeInfo(std::nullptr_t) {} ///< Create an empty TypeInfo object, must be assigned a valid one to be used explicit TypeInfo(OrtTypeInfo* p) : TypeInfoImpl{p} {} ///< C API Interop -#if !defined(ORT_MINIMAL_BUILD) - static TypeInfo CreateTensorInfo(ConstTensorTypeAndShapeInfo tensor_info); - static TypeInfo CreateSparseTensorInfo(ConstTensorTypeAndShapeInfo sparse_tensor_info); - static TypeInfo CreateSequenceTypeInfo(ConstTypeInfo sequence_type); - static TypeInfo CreateMapTypeInfo(ONNXTensorElementDataType key_type, ConstTypeInfo value_type); - static TypeInfo CreateOptionalTypeInfo(ConstTypeInfo contained_type); -#endif // !defined(ORT_MINIMAL_BUILD) - ConstTypeInfo GetConst() const { return ConstTypeInfo{this->p_}; } }; @@ -1780,8 +1701,7 @@ struct Value : detail::ValueImpl { * \param shape_len The number of tensor shape dimensions. */ template - static Value CreateTensor(const OrtMemoryInfo* info, T* p_data, size_t p_data_element_count, - const int64_t* shape, size_t shape_len); + static Value CreateTensor(const OrtMemoryInfo* info, T* p_data, size_t p_data_element_count, const int64_t* shape, size_t shape_len); /** \brief Creates a tensor with a user supplied buffer. Wraps OrtApi::CreateTensorWithDataAsOrtValue. * @@ -1792,25 +1712,11 @@ struct Value : detail::ValueImpl { * \param shape_len The number of tensor shape dimensions. * \param type The data type. */ - static Value CreateTensor(const OrtMemoryInfo* info, void* p_data, size_t p_data_byte_count, - const int64_t* shape, size_t shape_len, - ONNXTensorElementDataType type); - - /** \brief Creates a tensor with a user supplied buffer. Wraps OrtApi::CreateTensorWithDataAndDeleterAsOrtValue. - * - * \param deleter OrtAllocator that will be used to free the buffer when no longer required. - * \param p_data Pointer to the data buffer. - * \param p_data_byte_count The number of bytes in the data buffer. - * \param shape Pointer to the tensor shape dimensions. - * \param shape_len The number of tensor shape dimensions. - * \param type The data type. - */ - static Value CreateTensor(OrtAllocator* deleter, void* p_data, size_t p_data_byte_count, - const int64_t* shape, size_t shape_len, + static Value CreateTensor(const OrtMemoryInfo* info, void* p_data, size_t p_data_byte_count, const int64_t* shape, size_t shape_len, ONNXTensorElementDataType type); /** \brief Creates an OrtValue with a tensor using a supplied OrtAllocator. Wraps OrtApi::CreateTensorAsOrtValue. - * This overload will allocate the buffer for the tensor according to the supplied shape and data type. + * This overload will allocate the buffer for the tensor according to the supplied shape and data type. * The allocated buffer will be owned by the returned OrtValue and will be freed when the OrtValue is released. * The input data would need to be copied into the allocated buffer. * This API is not suitable for strings. @@ -1834,8 +1740,7 @@ struct Value : detail::ValueImpl { * \param shape_len The number of tensor shape dimensions. * \param type The data type. */ - static Value CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len, - ONNXTensorElementDataType type); + static Value CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len, ONNXTensorElementDataType type); /** \brief Creates an OrtValue with a Map Onnx type representation. * The API would ref-count the supplied OrtValues and they will be released @@ -2532,9 +2437,6 @@ struct CustomOpBase : OrtCustomOp { return std::vector{}; } - // Ort::CustomOpBase derived class should provide the following static method with the type/shape inferencing - // implementation if needed: - // static OrtStatusPtr InferOutputShape(Ort::ShapeInferContext& context) template decltype(&C::InferOutputShape) SetShapeInferFn(decltype(&C::InferOutputShape)) { OrtCustomOp::InferOutputShapeFn = [](const OrtCustomOp*, OrtShapeInferContext* ort_ctx) -> OrtStatusPtr { @@ -2557,129 +2459,6 @@ struct CustomOpBase : OrtCustomOp { int end_ver_ = MAX_CUSTOM_OP_END_VER; }; -namespace detail { -template -struct ValueInfoImpl : Ort::detail::Base { - using B = Ort::detail::Base; - using B::B; - - std::string Name() const; - ConstTypeInfo TypeInfo() const; -}; -} // namespace detail - -// Const object holder that does not own the underlying object -using ConstValueInfo = detail::ValueInfoImpl>; - -/** \brief Wrapper around ::OrtValueInfo - * - */ -struct ValueInfo : detail::ValueInfoImpl { - explicit ValueInfo(std::nullptr_t) {} ///< No instance is created - /// Take ownership of a pointer created by C API - explicit ValueInfo(OrtValueInfo* p) : ValueInfoImpl{p} {} - - // Create ValueInfo for a tensor - explicit ValueInfo(const std::string& name, const ConstTypeInfo& type_info); - - ConstValueInfo GetConst() const { return ConstValueInfo{this->p_}; } -}; - -namespace detail { -template -struct NodeImpl : Ort::detail::Base { - using B = Ort::detail::Base; - using B::B; -}; -} // namespace detail - -/** \brief Wrapper around ::OrtNode - * - */ -struct Node : detail::NodeImpl { - explicit Node(std::nullptr_t) {} ///< No instance is created - explicit Node(OrtNode* p) : NodeImpl{p} {} ///< Take ownership of a pointer created by C API - -#if !defined(ORT_MINIMAL_BUILD) - Node(const std::string& operator_name, const std::string& operator_domain, - const std::string& node_name, - const std::vector& input_names, - const std::vector& output_names); - - /// - /// Wraps CreateNode. Node takes ownership of attributes on success and updates the OpAttr in `attributes` to do so. - /// - Node(const std::string& operator_name, const std::string& operator_domain, - const std::string& node_name, - const std::vector& input_names, - const std::vector& output_names, - std::vector& attributes); - - private: - static void Init(const std::string& operator_name, const std::string& operator_domain, - const std::string& node_name, - const std::vector& input_names, - const std::vector& output_names, - std::vector& attributes, - OrtNode*& node); -#endif // !defined(ORT_MINIMAL_BUILD) -}; - -namespace detail { -template -struct GraphImpl : Ort::detail::Base { - using B = Ort::detail::Base; - using B::B; - -#if !defined(ORT_MINIMAL_BUILD) - void SetInputs(std::vector& inputs); - void SetOutputs(std::vector& outputs); - void AddInitializer(const std::string& name, Value& initializer, bool data_is_external); // Graph takes ownership of Value - void AddNode(Node& node); // Graph takes ownership of Node -#endif // !defined(ORT_MINIMAL_BUILD) -}; -} // namespace detail - -/** \brief Wrapper around ::OrtGraph - * - */ -struct Graph : detail::GraphImpl { - explicit Graph(std::nullptr_t) {} ///< No instance is created - explicit Graph(OrtGraph* p) : GraphImpl{p} {} ///< Take ownership of a pointer created by C API -#if !defined(ORT_MINIMAL_BUILD) - Graph(); -#endif -}; - -namespace detail { -template -struct ModelImpl : Ort::detail::Base { - using B = Ort::detail::Base; - using B::B; - -#if !defined(ORT_MINIMAL_BUILD) - void AddGraph(Graph& graph); -#endif -}; -} // namespace detail - -// Const object holder that does not own the underlying object -using ConstModel = detail::ModelImpl>; - -/** \brief Wrapper around ::OrtModel - * - */ -struct Model : detail::ModelImpl { - using DomainOpsetPair = std::pair; - - explicit Model(std::nullptr_t) {} ///< No instance is created - explicit Model(OrtModel* p) : ModelImpl{p} {} ///< Take ownership of a pointer created by C API - -#if !defined(ORT_MINIMAL_BUILD) - explicit Model(const std::vector& opsets); -#endif - - ConstModel GetConst() const { return ConstModel{this->p_}; } -}; } // namespace Ort + #include "onnxruntime_cxx_inline.h" diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index 48c5e52e33c53..3aeb9412f350e 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -10,9 +10,7 @@ #include #include #include -#include #include -#include // Convert OrtStatus to Ort::Status and return // instead of throwing @@ -997,59 +995,6 @@ inline size_t ConstSessionImpl::GetOverridableInitializerCount() const { return out; } -template -inline std::vector ConstSessionImpl::GetInputNames() const { - AllocatorWithDefaultOptions allocator; - - auto num_inputs = GetInputCount(); - std::vector input_names; - input_names.reserve(num_inputs); - - for (size_t i = 0; i < num_inputs; ++i) { - char* name = nullptr; - ThrowOnError(GetApi().SessionGetInputName(this->p_, i, allocator, &name)); - input_names.push_back(name); - allocator.Free(name); - } - - return input_names; -} - -template -inline std::vector ConstSessionImpl::GetOutputNames() const { - AllocatorWithDefaultOptions allocator; - - auto num_inputs = GetOutputCount(); - std::vector output_names; - output_names.reserve(num_inputs); - - for (size_t i = 0; i < num_inputs; ++i) { - char* name = nullptr; - ThrowOnError(GetApi().SessionGetOutputName(this->p_, i, allocator, &name)); - output_names.push_back(name); - allocator.Free(name); - } - - return output_names; -} - -template -inline std::vector ConstSessionImpl::GetOverridableInitializerNames() const { - AllocatorWithDefaultOptions allocator; - - auto num_initializers = GetOverridableInitializerCount(); - std::vector initializer_names; - initializer_names.reserve(num_initializers); - - for (size_t i = 0; i < num_initializers; ++i) { - char* name = nullptr; - ThrowOnError(GetApi().SessionGetOverridableInitializerName(this->p_, i, allocator, &name)); - initializer_names.push_back(name); - } - - return initializer_names; -} - template inline AllocatedStringPtr ConstSessionImpl::GetInputNameAllocated(size_t index, OrtAllocator* allocator) const { char* out; @@ -1106,45 +1051,6 @@ inline TypeInfo ConstSessionImpl::GetOverridableInitializerTypeInfo(size_t in return TypeInfo{out}; } -#if !defined(ORT_MINIMAL_BUILD) -template -inline int ConstSessionImpl::GetOpset(const std::string& domain) const { - int opset; - ThrowOnError(GetModelEditorApi().SessionGetOpsetForDomain(this->p_, domain.c_str(), &opset)); - return opset; -} -#endif // !defined(ORT_MINIMAL_BUILD) - -template -std::vector ConstSessionImpl::GetInputs() const { - const std::vector input_names = GetInputNames(); - - std::vector inputs; - inputs.reserve(input_names.size()); - - for (size_t i = 0; i < input_names.size(); ++i) { - auto type_info = GetInputTypeInfo(i); - inputs.emplace_back(ValueInfo{input_names[i], type_info.GetConst()}); - } - - return inputs; -} - -template -std::vector ConstSessionImpl::GetOutputs() const { - const std::vector output_names = GetOutputNames(); - - std::vector outputs; - outputs.reserve(output_names.size()); - - for (size_t i = 0; i < output_names.size(); ++i) { - auto type_info = GetOutputTypeInfo(i); - outputs.emplace_back(ValueInfo{output_names[i], type_info.GetConst()}); - } - - return outputs; -} - template inline std::vector SessionImpl::Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count, const char* const* output_names, size_t output_count) { @@ -1192,15 +1098,6 @@ inline void SessionImpl::SetEpDynamicOptions(const char* const* keys, const c ThrowOnError(GetApi().SetEpDynamicOptions(this->p_, keys, values, kv_len)); } -#if !defined(ORT_MINIMAL_BUILD) -template -inline void SessionImpl::FinalizeModelEditorSession(const Model& model, const SessionOptions& options, - OrtPrepackedWeightsContainer* prepacked_weights_container) { - ThrowOnError(GetModelEditorApi().ApplyModelToModelEditorSession(this->p_, model)); - ThrowOnError(GetModelEditorApi().FinalizeModelEditorSession(this->p_, options, prepacked_weights_container)); -} -#endif // #if !defined(ORT_MINIMAL_BUILD) - } // namespace detail inline SessionOptions::SessionOptions() { @@ -1247,32 +1144,6 @@ inline Session::Session(const Env& env, const void* model_data, size_t model_dat prepacked_weights_container, &this->p_)); } -#if !defined(ORT_MINIMAL_BUILD) -inline Session::Session(const Env& env, const Model& model, const SessionOptions& options) { - ThrowOnError(GetModelEditorApi().CreateSessionFromModel(env, model.GetConst(), options, &this->p_)); -} - -// static -inline Session Session::CreateModelEditorSession(const Env& env, const ORTCHAR_T* model_path, - const SessionOptions& options) { - OrtSession* session = nullptr; - ThrowOnError(GetModelEditorApi().CreateModelEditorSession(env, model_path, options, &session)); - return Session(session); -} - -// static -inline Session Session::CreateModelEditorSession(const Env& env, const void* model_data, size_t model_data_length, - const SessionOptions& options) { - OrtSession* session = nullptr; - ThrowOnError(GetModelEditorApi().CreateModelEditorSessionFromArray(env, model_data, model_data_length, options, - &session)); - return Session(session); -} - -void FinalizeModelEditorSession(const Model& model, const SessionOptions& options, - OrtPrepackedWeightsContainer* prepacked_weights_container); -#endif // #if !defined(ORT_MINIMAL_BUILD) - inline AllocatedStringPtr ModelMetadata::GetProducerNameAllocated(OrtAllocator* allocator) const { char* out; ThrowOnError(GetApi().ModelMetadataGetProducerName(p_, allocator, &out)); @@ -1340,59 +1211,6 @@ inline int64_t ModelMetadata::GetVersion() const { return out; } -inline TensorTypeAndShapeInfo::TensorTypeAndShapeInfo(ONNXTensorElementDataType element_type, - const std::vector& dims, - const std::vector* symbolic_dims) { - ThrowOnError(GetApi().CreateTensorTypeAndShapeInfo(&p_)); - ThrowOnError(GetApi().SetTensorElementType(p_, element_type)); - ThrowOnError(GetApi().SetDimensions(p_, dims.data(), dims.size())); - - if (symbolic_dims) { - std::vector symbolic_dims_cstr; - symbolic_dims_cstr.reserve(symbolic_dims->size()); - std::transform(symbolic_dims->begin(), symbolic_dims->end(), std::back_inserter(symbolic_dims_cstr), - [](const std::string& s) { return s.c_str(); }); - ThrowOnError(GetApi().SetSymbolicDimensions(p_, symbolic_dims_cstr.data(), symbolic_dims_cstr.size())); - } -} - -#if !defined(ORT_MINIMAL_BUILD) -// static -inline TypeInfo TypeInfo::CreateTensorInfo(ConstTensorTypeAndShapeInfo tensor_type_and_shape_info) { - OrtTypeInfo* output = nullptr; - ThrowOnError(GetModelEditorApi().CreateTensorTypeInfo(tensor_type_and_shape_info, &output)); - return TypeInfo{output}; -} - -// static -inline TypeInfo TypeInfo::CreateSparseTensorInfo(ConstTensorTypeAndShapeInfo sparse_tensor_type_and_shape_info) { - OrtTypeInfo* output = nullptr; - ThrowOnError(GetModelEditorApi().CreateSparseTensorTypeInfo(sparse_tensor_type_and_shape_info, &output)); - return TypeInfo{output}; -} - -// static -inline TypeInfo TypeInfo::CreateSequenceTypeInfo(ConstTypeInfo sequence_type) { - OrtTypeInfo* output; - ThrowOnError(GetModelEditorApi().CreateSequenceTypeInfo(sequence_type, &output)); - return TypeInfo{output}; -} - -// static -inline TypeInfo TypeInfo::CreateMapTypeInfo(ONNXTensorElementDataType key_type, ConstTypeInfo value_type) { - OrtTypeInfo* output; - ThrowOnError(GetModelEditorApi().CreateMapTypeInfo(key_type, value_type, &output)); - return TypeInfo{output}; -} - -// static -inline TypeInfo TypeInfo::CreateOptionalTypeInfo(ConstTypeInfo contained_type) { - OrtTypeInfo* output; - ThrowOnError(GetModelEditorApi().CreateOptionalTypeInfo(contained_type, &output)); - return TypeInfo{output}; -} -#endif // #if !defined(ORT_MINIMAL_BUILD) - namespace detail { template @@ -1426,16 +1244,9 @@ inline void TensorTypeAndShapeInfoImpl::GetSymbolicDimensions(const char** va ThrowOnError(GetApi().GetSymbolicDimensions(this->p_, values, values_count)); } -template -inline std::vector TensorTypeAndShapeInfoImpl::GetSymbolicDimensions() const { - std::vector out(GetDimensionsCount(), nullptr); - ThrowOnError(GetApi().GetSymbolicDimensions(this->p_, out.data(), out.size())); - return out; -} - template inline std::vector TensorTypeAndShapeInfoImpl::GetShape() const { - std::vector out(GetDimensionsCount(), -1); + std::vector out(GetDimensionsCount(), 0); ThrowOnError(GetApi().GetDimensions(this->p_, out.data(), out.size())); return out; } @@ -1749,35 +1560,23 @@ void ValueImpl::FillSparseTensorBlockSparse(const OrtMemoryInfo* data_mem_inf } // namespace detail template -inline Value Value::CreateTensor(const OrtMemoryInfo* info, T* p_data, size_t p_data_element_count, - const int64_t* shape, size_t shape_len) { +inline Value Value::CreateTensor(const OrtMemoryInfo* info, T* p_data, size_t p_data_element_count, const int64_t* shape, size_t shape_len) { return CreateTensor(info, p_data, p_data_element_count * sizeof(T), shape, shape_len, TypeToTensorType::type); } -inline Value Value::CreateTensor(const OrtMemoryInfo* info, void* p_data, size_t p_data_byte_count, - const int64_t* shape, size_t shape_len, +inline Value Value::CreateTensor(const OrtMemoryInfo* info, void* p_data, size_t p_data_byte_count, const int64_t* shape, size_t shape_len, ONNXTensorElementDataType type) { OrtValue* out; ThrowOnError(GetApi().CreateTensorWithDataAsOrtValue(info, p_data, p_data_byte_count, shape, shape_len, type, &out)); return Value{out}; } -inline Value Value::CreateTensor(OrtAllocator* deleter, void* p_data, size_t p_data_byte_count, - const int64_t* shape, size_t shape_len, - ONNXTensorElementDataType type) { - OrtValue* out; - ThrowOnError(GetApi().CreateTensorWithDataAndDeleterAsOrtValue(deleter, p_data, p_data_byte_count, - shape, shape_len, type, &out)); - return Value{out}; -} - template inline Value Value::CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len) { return CreateTensor(allocator, shape, shape_len, TypeToTensorType::type); } -inline Value Value::CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len, - ONNXTensorElementDataType type) { +inline Value Value::CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len, ONNXTensorElementDataType type) { OrtValue* out; ThrowOnError(GetApi().CreateTensorAsOrtValue(allocator, shape, shape_len, type, &out)); return Value{out}; @@ -1795,8 +1594,7 @@ inline Value Value::CreateSparseTensor(const OrtMemoryInfo* info, void* p_data, const Shape& values_shape, ONNXTensorElementDataType type) { OrtValue* out; ThrowOnError(GetApi().CreateSparseTensorWithValuesAsOrtValue(info, p_data, dense_shape.shape, dense_shape.shape_len, - values_shape.shape, values_shape.shape_len, type, - &out)); + values_shape.shape, values_shape.shape_len, type, &out)); return Value{out}; } @@ -2369,142 +2167,4 @@ inline const OrtOpAttr* ShapeInferContext::GetAttrHdl(const char* attr_name) con return attr_hdl; } -namespace detail { -inline std::vector StringsToCharPtrs(const std::vector& strings) { - std::vector ptrs; - ptrs.reserve(strings.size()); - std::transform(strings.begin(), strings.end(), std::back_inserter(ptrs), - [](const std::string& s) { return s.c_str(); }); - - return ptrs; -} -} // namespace detail - -#if !defined(ORT_MINIMAL_BUILD) -// static -inline void Node::Init(const std::string& operator_name, const std::string& operator_domain, - const std::string& node_name, - const std::vector& input_names, - const std::vector& output_names, - std::vector& attributes, - OrtNode*& node) { - auto inputs = detail::StringsToCharPtrs(input_names); - auto outputs = detail::StringsToCharPtrs(output_names); - - std::vector attributes_ptrs; - attributes_ptrs.reserve(attributes.size()); - std::transform(attributes.begin(), attributes.end(), std::back_inserter(attributes_ptrs), - [](OpAttr& attr) -> OrtOpAttr* { return attr; }); - - ThrowOnError(GetModelEditorApi().CreateNode(operator_name.c_str(), operator_domain.c_str(), node_name.c_str(), - inputs.data(), inputs.size(), - outputs.data(), outputs.size(), - attributes_ptrs.data(), attributes_ptrs.size(), - &node)); - - // Node now owns the attributes - std::for_each(attributes.begin(), attributes.end(), [](OpAttr& attr) { attr.release(); }); -} - -inline Node::Node(const std::string& operator_name, const std::string& operator_domain, - const std::string& node_name, - const std::vector& input_names, - const std::vector& output_names, - std::vector& attributes) { - Init(operator_name, operator_domain, node_name, input_names, output_names, attributes, p_); -} - -inline Node::Node(const std::string& operator_name, const std::string& operator_domain, - const std::string& node_name, - const std::vector& input_names, - const std::vector& output_names) { - std::vector empty_attributes; - Init(operator_name, operator_domain, node_name, input_names, output_names, empty_attributes, p_); -} - -inline Graph::Graph() { - ThrowOnError(GetModelEditorApi().CreateGraph(&p_)); -} - -inline Model::Model(const std::vector& opsets) { - std::vector domains; - std::vector versions; - domains.reserve(opsets.size()); - versions.reserve(opsets.size()); - - for (const auto& pair : opsets) { - domains.push_back(pair.first.c_str()); - versions.push_back(pair.second); - } - - ThrowOnError(GetModelEditorApi().CreateModel(domains.data(), versions.data(), opsets.size(), &p_)); -} - -inline ValueInfo::ValueInfo(const std::string& name, const ConstTypeInfo& type_info) { - ThrowOnError(GetModelEditorApi().CreateValueInfo(name.c_str(), type_info, &p_)); -} -#endif // !defined(ORT_MINIMAL_BUILD) - -namespace detail { -template <> -inline std::string ValueInfoImpl::Name() const { - const char* name = nullptr; - ThrowOnError(GetApi().GetValueInfoName(this->p_, &name)); - return name; -} - -template <> -inline ConstTypeInfo ValueInfoImpl::TypeInfo() const { - const OrtTypeInfo* type_info = nullptr; - ThrowOnError(GetApi().GetValueInfoTypeInfo(this->p_, &type_info)); - return ConstTypeInfo{type_info}; -} - -#if !defined(ORT_MINIMAL_BUILD) -template <> -inline void GraphImpl::SetInputs(std::vector& inputs) { - std::vector inputs_ptrs; - inputs_ptrs.reserve(inputs.size()); - std::transform(inputs.begin(), inputs.end(), std::back_inserter(inputs_ptrs), - [](ValueInfo& vi) -> OrtValueInfo* { return vi; }); - - ThrowOnError(GetModelEditorApi().SetGraphInputs(p_, inputs_ptrs.data(), inputs_ptrs.size())); - - // Graph now owns the inputs - std::for_each(inputs.begin(), inputs.end(), [](ValueInfo& vi) { vi.release(); }); -} - -template <> -inline void GraphImpl::SetOutputs(std::vector& outputs) { - std::vector outputs_ptrs; - outputs_ptrs.reserve(outputs.size()); - std::transform(outputs.begin(), outputs.end(), std::back_inserter(outputs_ptrs), - [](ValueInfo& vi) -> OrtValueInfo* { return vi; }); - - ThrowOnError(GetModelEditorApi().SetGraphOutputs(p_, outputs_ptrs.data(), outputs_ptrs.size())); - - // Graph now owns the outputs - std::for_each(outputs.begin(), outputs.end(), [](ValueInfo& vi) { vi.release(); }); -} - -template <> -inline void GraphImpl::AddInitializer(const std::string& name, Value& initializer, bool data_is_external) { - // Graph takes ownership of `initializer` - ThrowOnError(GetModelEditorApi().AddInitializerToGraph(p_, name.c_str(), initializer.release(), data_is_external)); -} - -template <> -inline void GraphImpl::AddNode(Node& node) { - // Graph takes ownership of `node` - ThrowOnError(GetModelEditorApi().AddNodeToGraph(p_, node.release())); -} - -template <> -inline void ModelImpl::AddGraph(Graph& graph) { - // Model takes ownership of `graph` - ThrowOnError(GetModelEditorApi().AddGraphToModel(p_, graph.release())); -} -#endif // !defined(ORT_MINIMAL_BUILD) - -} // namespace detail } // namespace Ort diff --git a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h index af1f9c04b2831..117a2cdabca2f 100644 --- a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h +++ b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h @@ -315,12 +315,9 @@ static const char* const kOrtSessionOptionEpContextEmbedMode = "ep.context_embed // in case user need to merge/connect multiple EPContext nodes in one model static const char* const kOrtSessionOptionEpContextNodeNamePrefix = "ep.context_node_name_prefix"; -// Share EP related resources across sessions +// Share EP related resources across EPs static const char* const kOrtSessionOptionShareEpContexts = "ep.share_ep_contexts"; -// Stop to share EP related resources across sessions from then on -static const char* const kOrtSessionOptionStopShareEpContexts = "ep.stop_share_ep_contexts"; - // Use this config when dumping EP context model with an external initializers file // All initializers will be inside the external data file if specified, otherwise all in Onnx file static const char* const kOrtSessionOptionsEpContextModelExternalInitializersFileName = diff --git a/js/build_webgpu.bat b/js/build_webgpu.bat deleted file mode 100644 index 95413509e701d..0000000000000 --- a/js/build_webgpu.bat +++ /dev/null @@ -1,79 +0,0 @@ -@echo off - -rem build_webgpu.bat --- build onnxruntime-web with WebGPU EP -rem -rem Usage: -rem build_webgpu.bat config [clean] -rem -rem Options: -rem config Build configuration, "d" or "r" -rem clean Perform a clean build, "clean" or empty - -setlocal enabledelayedexpansion - -set ROOT=%~dp0..\ -set BUILD_DIR=%ROOT%build_webgpu - -:arg1 -if ["%~1"]==["d"] ( - set CONFIG=Debug - set CONFIG_EXTRA_FLAG= - @rem --enable_wasm_profiling --wasm_run_tests_in_browser - @rem --cmake_extra_defines onnxruntime_ENABLE_WEBASSEMBLY_OUTPUT_OPTIMIZED_MODEL=1 - @rem --enable_wasm_debug_info - goto :arg2 -) -if ["%~1"]==["r"] ( - set CONFIG=Release - set CONFIG_EXTRA_FLAG= - @rem --enable_wasm_api_exception_catching --disable_rtti - goto :arg2 -) -echo Invalid configuration "%~1", must be "d"(Debug) or "r"(Release) -exit /b 1 - -:arg2 -if ["%~2"]==["clean"] ( - goto :clean -) -if not exist "%ROOT%js\web\dist" ( - goto :npm_ci -) - -goto :build_wasm - -:clean -if exist "%BUILD_DIR%" ( - rd /s /q %BUILD_DIR% -) - -pushd %ROOT% -git submodule sync --recursive -git submodule update --init --recursive -popd - -:npm_ci -pushd %ROOT%js -call npm ci -popd -pushd %ROOT%js\common -call npm ci -popd -pushd %ROOT%js\web -call npm ci -call npm run pull:wasm -popd - -:build_wasm - -set PATH=C:\Program Files\Git\usr\bin;%PATH% - -call %ROOT%build.bat --config %CONFIG% %CONFIG_EXTRA_FLAG% --skip_submodule_sync --build_wasm --target onnxruntime_webassembly --skip_tests^ - --enable_wasm_simd --enable_wasm_threads --use_jsep --use_webnn --use_webgpu --build_dir %BUILD_DIR% - -IF NOT "%ERRORLEVEL%" == "0" ( - exit /b %ERRORLEVEL% -) - -copy /Y %BUILD_DIR%\%CONFIG%\ort-wasm-simd-threaded.jsep.wasm %ROOT%js\web\dist\ -copy /Y %BUILD_DIR%\%CONFIG%\ort-wasm-simd-threaded.jsep.mjs %ROOT%js\web\dist\ diff --git a/js/common/lib/tensor-impl-type-mapping.ts b/js/common/lib/tensor-impl-type-mapping.ts index 58f4cc6281b09..14dbdca707220 100644 --- a/js/common/lib/tensor-impl-type-mapping.ts +++ b/js/common/lib/tensor-impl-type-mapping.ts @@ -44,6 +44,12 @@ export const NUMERIC_TENSOR_TYPEDARRAY_TO_TYPE_MAP = new Map { isTypedArrayChecked = true; const isBigInt64ArrayAvailable = typeof BigInt64Array !== 'undefined' && BigInt64Array.from; const isBigUint64ArrayAvailable = typeof BigUint64Array !== 'undefined' && BigUint64Array.from; - - // eslint-disable-next-line @typescript-eslint/naming-convention, @typescript-eslint/no-explicit-any - const Float16Array = (globalThis as any).Float16Array; const isFloat16ArrayAvailable = typeof Float16Array !== 'undefined' && Float16Array.from; if (isBigInt64ArrayAvailable) { diff --git a/js/common/lib/tensor-impl.ts b/js/common/lib/tensor-impl.ts index 2c54bdbfb6874..8feb8d7205fa1 100644 --- a/js/common/lib/tensor-impl.ts +++ b/js/common/lib/tensor-impl.ts @@ -261,13 +261,6 @@ export class Tensor implements TensorInterface { } else { throw new TypeError(`A Uint8ClampedArray tensor's data must be type of uint8`); } - } else if (arg0 === 'float16' && arg1 instanceof Uint16Array && typedArrayConstructor !== Uint16Array) { - // when Float16Array is available and data is of type Uint16Array. - // We allow Uint16Array to be passed in as data for 'float16' tensor until Float16Array is generally - // supported in JavaScript environment. - - // eslint-disable-next-line @typescript-eslint/no-explicit-any - data = new (globalThis as any).Float16Array(arg1.buffer, arg1.byteOffset, arg1.length); } else { throw new TypeError(`A ${type} tensor's data must be type of ${typedArrayConstructor}`); } diff --git a/js/common/package.json b/js/common/package.json index 2d331bb42e4c7..3d8d3f6533cfe 100644 --- a/js/common/package.json +++ b/js/common/package.json @@ -15,8 +15,7 @@ "build": "node ./build.js", "prepare": "npm run build", "pretest": "tsc --build ./test", - "test": "mocha \"./test/**/*.js\" --timeout 30000", - "test:f16": "mocha -n js-float16array \"./test/**/*.js\" --timeout 30000" + "test": "mocha ./test/**/*.js --timeout 30000" }, "devDependencies": { "typedoc": "^0.25.7" diff --git a/js/common/test/unit-tests/common.ts b/js/common/test/unit-tests/common.ts index bbbceed605bd4..0a6e4e5dd6ebd 100644 --- a/js/common/test/unit-tests/common.ts +++ b/js/common/test/unit-tests/common.ts @@ -29,10 +29,9 @@ export const NUMBER_COMPATIBLE_NUMERICAL_TYPES = [ export const BIGINT_TYPES = [['int64', BigInt64Array, true] as const, ['uint64', BigUint64Array, true] as const]; /** - * float16 type, data represented by Uint16Array/Float16Array + * float16 type, data represented by Uint16Array */ -// eslint-disable-next-line @typescript-eslint/no-explicit-any -export const FLOAT16_TYPE = ['float16', (globalThis as any).Float16Array ?? Uint16Array, false] as const; +export const FLOAT16_TYPE = ['float16', Uint16Array, false] as const; /** * A list of all numerical types. diff --git a/js/common/test/unit-tests/tensor/constructor-f16.ts b/js/common/test/unit-tests/tensor/constructor-f16.ts deleted file mode 100644 index 38c6ac037c5f9..0000000000000 --- a/js/common/test/unit-tests/tensor/constructor-f16.ts +++ /dev/null @@ -1,62 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -import assert from 'assert/strict'; -import { Tensor } from 'onnxruntime-common'; - -// eslint-disable-next-line @typescript-eslint/no-explicit-any -const globalF16 = (globalThis as any).Float16Array; - -(globalF16 ? describe : describe.skip)('Tensor Constructor Tests - check type float16 (Float16Array available)', () => { - it("[float16] new Tensor('float16', numbers, dims): allow number array when Float16Array is available", () => { - const tensor = new Tensor('float16', [1, 2, 3, 4], [2, 2]); - assert.equal(tensor.type, 'float16', "tensor.type should be 'float16'"); - assert(tensor.data instanceof globalF16, "tensor.data should be an instance of 'Float16Array'"); - assert.equal(tensor.data[0], 1, 'tensor.data[0] should be 1'); - assert.equal(tensor.data[1], 2, 'tensor.data[1] should be 2'); - assert.equal(tensor.data[2], 3, 'tensor.data[2] should be 3'); - assert.equal(tensor.data[3], 4, 'tensor.data[3] should be 4'); - assert.equal(tensor.data.length, 4, 'tensor.data.length should be 4'); - }); - - it("[float16] new Tensor('float16', float16array, dims): allow Float16Array when Float16Array is available", () => { - const tensor = new Tensor('float16', new globalF16([1, 2, 3, 4]), [2, 2]); - assert.equal(tensor.type, 'float16', "tensor.type should be 'float16'"); - assert(tensor.data instanceof globalF16, "tensor.data should be an instance of 'Float16Array'"); - assert.equal(tensor.data[0], 1, 'tensor.data[0] should be 1'); - assert.equal(tensor.data[1], 2, 'tensor.data[1] should be 2'); - assert.equal(tensor.data[2], 3, 'tensor.data[2] should be 3'); - assert.equal(tensor.data[3], 4, 'tensor.data[3] should be 4'); - assert.equal(tensor.data.length, 4, 'tensor.data.length should be 4'); - }); - - it("[float16] new Tensor('float16', uint16array, dims): allow Uint16Array when Float16Array is available", () => { - const tensor = new Tensor('float16', new Uint16Array([15360, 16384, 16896, 17408]), [2, 2]); - assert.equal(tensor.type, 'float16', "tensor.type should be 'float16'"); - assert(tensor.data instanceof globalF16, "tensor.data should be an instance of 'Float16Array'"); - assert.equal(tensor.data[0], 1, 'tensor.data[0] should be 1'); - assert.equal(tensor.data[1], 2, 'tensor.data[1] should be 2'); - assert.equal(tensor.data[2], 3, 'tensor.data[2] should be 3'); - assert.equal(tensor.data[3], 4, 'tensor.data[3] should be 4'); - assert.equal(tensor.data.length, 4, 'tensor.data.length should be 4'); - }); -}); - -(globalF16 ? describe.skip : describe)( - 'Tensor Constructor Tests - check type float16 (Float16Array not available)', - () => { - it( - "[float16] new Tensor('float16', numbers, dims): " + - "expect to throw because it's not allowed to construct 'float16' tensor from number array", - () => { - assert.throws(() => new Tensor('float16', [1, 2, 3, 4], [2, 2]), TypeError); - }, - ); - - it("[float16] new Tensor('float16', uint16array, dims): allow Uint16Array", () => { - const tensor = new Tensor('float16', new Uint16Array([15360, 16384, 16896, 17408]), [2, 2]); - assert.equal(tensor.type, 'float16', "tensor.type should be 'float16'"); - assert(tensor.data instanceof Uint16Array, "tensor.data should be an instance of 'Uint16Array'"); - }); - }, -); diff --git a/js/common/test/unit-tests/tensor/constructor-type.ts b/js/common/test/unit-tests/tensor/constructor-type.ts index d86e18ba744b8..02390800e8611 100644 --- a/js/common/test/unit-tests/tensor/constructor-type.ts +++ b/js/common/test/unit-tests/tensor/constructor-type.ts @@ -105,6 +105,14 @@ describe('Tensor Constructor Tests - check types', () => { assert(tensor.data instanceof Uint8Array, "tensor.data should be an instance of 'Uint8Array'"); }); + it( + "[float16] new Tensor('float16', numbers, dims): " + + "expect to throw because it's not allowed to construct 'float16' tensor from number array", + () => { + assert.throws(() => new Tensor('float16', [1, 2, 3, 4], [2, 2]), TypeError); + }, + ); + it("[badtype] new Tensor('a', numbers, dims): expect to throw because 'a' is an invalid type", () => { assert.throws(() => new TensorAny('a', [1, 2, 3, 4], [2, 2]), TypeError); }); diff --git a/js/web/lib/build-def.d.ts b/js/web/lib/build-def.d.ts index 83a52ebaefe05..59f64a3179605 100644 --- a/js/web/lib/build-def.d.ts +++ b/js/web/lib/build-def.d.ts @@ -40,13 +40,6 @@ interface BuildDefinitions { */ readonly ENABLE_BUNDLE_WASM_JS: boolean; - /** - * defines whether to use WebGPU EP instead of JSEP for WebGPU backend. - * - * This flag requires the corresponding WebAssembly artifact to be built with `--use_webgpu` flag. - */ - readonly USE_WEBGPU_EP: boolean; - // #endregion // #region Build definitions for ESM diff --git a/js/web/lib/wasm/jsep/backend-webgpu.ts b/js/web/lib/wasm/jsep/backend-webgpu.ts index 413e89111740e..a0010df4643a4 100644 --- a/js/web/lib/wasm/jsep/backend-webgpu.ts +++ b/js/web/lib/wasm/jsep/backend-webgpu.ts @@ -13,6 +13,7 @@ import { ProgramManager } from './webgpu/program-manager'; import { AdapterInfo, ComputeContext, + DeviceInfo, GpuArchitecture, GpuData, GpuVendor, @@ -134,6 +135,26 @@ class AdapterInfoImpl implements AdapterInfo { } } +class DeviceInfoImpl implements DeviceInfo { + readonly subgroupsSupported: boolean; + readonly subgroupsF16Supported: boolean; + readonly subgroupSizeRange?: readonly [number, number]; + + constructor(device: GPUDevice) { + this.subgroupsSupported = device.features.has('subgroups' as GPUFeatureName); + this.subgroupsF16Supported = device.features.has('subgroups' as GPUFeatureName); + // Currently subgroups feature is still experimental and size attributes are not in the WebGPU IDL, so we have to + // workaround the IDL type checks. + // TODO: clean this after subgroups feature is settled in IDL. + const deviceSubgroupsLimits = device.limits as { minSubgroupSize?: number; maxSubgroupSize?: number }; + if (!this.subgroupsSupported || !deviceSubgroupsLimits.minSubgroupSize || !deviceSubgroupsLimits.maxSubgroupSize) { + this.subgroupSizeRange = undefined; + } else { + this.subgroupSizeRange = [deviceSubgroupsLimits.minSubgroupSize, deviceSubgroupsLimits.maxSubgroupSize]; + } + } +} + /** * this class is designed to store status and being used as a singleton for JSEP. It will be passed to jsepInit() as * the first parameter so that it is stored for future use. @@ -141,6 +162,7 @@ class AdapterInfoImpl implements AdapterInfo { export class WebGpuBackend { adapterInfo: AdapterInfoImpl; device: GPUDevice; + deviceInfo: DeviceInfoImpl; /** * an instance of GpuDataManager to manage a GpuDataId -> GpuBuffer mapping */ @@ -252,9 +274,13 @@ export class WebGpuBackend { } requireFeatureIfAvailable('shader-f16'); // Try subgroups - requireFeatureIfAvailable('subgroups' as GPUFeatureName); + if (requireFeatureIfAvailable('subgroups' as GPUFeatureName)) { + // If subgroups feature is available, also try subgroups-f16 + requireFeatureIfAvailable('subgroups-f16' as GPUFeatureName); + } this.device = await adapter.requestDevice(deviceDescriptor); + this.deviceInfo = new DeviceInfoImpl(this.device); this.adapterInfo = new AdapterInfoImpl(adapter.info || (await adapter.requestAdapterInfo())); this.gpuDataManager = createGpuDataManager(this); this.programManager = new ProgramManager(this); diff --git a/js/web/lib/wasm/jsep/backend-webnn.ts b/js/web/lib/wasm/jsep/backend-webnn.ts index 55784ae13ad7a..2b9a9208e2e53 100644 --- a/js/web/lib/wasm/jsep/backend-webnn.ts +++ b/js/web/lib/wasm/jsep/backend-webnn.ts @@ -314,8 +314,7 @@ export class WebNNBackend { bufferView = new Float32Array(buffer); break; case 'float16': - bufferView = - typeof Float16Array !== 'undefined' && Float16Array.from ? new Float16Array(buffer) : new Uint16Array(buffer); + bufferView = new Uint16Array(buffer); break; case 'int32': bufferView = new Int32Array(buffer); diff --git a/js/web/lib/wasm/jsep/init.ts b/js/web/lib/wasm/jsep/init.ts index 8ab6b054bf8a7..b4071eae51c8f 100644 --- a/js/web/lib/wasm/jsep/init.ts +++ b/js/web/lib/wasm/jsep/init.ts @@ -1,17 +1,23 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import type { Env } from 'onnxruntime-common'; +import { Env } from 'onnxruntime-common'; import { calculateTensorSizeInBytes, DataType } from '../wasm-common'; import type { OrtWasmModule } from '../wasm-types'; -import type { WebGpuBackend } from './backend-webgpu'; +import { WebGpuBackend } from './backend-webgpu'; import { LOG_DEBUG } from './log'; -import type { TensorView } from './tensor-view'; +import { TensorView } from './tensor-view'; import { ShapeUtil } from './util'; -import type { AdapterInfo, ComputeContext, ComputeContextInputsOutputsMapping, ProgramInfo } from './webgpu/types'; +import { + AdapterInfo, + ComputeContext, + ComputeContextInputsOutputsMapping, + DeviceInfo, + ProgramInfo, +} from './webgpu/types'; import { WebNNBackend } from './backend-webnn'; /* eslint-disable no-bitwise */ @@ -70,6 +76,7 @@ class TensorViewImpl implements TensorView { class ComputeContextImpl implements ComputeContext { readonly adapterInfo: AdapterInfo; + readonly deviceInfo: DeviceInfo; readonly opKernelContext: number; readonly inputs: readonly TensorView[]; readonly outputCount: number; @@ -87,6 +94,7 @@ class ComputeContextImpl implements ComputeContext { contextDataOffset: number, ) { this.adapterInfo = backend.adapterInfo; + this.deviceInfo = backend.deviceInfo; // extract context data const ptrSize = module.PTR_SIZE; @@ -197,83 +205,79 @@ export const init = async ( } if (name === 'webgpu') { - if (!BUILD_DEFS.USE_WEBGPU_EP) { - // eslint-disable-next-line @typescript-eslint/no-require-imports, @typescript-eslint/no-var-requires - const webGpuBackendImpl = require('./backend-webgpu').WebGpuBackend; - const backend = new webGpuBackendImpl(); - await backend.initialize(env, gpuAdapter!); + const backend = new WebGpuBackend(); + await backend.initialize(env, gpuAdapter!); - jsepInit('webgpu', [ - // backend - backend, - - // jsepAlloc() - (size: number) => backend.alloc(Number(size)), + jsepInit('webgpu', [ + // backend + backend, - // jsepFree() - (ptr: number) => backend.free(ptr), + // jsepAlloc() + (size: number) => backend.alloc(Number(size)), - // jsepCopy(src, dst, size, isSourceGpu) - (src: number, dst: number, size: number, isSourceGpu = false) => { - if (isSourceGpu) { - LOG_DEBUG( - 'verbose', - () => `[WebGPU] jsepCopyGpuToGpu: src=${Number(src)}, dst=${Number(dst)}, size=${Number(size)}`, - ); - backend.memcpy(Number(src), Number(dst)); - } else { - LOG_DEBUG( - 'verbose', - () => - `[WebGPU] jsepCopyCpuToGpu: dataOffset=${Number(src)}, gpuDataId=${Number(dst)}, size=${Number(size)}`, - ); - const data = module.HEAPU8.subarray(Number(src >>> 0), Number(src >>> 0) + Number(size)); - backend.upload(Number(dst), data); - } - }, + // jsepFree() + (ptr: number) => backend.free(ptr), - // jsepCopyAsync(src, dst, size) - async (gpuDataId: number, dataOffset: number, size: number): Promise => { + // jsepCopy(src, dst, size, isSourceGpu) + (src: number, dst: number, size: number, isSourceGpu = false) => { + if (isSourceGpu) { LOG_DEBUG( 'verbose', - () => `[WebGPU] jsepCopyGpuToCpu: gpuDataId=${gpuDataId}, dataOffset=${dataOffset}, size=${size}`, + () => `[WebGPU] jsepCopyGpuToGpu: src=${Number(src)}, dst=${Number(dst)}, size=${Number(size)}`, ); - - await backend.download(Number(gpuDataId), () => - module.HEAPU8.subarray(Number(dataOffset) >>> 0, Number(dataOffset + size) >>> 0), - ); - }, - - // jsepCreateKernel - (kernelType: string, kernelId: number, attribute: unknown) => - backend.createKernel( - kernelType, - Number(kernelId), - attribute, - module.UTF8ToString(module._JsepGetNodeName!(Number(kernelId))), - ), - - // jsepReleaseKernel - (kernel: number) => backend.releaseKernel(kernel), - - // jsepRun - (kernel: number, contextDataOffset: number, sessionHandle: number, errors: Array>) => { + backend.memcpy(Number(src), Number(dst)); + } else { LOG_DEBUG( 'verbose', () => - `[WebGPU] jsepRun: sessionHandle=${sessionHandle}, kernel=${kernel}, contextDataOffset=${contextDataOffset}`, + `[WebGPU] jsepCopyCpuToGpu: dataOffset=${Number(src)}, gpuDataId=${Number(dst)}, size=${Number(size)}`, ); - const context = new ComputeContextImpl(module, backend, Number(contextDataOffset)); - return backend.computeKernel(Number(kernel), context, errors); - }, - // jsepCaptureBegin - () => backend.captureBegin(), - // jsepCaptureEnd - () => backend.captureEnd(), - // jsepReplay - () => backend.replay(), - ]); - } + const data = module.HEAPU8.subarray(Number(src >>> 0), Number(src >>> 0) + Number(size)); + backend.upload(Number(dst), data); + } + }, + + // jsepCopyAsync(src, dst, size) + async (gpuDataId: number, dataOffset: number, size: number): Promise => { + LOG_DEBUG( + 'verbose', + () => `[WebGPU] jsepCopyGpuToCpu: gpuDataId=${gpuDataId}, dataOffset=${dataOffset}, size=${size}`, + ); + + await backend.download(Number(gpuDataId), () => + module.HEAPU8.subarray(Number(dataOffset) >>> 0, Number(dataOffset + size) >>> 0), + ); + }, + + // jsepCreateKernel + (kernelType: string, kernelId: number, attribute: unknown) => + backend.createKernel( + kernelType, + Number(kernelId), + attribute, + module.UTF8ToString(module._JsepGetNodeName!(Number(kernelId))), + ), + + // jsepReleaseKernel + (kernel: number) => backend.releaseKernel(kernel), + + // jsepRun + (kernel: number, contextDataOffset: number, sessionHandle: number, errors: Array>) => { + LOG_DEBUG( + 'verbose', + () => + `[WebGPU] jsepRun: sessionHandle=${sessionHandle}, kernel=${kernel}, contextDataOffset=${contextDataOffset}`, + ); + const context = new ComputeContextImpl(module, backend, Number(contextDataOffset)); + return backend.computeKernel(Number(kernel), context, errors); + }, + // jsepCaptureBegin + () => backend.captureBegin(), + // jsepCaptureEnd + () => backend.captureEnd(), + // jsepReplay + () => backend.replay(), + ]); } else { const backend = new WebNNBackend(env); jsepInit('webnn', [ diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts index 50620cea33863..ad1de42106d6d 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts @@ -46,11 +46,6 @@ export const createConvTranspose2DProgramInfo = ( const inputChannelsPerGroup = wShape[2] / group; const outputChannelsPerGroup = wShape[3]; const aComponents = isChannelsLast ? getMaxComponents(inputChannelsPerGroup) : 1; - const packInputAs4 = isChannelsLast && outputChannelsPerGroup === 1 && inputChannelsPerGroup >= 4; - const inputChannelsPerGroupInt = packInputAs4 - ? Math.floor(inputChannelsPerGroup / 4) * 4 - : Math.floor(inputChannelsPerGroup / aComponents) * aComponents; - const inputChannelsRemainder = inputChannelsPerGroup - inputChannelsPerGroupInt; const components = isChannelsLast ? getMaxComponents(outputChannelsPerGroup) : 1; const bComponents = isChannelsLast ? (outputChannelsPerGroup === 1 ? aComponents : components) : 1; const outputSize = ShapeUtil.size(outputShape) / components; @@ -83,7 +78,6 @@ export const createConvTranspose2DProgramInfo = ( { type: DataType.uint32, data: dilations }, { type: DataType.uint32, data: effectiveFilterDims }, { type: DataType.int32, data: pads }, - { type: DataType.uint32, data: inputChannelsPerGroupInt }, { type: DataType.uint32, data: inputChannelsPerGroup }, { type: DataType.uint32, data: outputChannelsPerGroup }, ...createTensorShapeVariables(inputs[0].dims, inputs[1].dims), @@ -102,7 +96,6 @@ export const createConvTranspose2DProgramInfo = ( { name: 'dilations', type: 'u32', length: filterDims.length }, { name: 'effective_filter_dims', type: 'u32', length: effectiveFilterDims.length }, { name: 'pads', type: 'i32', length: pads.length }, - { name: 'input_channels_per_group_int', type: 'u32' }, { name: 'input_channels_per_group', type: 'u32' }, { name: 'output_channels_per_group', type: 'u32' }, ]; @@ -121,40 +114,16 @@ export const createConvTranspose2DProgramInfo = ( const calculateResult = (): string => { let calcStr = ''; - if (packInputAs4) { - if (aComponents === 4) { - calcStr += ` - let xValue = ${dy.getByOffset('x_offset')}; - let wValue = ${w.getByOffset('w_offset')}; - dotProd = dotProd + dot(xValue, wValue); - x_offset += 1u; - w_offset += 1u;`; - } else if (aComponents === 2) { - calcStr += ` - dotProd = dotProd + dot(vec4<${dataType}>(${dy.getByOffset('x_offset')}, ${dy.getByOffset('x_offset + 1u')}), vec4<${dataType}>(${w.getByOffset('w_offset')}, ${w.getByOffset('w_offset + 1u')})); - x_offset += 2u; - w_offset += 2u;`; - } else if (aComponents === 1) { - calcStr += ` - dotProd = dotProd + dot(vec4<${dataType}>(${dy.getByOffset('x_offset')}, ${dy.getByOffset('x_offset + 1u')}, ${dy.getByOffset('x_offset + 2u')}, ${dy.getByOffset('x_offset + 3u')}), vec4<${dataType}>(${w.getByOffset('w_offset')}, ${w.getByOffset('w_offset + 1u')}, ${w.getByOffset('w_offset + 2u')}, ${w.getByOffset('w_offset + 3u')})); - x_offset += 4u; - w_offset += 4u;`; - } - } else { + if (aComponents === 1) { calcStr += ` - let xValue = ${ - isChannelsLast - ? dy.getByOffset( - `${dy.indicesToOffset(`${dy.type.indices}(batch, idyR, idyC, inputChannel)`)} / ${aComponents}`, - ) - : dy.get('batch', 'inputChannel', 'idyR', 'idyC') - }; - `; - if (aComponents === 1) { + let w_offset = ${w.indicesToOffset(`${w.type.indices}(u32(wRPerm), u32(wCPerm), inputChannel, wOutChannel)`)}; + let wValue = ${w.getByOffset(`w_offset / ${bComponents}`)}; + dotProd = dotProd + xValue * wValue;`; + } else { + if (outputChannelsPerGroup === 1) { calcStr += ` - let w_offset = ${w.indicesToOffset(`${w.type.indices}(u32(wRPerm), u32(wCPerm), inputChannel, wOutChannel)`)}; - let wValue = ${w.getByOffset(`w_offset / ${bComponents}`)}; - dotProd = dotProd + xValue * wValue;`; + let wValue = ${w.getByOffset(`${w.indicesToOffset(`${w.type.indices}(u32(wRPerm), u32(wCPerm), inputChannel, wOutChannel)`)} / ${bComponents}`)}; + dotProd = dotProd + dot(xValue, wValue);`; } else { for (let c = 0; c < aComponents; c++) { calcStr += ` @@ -165,32 +134,6 @@ export const createConvTranspose2DProgramInfo = ( } return calcStr; }; - const calculateRemainder = (): string => { - if (inputChannelsRemainder === 0) { - return ''; - } - if (!packInputAs4) { - throw new Error(`packInputAs4 ${packInputAs4} is not true.`); - } - let calcStr = ''; - if (aComponents === 1) { - calcStr += 'dotProd = dotProd'; - for (let i = 0; i < inputChannelsRemainder; i++) { - calcStr += ` - + ${dy.getByOffset(`x_offset + ${i}`)} * ${w.getByOffset(`w_offset + ${i}`)}`; - } - calcStr += ';'; - } else if (aComponents === 2) { - if (inputChannelsRemainder !== 2) { - throw new Error(`Invalid inputChannelsRemainder ${inputChannelsRemainder}.`); - } - calcStr += ` - let xValue = ${dy.getByOffset('x_offset')}; - let wValue = ${w.getByOffset('w_offset')}; - dotProd = dotProd + dot(xValue, wValue);`; - } - return calcStr; - }; const codeSnippet = ` let outputIndices = ${output.offsetToIndices(`global_idx * ${components}`)}; let batch = ${output.indicesGet('outputIndices', 0)}; @@ -226,6 +169,7 @@ export const createConvTranspose2DProgramInfo = ( // Minimum wC >= 0 that satisfies (dyCCorner + wC) % (uniforms.strides.y) == 0 wC = u32(((dyCCorner + i32(uniforms.strides.y) - 1) / i32(uniforms.strides.y)) * i32(uniforms.strides.y) - dyCCorner); } + for (; wC < uniforms.effective_filter_dims.y; wC = wC + 1) { if (wC % uniforms.dilations.y != 0) { continue; @@ -238,19 +182,17 @@ export const createConvTranspose2DProgramInfo = ( } let idyC: u32 = u32(dyC); var inputChannel = groupId * uniforms.input_channels_per_group; - ${ - packInputAs4 - ? ` - var x_offset = ${dy.indicesToOffset(`${dy.type.indices}(batch, idyR, idyC, inputChannel)`)} / ${aComponents}; - var w_offset = ${w.indicesToOffset(`${w.type.indices}(wRPerm, wCPerm, inputChannel, wOutChannel)`)} / ${bComponents}; - ` - : '' - } - for (var d2: u32 = 0; d2 < uniforms.input_channels_per_group_int; d2 = d2 + ${packInputAs4 ? 4 : aComponents}) { + for (var d2: u32 = 0; d2 < uniforms.input_channels_per_group; d2 = d2 + ${aComponents}) { + let xValue = ${ + isChannelsLast + ? dy.getByOffset( + `${dy.indicesToOffset(`${dy.type.indices}(batch, idyR, idyC, inputChannel)`)} / ${aComponents}`, + ) + : dy.get('batch', 'inputChannel', 'idyR', 'idyC') + }; ${calculateResult()} - inputChannel = inputChannel + ${packInputAs4 ? 4 : aComponents}; + inputChannel = inputChannel + ${aComponents}; } - ${calculateRemainder()} wC = wC + uniforms.strides.y - 1; } wR = wR + uniforms.strides[0] - 1; @@ -269,7 +211,7 @@ export const createConvTranspose2DProgramInfo = ( return { name: 'ConvTranspose2D', shaderCache: { - hint: `${attributes.cacheKey};${aComponents}${bComponents}${components}${packInputAs4}${inputChannelsRemainder}`, + hint: `${attributes.cacheKey};${aComponents}${bComponents}${components}${outputChannelsPerGroup === 1}`, inputDependencies, }, getRunData: () => ({ diff --git a/js/web/lib/wasm/jsep/webgpu/ops/attention.ts b/js/web/lib/wasm/jsep/webgpu/ops/attention.ts index 6a8dffb73fa08..6a78c8ae3b190 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/attention.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/attention.ts @@ -433,7 +433,7 @@ const createInPlaceSoftmaxProgramInfo = ( getShaderSource, getRunData: () => ({ outputs: [], - dispatchGroup: { x: 1, y: sequenceLength, z: batchSize * numHeads }, + dispatchGroup: { x: Math.ceil(totalSequenceLength / WG), y: sequenceLength, z: batchSize * numHeads }, programUniforms, }), }; diff --git a/js/web/lib/wasm/jsep/webgpu/program-manager.ts b/js/web/lib/wasm/jsep/webgpu/program-manager.ts index 18d505f57655a..2c5180c5db3ee 100644 --- a/js/web/lib/wasm/jsep/webgpu/program-manager.ts +++ b/js/web/lib/wasm/jsep/webgpu/program-manager.ts @@ -99,6 +99,7 @@ export class ProgramManager { const extensionsInfo: Array<{ feature: GPUFeatureName; extension: string }> = [ { feature: 'shader-f16', extension: 'f16' }, { feature: 'subgroups' as GPUFeatureName, extension: 'subgroups' }, + { feature: 'subgroups-f16' as GPUFeatureName, extension: 'subgroups_f16' }, ]; extensionsInfo.forEach((info) => { if (device.features.has(info.feature)) { diff --git a/js/web/lib/wasm/jsep/webgpu/types.ts b/js/web/lib/wasm/jsep/webgpu/types.ts index f3cfc6cb98cae..9321ac170d036 100644 --- a/js/web/lib/wasm/jsep/webgpu/types.ts +++ b/js/web/lib/wasm/jsep/webgpu/types.ts @@ -21,6 +21,11 @@ export interface AdapterInfo { isArchitecture: (architecture: GpuArchitecture) => boolean; isVendor: (vendor: GpuVendor) => boolean; } +export interface DeviceInfo { + readonly subgroupsSupported: boolean; + readonly subgroupsF16Supported: boolean; + readonly subgroupSizeRange?: readonly [number, number]; +} export interface GpuData { type: GpuDataType; @@ -160,6 +165,11 @@ export interface ComputeContext { */ readonly adapterInfo: AdapterInfo; + /** + * gpu device info + */ + readonly deviceInfo: DeviceInfo; + /** * stores the pointer to OpKernelContext */ diff --git a/js/web/lib/wasm/proxy-wrapper.ts b/js/web/lib/wasm/proxy-wrapper.ts index 30b1f5101e5f2..5d97bb83e3475 100644 --- a/js/web/lib/wasm/proxy-wrapper.ts +++ b/js/web/lib/wasm/proxy-wrapper.ts @@ -12,11 +12,7 @@ import { } from './proxy-messages'; import * as core from './wasm-core-impl'; import { initializeWebAssembly } from './wasm-factory'; -import { - importProxyWorker, - inferWasmPathPrefixFromScriptSrc, - isEsmImportMetaUrlHardcodedAsFileUri, -} from './wasm-utils-import'; +import { importProxyWorker, inferWasmPathPrefixFromScriptSrc } from './wasm-utils-import'; const isProxy = (): boolean => !!env.wasm.proxy && typeof document !== 'undefined'; let proxyWorker: Worker | undefined; @@ -120,7 +116,7 @@ export const initializeWebAssemblyAndOrtRuntime = async (): Promise => { BUILD_DEFS.IS_ESM && BUILD_DEFS.ENABLE_BUNDLE_WASM_JS && !message.in!.wasm.wasmPaths && - (objectUrl || isEsmImportMetaUrlHardcodedAsFileUri) + (objectUrl || BUILD_DEFS.ESM_IMPORT_META_URL?.startsWith('file:')) ) { // for a build bundled the wasm JS, if either of the following conditions is met: // - the proxy worker is loaded from a blob URL diff --git a/js/web/lib/wasm/session-options.ts b/js/web/lib/wasm/session-options.ts index 89a4484e5a1c4..17e564247863d 100644 --- a/js/web/lib/wasm/session-options.ts +++ b/js/web/lib/wasm/session-options.ts @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import type { InferenceSession } from 'onnxruntime-common'; +import { InferenceSession } from 'onnxruntime-common'; import { getInstance } from './wasm-factory'; import { allocWasmString, checkLastError, iterateExtraOptions } from './wasm-utils'; @@ -54,28 +54,13 @@ const appendDefaultOptions = (options: InferenceSession.SessionOptions): void => } }; -const appendSessionConfig = (sessionOptionsHandle: number, key: string, value: string, allocs: number[]): void => { - const keyDataOffset = allocWasmString(key, allocs); - const valueDataOffset = allocWasmString(value, allocs); - if (getInstance()._OrtAddSessionConfigEntry(sessionOptionsHandle, keyDataOffset, valueDataOffset) !== 0) { - checkLastError(`Can't set a session config entry: ${key} - ${value}.`); - } -}; - -const appendEpOption = (epOptions: Array<[number, number]>, key: string, value: string, allocs: number[]): void => { - const keyDataOffset = allocWasmString(key, allocs); - const valueDataOffset = allocWasmString(value, allocs); - epOptions.push([keyDataOffset, valueDataOffset]); -}; - -const setExecutionProviders = async ( +const setExecutionProviders = ( sessionOptionsHandle: number, executionProviders: readonly InferenceSession.ExecutionProviderConfig[], allocs: number[], -): Promise => { +): void => { for (const ep of executionProviders) { let epName = typeof ep === 'string' ? ep : ep.name; - const epOptions: Array<[number, number]> = []; // check EP name switch (epName) { @@ -86,44 +71,26 @@ const setExecutionProviders = async ( // const context = (webnnOptions as InferenceSession.WebNNOptionsWithMLContext)?.context; const deviceType = (webnnOptions as InferenceSession.WebNNContextOptions)?.deviceType; if (deviceType) { - appendSessionConfig(sessionOptionsHandle, 'deviceType', deviceType, allocs); + const keyDataOffset = allocWasmString('deviceType', allocs); + const valueDataOffset = allocWasmString(deviceType, allocs); + if (getInstance()._OrtAddSessionConfigEntry(sessionOptionsHandle, keyDataOffset, valueDataOffset) !== 0) { + checkLastError(`Can't set a session config entry: 'deviceType' - ${deviceType}.`); + } } } break; case 'webgpu': - if (BUILD_DEFS.USE_WEBGPU_EP) { - epName = 'WebGPU'; - let customDevice: GPUDevice | undefined; - - if (typeof ep !== 'string') { - const customOptions = ep as unknown as { device: GPUDevice }; - if (customOptions.device) { - if (typeof GPUDevice !== 'undefined' && customOptions.device instanceof GPUDevice) { - customDevice = customOptions.device; - } else { - throw new Error('Invalid GPU device set in WebGPU EP options.'); - } + epName = 'JS'; + if (typeof ep !== 'string') { + const webgpuOptions = ep as InferenceSession.WebGpuExecutionProviderOption; + if (webgpuOptions?.preferredLayout) { + if (webgpuOptions.preferredLayout !== 'NCHW' && webgpuOptions.preferredLayout !== 'NHWC') { + throw new Error(`preferredLayout must be either 'NCHW' or 'NHWC': ${webgpuOptions.preferredLayout}`); } - - // TODO: handle more options - } - - const info = getInstance().webgpuRegisterDevice!(customDevice); - if (info) { - const [deviceId, instanceHandle, deviceHandle] = info; - appendEpOption(epOptions, 'deviceId', deviceId.toString(), allocs); - appendEpOption(epOptions, 'webgpuInstance', instanceHandle.toString(), allocs); - appendEpOption(epOptions, 'webgpuDevice', deviceHandle.toString(), allocs); - } - } else { - epName = 'JS'; - if (typeof ep !== 'string') { - const webgpuOptions = ep as InferenceSession.WebGpuExecutionProviderOption; - if (webgpuOptions?.preferredLayout) { - if (webgpuOptions.preferredLayout !== 'NCHW' && webgpuOptions.preferredLayout !== 'NHWC') { - throw new Error(`preferredLayout must be either 'NCHW' or 'NHWC': ${webgpuOptions.preferredLayout}`); - } - appendSessionConfig(sessionOptionsHandle, 'preferredLayout', webgpuOptions.preferredLayout, allocs); + const keyDataOffset = allocWasmString('preferredLayout', allocs); + const valueDataOffset = allocWasmString(webgpuOptions.preferredLayout, allocs); + if (getInstance()._OrtAddSessionConfigEntry(sessionOptionsHandle, keyDataOffset, valueDataOffset) !== 0) { + checkLastError(`Can't set a session config entry: 'preferredLayout' - ${webgpuOptions.preferredLayout}.`); } } } @@ -136,34 +103,13 @@ const setExecutionProviders = async ( } const epNameDataOffset = allocWasmString(epName, allocs); - const epOptionsCount = epOptions.length; - let keysOffset = 0; - let valuesOffset = 0; - if (epOptionsCount > 0) { - keysOffset = getInstance()._malloc(epOptionsCount * getInstance().PTR_SIZE); - allocs.push(keysOffset); - valuesOffset = getInstance()._malloc(epOptionsCount * getInstance().PTR_SIZE); - allocs.push(valuesOffset); - for (let i = 0; i < epOptionsCount; i++) { - getInstance().setValue(keysOffset + i * getInstance().PTR_SIZE, epOptions[i][0], '*'); - getInstance().setValue(valuesOffset + i * getInstance().PTR_SIZE, epOptions[i][1], '*'); - } - } - if ( - (await getInstance()._OrtAppendExecutionProvider( - sessionOptionsHandle, - epNameDataOffset, - keysOffset, - valuesOffset, - epOptionsCount, - )) !== 0 - ) { + if (getInstance()._OrtAppendExecutionProvider(sessionOptionsHandle, epNameDataOffset) !== 0) { checkLastError(`Can't append execution provider: ${epName}.`); } } }; -export const setSessionOptions = async (options?: InferenceSession.SessionOptions): Promise<[number, number[]]> => { +export const setSessionOptions = (options?: InferenceSession.SessionOptions): [number, number[]] => { const wasm = getInstance(); let sessionOptionsHandle = 0; const allocs: number[] = []; @@ -209,19 +155,20 @@ export const setSessionOptions = async (options?: InferenceSession.SessionOption } if (sessionOptions.executionProviders) { - await setExecutionProviders(sessionOptionsHandle, sessionOptions.executionProviders, allocs); + setExecutionProviders(sessionOptionsHandle, sessionOptions.executionProviders, allocs); } if (sessionOptions.enableGraphCapture !== undefined) { if (typeof sessionOptions.enableGraphCapture !== 'boolean') { throw new Error(`enableGraphCapture must be a boolean value: ${sessionOptions.enableGraphCapture}`); } - appendSessionConfig( - sessionOptionsHandle, - 'enableGraphCapture', - sessionOptions.enableGraphCapture.toString(), - allocs, - ); + const keyDataOffset = allocWasmString('enableGraphCapture', allocs); + const valueDataOffset = allocWasmString(sessionOptions.enableGraphCapture.toString(), allocs); + if (wasm._OrtAddSessionConfigEntry(sessionOptionsHandle, keyDataOffset, valueDataOffset) !== 0) { + checkLastError( + `Can't set a session config entry: 'enableGraphCapture' - ${sessionOptions.enableGraphCapture}.`, + ); + } } if (sessionOptions.freeDimensionOverrides) { @@ -241,7 +188,12 @@ export const setSessionOptions = async (options?: InferenceSession.SessionOption if (sessionOptions.extra !== undefined) { iterateExtraOptions(sessionOptions.extra, '', new WeakSet>(), (key, value) => { - appendSessionConfig(sessionOptionsHandle, key, value, allocs); + const keyDataOffset = allocWasmString(key, allocs); + const valueDataOffset = allocWasmString(value, allocs); + + if (wasm._OrtAddSessionConfigEntry(sessionOptionsHandle, keyDataOffset, valueDataOffset) !== 0) { + checkLastError(`Can't set a session config entry: ${key} - ${value}.`); + } }); } diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts index dbcf80adf3552..4bccfa76fdda3 100644 --- a/js/web/lib/wasm/wasm-core-impl.ts +++ b/js/web/lib/wasm/wasm-core-impl.ts @@ -102,20 +102,11 @@ export const initRuntime = async (env: Env): Promise => { * @param epName */ export const initEp = async (env: Env, epName: string): Promise => { - // initialize ASYNCIFY support - getInstance().asyncInit?.(); - - if (epName === 'webgpu' && BUILD_DEFS.USE_WEBGPU_EP) { - getInstance().webgpuInit!((device) => { - env.webgpu.device = device; - }); - } - if (!BUILD_DEFS.DISABLE_JSEP) { // eslint-disable-next-line @typescript-eslint/no-require-imports, @typescript-eslint/no-var-requires const initJsep = require('./jsep/init').init; - if (epName === 'webgpu' && !BUILD_DEFS.USE_WEBGPU_EP) { + if (epName === 'webgpu') { // perform WebGPU availability check if (typeof navigator === 'undefined' || !navigator.gpu) { throw new Error('WebGPU is not supported in current environment'); @@ -279,7 +270,7 @@ export const createSession = async ( const outputNamesUTF8Encoded = []; try { - [sessionOptionsHandle, allocs] = await setSessionOptions(options); + [sessionOptionsHandle, allocs] = setSessionOptions(options); if (options?.externalData && wasm.mountExternalData) { const loadingPromises = []; @@ -287,7 +278,7 @@ export const createSession = async ( const path = typeof file === 'string' ? file : file.path; loadingPromises.push( loadFile(typeof file === 'string' ? file : file.data).then((data) => { - wasm.mountExternalData(path, data); + wasm.mountExternalData!(path, data); }), ); } @@ -321,7 +312,6 @@ export const createSession = async ( } sessionHandle = await wasm._OrtCreateSession(modelDataOffset, modelDataLength, sessionOptionsHandle); - wasm.webgpuOnCreateSession?.(sessionHandle); if (sessionHandle === 0) { checkLastError("Can't create a session."); } @@ -454,7 +444,6 @@ export const releaseSession = (sessionId: number): void => { } wasm.jsepOnReleaseSession?.(sessionId); - wasm.webgpuOnReleaseSession?.(sessionId); inputNamesUTF8Encoded.forEach((buf) => wasm._OrtFree(buf)); outputNamesUTF8Encoded.forEach((buf) => wasm._OrtFree(buf)); @@ -502,20 +491,11 @@ export const prepareInputOutputTensor = async ( const gpuBuffer = tensor[2].gpuBuffer; dataByteLength = calculateTensorSizeInBytes(tensorDataTypeStringToEnum(dataType), dims)!; - if (BUILD_DEFS.USE_WEBGPU_EP) { - const registerBuffer = wasm.webgpuRegisterBuffer; - if (!registerBuffer) { - throw new Error('Tensor location "gpu-buffer" is not supported without using WebGPU.'); - } - - rawData = registerBuffer(gpuBuffer, sessionId); - } else { - const registerBuffer = wasm.jsepRegisterBuffer; - if (!registerBuffer) { - throw new Error('Tensor location "gpu-buffer" is not supported without using WebGPU.'); - } - rawData = registerBuffer(sessionId, index, gpuBuffer, dataByteLength); + const registerBuffer = wasm.jsepRegisterBuffer; + if (!registerBuffer) { + throw new Error('Tensor location "gpu-buffer" is not supported without using WebGPU.'); } + rawData = registerBuffer(sessionId, index, gpuBuffer, dataByteLength); } else if (location === 'ml-tensor') { const mlTensor = tensor[2].mlTensor as MLTensor; dataByteLength = calculateTensorSizeInBytes(tensorDataTypeStringToEnum(dataType), dims)!; @@ -811,7 +791,7 @@ export const run = async ( // If a certain output's preferred location is GPU but the tensor is empty, we still need to create a CPU // tensor for it. There is no mapping GPU buffer for an empty tensor. if (preferredLocation === 'gpu-buffer' && size > 0) { - const getBuffer = BUILD_DEFS.USE_WEBGPU_EP ? wasm.webgpuGetBuffer : wasm.jsepGetBuffer; + const getBuffer = wasm.jsepGetBuffer; if (!getBuffer) { throw new Error('preferredLocation "gpu-buffer" is not supported without using WebGPU.'); } @@ -824,43 +804,20 @@ export const run = async ( // do not release the tensor right now. it will be released when user calls tensor.dispose(). keepOutputTensor = true; - if (BUILD_DEFS.USE_WEBGPU_EP) { - wasm.webgpuRegisterBuffer!(gpuBuffer, sessionId, dataOffset); - const downloadDataFunction = wasm.webgpuCreateDownloader!(gpuBuffer, bufferSize, sessionId); - output.push([ - type, - dims, - { - gpuBuffer, - download: async () => { - const arrayBuffer = await downloadDataFunction(); - const data = new (tensorTypeToTypedArrayConstructor(type!))(arrayBuffer); - return data as Tensor.DataTypeMap[Tensor.GpuBufferDataTypes]; - }, - dispose: () => { - if (wasm._OrtReleaseTensor(tensor) !== 0) { - checkLastError("Can't release tensor."); - } - }, - }, - 'gpu-buffer', - ]); - } else { - output.push([ - type, - dims, - { - gpuBuffer, - download: wasm.jsepCreateDownloader!(gpuBuffer, bufferSize, type), - dispose: () => { - if (wasm._OrtReleaseTensor(tensor) !== 0) { - checkLastError("Can't release tensor."); - } - }, + output.push([ + type, + dims, + { + gpuBuffer, + download: wasm.jsepCreateDownloader!(gpuBuffer, bufferSize, type), + dispose: () => { + if (wasm._OrtReleaseTensor(tensor) !== 0) { + checkLastError("Can't release tensor."); + } }, - 'gpu-buffer', - ]); - } + }, + 'gpu-buffer', + ]); } else if (preferredLocation === 'ml-tensor' && size > 0) { const ensureTensor = wasm.jsepEnsureTensor; if (!ensureTensor) { @@ -930,18 +887,6 @@ export const run = async ( } finally { wasm.stackRestore(beforeRunStack); - if (BUILD_DEFS.USE_WEBGPU_EP) { - inputTensors.forEach((t) => { - if (t && t[3] === 'gpu-buffer') { - wasm.webgpuUnregisterBuffer!(t[2].gpuBuffer); - } - }); - outputTensors.forEach((t) => { - if (t && t[3] === 'gpu-buffer') { - wasm.webgpuUnregisterBuffer!(t[2].gpuBuffer); - } - }); - } inputTensorHandles.forEach((v) => wasm._OrtReleaseTensor(v)); outputTensorHandles.forEach((v) => wasm._OrtReleaseTensor(v)); inputOutputAllocs.forEach((p) => wasm._free(p)); diff --git a/js/web/lib/wasm/wasm-types.ts b/js/web/lib/wasm/wasm-types.ts index 9b2ec71fd351d..b4871e145f4d7 100644 --- a/js/web/lib/wasm/wasm-types.ts +++ b/js/web/lib/wasm/wasm-types.ts @@ -41,6 +41,18 @@ export declare namespace JSEP { type DownloadTensorFunction = (tensorId: number, dstBuffer: ArrayBufferView | ArrayBuffer) => Promise; export interface Module extends WebGpuModule, WebNnModule { + /** + * Mount the external data file to an internal map, which will be used during session initialization. + * + * @param externalDataFilePath - specify the relative path of the external data file. + * @param externalDataFileData - specify the content data. + */ + mountExternalData(externalDataFilePath: string, externalDataFileData: Uint8Array): void; + /** + * Unmount all external data files from the internal map. + */ + unmountExternalData(): void; + /** * This is the entry of JSEP initialization. This function is called once when initializing ONNX Runtime per * backend. This function initializes Asyncify support. If name is 'webgpu', also initializes WebGPU backend and @@ -282,21 +294,6 @@ export declare namespace JSEP { } } -export declare namespace WebGpu { - export interface Module { - webgpuInit(setDefaultDevice: (device: GPUDevice) => void): void; - webgpuRegisterDevice( - device?: GPUDevice, - ): undefined | [deviceId: number, instanceHandle: number, deviceHandle: number]; - webgpuOnCreateSession(sessionHandle: number): void; - webgpuOnReleaseSession(sessionHandle: number): void; - webgpuRegisterBuffer(buffer: GPUBuffer, sessionHandle: number, bufferHandle?: number): number; - webgpuUnregisterBuffer(buffer: GPUBuffer): void; - webgpuGetBuffer(bufferHandle: number): GPUBuffer; - webgpuCreateDownloader(gpuBuffer: GPUBuffer, size: number, sessionHandle: number): () => Promise; - } -} - export interface OrtInferenceAPIs { _OrtInit(numThreads: number, loggingLevel: number): number; @@ -361,13 +358,7 @@ export interface OrtInferenceAPIs { logVerbosityLevel: number, optimizedModelFilePath: number, ): number; - _OrtAppendExecutionProvider( - sessionOptionsHandle: number, - name: number, - providerOptionsKeys: number, - providerOptionsValues: number, - numKeys: number, - ): Promise; + _OrtAppendExecutionProvider(sessionOptionsHandle: number, name: number): number; _OrtAddFreeDimensionOverride(sessionOptionsHandle: number, name: number, dim: number): number; _OrtAddSessionConfigEntry(sessionOptionsHandle: number, configKey: number, configValue: number): number; _OrtReleaseSessionOptions(sessionOptionsHandle: number): number; @@ -382,11 +373,8 @@ export interface OrtInferenceAPIs { /** * The interface of the WebAssembly module for ONNX Runtime, compiled from C++ source code by Emscripten. */ -export interface OrtWasmModule - extends EmscriptenModule, - OrtInferenceAPIs, - Partial, - Partial { +export interface OrtWasmModule extends EmscriptenModule, OrtInferenceAPIs, Partial { + PTR_SIZE: number; // #region emscripten functions stackSave(): number; stackRestore(stack: number): void; @@ -399,31 +387,7 @@ export interface OrtWasmModule stringToUTF8(str: string, offset: number, maxBytes: number): void; // #endregion - // #region ORT shared - - readonly PTR_SIZE: 4 | 8; - - /** - * Mount the external data file to an internal map, which will be used during session initialization. - * - * @param externalDataFilePath - specify the relative path of the external data file. - * @param externalDataFileData - specify the content data. - */ - mountExternalData(externalDataFilePath: string, externalDataFileData: Uint8Array): void; - /** - * Unmount all external data files from the internal map. - */ - unmountExternalData(): void; - - /** - * This function patches the WebAssembly module to support Asyncify. This function should be called at least once - * before any ORT API is called. - */ - asyncInit?(): void; - - // #endregion - // #region config - readonly numThreads?: number; + numThreads?: number; // #endregion } diff --git a/js/web/lib/wasm/wasm-utils-import.ts b/js/web/lib/wasm/wasm-utils-import.ts index a8e27f6f334bc..871b575d71edc 100644 --- a/js/web/lib/wasm/wasm-utils-import.ts +++ b/js/web/lib/wasm/wasm-utils-import.ts @@ -11,39 +11,6 @@ import { isNode } from './wasm-utils-env'; */ const origin = isNode || typeof location === 'undefined' ? undefined : location.origin; -/** - * Some bundlers (eg. Webpack) will rewrite `import.meta.url` to a file URL at compile time. - * - * This function checks if `import.meta.url` starts with `file:`, but using the `>` and `<` operators instead of - * `startsWith` function so that code minimizers can remove the dead code correctly. - * - * For example, if we use terser to minify the following code: - * ```js - * if ("file://hard-coded-filename".startsWith("file:")) { - * console.log(1) - * } else { - * console.log(2) - * } - * - * if ("file://hard-coded-filename" > "file:" && "file://hard-coded-filename" < "file;") { - * console.log(3) - * } else { - * console.log(4) - * } - * ``` - * - * The minified code will be: - * ```js - * "file://hard-coded-filename".startsWith("file:")?console.log(1):console.log(2),console.log(3); - * ``` - * - * (use Terser 5.39.0 with default options, https://try.terser.org/) - * - * @returns true if the import.meta.url is hardcoded as a file URI. - */ -export const isEsmImportMetaUrlHardcodedAsFileUri = - BUILD_DEFS.IS_ESM && BUILD_DEFS.ESM_IMPORT_META_URL! > 'file:' && BUILD_DEFS.ESM_IMPORT_META_URL! < 'file;'; - const getScriptSrc = (): string | undefined => { // if Nodejs, return undefined if (isNode) { @@ -59,22 +26,9 @@ const getScriptSrc = (): string | undefined => { // new URL('actual-bundle-name.js', import.meta.url).href // ``` // So that bundler can preprocess the URL correctly. - if (isEsmImportMetaUrlHardcodedAsFileUri) { + if (BUILD_DEFS.ESM_IMPORT_META_URL?.startsWith('file:')) { // if the rewritten URL is a relative path, we need to use the origin to resolve the URL. - - // The following is a workaround for Vite. - // - // Vite uses a bundler(rollup/rolldown) that does not rewrite `import.meta.url` to a file URL. So in theory, this - // code path should not be executed in Vite. However, the bundler does not know it and it still try to load the - // following pattern: - // - `return new URL('filename', import.meta.url).href` - // - // By replacing the pattern above with the following code, we can skip the resource loading behavior: - // - `const URL2 = URL; return new URL2('filename', import.meta.url).href;` - // - // And it still works in Webpack. - const URL2 = URL; - return new URL(new URL2(BUILD_DEFS.BUNDLE_FILENAME, BUILD_DEFS.ESM_IMPORT_META_URL).href, origin).href; + return new URL(new URL(BUILD_DEFS.BUNDLE_FILENAME, BUILD_DEFS.ESM_IMPORT_META_URL).href, origin).href; } return BUILD_DEFS.ESM_IMPORT_META_URL; diff --git a/js/web/script/build.ts b/js/web/script/build.ts index 98e61c9f87fbb..6006de62b41b6 100644 --- a/js/web/script/build.ts +++ b/js/web/script/build.ts @@ -27,8 +27,7 @@ const args = minimist(process.argv.slice(2)); * --bundle-mode=node * Build a single ort-web bundle for nodejs. */ -const BUNDLE_MODE: 'prod' | 'dev' | 'perf' | 'node' = - process.env.npm_config_bundle_mode || args['bundle-mode'] || 'prod'; +const BUNDLE_MODE: 'prod' | 'dev' | 'perf' | 'node' = args['bundle-mode'] || 'prod'; /** * --debug @@ -42,18 +41,7 @@ const BUNDLE_MODE: 'prod' | 'dev' | 'perf' | 'node' = * Enable debug mode. In this mode, esbuild metafile feature will be enabled. Full bundle analysis will be saved to a * file as JSON. */ -const DEBUG = process.env.npm_config_debug || args.debug; // boolean|'verbose'|'save' - -/** - * --webgpu-ep - * --no-webgpu-ep (default) - * - * Enable or disable the use of WebGPU EP. If enabled, the WebGPU EP will be used. If disabled, the WebGPU backend will - * be used with JSEP. - * - * (temporary) This flag is used to test the WebGPU EP integration. It will be removed in the future. - */ -const USE_WEBGPU_EP = process.env.npm_config_webgpu_ep ?? args['webgpu-ep'] ?? false; +const DEBUG = args.debug; // boolean|'verbose'|'save' /** * Root folder of the source code: `/js/` @@ -69,7 +57,6 @@ const DEFAULT_DEFINE = { 'BUILD_DEFS.DISABLE_WASM': 'false', 'BUILD_DEFS.DISABLE_WASM_PROXY': 'false', 'BUILD_DEFS.ENABLE_BUNDLE_WASM_JS': 'false', - 'BUILD_DEFS.USE_WEBGPU_EP': JSON.stringify(!!USE_WEBGPU_EP), 'BUILD_DEFS.IS_ESM': 'false', 'BUILD_DEFS.ESM_IMPORT_META_URL': 'undefined', @@ -136,17 +123,13 @@ async function minifyWasmModuleJsForBrowser(filepath: string): Promise { // ``` // with: // ``` - // new Worker((() => { - // const URL2 = URL; - // return import.meta.url > 'file:' && import.meta.url < 'file;' - // ? new URL2(BUILD_DEFS.BUNDLE_FILENAME, import.meta.url) - // : new URL(import.meta.url); - // })(), ... + // new Worker(import.meta.url.startsWith('file:') + // ? new URL(BUILD_DEFS.BUNDLE_FILENAME, import.meta.url) + // : new URL(import.meta.url), ... // ``` // // NOTE: this is a workaround for some bundlers that does not support runtime import.meta.url. - // - // Check more details in the comment of `isEsmImportMetaUrlHardcodedAsFileUri()` and `getScriptSrc()` in file `lib/wasm/wasm-utils-import.ts`. + // TODO: in emscripten 3.1.61+, need to update this code. // First, check if there is exactly one occurrence of "new Worker(new URL(import.meta.url)". const matches = [...contents.matchAll(/new Worker\(new URL\(import\.meta\.url\),/g)]; @@ -159,12 +142,7 @@ async function minifyWasmModuleJsForBrowser(filepath: string): Promise { // Replace the only occurrence. contents = contents.replace( /new Worker\(new URL\(import\.meta\.url\),/, - `new Worker((() => { - const URL2 = URL; - return (import.meta.url > 'file:' && import.meta.url < 'file;') - ? new URL2(BUILD_DEFS.BUNDLE_FILENAME, import.meta.url) - : new URL(import.meta.url); - })(),`, + `new Worker(import.meta.url.startsWith('file:')?new URL(BUILD_DEFS.BUNDLE_FILENAME, import.meta.url):new URL(import.meta.url),`, ); // Use terser to minify the code with special configurations: diff --git a/js/web/test/data/ops/conv-transpose.jsonc b/js/web/test/data/ops/conv-transpose.jsonc index 008d58530ee36..6429845d23df9 100644 --- a/js/web/test/data/ops/conv-transpose.jsonc +++ b/js/web/test/data/ops/conv-transpose.jsonc @@ -348,128 +348,6 @@ } ] }, - { - "name": "ConvTranspose NHWC- group - A", - "operator": "ConvTranspose", - "inputShapeDefinitions": "rankOnly", - "opset": { "domain": "", "version": 17 }, - "attributes": [ - { "name": "kernel_shape", "data": [1, 1], "type": "ints" }, - { "name": "group", "data": 2, "type": "int" } - ], - "cases": [ - { - "name": "T[0]", - "inputs": [ - { - "data": [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 18.0, 20.0, 22.0, 24.0, 26.0, 28.0, 30.0, 32.0, 34.0], - "dims": [1, 2, 3, 3], - "type": "float32" - }, - { - "data": [1.0, 2.0], - "dims": [2, 1, 1, 1], - "type": "float32" - } - ], - "outputs": [ - { - "data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 36, 40, 44, 48, 52, 56, 60, 64, 68], - "dims": [1, 2, 3, 3], - "type": "float32" - } - ] - } - ] - }, - { - "name": "ConvTranspose NHWC- group - B", - "operator": "ConvTranspose", - "inputShapeDefinitions": "rankOnly", - "opset": { "domain": "", "version": 17 }, - "attributes": [ - { "name": "kernel_shape", "data": [2, 2], "type": "ints" }, - { "name": "group", "data": 3, "type": "int" } - ], - "cases": [ - { - "name": "T[0]", - "inputs": [ - { - "data": [ - 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, - 19.0, 20.0, 21.0, 22.0, 23.0, 0, 0, 0 - ], - "dims": [1, 3, 3, 3], - "type": "float32" - }, - { - "data": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0], - "dims": [3, 1, 2, 2], - "type": "float32" - }, - { - "data": [0.125, 0.25, 0.375], - "dims": [3], - "type": "float32" - } - ], - "outputs": [ - { - "data": [ - 0.125, 1.125, 4.125, 4.125, 3.125, 13.125, 23.125, 18.125, 15.125, 43.125, 53.125, 36.125, 18.125, 45.125, - 52.125, 32.125, 45.25, 104.25, 115.25, 66.25, 123.25, 279.25, 305.25, 172.25, 159.25, 357.25, 383.25, - 214.25, 105.25, 232.25, 247.25, 136.25, 162.375, 351.375, 370.375, 200.375, 387.375, 833.375, 875.375, - 470.375, 231.375, 494.375, 517.375, 276.375, 0.375, 0.375, 0.375, 0.375 - ], - "dims": [1, 3, 4, 4], - "type": "float32" - } - ] - } - ] - }, - { - "name": "ConvTranspose NHWC- group - C", - "operator": "ConvTranspose", - "inputShapeDefinitions": "rankOnly", - "opset": { "domain": "", "version": 17 }, - "attributes": [ - { "name": "kernel_shape", "data": [2, 2], "type": "ints" }, - { "name": "group", "data": 3, "type": "int" } - ], - "cases": [ - { - "name": "T[0]", - "inputs": [ - { - "data": [ - 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, - 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0, 32.0, 33.0, 34.0, 35.0 - ], - "dims": [1, 3, 3, 4], - "type": "float32" - }, - { - "data": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0], - "dims": [3, 1, 2, 2], - "type": "float32" - } - ], - "outputs": [ - { - "data": [ - 0, 1, 4, 7, 6, 4, 16, 26, 36, 26, 20, 56, 66, 76, 50, 24, 59, 66, 73, 44, 60, 137, 148, 159, 90, 164, 368, - 394, 420, 234, 212, 472, 498, 524, 290, 140, 307, 322, 337, 184, 216, 465, 484, 503, 270, 516, 1104, 1146, - 1188, 634, 596, 1272, 1314, 1356, 722, 352, 747, 770, 793, 420 - ], - "dims": [1, 3, 4, 5], - "type": "float32" - } - ] - } - ] - }, { "name": "ConvTranspose with bias addition C", "operator": "ConvTranspose", diff --git a/js/web/test/e2e/exports/main.js b/js/web/test/e2e/exports/main.js index d8c7bbf69039f..8ed22a6784e7c 100644 --- a/js/web/test/e2e/exports/main.js +++ b/js/web/test/e2e/exports/main.js @@ -3,7 +3,7 @@ 'use strict'; -const { runDevTest, runProdTest, verifyAssets } = require('./test'); +const { runDevTest, runProdTest } = require('./test'); const { installOrtPackages } = require('./utils'); /** @@ -29,14 +29,5 @@ module.exports = async function main(PRESERVE, PACKAGES_TO_INSTALL) { await runDevTest('vite-default', '\x1b[32m➜\x1b[39m \x1b[1mLocal\x1b[22m:', 5173); await runProdTest('vite-default', '\x1b[32m➜\x1b[39m \x1b[1mLocal\x1b[22m:', 4173); - - await verifyAssets('vite-default', async (cwd) => { - const globby = await import('globby'); - - return { - test: 'File "dist/assets/**/ort.*.mjs" should not exist', - success: globby.globbySync('dist/assets/**/ort.*.mjs', { cwd }).length === 0, - }; - }); } }; diff --git a/js/web/test/e2e/exports/test.js b/js/web/test/e2e/exports/test.js index e2bcffea97519..9c5ed745ab0b5 100644 --- a/js/web/test/e2e/exports/test.js +++ b/js/web/test/e2e/exports/test.js @@ -121,29 +121,7 @@ async function runProdTest(testCaseName, ready, port) { await runTest(testCaseName, ['prod'], ready, 'npm run start', port); } -async function verifyAssets(testCaseName, testers) { - testers = Array.isArray(testers) ? testers : [testers]; - const wd = path.join(__dirname, 'testcases', testCaseName); - - console.log(`[${testCaseName}] Verifying assets...`); - - const testResults = []; - - try { - for (const tester of testers) { - testResults.push(await tester(wd)); - } - - if (testResults.some((r) => !r.success)) { - throw new Error(`[${testCaseName}] asset verification failed.`); - } - } finally { - console.log(`[${testCaseName}] asset verification result:`, testResults); - } -} - module.exports = { runDevTest, runProdTest, - verifyAssets, }; diff --git a/onnxruntime/contrib_ops/webgpu/bert/bias_add.cc b/onnxruntime/contrib_ops/webgpu/bert/bias_add.cc deleted file mode 100644 index 65c14e8cb0bdd..0000000000000 --- a/onnxruntime/contrib_ops/webgpu/bert/bias_add.cc +++ /dev/null @@ -1,80 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/providers/webgpu/shader_helper.h" -#include "core/providers/webgpu/webgpu_supported_types.h" -#include "contrib_ops/webgpu/bert/bias_add.h" -#include "contrib_ops/webgpu/webgpu_contrib_kernels.h" - -namespace onnxruntime { -namespace contrib { -namespace webgpu { - -ONNX_OPERATOR_KERNEL_EX( - BiasAdd, - kMSDomain, - 1, - kWebGpuExecutionProvider, - (*KernelDefBuilder::Create()) - .TypeConstraint("T", WebGpuSupportedFloatTypes()), - BiasAdd); - -Status BiasAddProgram::GenerateShaderCode(ShaderHelper& shader) const { - const ShaderVariableHelper& input = shader.AddInput("input"); - const ShaderVariableHelper& bias = shader.AddInput("bias"); - const ShaderVariableHelper& residual = shader.AddInput("residual"); - const ShaderVariableHelper& output = shader.AddOutput("output"); - - shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size") - << "let value = " << input.GetByOffset("global_idx") - << " + " << bias.GetByOffset("global_idx % uniforms.channels") - << " + " << residual.GetByOffset("global_idx") << ";\n" - << output.SetByOffset("global_idx", "value"); - - return Status::OK(); -} - -static int64_t GetMaxComponents(int64_t size) { - if (size % 4 == 0) { - return 4; - } else if (size % 2 == 0) { - return 2; - } - return 1; -} - -Status BiasAdd::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const { - const auto* input = context.Input(0); - const auto* bias = context.Input(1); - const auto* residual = context.Input(2); - - TensorShape input_shape = input->Shape(); - - if (input_shape.NumDimensions() != 3) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "BiasAdd input should have 3 dimensions."); - } - - int64_t channels = input_shape[2]; - int64_t components = GetMaxComponents(channels); - channels /= components; - - TensorShape bias_shape = bias->Shape(); - if (bias_shape.NumDimensions() != 1 || bias_shape[0] != channels) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "BiasAdd bias should have 1 dimension with size equal to the number of channels."); - } - - auto* output = context.Output(0, input_shape); - int64_t output_size = output->Shape().Size() / components; - - BiasAddProgram program{}; - program.AddInputs({{input}, {bias}, {residual}}) - .AddOutput({output}) - .SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) - .AddUniformVariables({{static_cast(output_size)}, - {static_cast(channels)}}); - return context.RunProgram(program); -} - -} // namespace webgpu -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/bert/bias_add.h b/onnxruntime/contrib_ops/webgpu/bert/bias_add.h deleted file mode 100644 index 58cc5f09f8003..0000000000000 --- a/onnxruntime/contrib_ops/webgpu/bert/bias_add.h +++ /dev/null @@ -1,32 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include "core/providers/webgpu/program.h" -#include "core/providers/webgpu/webgpu_kernel.h" - -namespace onnxruntime { -namespace contrib { -namespace webgpu { - -using namespace onnxruntime::webgpu; -using onnxruntime::webgpu::ComputeContext; - -class BiasAddProgram final : public Program { - public: - BiasAddProgram() : Program{"BiasAdd"} {} - Status GenerateShaderCode(ShaderHelper& sh) const override; - WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"output_size", ProgramUniformVariableDataType::Uint32}, - {"channels", ProgramUniformVariableDataType::Uint32}); -}; - -class BiasAdd final : public WebGpuKernel { - public: - BiasAdd(const OpKernelInfo& info) : WebGpuKernel(info) {} - Status ComputeInternal(ComputeContext& context) const override; -}; - -} // namespace webgpu -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/bert/fast_gelu.cc b/onnxruntime/contrib_ops/webgpu/bert/fast_gelu.cc index 29ea4f81dd5e1..a5cae7e7f6747 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/fast_gelu.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/fast_gelu.cc @@ -50,7 +50,7 @@ Status FastGelu::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) c const auto* bias = context.Input(1); auto* output = context.Output(0, input->Shape()); - uint32_t data_size = onnxruntime::narrow(output->Shape().Size()); + uint32_t data_size = gsl::narrow(output->Shape().Size()); if (data_size == 0) { return Status::OK(); } @@ -60,7 +60,7 @@ Status FastGelu::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) c int bias_components = 1; if (bias != nullptr) { - bias_size = onnxruntime::narrow(bias->Shape().Size()); + bias_size = gsl::narrow(bias->Shape().Size()); if (bias_size % 4 == 0) { bias_components = 4; bias_size = bias_size / 4; diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc index 1e95d3d9610ff..57ae8a7e5ba74 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc @@ -98,7 +98,7 @@ Status CopyKVCache(onnxruntime::webgpu::ComputeContext& context, const WebgpuAtt program.AddOutputs({{present_key, ProgramTensorMetadataDependency::Rank, components}, {present_value, ProgramTensorMetadataDependency::Rank, components}}) .AddIndices(valid_present_shape); - program.SetDispatchGroupSize(onnxruntime::narrow(valid_kv_size + 63 / 64)) + program.SetDispatchGroupSize(gsl::narrow(valid_kv_size + 63 / 64)) .SetWorkgroupSize(64) .CacheHint(has_past, parameters.qkv_format_, parameters.past_present_share_buffer_) .AddUniformVariables({{static_cast(valid_kv_size)}, @@ -379,7 +379,7 @@ Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const { if (sg_size > 8) { for (var i:u32 = 0; i < qkv_head_size_vec; i++) { - var val = v_tile[capped_sg_id][i]; + var val = select(vec4(0), v_tile[capped_sg_id][i], k_start + capped_sg_id < seq_causal_length); var sum = subgroupShuffle(val, 0) * qk_1[0]; sum += subgroupShuffle(val, 1) * qk_1[1]; sum += subgroupShuffle(val, 2) * qk_1[2]; diff --git a/onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.cc b/onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.cc index 20e1583e0da8f..bc8b7493fc916 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.cc @@ -66,11 +66,11 @@ Status RotaryEmbedding::ComputeInternal(onnxruntime::webgpu::ComputeContext& con const auto* sin_cache = context.Input(3); auto* output = context.Output(0, input_shape); - const auto batch_size = onnxruntime::narrow(input->Shape()[0]); - const auto batch_stride = onnxruntime::narrow(input_shape.SizeFromDimension(1)); - const auto sequence_length = onnxruntime::narrow(input_shape[input_shape.NumDimensions() - 2]); + const auto batch_size = gsl::narrow(input->Shape()[0]); + const auto batch_stride = gsl::narrow(input_shape.SizeFromDimension(1)); + const auto sequence_length = gsl::narrow(input_shape[input_shape.NumDimensions() - 2]); const auto hidden_size = batch_stride / sequence_length; - const auto half_rotary_embedding_dim = onnxruntime::narrow(cos_cache->Shape()[1]); + const auto half_rotary_embedding_dim = gsl::narrow(cos_cache->Shape()[1]); const auto head_size = rotary_embedding_dim_ == 0 ? half_rotary_embedding_dim * 2 : hidden_size / num_heads_; // Rotary embeddings will be calculated in a pair-wise fashion. In accordance, use the shape @@ -85,11 +85,11 @@ Status RotaryEmbedding::ComputeInternal(onnxruntime::webgpu::ComputeContext& con std::vector global_dims(rank); std::vector global_strides(rank); for (size_t j = 0; j < rank; ++j) { - global_dims[j] = onnxruntime::narrow(global_shape[j]); - global_strides[j] = onnxruntime::narrow(global_shape.SizeFromDimension(j + 1)); + global_dims[j] = gsl::narrow(global_shape[j]); + global_strides[j] = gsl::narrow(global_shape.SizeFromDimension(j + 1)); } - const auto output_size = onnxruntime::narrow(global_shape.Size()); + const auto output_size = gsl::narrow(global_shape.Size()); RotaryEmbeddingProgram program{interleaved_}; const auto input_output_strides = input_shape.NumDimensions() == 3 diff --git a/onnxruntime/contrib_ops/webgpu/bert/skip_layer_norm.cc b/onnxruntime/contrib_ops/webgpu/bert/skip_layer_norm.cc index d5d4632c01e2a..a1840257d734f 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/skip_layer_norm.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/skip_layer_norm.cc @@ -122,7 +122,7 @@ Status SkipLayerNorm::ComputeInternal(onnxruntime::webgpu::ComputeCo } const bool is_fp16 = x->GetElementType() == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16; - const uint32_t hidden_size = onnxruntime::narrow(x_shape[x_shape.NumDimensions() - 1]); + const uint32_t hidden_size = gsl::narrow(x_shape[x_shape.NumDimensions() - 1]); const int components = GetMaxComponents(hidden_size); const bool has_input_skip_bias_sum = input_skip_bias_sum != nullptr; @@ -133,7 +133,7 @@ Status SkipLayerNorm::ComputeInternal(onnxruntime::webgpu::ComputeCo .AddInputs({{skip, ProgramTensorMetadataDependency::Type, components}}) .AddInputs({{gamma, ProgramTensorMetadataDependency::Type, components}}) .AddOutputs({{output, ProgramTensorMetadataDependency::None, components}}) - .SetDispatchGroupSize(onnxruntime::narrow(ceil(1.0 * data_size / hidden_size))) + .SetDispatchGroupSize(gsl::narrow(ceil(1.0 * data_size / hidden_size))) .AddUniformVariables({ {static_cast(components)}, }) diff --git a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc deleted file mode 100644 index 05cbfb1f99c48..0000000000000 --- a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc +++ /dev/null @@ -1,326 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h" -#include "core/providers/webgpu/shader_helper.h" - -namespace onnxruntime { -namespace contrib { -namespace webgpu { - -Status DP4AMatMulQuantizeProgram::GenerateShaderCode(ShaderHelper& shader) const { - shader.AddInput("input_a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); - shader.AddOutput("output", ShaderUsage::UseUniform); - shader.AddOutput("scales", ShaderUsage::UseUniform); - shader.AdditionalImplementation() << R"ADDNL_FN( - fn readInput(offset: u32) -> input_a_value_t - { - if (offset > uniforms.input_size) { - return input_a_value_t(0); - } - return input_a[offset]; - } - )ADDNL_FN"; - shader.MainFunctionBody() << R"MAIN_FN( - var local_a : array, 32>; - var max_value:vec4 = vec4(0); - for (var idx:u32=0;idx<32;idx+=1) - { - local_a[idx] = readInput(workgroup_idx*32 + idx); - max_value = max(max_value, abs(local_a[idx])); - } - var scale = max(max_value.x, max_value.y); - scale = max(scale, max_value.z); - scale = max(scale, max_value.w); - for (var idx:u32=0;idx<32;idx+=1) - { - output[workgroup_idx*32+idx] = pack4x8snorm(vec4(local_a[idx]/scale)); - } - // 127 is the max value of signed int8 [-127,127] used by pack4x8snorm for 1.0f. - scales[workgroup_idx] = scale/127; - )MAIN_FN"; - return Status::OK(); -} - -Status DP4AMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const { - shader.AddInput("input_a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias); - shader.AddInput("scales_a", ShaderUsage::UseUniform); - shader.AddInput("input_b", ShaderUsage::UseUniform); - shader.AddInput("scales_b", ShaderUsage::UseUniform); - shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseElementTypeAlias); - - // This shader implements co-operative matrix multiply. The key idea here is to - // assume there is a primitive for medium size matrix multiply a subgroup can perform, - // using all its lanes and pooling all its registers to keep the values in registry. - // - // The entire workgroup which has N subgroups first loads a tile into shared memory, - // Then each subgroup loads a subtile from shared memory into registers and uses - // the medium size matrix multiply primitive to perform the math. - // The values for tile/subtile size are chosen to conform to the resource limits - // of an alderlake/tiger lake gpu. A tile is 64x64, workgroup is 256 threads - - // therefore there are 16 subgroups and 16 lanes in each subgroup. - // K the hidden dimension is paged in from RAM at k tile size which is 64. - // All this puts the shared memory requirement slightly above 16KB. - // WebGPU limit is 16KB, output is moved to registers instead of SHM to make - // everything fit in shared memory. - // - // Each subgroup performs a 16 x 64 x 16 multiply which is implemented with - // subgroup shuffle as a placeholder for the day the medium matrix mul primitive - // becomes available in WGSL. The registry requirements is ~2KB per subgroup, on - // Alderlake/Tigerlake subgroup has 8KB of registry space pooling the - // 512B of registry from each lane. - // - // The medium size matmul is implemented using dot4I8Packed, so the inputs for - // this shader require A to be int8 quantized with block size 64. B is regular - // matmulnbits input with block size 32. - - shader.AdditionalImplementation() << " const block_size = " << block_size_ << ";"; - - shader.AdditionalImplementation() << R"ADDNL_FN( - const tile_size = 64; - const subtile_size = 16; - const tile_size_k = 32; - const vec_factor = 4; - const u32_factor = 4; - const tile_size_k_vec = 2; - - // Shared memory - var tile_A : array, tile_size>, tile_size_k_vec>; // 64 x 32 - var scale_A : array; // 64 x 1 - var tile_B : array, tile_size>, tile_size_k_vec>; // 64 x 32 - var scale_B : array; // 64 x 1 - - fn loadSHMA(a_global_base:u32, kidx_v:u32, row: u32, col: u32) - { - let a_global = a_global_base + row; - if (a_global >= uniforms.M) - { - return; - } - tile_A[col][row] = input_a[a_global*uniforms.K16+kidx_v+col]; - if (col == 0) - { - // kidx_v - covers 16 values of k - scale_A[row] = scales_a[a_global*(uniforms.K/128) + kidx_v/8]; - } - } - - fn loadSHMB(b_global_base:u32, kidx_v:u32, row: u32, col: u32) - { - let b_global = b_global_base + row; - if (b_global >= uniforms.N) - { - return; - } - - let b_value = input_b[b_global*uniforms.K16+kidx_v+col]; - var b_value_lower = vec4(unpack4xU8(b_value[0] & 0x0F0F0F0Fu)) - vec4(8); - var b_value_upper = vec4(unpack4xU8((b_value[0] >> 4) & 0x0F0F0F0Fu)) - vec4(8); - tile_B[col][row][0] = pack4xI8(vec4(b_value_lower[0], b_value_upper[0], b_value_lower[1], b_value_upper[1])); - tile_B[col][row][1] = pack4xI8(vec4(b_value_lower[2], b_value_upper[2], b_value_lower[3], b_value_upper[3])); - b_value_lower = vec4(unpack4xU8(b_value[1] & 0x0F0F0F0Fu)) - vec4(8); - b_value_upper = vec4(unpack4xU8((b_value[1] >> 4) & 0x0F0F0F0Fu)) - vec4(8); - tile_B[col][row][2] = pack4xI8(vec4(b_value_lower[0], b_value_upper[0], b_value_lower[1], b_value_upper[1])); - tile_B[col][row][3] = pack4xI8(vec4(b_value_lower[2], b_value_upper[2], b_value_lower[3], b_value_upper[3])); - if (col == 0) - { - // kidx_v - each kidx_v covers 16 values of k - scale_B[row] = scales_b[b_global*(uniforms.K/block_size) + kidx_v/(block_size/16)]; - } - } - - // Scaled dot product of 8 packed unsigned integers. - fn SDP8AI(a1:vec4, b1:vec4, a2:vec4, b2:vec4, scale:output_element_t) -> output_element_t - { - var local_sum = dot4I8Packed(a1[0], b1[0]); - local_sum += dot4I8Packed(a1[1], b1[1]); - local_sum += dot4I8Packed(a1[2], b1[2]); - local_sum += dot4I8Packed(a1[3], b1[3]); - local_sum += dot4I8Packed(a2[0], b2[0]); - local_sum += dot4I8Packed(a2[1], b2[1]); - local_sum += dot4I8Packed(a2[2], b2[2]); - local_sum += dot4I8Packed(a2[3], b2[3]); - return output_element_t(local_sum) * scale; - } - )ADDNL_FN"; - - shader.MainFunctionBody() << R"MAIN_FN( - // During the load phase we use all 256 threads to load 64 rows of A/B. - // For each row we load tile_size_k_vec (2) vectorized elements, which are 32 elements of K. - let a_global_base = workgroup_id.x * tile_size; - let b_global_base = workgroup_id.y * tile_size; - let load_AorB = u32(local_idx/128); - let load_row = u32((local_idx%128)/2); - let load_col = u32(local_idx%2); - - // During the compute phase, we have the 64x64 tile split into - // subtiles of 16x16. We have a grid of 4x4 subtiles. - let subtile_id = u32(local_idx / subtile_size); - let subtile_idx = u32(subtile_id / 4); - let subtile_idy = u32(subtile_id % 4); - let base_A = subtile_idx * 16; - let base_B = subtile_idy * 16; - // For each subtile we have 16 threads assigned. - let a_idx = u32(local_idx % subtile_size); - - var lane_output1: vec4; - var lane_output2: vec4; - var lane_output3: vec4; - var lane_output4: vec4; - // K's vectrorization is 16 items per index. See input_a/input_b. - // tile_size_k_vec - is the k tile size in vectorized space (1/16). That is - // k tile size is 32. In vectorized space that is 32/16 = 2. - for (var kidx_v:u32 = 0; kidx_v < uniforms.K16; kidx_v+=tile_size_k_vec) - { - // Load Phase: Populate shared memory for the workgroup. - if (load_AorB == 0) - { - loadSHMA(a_global_base, kidx_v, load_row, load_col); - } - else - { - loadSHMB(b_global_base, kidx_v, load_row, load_col); - } - workgroupBarrier(); - - // Compute phase: Perform matmul for this subtile 16 x 32 x 16. - // Step 1: Load from shared memory into registers across entire subgroup. - var own_a0: vec4 = tile_A[0][base_A + a_idx]; - var own_a1: vec4 = tile_A[1][base_A + a_idx]; - var own_scale_a: output_element_t = scale_A[base_A + a_idx]; - if (sg_size == 16) - { - var own_b0: vec4 = tile_B[0][base_B + sg_id]; - var own_b1: vec4 = tile_B[1][base_B + sg_id]; - var own_scale_b: output_element_t = scale_B[base_B + sg_id]; - // Step 2: Access registers across the subgroup using subgroupShuffle and perform the matmul. - lane_output1[0] += SDP8AI(own_a0, subgroupShuffle(own_b0, 0), own_a1, subgroupShuffle(own_b1, 0), subgroupShuffle(own_scale_b, 0) * own_scale_a); - lane_output1[1] += SDP8AI(own_a0, subgroupShuffle(own_b0, 1), own_a1, subgroupShuffle(own_b1, 1), subgroupShuffle(own_scale_b, 1) * own_scale_a); - lane_output1[2] += SDP8AI(own_a0, subgroupShuffle(own_b0, 2), own_a1, subgroupShuffle(own_b1, 2), subgroupShuffle(own_scale_b, 2) * own_scale_a); - lane_output1[3] += SDP8AI(own_a0, subgroupShuffle(own_b0, 3), own_a1, subgroupShuffle(own_b1, 3), subgroupShuffle(own_scale_b, 3) * own_scale_a); - - lane_output2[0] += SDP8AI(own_a0, subgroupShuffle(own_b0, 4), own_a1, subgroupShuffle(own_b1, 4), subgroupShuffle(own_scale_b, 4) * own_scale_a); - lane_output2[1] += SDP8AI(own_a0, subgroupShuffle(own_b0, 5), own_a1, subgroupShuffle(own_b1, 5), subgroupShuffle(own_scale_b, 5) * own_scale_a); - lane_output2[2] += SDP8AI(own_a0, subgroupShuffle(own_b0, 6), own_a1, subgroupShuffle(own_b1, 6), subgroupShuffle(own_scale_b, 6) * own_scale_a); - lane_output2[3] += SDP8AI(own_a0, subgroupShuffle(own_b0, 7), own_a1, subgroupShuffle(own_b1, 7), subgroupShuffle(own_scale_b, 7) * own_scale_a); - - lane_output3[0] += SDP8AI(own_a0, subgroupShuffle(own_b0, 8), own_a1, subgroupShuffle(own_b1, 8), subgroupShuffle(own_scale_b, 8) * own_scale_a); - lane_output3[1] += SDP8AI(own_a0, subgroupShuffle(own_b0, 9), own_a1, subgroupShuffle(own_b1, 9), subgroupShuffle(own_scale_b, 9) * own_scale_a); - lane_output3[2] += SDP8AI(own_a0, subgroupShuffle(own_b0, 10), own_a1, subgroupShuffle(own_b1, 10), subgroupShuffle(own_scale_b, 10) * own_scale_a); - lane_output3[3] += SDP8AI(own_a0, subgroupShuffle(own_b0, 11), own_a1, subgroupShuffle(own_b1, 11), subgroupShuffle(own_scale_b, 11) * own_scale_a); - - lane_output4[0] += SDP8AI(own_a0, subgroupShuffle(own_b0, 12), own_a1, subgroupShuffle(own_b1, 12), subgroupShuffle(own_scale_b, 12) * own_scale_a); - lane_output4[1] += SDP8AI(own_a0, subgroupShuffle(own_b0, 13), own_a1, subgroupShuffle(own_b1, 13), subgroupShuffle(own_scale_b, 13) * own_scale_a); - lane_output4[2] += SDP8AI(own_a0, subgroupShuffle(own_b0, 14), own_a1, subgroupShuffle(own_b1, 14), subgroupShuffle(own_scale_b, 14) * own_scale_a); - lane_output4[3] += SDP8AI(own_a0, subgroupShuffle(own_b0, 15), own_a1, subgroupShuffle(own_b1, 15), subgroupShuffle(own_scale_b, 15) * own_scale_a); - } - else - { - // Code for other subgroup sizes, simply doesnt use subgroups at all. - // Relies on reads from single location tile_B[][base_B + col] by all - // being optimized by the hardware. - lane_output1[0] += SDP8AI(own_a0, tile_B[0][base_B + 0], own_a1, tile_B[1][base_B + 0], own_scale_a * scale_B[base_B + 0]); - lane_output1[1] += SDP8AI(own_a0, tile_B[0][base_B + 1], own_a1, tile_B[1][base_B + 1], own_scale_a * scale_B[base_B + 1]); - lane_output1[2] += SDP8AI(own_a0, tile_B[0][base_B + 2], own_a1, tile_B[1][base_B + 2], own_scale_a * scale_B[base_B + 2]); - lane_output1[3] += SDP8AI(own_a0, tile_B[0][base_B + 3], own_a1, tile_B[1][base_B + 3], own_scale_a * scale_B[base_B + 3]); - - lane_output2[0] += SDP8AI(own_a0, tile_B[0][base_B + 4], own_a1, tile_B[1][base_B + 4], own_scale_a * scale_B[base_B + 4]); - lane_output2[1] += SDP8AI(own_a0, tile_B[0][base_B + 5], own_a1, tile_B[1][base_B + 5], own_scale_a * scale_B[base_B + 5]); - lane_output2[2] += SDP8AI(own_a0, tile_B[0][base_B + 6], own_a1, tile_B[1][base_B + 6], own_scale_a * scale_B[base_B + 6]); - lane_output2[3] += SDP8AI(own_a0, tile_B[0][base_B + 7], own_a1, tile_B[1][base_B + 7], own_scale_a * scale_B[base_B + 7]); - - lane_output3[0] += SDP8AI(own_a0, tile_B[0][base_B + 8], own_a1, tile_B[1][base_B + 8], own_scale_a * scale_B[base_B + 8]); - lane_output3[1] += SDP8AI(own_a0, tile_B[0][base_B + 9], own_a1, tile_B[1][base_B + 9], own_scale_a * scale_B[base_B + 9]); - lane_output3[2] += SDP8AI(own_a0, tile_B[0][base_B + 10], own_a1, tile_B[1][base_B + 10], own_scale_a * scale_B[base_B + 10]); - lane_output3[3] += SDP8AI(own_a0, tile_B[0][base_B + 11], own_a1, tile_B[1][base_B + 11], own_scale_a * scale_B[base_B + 11]); - - lane_output4[0] += SDP8AI(own_a0, tile_B[0][base_B + 12], own_a1, tile_B[1][base_B + 12], own_scale_a * scale_B[base_B + 12]); - lane_output4[1] += SDP8AI(own_a0, tile_B[0][base_B + 13], own_a1, tile_B[1][base_B + 13], own_scale_a * scale_B[base_B + 13]); - lane_output4[2] += SDP8AI(own_a0, tile_B[0][base_B + 14], own_a1, tile_B[1][base_B + 14], own_scale_a * scale_B[base_B + 14]); - lane_output4[3] += SDP8AI(own_a0, tile_B[0][base_B + 15], own_a1, tile_B[1][base_B + 15], own_scale_a * scale_B[base_B + 15]); - } - workgroupBarrier(); - } - - let a_global = a_global_base + base_A + a_idx; - let b_global = b_global_base + base_B; - let output_idx = ((a_global) * uniforms.N + b_global)/4; - // This creates a shader requirement that uniforms.N % 16 == 0 - if (a_global < uniforms.M && b_global < uniforms.N) - { - output[output_idx] = lane_output1; - output[output_idx+1] = lane_output2; - output[output_idx+2] = lane_output3; - output[output_idx+3] = lane_output4; - } - )MAIN_FN"; - - return Status::OK(); -} - -Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor* scales, - uint32_t M, - uint32_t N, - uint32_t K, - uint32_t block_size, - onnxruntime::webgpu::ComputeContext& context, - Tensor* y) { - constexpr uint32_t kVec4Components = 4; - constexpr uint32_t kVec2Components = 2; - constexpr uint32_t kU32Components = 4; - - constexpr uint32_t kBlockSizeA = 128; - DP4AMatMulQuantizeProgram quantize_program; - quantize_program.SetWorkgroupSize(1); - quantize_program.SetDispatchGroupSize(M * K / kBlockSizeA, 1, 1); - TensorShape a_quant_shape{1, M, K / kU32Components}; - Tensor a_quant = context.CreateGPUTensor(DataTypeImpl::GetType(), a_quant_shape); - TensorShapeVector a_scales_dims({1, 1, M, K / kBlockSizeA}); - Tensor a_scale = context.CreateGPUTensor(a->DataType(), a_scales_dims); - quantize_program.AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, static_cast(kVec4Components)}}) - .AddOutputs({{&a_quant, ProgramTensorMetadataDependency::Rank, a_quant.Shape(), 1}, - {&a_scale, ProgramTensorMetadataDependency::Rank, a_scale.Shape(), 1}}) - .AddUniformVariable({static_cast(M * K / kVec4Components)}); - ORT_RETURN_IF_ERROR(context.RunProgram(quantize_program)); - - constexpr uint32_t kTileSize = 64; - TensorShape reshaped_y_shape{1, M, N / kVec4Components}; - DP4AMatMulNBitsProgram mul_program{block_size}; - mul_program.SetWorkgroupSize(256); - mul_program.SetDispatchGroupSize( - (M + kTileSize - 1) / kTileSize, - (N + kTileSize - 1) / kTileSize, 1); - mul_program.AddInputs({{&a_quant, ProgramTensorMetadataDependency::TypeAndRank, static_cast(kVec4Components)}, - {&a_scale, ProgramTensorMetadataDependency::TypeAndRank, 1}, - {b, ProgramTensorMetadataDependency::TypeAndRank, static_cast(kVec2Components * kU32Components)}, - {scales, ProgramTensorMetadataDependency::TypeAndRank, 1}}) - .AddUniformVariables({{static_cast(M)}, - {static_cast(N)}, - {static_cast(K)}, - {static_cast(K / 8)}, - {static_cast(K / 16)}}) - .AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, reshaped_y_shape, static_cast(kVec4Components)}) - .CacheHint("Block" + std::to_string(block_size)); - return context.RunProgram(mul_program); -} - -bool CanApplyDP4AMatrixMatMulNBits(onnxruntime::webgpu::ComputeContext& context, - uint64_t accuracy_level, - uint32_t block_size, - uint32_t batch_count, - uint32_t N, - uint32_t K, - uint32_t components_k, - bool has_zero_points) { - // macOS - Avoid using dp4a on Metal, as it does not appear to have native dp4a support. - // https://github.com/gpuweb/gpuweb/issues/2677#issuecomment-1713292226 - bool use_dp4a = context.Device().HasFeature(wgpu::FeatureName::Subgroups) && - context.AdapterInfo().backendType != wgpu::BackendType::Metal; - return (accuracy_level == 4 && block_size % 32 == 0 && - batch_count == 1 && components_k == 4 && K % 64 == 0 && N % 16 == 0 && - !has_zero_points && use_dp4a); -} - -} // namespace webgpu -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h deleted file mode 100644 index 15b86d78301ad..0000000000000 --- a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h +++ /dev/null @@ -1,56 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include "core/providers/webgpu/program.h" -#include "core/providers/webgpu/webgpu_kernel.h" - -namespace onnxruntime { -namespace contrib { -namespace webgpu { - -using namespace onnxruntime::webgpu; - -class DP4AMatMulQuantizeProgram final : public Program { - public: - DP4AMatMulQuantizeProgram() : Program{"DP4AMatMulQuantize"} {} - Status GenerateShaderCode(ShaderHelper& sh) const override; - WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"input_size", ProgramUniformVariableDataType::Uint32}); -}; - -class DP4AMatMulNBitsProgram final : public Program { - public: - DP4AMatMulNBitsProgram(uint32_t block_size) : Program{"DP4AMatMulNBits"}, block_size_(block_size) {} - Status GenerateShaderCode(ShaderHelper& sh) const override; - WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES( - {"M", ProgramUniformVariableDataType::Uint32}, - {"N", ProgramUniformVariableDataType::Uint32}, - {"K", ProgramUniformVariableDataType::Uint32}, - {"K8", ProgramUniformVariableDataType::Uint32}, - {"K16", ProgramUniformVariableDataType::Uint32}); - - private: - uint32_t block_size_; -}; - -Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor* scales, - uint32_t M, - uint32_t N, - uint32_t K, - uint32_t block_size, - onnxruntime::webgpu::ComputeContext& context, - Tensor* y); - -bool CanApplyDP4AMatrixMatMulNBits(onnxruntime::webgpu::ComputeContext& context, - uint64_t accuracy_level, - uint32_t block_size, - uint32_t batch_count, - uint32_t N, - uint32_t K, - uint32_t components_k, - bool has_zero_points); - -} // namespace webgpu -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc index cce10a59fbd4b..28d622b2c9c33 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc @@ -5,7 +5,6 @@ #include "contrib_ops/webgpu/quantization/matmul_nbits.h" #include "contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.h" -#include "contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h" #include "contrib_ops/webgpu/webgpu_contrib_kernels.h" #include "core/providers/cpu/math/matmul_helper.h" #include "core/providers/webgpu/shader_helper.h" @@ -372,7 +371,7 @@ Status MatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const { } } else { const std::string quantized_data_type = QuantizedDataType(a.NumComponents()); - const int output_element_number = y.NumComponents() * onnxruntime::narrow(output_number_); + const int output_element_number = y.NumComponents() * gsl::narrow(output_number_); const uint32_t shared_memory_size = output_number_ * WORKGROUP_SIZE; std::string offset = "workgroup_idx * " + std::to_string(output_number_); @@ -533,6 +532,255 @@ Status MatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const { return Status::OK(); } +Status DP4AMatMulQuantizeProgram::GenerateShaderCode(ShaderHelper& shader) const { + shader.AddInput("input_a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); + shader.AddOutput("output", ShaderUsage::UseUniform); + shader.AddOutput("scales", ShaderUsage::UseUniform); + shader.AdditionalImplementation() << R"ADDNL_FN( + fn readInput(offset: u32) -> input_a_value_t + { + if (offset > uniforms.input_size) { + return input_a_value_t(0); + } + return input_a[offset]; + } +)ADDNL_FN"; + shader.MainFunctionBody() << R"MAIN_FN( + var local_a : array, 32>; + var max_value:vec4 = vec4(0); + for (var idx:u32=0;idx<32;idx+=1) + { + local_a[idx] = readInput(workgroup_idx*32 + idx); + max_value = max(max_value, abs(local_a[idx])); + } + var scale = max(max_value.x, max_value.y); + scale = max(scale, max_value.z); + scale = max(scale, max_value.w); + for (var idx:u32=0;idx<32;idx+=1) + { + output[workgroup_idx*32+idx] = pack4x8snorm(vec4(local_a[idx]/scale)); + } + // 127 is the max value of signed int8 [-127,127] used by pack4x8snorm for 1.0f. + scales[workgroup_idx] = scale/127; +)MAIN_FN"; + return Status::OK(); +} + +Status DP4AMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const { + shader.AddInput("input_a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias); + shader.AddInput("scales_a", ShaderUsage::UseUniform); + shader.AddInput("input_b", ShaderUsage::UseUniform); + shader.AddInput("scales_b", ShaderUsage::UseUniform); + shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseElementTypeAlias); + + // This shader implements co-operative matrix multiply. The key idea here is to + // assume there is a primitive for medium size matrix multiply a subgroup can perform, + // using all its lanes and pooling all its registers to keep the values in registry. + // + // The entire workgroup which has N subgroups first loads a tile into shared memory, + // Then each subgroup loads a subtile from shared memory into registers and uses + // the medium size matrix multiply primitive to perform the math. + // The values for tile/subtile size are chosen to conform to the resource limits + // of an alderlake/tiger lake gpu. A tile is 64x64, workgroup is 256 threads - + // therefore there are 16 subgroups and 16 lanes in each subgroup. + // K the hidden dimension is paged in from RAM at k tile size which is 64. + // All this puts the shared memory requirement slightly above 16KB. + // WebGPU limit is 16KB, output is moved to registers instead of SHM to make + // everything fit in shared memory. + // + // Each subgroup performs a 16 x 64 x 16 multiply which is implemented with + // subgroup shuffle as a placeholder for the day the medium matrix mul primitive + // becomes available in WGSL. The registry requirements is ~2KB per subgroup, on + // Alderlake/Tigerlake subgroup has 8KB of registry space pooling the + // 512B of registry from each lane. + // + // The medium size matmul is implemented using dot4I8Packed, so the inputs for + // this shader require A to be int8 quantized with block size 64. B is regular + // matmulnbits input with block size 32. + + shader.AdditionalImplementation() << R"ADDNL_FN( + const tile_size = 64; + const subtile_size = 16; + const tile_size_k = 32; + const vec_factor = 4; + const u32_factor = 4; + const tile_size_k_vec = 2; + const block_size = 32; + + // Shared memory + var tile_A : array, tile_size>, tile_size_k_vec>; // 64 x 32 + var scale_A : array; // 64 x 1 + var tile_B : array, tile_size>, tile_size_k_vec>; // 64 x 32 + var scale_B : array; // 64 x 1 + + fn loadSHMA(a_global_base:u32, kidx_v:u32, row: u32, col: u32) + { + let a_global = a_global_base + row; + if (a_global >= uniforms.M) + { + return; + } + tile_A[col][row] = input_a[a_global*uniforms.K16+kidx_v+col]; + if (col == 0) + { + // kidx_v - covers 16 values of k + scale_A[row] = scales_a[a_global*(uniforms.K/128) + kidx_v/8]; + } + } + + fn loadSHMB(b_global_base:u32, kidx_v:u32, row: u32, col: u32) + { + let b_global = b_global_base + row; + if (b_global >= uniforms.N) + { + return; + } + + let b_value = input_b[b_global*uniforms.K16+kidx_v+col]; + var b_value_lower = vec4(unpack4xU8(b_value[0] & 0x0F0F0F0Fu)) - vec4(8); + var b_value_upper = vec4(unpack4xU8((b_value[0] >> 4) & 0x0F0F0F0Fu)) - vec4(8); + tile_B[col][row][0] = pack4xI8(vec4(b_value_lower[0], b_value_upper[0], b_value_lower[1], b_value_upper[1])); + tile_B[col][row][1] = pack4xI8(vec4(b_value_lower[2], b_value_upper[2], b_value_lower[3], b_value_upper[3])); + b_value_lower = vec4(unpack4xU8(b_value[1] & 0x0F0F0F0Fu)) - vec4(8); + b_value_upper = vec4(unpack4xU8((b_value[1] >> 4) & 0x0F0F0F0Fu)) - vec4(8); + tile_B[col][row][2] = pack4xI8(vec4(b_value_lower[0], b_value_upper[0], b_value_lower[1], b_value_upper[1])); + tile_B[col][row][3] = pack4xI8(vec4(b_value_lower[2], b_value_upper[2], b_value_lower[3], b_value_upper[3])); + if (col == 0) + { + // kidx_v - each kidx_v covers 16 values of k + scale_B[row] = scales_b[b_global*(uniforms.K/32) + kidx_v/2]; + } + } + + // Scaled dot product of 8 packed unsigned integers. + fn SDP8AI(a1:vec4, b1:vec4, a2:vec4, b2:vec4, scale:output_element_t) -> output_element_t + { + var local_sum = dot4I8Packed(a1[0], b1[0]); + local_sum += dot4I8Packed(a1[1], b1[1]); + local_sum += dot4I8Packed(a1[2], b1[2]); + local_sum += dot4I8Packed(a1[3], b1[3]); + local_sum += dot4I8Packed(a2[0], b2[0]); + local_sum += dot4I8Packed(a2[1], b2[1]); + local_sum += dot4I8Packed(a2[2], b2[2]); + local_sum += dot4I8Packed(a2[3], b2[3]); + return output_element_t(local_sum) * scale; + } +)ADDNL_FN"; + + shader.MainFunctionBody() << R"MAIN_FN( + // During the load phase we use all 256 threads to load 64 rows of A/B. + // For each row we load tile_size_k_vec (2) vectorized elements, which are 32 elements of K. + let a_global_base = workgroup_id.x * tile_size; + let b_global_base = workgroup_id.y * tile_size; + let load_AorB = u32(local_idx/128); + let load_row = u32((local_idx%128)/2); + let load_col = u32(local_idx%2); + + // During the compute phase, we have the 64x64 tile split into + // subtiles of 16x16. We have a grid of 4x4 subtiles. + let subtile_id = u32(local_idx / subtile_size); + let subtile_idx = u32(subtile_id / 4); + let subtile_idy = u32(subtile_id % 4); + let base_A = subtile_idx * 16; + let base_B = subtile_idy * 16; + // For each subtile we have 16 threads assigned. + let a_idx = u32(local_idx % subtile_size); + + var lane_output1: vec4; + var lane_output2: vec4; + var lane_output3: vec4; + var lane_output4: vec4; + // K's vectrorization is 16 items per index. See input_a/input_b. + // tile_size_k_vec - is the k tile size in vectorized space (1/16). That is + // k tile size is 32. In vectorized space that is 32/16 = 2. + for (var kidx_v:u32 = 0; kidx_v < uniforms.K16; kidx_v+=tile_size_k_vec) + { + // Load Phase: Populate shared memory for the workgroup. + if (load_AorB == 0) + { + loadSHMA(a_global_base, kidx_v, load_row, load_col); + } + else + { + loadSHMB(b_global_base, kidx_v, load_row, load_col); + } + workgroupBarrier(); + + // Compute phase: Perform matmul for this subtile 16 x 32 x 16. + // Step 1: Load from shared memory into registers across entire subgroup. + var own_a0: vec4 = tile_A[0][base_A + a_idx]; + var own_a1: vec4 = tile_A[1][base_A + a_idx]; + var own_scale_a: output_element_t = scale_A[base_A + a_idx]; + if (sg_size == 16) + { + var own_b0: vec4 = tile_B[0][base_B + sg_id]; + var own_b1: vec4 = tile_B[1][base_B + sg_id]; + var own_scale_b: output_element_t = scale_B[base_B + sg_id]; + // Step 2: Access registers across the subgroup using subgroupShuffle and perform the matmul. + lane_output1[0] += SDP8AI(own_a0, subgroupShuffle(own_b0, 0), own_a1, subgroupShuffle(own_b1, 0), subgroupShuffle(own_scale_b, 0) * own_scale_a); + lane_output1[1] += SDP8AI(own_a0, subgroupShuffle(own_b0, 1), own_a1, subgroupShuffle(own_b1, 1), subgroupShuffle(own_scale_b, 1) * own_scale_a); + lane_output1[2] += SDP8AI(own_a0, subgroupShuffle(own_b0, 2), own_a1, subgroupShuffle(own_b1, 2), subgroupShuffle(own_scale_b, 2) * own_scale_a); + lane_output1[3] += SDP8AI(own_a0, subgroupShuffle(own_b0, 3), own_a1, subgroupShuffle(own_b1, 3), subgroupShuffle(own_scale_b, 3) * own_scale_a); + + lane_output2[0] += SDP8AI(own_a0, subgroupShuffle(own_b0, 4), own_a1, subgroupShuffle(own_b1, 4), subgroupShuffle(own_scale_b, 4) * own_scale_a); + lane_output2[1] += SDP8AI(own_a0, subgroupShuffle(own_b0, 5), own_a1, subgroupShuffle(own_b1, 5), subgroupShuffle(own_scale_b, 5) * own_scale_a); + lane_output2[2] += SDP8AI(own_a0, subgroupShuffle(own_b0, 6), own_a1, subgroupShuffle(own_b1, 6), subgroupShuffle(own_scale_b, 6) * own_scale_a); + lane_output2[3] += SDP8AI(own_a0, subgroupShuffle(own_b0, 7), own_a1, subgroupShuffle(own_b1, 7), subgroupShuffle(own_scale_b, 7) * own_scale_a); + + lane_output3[0] += SDP8AI(own_a0, subgroupShuffle(own_b0, 8), own_a1, subgroupShuffle(own_b1, 8), subgroupShuffle(own_scale_b, 8) * own_scale_a); + lane_output3[1] += SDP8AI(own_a0, subgroupShuffle(own_b0, 9), own_a1, subgroupShuffle(own_b1, 9), subgroupShuffle(own_scale_b, 9) * own_scale_a); + lane_output3[2] += SDP8AI(own_a0, subgroupShuffle(own_b0, 10), own_a1, subgroupShuffle(own_b1, 10), subgroupShuffle(own_scale_b, 10) * own_scale_a); + lane_output3[3] += SDP8AI(own_a0, subgroupShuffle(own_b0, 11), own_a1, subgroupShuffle(own_b1, 11), subgroupShuffle(own_scale_b, 11) * own_scale_a); + + lane_output4[0] += SDP8AI(own_a0, subgroupShuffle(own_b0, 12), own_a1, subgroupShuffle(own_b1, 12), subgroupShuffle(own_scale_b, 12) * own_scale_a); + lane_output4[1] += SDP8AI(own_a0, subgroupShuffle(own_b0, 13), own_a1, subgroupShuffle(own_b1, 13), subgroupShuffle(own_scale_b, 13) * own_scale_a); + lane_output4[2] += SDP8AI(own_a0, subgroupShuffle(own_b0, 14), own_a1, subgroupShuffle(own_b1, 14), subgroupShuffle(own_scale_b, 14) * own_scale_a); + lane_output4[3] += SDP8AI(own_a0, subgroupShuffle(own_b0, 15), own_a1, subgroupShuffle(own_b1, 15), subgroupShuffle(own_scale_b, 15) * own_scale_a); + } + else + { + // Code for other subgroup sizes, simply doesnt use subgroups at all. + // Relies on reads from single location tile_B[][base_B + col] by all + // being optimized by the hardware. + lane_output1[0] += SDP8AI(own_a0, tile_B[0][base_B + 0], own_a1, tile_B[1][base_B + 0], own_scale_a * scale_B[base_B + 0]); + lane_output1[1] += SDP8AI(own_a0, tile_B[0][base_B + 1], own_a1, tile_B[1][base_B + 1], own_scale_a * scale_B[base_B + 1]); + lane_output1[2] += SDP8AI(own_a0, tile_B[0][base_B + 2], own_a1, tile_B[1][base_B + 2], own_scale_a * scale_B[base_B + 2]); + lane_output1[3] += SDP8AI(own_a0, tile_B[0][base_B + 3], own_a1, tile_B[1][base_B + 3], own_scale_a * scale_B[base_B + 3]); + + lane_output2[0] += SDP8AI(own_a0, tile_B[0][base_B + 4], own_a1, tile_B[1][base_B + 4], own_scale_a * scale_B[base_B + 4]); + lane_output2[1] += SDP8AI(own_a0, tile_B[0][base_B + 5], own_a1, tile_B[1][base_B + 5], own_scale_a * scale_B[base_B + 5]); + lane_output2[2] += SDP8AI(own_a0, tile_B[0][base_B + 6], own_a1, tile_B[1][base_B + 6], own_scale_a * scale_B[base_B + 6]); + lane_output2[3] += SDP8AI(own_a0, tile_B[0][base_B + 7], own_a1, tile_B[1][base_B + 7], own_scale_a * scale_B[base_B + 7]); + + lane_output3[0] += SDP8AI(own_a0, tile_B[0][base_B + 8], own_a1, tile_B[1][base_B + 8], own_scale_a * scale_B[base_B + 8]); + lane_output3[1] += SDP8AI(own_a0, tile_B[0][base_B + 9], own_a1, tile_B[1][base_B + 9], own_scale_a * scale_B[base_B + 9]); + lane_output3[2] += SDP8AI(own_a0, tile_B[0][base_B + 10], own_a1, tile_B[1][base_B + 10], own_scale_a * scale_B[base_B + 10]); + lane_output3[3] += SDP8AI(own_a0, tile_B[0][base_B + 11], own_a1, tile_B[1][base_B + 11], own_scale_a * scale_B[base_B + 11]); + + lane_output4[0] += SDP8AI(own_a0, tile_B[0][base_B + 12], own_a1, tile_B[1][base_B + 12], own_scale_a * scale_B[base_B + 12]); + lane_output4[1] += SDP8AI(own_a0, tile_B[0][base_B + 13], own_a1, tile_B[1][base_B + 13], own_scale_a * scale_B[base_B + 13]); + lane_output4[2] += SDP8AI(own_a0, tile_B[0][base_B + 14], own_a1, tile_B[1][base_B + 14], own_scale_a * scale_B[base_B + 14]); + lane_output4[3] += SDP8AI(own_a0, tile_B[0][base_B + 15], own_a1, tile_B[1][base_B + 15], own_scale_a * scale_B[base_B + 15]); + } + workgroupBarrier(); + } + + let a_global = a_global_base + base_A + a_idx; + let b_global = b_global_base + base_B; + let output_idx = ((a_global) * uniforms.N + b_global)/4; + // This creates a shader requirement that uniforms.N % 16 == 0 + if (a_global < uniforms.M && b_global < uniforms.N) + { + output[output_idx] = lane_output1; + output[output_idx+1] = lane_output2; + output[output_idx+2] = lane_output3; + output[output_idx+3] = lane_output4; + } +)MAIN_FN"; + + return Status::OK(); +} + Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const { const Tensor* a = context.Input(0); const Tensor* b = context.Input(1); @@ -548,16 +796,16 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context TensorShape b_shape({N_, K_}); ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b_shape, false, true)); auto* y = context.Output(0, helper.OutputShape()); - const uint32_t data_size = onnxruntime::narrow(y->Shape().Size()); + const uint32_t data_size = gsl::narrow(y->Shape().Size()); if (data_size == 0) { return Status::OK(); } - const uint32_t batch_count = onnxruntime::narrow(helper.OutputOffsets().size()); - const uint32_t M = onnxruntime::narrow(helper.M()); - const uint32_t N = onnxruntime::narrow(helper.N()); - const uint32_t K = onnxruntime::narrow(helper.K()); - const uint32_t block_size = onnxruntime::narrow(block_size_); + const uint32_t batch_count = gsl::narrow(helper.OutputOffsets().size()); + const uint32_t M = gsl::narrow(helper.M()); + const uint32_t N = gsl::narrow(helper.N()); + const uint32_t K = gsl::narrow(helper.K()); + const uint32_t block_size = gsl::narrow(block_size_); constexpr uint32_t nbits = 4; const uint32_t n_blocks_per_col = (K + block_size - 1) / block_size; @@ -574,17 +822,56 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context return ApplySubgroupMatrixMatMulNBits(a, b, scales, M, N, K, context, y); } - if (M >= kMinMForTileOptimization && - CanApplyDP4AMatrixMatMulNBits(context, accuracy_level_, block_size, batch_count, N, K, components_a, has_zero_points)) { - return ApplyDP4AMatrixMatMulNBits(a, b, scales, M, N, K, block_size, context, y); + const bool has_subgroup = context.Device().HasFeature(wgpu::FeatureName::Subgroups); + // macOS - Avoid using dp4a on Metal, as it does not appear to have native dp4a support. + // https://github.com/gpuweb/gpuweb/issues/2677#issuecomment-1713292226 + const bool use_dp4a = has_subgroup && context.AdapterInfo().backendType != wgpu::BackendType::Metal; + if (accuracy_level_ == 4 && block_size == 32 && + batch_count == 1 && components_a == 4 && K % 64 == 0 && N % 16 == 0 && + !has_zero_points && use_dp4a && M >= kMinMForTileOptimization) { + constexpr uint32_t kVec4Components = 4; + constexpr uint32_t kVec2Components = 2; + constexpr uint32_t kU32Components = 4; + + constexpr uint32_t kBlockSizeA = 128; + DP4AMatMulQuantizeProgram quantize_program; + quantize_program.SetWorkgroupSize(1); + quantize_program.SetDispatchGroupSize(M * K / kBlockSizeA, 1, 1); + TensorShape a_quant_shape{1, M, K / kU32Components}; + Tensor a_quant = context.CreateGPUTensor(DataTypeImpl::GetType(), a_quant_shape); + TensorShapeVector a_scales_dims({1, 1, M, K / kBlockSizeA}); + Tensor a_scale = context.CreateGPUTensor(a->DataType(), a_scales_dims); + quantize_program.AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow(kVec4Components)}}) + .AddOutputs({{&a_quant, ProgramTensorMetadataDependency::Rank, a_quant.Shape(), gsl::narrow(1)}, + {&a_scale, ProgramTensorMetadataDependency::Rank, a_scale.Shape(), gsl::narrow(1)}}) + .AddUniformVariable({static_cast(M * K / kVec4Components)}); + ORT_RETURN_IF_ERROR(context.RunProgram(quantize_program)); + + constexpr uint32_t kTileSize = 64; + TensorShape reshaped_y_shape{1, M, N / kVec4Components}; + DP4AMatMulNBitsProgram mul_program; + mul_program.SetWorkgroupSize(256); + mul_program.SetDispatchGroupSize( + (M + kTileSize - 1) / kTileSize, + (N + kTileSize - 1) / kTileSize, 1); + mul_program.AddInputs({{&a_quant, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow(kVec4Components)}, + {&a_scale, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow(1)}, + {b, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow(kVec2Components * kU32Components)}, + {scales, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow(1)}}) + .AddUniformVariables({{static_cast(M)}, + {static_cast(N)}, + {static_cast(K)}, + {static_cast(K / 8)}, + {static_cast(K / 16)}}) + .AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, reshaped_y_shape, gsl::narrow(kVec4Components)}); + return context.RunProgram(mul_program); } // TODO: Support output_number > 1. Some cases are failed when output_number > 1. constexpr uint32_t output_number = 1; const uint32_t tile_m = M > kMinMForTileOptimization ? 4 : 1; - const bool has_subgroup = context.Device().HasFeature(wgpu::FeatureName::Subgroups); const bool use_subgroup = has_subgroup && context.AdapterInfo().vendor == std::string_view{"intel"} && components_a == 4 && block_size == 32; - MatMulNBitsProgram program{output_number, block_size, tile_m, static_cast(components_b), has_zero_points, use_subgroup}; + MatMulNBitsProgram program{output_number, block_size, tile_m, gsl::narrow(components_b), has_zero_points, use_subgroup}; if (M > kMinMForTileOptimization && block_size == 32) { components = 1; constexpr uint32_t workgroup_size = 64; @@ -597,8 +884,7 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context program.CacheHint("T_M" + std::to_string(tile_m) + "Subgroup" + std::to_string(use_subgroup)); } else if (block_size == 32) { components = 1; - // TODO: Tune the workgroup size when `M=1`. - constexpr uint32_t workgroup_size = 128; + constexpr uint32_t workgroup_size = 64; const uint32_t workgroup_y = N % 8 == 0 ? 8 : 1; const uint32_t workgroup_x = workgroup_size / workgroup_y; program.SetWorkgroupSize(workgroup_x, workgroup_y, 1); @@ -614,10 +900,10 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context TensorShape reshaped_y_shape{batch_count, M, N / components}; program - .AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, reshaped_a_shape, static_cast(components_a)}, - {b, ProgramTensorMetadataDependency::TypeAndRank, reshaped_b_shape, static_cast(components_b * 4 /** b will be accessed as uint32 which includs 4 uint8. So here we need to multiply 4.*/)}, + .AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, reshaped_a_shape, gsl::narrow(components_a)}, + {b, ProgramTensorMetadataDependency::TypeAndRank, reshaped_b_shape, gsl::narrow(components_b * 4 /** b will be accessed as uint32 which includs 4 uint8. So here we need to multiply 4.*/)}, {scales, ProgramTensorMetadataDependency::None}}) - .AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, reshaped_y_shape, static_cast(components)}) + .AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, reshaped_y_shape, gsl::narrow(components)}) .AddUniformVariable({block_size}); if (has_zero_points) { program.AddInput({zero_points, ProgramTensorMetadataDependency::None, {(zero_points->Shape().Size() + 3) / 4}, 4}); diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h index 10221e19c7400..3d72629bf6b25 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h @@ -35,6 +35,25 @@ class MatMulNBitsProgram final : public Program { bool use_subgroup_; }; +class DP4AMatMulQuantizeProgram final : public Program { + public: + DP4AMatMulQuantizeProgram() : Program{"DP4AMatMulQuantize"} {} + Status GenerateShaderCode(ShaderHelper& sh) const override; + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"input_size", ProgramUniformVariableDataType::Uint32}); +}; + +class DP4AMatMulNBitsProgram final : public Program { + public: + DP4AMatMulNBitsProgram() : Program{"DP4AMatMulNBits"} {} + Status GenerateShaderCode(ShaderHelper& sh) const override; + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES( + {"M", ProgramUniformVariableDataType::Uint32}, + {"N", ProgramUniformVariableDataType::Uint32}, + {"K", ProgramUniformVariableDataType::Uint32}, + {"K8", ProgramUniformVariableDataType::Uint32}, + {"K16", ProgramUniformVariableDataType::Uint32}); +}; + class MatMulNBits final : public WebGpuKernel { public: MatMulNBits(const OpKernelInfo& info) : WebGpuKernel(info) { diff --git a/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc index cb024d2a758a9..2944a4d61b8ef 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc @@ -185,13 +185,13 @@ Status ApplySubgroupMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Te mul_program.SetDispatchGroupSize( (N + kTileSizeB - 1) / kTileSizeB, (M + kTileSizeA - 1) / kTileSizeA, 1); - mul_program.AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, 1}, - {b, ProgramTensorMetadataDependency::TypeAndRank, static_cast(kU32Components)}, - {scales, ProgramTensorMetadataDependency::TypeAndRank, 1}}) + mul_program.AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow(1)}, + {b, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow(kU32Components)}, + {scales, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow(1)}}) .AddUniformVariables({{static_cast(M)}, {static_cast(N)}, {static_cast(K)}}) - .AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, y_shape, 1}); + .AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, y_shape, gsl::narrow(1)}); return context.RunProgram(mul_program); } diff --git a/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc b/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc index 068a94c7390e2..2e7ed5a16a2f0 100644 --- a/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc @@ -37,8 +37,8 @@ Status RegisterWebGpuContribKernels(KernelRegistry& kernel_registry) { static const BuildKernelCreateInfoFn function_table[] = { BuildKernelCreateInfo, // default entry to avoid the list become empty after ops-reducing // BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, BuildKernelCreateInfo, // BuildKernelCreateInfo, // BuildKernelCreateInfo, diff --git a/onnxruntime/core/framework/compute_capability.h b/onnxruntime/core/framework/compute_capability.h index 819264b3960e7..5f21ba2f013e0 100644 --- a/onnxruntime/core/framework/compute_capability.h +++ b/onnxruntime/core/framework/compute_capability.h @@ -2,11 +2,8 @@ // Licensed under the MIT License. #pragma once -#include #include "core/common/common.h" #include "core/graph/indexed_sub_graph.h" -#include "core/graph/graph.h" -#include "core/optimizer/graph_optimizer_registry.h" namespace onnxruntime { // A structure encodes a subgraph and the method to run it. @@ -24,22 +21,5 @@ struct ComputeCapability { ComputeCapability(std::unique_ptr t_sub_graph) : sub_graph(std::move(t_sub_graph)) {} - - // Optional function to optimize this ComputeCapability. - // This will be called by ORT once the ComputeCapability is assigned to the EP. - std::function - optimization_func; - - // Optional ComputeCapability instances for sets of nodes within this ComputeCapability that should be optimized. - // when an optimization is applied, ORT will update this ComputeCapability to reflect the changes made. - // IndexedSubGraph.nodes: - // - update based on RemovedNode/AddNode calls - // IndexedSubGraph.MetaDef (if present): - // - inputs and outputs will be unchanged - // - constant_initializers MAY change if we constant fold an initializer during optimization - std::vector> nodes_to_optimize; }; } // namespace onnxruntime diff --git a/onnxruntime/core/framework/execution_provider.cc b/onnxruntime/core/framework/execution_provider.cc index df85daa006a43..3a937a119d03b 100644 --- a/onnxruntime/core/framework/execution_provider.cc +++ b/onnxruntime/core/framework/execution_provider.cc @@ -14,7 +14,6 @@ namespace onnxruntime { std::vector> IExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph, const IKernelLookup& kernel_lookup, - const GraphOptimizerRegistry&, IResourceAccountant*) const { std::vector> result; for (const auto& node : graph.Nodes()) { diff --git a/onnxruntime/core/framework/external_data_loader.cc b/onnxruntime/core/framework/external_data_loader.cc index c577805e69cc4..fe73a55735631 100644 --- a/onnxruntime/core/framework/external_data_loader.cc +++ b/onnxruntime/core/framework/external_data_loader.cc @@ -60,12 +60,7 @@ common::Status LoadWebAssemblyExternalData(const Env& env, break; case 1: // Load external data to GPU. - // TODO: use a unified interface for upload external buffer. - if (Module.webgpuUploadExternalBuffer) { - Module.webgpuUploadExternalBuffer(dataIdOrBuffer, data); - } else { - Module.jsepUploadExternalBuffer(dataIdOrBuffer, data); - } + Module.jsepUploadExternalBuffer(dataIdOrBuffer, data); break; default: return 4; // Unknown error occurred in memory copy. diff --git a/onnxruntime/core/framework/external_data_loader.h b/onnxruntime/core/framework/external_data_loader.h index 90d48ca800797..117da7d0a4afa 100644 --- a/onnxruntime/core/framework/external_data_loader.h +++ b/onnxruntime/core/framework/external_data_loader.h @@ -42,7 +42,7 @@ class IExternalDataLoader { enum class ExternalDataLoadType { CPU = 0, -#if defined(USE_JSEP) || defined(USE_WEBGPU) +#if defined(USE_JSEP) WEBGPU_BUFFER = 1, #endif }; diff --git a/onnxruntime/core/framework/fallback_cpu_capability.cc b/onnxruntime/core/framework/fallback_cpu_capability.cc index d3e435c0341b0..1eb7420b44d2c 100644 --- a/onnxruntime/core/framework/fallback_cpu_capability.cc +++ b/onnxruntime/core/framework/fallback_cpu_capability.cc @@ -1,8 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) - #include "core/framework/fallback_cpu_capability.h" #include "core/common/inlined_containers.h" @@ -178,5 +176,3 @@ std::unordered_set GetCpuPreferredNodes(const onnxruntime::GraphViewe } } // namespace onnxruntime - -#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) diff --git a/onnxruntime/core/framework/fallback_cpu_capability.h b/onnxruntime/core/framework/fallback_cpu_capability.h index ddcc1de96d2af..bca75adbfd5a7 100644 --- a/onnxruntime/core/framework/fallback_cpu_capability.h +++ b/onnxruntime/core/framework/fallback_cpu_capability.h @@ -3,8 +3,6 @@ #pragma once -#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) - #include #include "core/common/inlined_containers_fwd.h" #include "core/framework/execution_provider.h" // for IExecutionProvider::IKernelLookup @@ -28,5 +26,3 @@ std::unordered_set GetCpuPreferredNodes(const GraphViewer& graph, const logging::Logger& logger); } // namespace onnxruntime - -#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) diff --git a/onnxruntime/core/framework/graph_partitioner.cc b/onnxruntime/core/framework/graph_partitioner.cc index ff4d300f665b1..111f8e0a5fc34 100644 --- a/onnxruntime/core/framework/graph_partitioner.cc +++ b/onnxruntime/core/framework/graph_partitioner.cc @@ -142,15 +142,13 @@ struct GetCapabilityForEPParams { std::reference_wrapper debug_graph_fn; #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) IResourceAccountant* resource_accountant; - std::reference_wrapper graph_optimizer_registry; }; auto get_capabilities = [](const IExecutionProvider& ep, const GraphViewer& graph_viewer, const IExecutionProvider::IKernelLookup& kernel_lookup, - IResourceAccountant* resource_accountant, - const GraphOptimizerRegistry& graph_optimizer_registry) { - auto capabilities = ep.GetCapability(graph_viewer, kernel_lookup, graph_optimizer_registry, resource_accountant); + IResourceAccountant* resource_accountant) { + auto capabilities = ep.GetCapability(graph_viewer, kernel_lookup, resource_accountant); // In theory an EP could return an empty capability. Remove those. capabilities.erase(std::remove_if(capabilities.begin(), capabilities.end(), @@ -184,11 +182,10 @@ static Status GetCapabilityForEP(const GetCapabilityForEPParams& params, const l auto& graph = params.graph.get(); auto& capabilities = params.capabilities.get(); - const auto& graph_optimizer_registry = params.graph_optimizer_registry.get(); { const GraphViewer graph_viewer(graph); - capabilities = get_capabilities(current_ep, graph_viewer, kernel_lookup, params.resource_accountant, graph_optimizer_registry); + capabilities = get_capabilities(current_ep, graph_viewer, kernel_lookup, params.resource_accountant); if (capabilities.empty()) { return Status::OK(); @@ -226,7 +223,7 @@ static Status GetCapabilityForEP(const GetCapabilityForEPParams& params, const l capabilities.clear(); const GraphViewer graph_viewer(graph); - capabilities = get_capabilities(current_ep, graph_viewer, kernel_lookup, params.resource_accountant, graph_optimizer_registry); + capabilities = get_capabilities(current_ep, graph_viewer, kernel_lookup, params.resource_accountant); // all nodes with an index >= first_new_node with domain of kMSInternalNHWCDomain should be in the capabilities InlinedHashSet new_nodes_in_capabilities; @@ -264,7 +261,6 @@ static Status GetCapabilityForEP(const GetCapabilityForEPParams& params, const l static Status GetCapabilityForEPForAotInlining(const GraphViewer& graph_viewer, const KernelRegistryManager& kernel_registry_mgr, const IExecutionProvider& current_ep, - const GraphOptimizerRegistry& graph_optimizer_registry, const logging::Logger& logger, std::vector>& capabilities) { const auto& ep_type = current_ep.Type(); @@ -276,62 +272,14 @@ static Status GetCapabilityForEPForAotInlining(const GraphViewer& graph_viewer, logger}; // TODO: Provide EP with a capability to look inside the functions. - capabilities = get_capabilities(current_ep, graph_viewer, kernel_lookup, nullptr, graph_optimizer_registry); + capabilities = get_capabilities(current_ep, graph_viewer, kernel_lookup, nullptr); return Status::OK(); } /** - * Check whether the given IndexedSubGraph is available for assigning to a specific provider. - * - */ -static bool IsIndexedSubGraphAvailableForAssignment(Graph& graph, - const IndexedSubGraph& capability, - GraphPartitioner::Mode mode, - const std::string& provider_type) { - // The provider can run a single node in the if not using meta-defs. - if (capability.GetMetaDef() == nullptr && capability.nodes.size() == 1) { - auto* node = graph.GetNode(capability.nodes[0]); - if (nullptr != node && node->GetExecutionProviderType().empty()) { - // The node was not fused or assigned. - return true; - } - return false; - } - - // if mode is kAssignOnly we want all nodes that can _potentially_ be taken by compiling EPs to be assigned, - // so that we aggregate the nodes covered and ensure the original nodes remain in the ORT format model by - // preventing level 2 and 3 optimizers from changing them. optimizers check the EP the node is assigned to - // and only make changes if the EP is on the optimizer's list of supported EPs. an EP that compiles nodes - // should never be on those lists. - // - // when the ORT format model is loaded we will process it normally with EP priority being applied for - // whichever EPs are enabled at the time. - // - // e.g. an Android NNAPI EP may take different/overlapping nodes to a iOS CoreML EP. - // We want the ORT format model to be able to be run as efficiently as possible on either platform, - // so we want all the nodes that either may take to be preserved. If we did not do this we would - // need to create one ORT format model for Android and one for iOS. - if (mode == GraphPartitioner::Mode::kAssignOnly) { - return true; - } - - for (auto node_index : capability.nodes) { - const auto* node = graph.GetNode(node_index); - if ((nullptr == node) || - (!node->GetExecutionProviderType().empty() && node->GetExecutionProviderType() != provider_type)) { - // The node was fused or assigned, so that the whole sub-graph will not be assigned to this - // The assumption is that this can only run the sub-graph as a whole unit. - return false; - } - } - - return true; -} - -/** - * Return a fused node or assign the nodes in the indexed subgraph to the current EP. - * + * Check if a node can be placed on a specific provider. + * Do nothing if the node is already assigned * \param graph * \param capability * \param kernel_registry_mgr @@ -350,42 +298,75 @@ static Node* PlaceNode(Graph& graph, const IndexedSubGraph& capability, if (nullptr == capability.GetMetaDef()) { TryAssignSingleNode(graph, capability, provider_type); } else { - const bool acc_enabled = capability.IsAccountingEnabled(); - if (mode == GraphPartitioner::Mode::kNormal) { - std::ostringstream oss; - oss << provider_type << "_" << capability.GetMetaDef()->name << "_" << fused_node_unique_id++; - std::string node_name = oss.str(); + // The can run a fused in the . - Node* fused_node = nullptr; - if (fusion_style == IExecutionProvider::FusionStyle::Function) { - fused_node = &graph.FuseSubGraph(capability, node_name); - } else { - // create a fused node without copying everything to a Function body. The IndexedSubGraph will be passed - // through to Compile via a filtered GraphViewer. - fused_node = &graph.BeginFuseSubGraph(capability, node_name); + // Check whether any node in the was already assigned. If so it cannot be stolen as assignment is done + // in order of EP priority + bool sub_graph_available_for_assignment = true; + if (mode != GraphPartitioner::Mode::kAssignOnly) { + // if mode is kAssignOnly we want all nodes that can _potentially_ be taken by compiling EPs to be assigned, + // so that we aggregate the nodes covered and ensure the original nodes remain in the ORT format model by + // preventing level 2 and 3 optimizers from changing them. optimizers check the EP the node is assigned to + // and only make changes if the EP is on the optimizer's list of supported EPs. an EP that compiles nodes + // should never be on those lists. + // + // when the ORT format model is loaded we will process it normally with EP priority being applied for + // whichever EPs are enabled at the time. + // + // e.g. an Android NNAPI EP may take different/overlapping nodes to a iOS CoreML EP. + // We want the ORT format model to be able to be run as efficiently as possible on either platform, + // so we want all the nodes that either may take to be preserved. If we did not do this we would + // need to create one ORT format model for Android and one for iOS. + for (auto node_index : capability.nodes) { + const auto* node = graph.GetNode(node_index); + if ((nullptr == node) || + (!node->GetExecutionProviderType().empty() && node->GetExecutionProviderType() != provider_type)) { + // The node was fused or assigned, so that the whole sub-graph will not be assigned to this + // The assumption is that this can only run the sub-graph as a whole unit. + sub_graph_available_for_assignment = false; + break; + } } + } - fused_node->SetExecutionProviderType(provider_type); - if (acc_enabled) { - // We account for the fused node. We operate under assumption - // that the fused node would use no more memory when the nodes we are fusing. - // and potentially less than that, and therefore, no threshold check is needed here. - // All threshold checks are done within the EP. - capability.ComputeAndAccountForNode(*fused_node); - } + if (sub_graph_available_for_assignment) { + const bool acc_enabled = capability.IsAccountingEnabled(); + if (mode == GraphPartitioner::Mode::kNormal) { + std::ostringstream oss; + oss << provider_type << "_" << capability.GetMetaDef()->name << "_" << fused_node_unique_id++; + std::string node_name = oss.str(); + + Node* fused_node = nullptr; + if (fusion_style == IExecutionProvider::FusionStyle::Function) { + fused_node = &graph.FuseSubGraph(capability, node_name); + } else { + // create a fused node without copying everything to a Function body. The IndexedSubGraph will be passed + // through to Compile via a filtered GraphViewer. + fused_node = &graph.BeginFuseSubGraph(capability, node_name); + } - result = fused_node; - } else { - // assign the nodes in the indexed subgraph to the current EP so that level 2+ optimizers will not change them. - // This is used when exporting an ORT format model to maintain the original nodes and re-do the fusion - // at runtime. The original nodes provide a fallback if fewer nodes can be fused at runtime due to device - // capabilities. - for (size_t i = 0, limit = capability.nodes.size(); i < limit; ++i) { - auto* node = graph.GetNode(capability.nodes[i]); - if (node != nullptr) { - node->SetExecutionProviderType(provider_type); - if (acc_enabled) { - capability.AccountForNode(i); + fused_node->SetExecutionProviderType(provider_type); + if (acc_enabled) { + // We account for the fused node. We operate under assumption + // that the fused node would use no more memory when the nodes we are fusing. + // and potentially less than that, and therefore, no threshold check is needed here. + // All threshold checks are done within the EP. + capability.ComputeAndAccountForNode(*fused_node); + } + + result = fused_node; + } else { + // assign the nodes in the indexed subgraph to the current EP so that level 2+ optimizers will not change them. + // This is used when exporting an ORT format model to maintain the original nodes and re-do the fusion + // at runtime. The original nodes provide a fallback if fewer nodes can be fused at runtime due to device + // capabilities. + for (size_t i = 0, limit = capability.nodes.size(); i < limit; ++i) { + auto* node = graph.GetNode(capability.nodes[i]); + if (node != nullptr) { + node->SetExecutionProviderType(provider_type); + if (acc_enabled) { + capability.AccountForNode(i); + } } } } @@ -405,8 +386,7 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, FuncManager& func_mgr, int& fused_node_unique_id, const layout_transformation::TransformLayoutFunction& transform_layout_fn, const layout_transformation::DebugGraphFn& debug_graph_fn, - const logging::Logger& logger, IResourceAccountant* resource_accountant, - const GraphOptimizerRegistry& graph_optimizer_registry) { + const logging::Logger& logger, IResourceAccountant* resource_accountant) { // handle testing edge case where optimizers or constant lifting results in graph with no nodes. // doing it here saves all providers checking for this in GetCapability if (graph.NumberOfNodes() == 0) { @@ -420,7 +400,7 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, FuncManager& func_mgr, // we pass through the FuncManager from the top level graph ORT_RETURN_IF_ERROR(PartitionOnnxFormatModelImpl(*subgraph, func_mgr, kernel_registry_mgr, fused_kernel_registry, current_ep, mode, fused_node_unique_id, - transform_layout_fn, debug_graph_fn, logger, resource_accountant, graph_optimizer_registry)); + transform_layout_fn, debug_graph_fn, logger, resource_accountant)); } } @@ -444,8 +424,7 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, FuncManager& func_mgr, mode, std::cref(transform_layout_fn), std::cref(debug_graph_fn), - resource_accountant, - std::ref(graph_optimizer_registry)}; + resource_accountant}; ORT_RETURN_IF_ERROR(GetCapabilityForEP(get_capability_params, logger)); if (capabilities.empty()) { @@ -471,30 +450,7 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, FuncManager& func_mgr, entry->sub_graph->GetMetaDef() != nullptr; })); for (auto& capability : capabilities) { - // The can run a fused in the . - // Check whether any node in the was already assigned. If so it cannot be stolen as assignment is done - // in order of EP priority - bool sub_graph_available_for_assignment = IsIndexedSubGraphAvailableForAssignment(graph, *capability->sub_graph, mode, type); - - // If the is available to be assigned to the EP and the ComputeCapability has nodes_to_optimize, - // run EP related optimizations and update ComputeCapability. - if (sub_graph_available_for_assignment && !capability->nodes_to_optimize.empty()) { - for (auto& optimization_cc : capability->nodes_to_optimize) { - if (optimization_cc->optimization_func) { - auto status = optimization_cc->optimization_func(graph, *optimization_cc, *capability, graph_optimizer_registry); - if (status != Status::OK()) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, type, "The optimization function failed to finish."); - } - // #TODO: Handle nested optimization ComputeCapability - } - } - } - - Node* n = nullptr; - if (sub_graph_available_for_assignment) { - n = PlaceNode(graph, *capability->sub_graph, fusion_style, type, mode, fused_node_unique_id); - } - + Node* n = PlaceNode(graph, *capability->sub_graph, fusion_style, type, mode, fused_node_unique_id); if (n != nullptr) { // searching in kernel registries, if no kernel registered for the fused_node, use compile approach if (!KernelRegistryManager::HasImplementationOf(kernel_registry_mgr, *n, type, logger)) { @@ -631,7 +587,6 @@ static Status InlineNodes(Graph& graph, bool& modified_graph) { static Status InlineFunctionsAOTImpl(const ExecutionProviders& execution_providers, const KernelRegistryManager& kernel_registry_mgr, Graph& graph, - const GraphOptimizerRegistry& graph_optimizer_registry, const logging::Logger& logger, InlinedHashSet& not_inlined, size_t& inlined_count) { @@ -648,7 +603,6 @@ static Status InlineFunctionsAOTImpl(const ExecutionProviders& execution_provide ORT_RETURN_IF_ERROR(InlineFunctionsAOTImpl(execution_providers, kernel_registry_mgr, *subgraph, - graph_optimizer_registry, logger, not_inlined, inlined_count)); @@ -673,7 +627,7 @@ static Status InlineFunctionsAOTImpl(const ExecutionProviders& execution_provide InlinedHashSet claimed_by_ep; for (const auto& ep : execution_providers) { std::vector> capabilities; - ORT_RETURN_IF_ERROR(GetCapabilityForEPForAotInlining(graph_viewer, kernel_registry_mgr, *ep, graph_optimizer_registry, logger, + ORT_RETURN_IF_ERROR(GetCapabilityForEPForAotInlining(graph_viewer, kernel_registry_mgr, *ep, logger, capabilities)); for (auto& capability : capabilities) { const auto& nodes = capability->sub_graph->nodes; @@ -713,28 +667,23 @@ static Status InlineFunctionsAOTImpl(const ExecutionProviders& execution_provide } // Validate the ep_context_path to make sure it is file path and check whether the file exist already -static Status GetValidatedEpContextPath(const std::filesystem::path& ep_context_path, - const std::filesystem::path& model_path, - std::filesystem::path& context_cache_path) { +static Status EpContextFilePathCheck(const std::string& ep_context_path, + const std::filesystem::path& model_path) { + std::filesystem::path context_cache_path; if (!ep_context_path.empty()) { context_cache_path = ep_context_path; if (!context_cache_path.has_filename()) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "context_file_path should not point to a folder."); } } else if (!model_path.empty()) { - auto pos = model_path.native().find_last_of(ORT_TSTR(".")); - if (pos != std::string::npos) { - context_cache_path = model_path.native().substr(0, pos) + ORT_TSTR("_ctx.onnx"); - } else { - context_cache_path = model_path.native() + ORT_TSTR("_ctx.onnx"); - } + context_cache_path = model_path.native() + ORT_TSTR("_ctx.onnx"); } else { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Both ep_context_path and model_path are empty."); } if (std::filesystem::exists(context_cache_path)) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to generate EP context model since the file '", - context_cache_path, "' exist already. Please remove the EP context model if you want to re-generate it."); + context_cache_path, "' exist already."); } return Status::OK(); @@ -765,7 +714,15 @@ static Status CreateEpContextModel(const ExecutionProviders& execution_providers }; std::filesystem::path context_cache_path; - ORT_RETURN_IF_ERROR(GetValidatedEpContextPath(ep_context_path, graph.ModelPath(), context_cache_path)); + const std::filesystem::path& model_path = graph.ModelPath(); + + if (!ep_context_path.empty()) { + context_cache_path = ep_context_path; + } else if (!model_path.empty()) { + context_cache_path = model_path.native() + ORT_TSTR("_ctx.onnx"); + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Both ep_context_path and model_path are empty"); + } Model ep_context_model(graph.Name(), false, graph.GetModel().MetaData(), graph.GetModel().ModelPath(), // use source model path so that external initializers can find the data file path @@ -837,7 +794,6 @@ static Status PartitionOnnxFormatModel(const PartitionParams& partition_params, const ExecutionProviders& execution_providers, KernelRegistryManager& kernel_registry_manager, const std::optional& acc_map, - const GraphOptimizerRegistry& graph_optimizer_registry, const logging::Logger& logger) { bool modified_graph = false; @@ -861,7 +817,7 @@ static Status PartitionOnnxFormatModel(const PartitionParams& partition_params, fused_kernel_registry, *ep, mode, fused_node_unique_id, transform_layout_function, partition_params.debug_graph_fn, - logger, resource_accountant, graph_optimizer_registry)); + logger, resource_accountant)); } // expand any nodes that have an ONNX function definition but no matching ORT kernel. @@ -882,7 +838,6 @@ static Status PartitionOnnxFormatModel(const PartitionParams& partition_params, static Status PartitionOrtFormatModelImpl(const PartitionParams& partition_params, KernelRegistryManager& kernel_registry_mgr, IExecutionProvider& current_ep, - const GraphOptimizerRegistry& graph_optimizer_registry, const logging::Logger& logger) { // handle testing edge case where optimizers or constant lifting results in graph with no nodes. // doing it here saves all providers checking for this in GetCapability @@ -898,7 +853,7 @@ static Status PartitionOrtFormatModelImpl(const PartitionParams& partition_param PartitionParams subgraph_partition_params = partition_params; subgraph_partition_params.graph = std::ref(subgraph); ORT_RETURN_IF_ERROR(PartitionOrtFormatModelImpl(subgraph_partition_params, kernel_registry_mgr, - current_ep, graph_optimizer_registry, logger)); + current_ep, logger)); } } @@ -914,8 +869,7 @@ static Status PartitionOrtFormatModelImpl(const PartitionParams& partition_param std::cref(partition_params.transform_layout_function), std::cref(partition_params.debug_graph_fn), #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) - nullptr, - std::ref(graph_optimizer_registry) + nullptr }; // clang-format on @@ -1008,11 +962,10 @@ static Status PartitionOrtFormatModelImpl(const PartitionParams& partition_param static Status PartitionOrtFormatModel(const PartitionParams& partition_params, const ExecutionProviders& execution_providers, KernelRegistryManager& kernel_registry_manager, - const GraphOptimizerRegistry& graph_optimizer_registry, const logging::Logger& logger) { // process full graph with each EP for (const auto& ep : execution_providers) { - ORT_RETURN_IF_ERROR(PartitionOrtFormatModelImpl(partition_params, kernel_registry_manager, *ep, graph_optimizer_registry, logger)); + ORT_RETURN_IF_ERROR(PartitionOrtFormatModelImpl(partition_params, kernel_registry_manager, *ep, logger)); } return Status::OK(); @@ -1039,7 +992,6 @@ Status GraphPartitioner::InlineFunctionsAOT(Model& model, ORT_RETURN_IF_ERROR(InlineFunctionsAOTImpl(execution_providers, kernel_registry_manager, graph, - *graph_optimizer_registry_, logger, not_inlined, inlined_count)); @@ -1096,7 +1048,8 @@ Status GraphPartitioner::Partition(Graph& graph, FuncManager& func_mgr, std::ref(*fused_kernel_registry), std::ref(fused_node_unique_id), std::cref(transform_layout_function), - std::cref(debug_graph_fn)}; + std::cref(debug_graph_fn), + }; #else // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) @@ -1115,8 +1068,7 @@ Status GraphPartitioner::Partition(Graph& graph, FuncManager& func_mgr, if (ep_context_enabled) { std::string ep_context_path = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextFilePath, ""); // Check before EP compile graphs - std::filesystem::path context_cache_path; - ORT_RETURN_IF_ERROR(GetValidatedEpContextPath(ep_context_path, graph.ModelPath(), context_cache_path)); + ORT_RETURN_IF_ERROR(EpContextFilePathCheck(ep_context_path, graph.ModelPath())); } // We use this only if Resource Aware Partitioning is enabled for any of the EPs @@ -1125,7 +1077,7 @@ Status GraphPartitioner::Partition(Graph& graph, FuncManager& func_mgr, ORT_RETURN_IF_ERROR(NodeStatsRecorder::CreateAccountants(config_options, graph.ModelPath(), ep_acc_map)); ORT_RETURN_IF_ERROR(PartitionOnnxFormatModel(partition_params, mode, providers_, kernel_registry_mgr_, - ep_acc_map, *graph_optimizer_registry_, logger)); + ep_acc_map, logger)); if (ep_context_enabled) { std::string ep_context_path = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextFilePath, ""); @@ -1139,7 +1091,7 @@ Status GraphPartitioner::Partition(Graph& graph, FuncManager& func_mgr, return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "ONNX models are not supported in this build."); #endif //! defined(ORT_MINIMAL_BUILD) } else { - ORT_RETURN_IF_ERROR(PartitionOrtFormatModel(partition_params, providers_, kernel_registry_mgr_, *graph_optimizer_registry_, logger)); + ORT_RETURN_IF_ERROR(PartitionOrtFormatModel(partition_params, providers_, kernel_registry_mgr_, logger)); } #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) diff --git a/onnxruntime/core/framework/graph_partitioner.h b/onnxruntime/core/framework/graph_partitioner.h index b9d4022cb5a14..d1ef193cf1520 100644 --- a/onnxruntime/core/framework/graph_partitioner.h +++ b/onnxruntime/core/framework/graph_partitioner.h @@ -7,7 +7,6 @@ #include "core/graph/graph.h" #include "core/framework/fuse_nodes_funcs.h" #include "core/framework/transform_layout_functions.h" -#include "core/optimizer/graph_optimizer_registry.h" namespace onnxruntime { @@ -25,12 +24,9 @@ class GraphPartitioner { }; // The order of providers represents the user preference. - GraphPartitioner(KernelRegistryManager& kernel_registry_mgr, - const ExecutionProviders& providers, - std::unique_ptr graph_optimizer_registry) + GraphPartitioner(KernelRegistryManager& kernel_registry_mgr, const ExecutionProviders& providers) : kernel_registry_mgr_(kernel_registry_mgr), - providers_(providers), - graph_optimizer_registry_(std::move(graph_optimizer_registry)) { + providers_(providers) { } // Run partitioning. @@ -68,7 +64,6 @@ class GraphPartitioner { KernelRegistryManager& kernel_registry_mgr_; const ExecutionProviders& providers_; - std::unique_ptr graph_optimizer_registry_; }; } // namespace onnxruntime diff --git a/onnxruntime/core/framework/onnxruntime_typeinfo.cc b/onnxruntime/core/framework/onnxruntime_typeinfo.cc index 1c446840b7938..a884927abddb7 100644 --- a/onnxruntime/core/framework/onnxruntime_typeinfo.cc +++ b/onnxruntime/core/framework/onnxruntime_typeinfo.cc @@ -10,8 +10,8 @@ #include "core/framework/sparse_tensor.h" #include "core/graph/onnx_protobuf.h" #include "core/session/ort_apis.h" -#include "core/session/model_editor_api.h" #include "core/framework/error_code_helper.h" + #include "core/framework/tensor_type_and_shape.h" #include "core/framework/onnxruntime_map_type_info.h" #include "core/framework/onnxruntime_sequence_type_info.h" @@ -40,7 +40,7 @@ OrtTypeInfo::OrtTypeInfo(std::unique_ptr optional_type_info : type(ONNX_TYPE_OPTIONAL), optional_type_info(std::move(optional_type_info)) {} OrtTypeInfo::OrtTypeInfo(ONNXType type, std::unique_ptr data) noexcept - : type(type), tensor_type_info(std::move(data)) { + : type(type), data(std::move(data)) { } OrtTypeInfo::~OrtTypeInfo() = default; @@ -55,9 +55,7 @@ ORT_API_STATUS_IMPL(OrtApis::GetOnnxTypeFromTypeInfo, _In_ const struct OrtTypeI ORT_API_STATUS_IMPL(OrtApis::CastTypeInfoToTensorInfo, _In_ const struct OrtTypeInfo* input, _Outptr_result_maybenull_ const struct OrtTensorTypeAndShapeInfo** out) { API_IMPL_BEGIN - *out = (input->type == ONNX_TYPE_TENSOR || input->type == ONNX_TYPE_SPARSETENSOR) - ? input->tensor_type_info.get() - : nullptr; + *out = (input->type == ONNX_TYPE_TENSOR || input->type == ONNX_TYPE_SPARSETENSOR) ? input->data.get() : nullptr; return nullptr; API_IMPL_END } @@ -86,8 +84,8 @@ ORT_API_STATUS_IMPL(OrtApis::CastTypeInfoToOptionalTypeInfo, _In_ const OrtTypeI API_IMPL_END } -ORT_API_STATUS_IMPL(OrtApis::GetDenotationFromTypeInfo, _In_ const OrtTypeInfo* type_info, - _Out_ const char** const out, _Out_ size_t* len) { +ORT_API_STATUS_IMPL(OrtApis::GetDenotationFromTypeInfo, _In_ const OrtTypeInfo* type_info, _Out_ const char** const out, + _Out_ size_t* len) { API_IMPL_BEGIN *out = type_info->denotation.c_str(); *len = type_info->denotation.size(); @@ -95,61 +93,6 @@ ORT_API_STATUS_IMPL(OrtApis::GetDenotationFromTypeInfo, _In_ const OrtTypeInfo* API_IMPL_END } -#if !defined(ORT_MINIMAL_BUILD) -ORT_API_STATUS_IMPL(OrtModelEditorAPI::CreateTensorTypeInfo, _In_ const OrtTensorTypeAndShapeInfo* tensor_info, - _Out_ OrtTypeInfo** type_info) { - API_IMPL_BEGIN - auto ti = std::make_unique(ONNXType::ONNX_TYPE_TENSOR); - ti->tensor_type_info = tensor_info->Clone(); - *type_info = ti.release(); - return nullptr; - API_IMPL_END -} - -ORT_API_STATUS_IMPL(OrtModelEditorAPI::CreateSparseTensorTypeInfo, _In_ const OrtTensorTypeAndShapeInfo* tensor_info, - _Out_ OrtTypeInfo** type_info) { - API_IMPL_BEGIN - auto ti = std::make_unique(ONNXType::ONNX_TYPE_SPARSETENSOR); - ti->tensor_type_info = tensor_info->Clone(); - *type_info = ti.release(); - return nullptr; - API_IMPL_END -} - -ORT_API_STATUS_IMPL(OrtModelEditorAPI::CreateMapTypeInfo, ONNXTensorElementDataType map_key_type, - _In_ const OrtTypeInfo* map_value_type, _Out_ OrtTypeInfo** type_info) { - API_IMPL_BEGIN - auto ti = std::make_unique(ONNXType::ONNX_TYPE_MAP); - ti->map_type_info = std::make_unique(map_key_type, map_value_type->Clone()); - *type_info = ti.release(); - - return nullptr; - API_IMPL_END -} - -ORT_API_STATUS_IMPL(OrtModelEditorAPI::CreateSequenceTypeInfo, _In_ const OrtTypeInfo* sequence_type, - _Out_ OrtTypeInfo** type_info) { - API_IMPL_BEGIN - auto ti = std::make_unique(ONNXType::ONNX_TYPE_SEQUENCE); - ti->sequence_type_info = std::make_unique(sequence_type->Clone()); - *type_info = ti.release(); - - return nullptr; - API_IMPL_END -} - -ORT_API_STATUS_IMPL(OrtModelEditorAPI::CreateOptionalTypeInfo, _In_ const OrtTypeInfo* contained_type, - _Out_ OrtTypeInfo** type_info) { - API_IMPL_BEGIN - auto ti = std::make_unique(ONNXType::ONNX_TYPE_OPTIONAL); - ti->optional_type_info = std::make_unique(contained_type->Clone()); - *type_info = ti.release(); - - return nullptr; - API_IMPL_END -} -#endif // !defined(ORT_MINIMAL_BUILD) - ORT_API(void, OrtApis::ReleaseTypeInfo, _Frees_ptr_opt_ OrtTypeInfo* ptr) { std::unique_ptr p(ptr); } @@ -355,8 +298,8 @@ std::unique_ptr OrtTypeInfo::Clone() const { #endif case ONNX_TYPE_TENSOR: { std::unique_ptr info; - if (tensor_type_info) { - info = tensor_type_info->Clone(); + if (data) { + info = data->Clone(); } result = MakePtr(type, std::move(info)); result->denotation = denotation; diff --git a/onnxruntime/core/framework/onnxruntime_typeinfo.h b/onnxruntime/core/framework/onnxruntime_typeinfo.h index 54bb946e0d36b..72d263d5fa442 100644 --- a/onnxruntime/core/framework/onnxruntime_typeinfo.h +++ b/onnxruntime/core/framework/onnxruntime_typeinfo.h @@ -31,7 +31,7 @@ struct OrtTypeInfo { ONNXType type; std::string denotation; - std::unique_ptr tensor_type_info; + std::unique_ptr data; std::unique_ptr map_type_info; std::unique_ptr sequence_type_info; std::unique_ptr optional_type_info; diff --git a/onnxruntime/core/framework/session_state_utils.cc b/onnxruntime/core/framework/session_state_utils.cc index 9d45ec38e5a32..83a353615bc35 100644 --- a/onnxruntime/core/framework/session_state_utils.cc +++ b/onnxruntime/core/framework/session_state_utils.cc @@ -81,11 +81,6 @@ static common::Status ExtDataTensorProtoToTensor(const Env& env, ORT_RETURN_IF_ERROR(utils::GetExtDataFromTensorProto(env, proto_path.c_str(), tensor_proto, ext_data_buf, ext_data_len, ext_data_deleter, buffered_tensor, &prepacked_for_graph)); - if constexpr (endian::native != endian::little) { - if (!proto_path.empty() && (proto_path.compare(onnxruntime::utils::kTensorProtoMemoryAddressTag) != 0)) { - utils::ConvertRawDataInTensorProto(const_cast(&tensor_proto), ext_data_buf, ext_data_len); - } - } // NB: creating a do-nothing allocator per tensor is wasteful; can perhaps be // avoided if the Tensor class implements the do-nothing behavior when given a @@ -208,12 +203,13 @@ static common::Status DeserializeTensorProto(const Env& env, const std::basic_st } } -common::Status AllocateTensor(const onnxruntime::MemBuffer* m, - std::unique_ptr& p_tensor, - const onnxruntime::DataTypeImpl* const& type, - onnxruntime::TensorShape& tensor_shape, - bool use_device_allocator_for_initializers, - const onnxruntime::AllocatorPtr& alloc) { +common::Status AllocateTensor( + const onnxruntime::MemBuffer* m, + std::unique_ptr& p_tensor, + const onnxruntime::DataTypeImpl* const& type, + onnxruntime::TensorShape& tensor_shape, + bool use_device_allocator_for_initializers, + const onnxruntime::AllocatorPtr& alloc) { if (m != nullptr) { p_tensor = std::make_unique(type, tensor_shape, m->GetBuffer(), m->GetAllocInfo()); if (m->GetLen() < p_tensor->SizeInBytes()) { @@ -358,7 +354,6 @@ common::Status SaveInitializedTensors( } ORT_RETURN_IF_ERROR(planner.Trace(entry.first, entry.second)); } - // 2. allocate weight buffer on different locations // planned_initializers_memory_size_in_byte is not actual physical size. // It's the virtual size computed by planner. @@ -391,9 +386,6 @@ common::Status SaveInitializedTensors( if (user_supplied_initializer_ids.find(entry.first) != user_supplied_initializer_ids.end()) { ort_value = *(session_options.initializers_to_share_map.at(name)); LOGS(logger, INFO) << "Using user supplied initializer with name (" << name << ")."; - - } else if (graph.GetOrtValueInitializer(name, ort_value)) { - // populated OrtValue from the Graph instance } else { const ONNX_NAMESPACE::TensorProto& tensor_proto = *(entry.second); @@ -405,9 +397,10 @@ common::Status SaveInitializedTensors( session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsUseDeviceAllocatorForInitializers, "0") == "1"; Tensor* p_tensor = nullptr; - auto buffered_tensors_iter = buffered_tensors.find(name); - if (buffered_tensors_iter != buffered_tensors.end()) { - p_tensor = buffered_tensors_iter->second.get(); + if (auto iter = buffered_tensors.find(name); + iter != buffered_tensors.end()) { + p_tensor = iter->second.release(); + buffered_tensors.erase(iter); } Status st = DeserializeTensorProto(env, graph_loc, tensor_proto, (m.has_value()) ? &*m : nullptr, alloc, @@ -419,12 +412,6 @@ common::Status SaveInitializedTensors( oss << "Deserialize tensor " << name << " failed." << st.ErrorMessage(); return Status(st.Category(), st.Code(), oss.str()); } - - if (p_tensor != nullptr) { - // p_tensor was wrapped in a deleter by DeserializeTensorProto so we can simply release it here. - ORT_IGNORE_RETURN_VALUE(buffered_tensors_iter->second.release()); - buffered_tensors.erase(buffered_tensors_iter); - } } // 'name' is a reference to a string within the TensorProto that save_tensor_func may free diff --git a/onnxruntime/core/framework/tensor_type_and_shape.cc b/onnxruntime/core/framework/tensor_type_and_shape.cc index 9bbea279da82d..418e46924fb9f 100644 --- a/onnxruntime/core/framework/tensor_type_and_shape.cc +++ b/onnxruntime/core/framework/tensor_type_and_shape.cc @@ -49,27 +49,10 @@ ORT_API_STATUS_IMPL(OrtApis::SetTensorElementType, _Inout_ OrtTensorTypeAndShape API_IMPL_END } -ORT_API_STATUS_IMPL(OrtApis::SetDimensions, OrtTensorTypeAndShapeInfo* info, +ORT_API_STATUS_IMPL(OrtApis::SetDimensions, OrtTensorTypeAndShapeInfo* this_ptr, _In_ const int64_t* dim_values, size_t dim_count) { API_IMPL_BEGIN - if (std::any_of(dim_values, dim_values + dim_count, [](int64_t v) { return v < -1; })) { - return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "dim_values must be -1 (symbolic dimension) or larger."); - } - - auto num_dims = std::max(dim_count, info->dim_params.size()); - - // make shape and dim_values consistent - info->dim_params.resize(num_dims, ""); - - onnxruntime::TensorShapeVector dims; - dims.resize(num_dims, -1); - - for (size_t idx = 0; idx < dim_count; ++idx) { - dims[idx] = dim_values[idx]; - } - - info->shape = onnxruntime::TensorShape(dims); - + this_ptr->shape = onnxruntime::TensorShape(dim_values, dim_count); return nullptr; API_IMPL_END } @@ -105,22 +88,10 @@ ORT_API_STATUS_IMPL(OrtApis::GetSymbolicDimensions, ORT_API_STATUS_IMPL(OrtApis::SetSymbolicDimensions, _In_ struct OrtTensorTypeAndShapeInfo* info, _In_ const char** names, _In_ size_t dim_params_length) { - auto num_dims = std::max(info->shape.NumDimensions(), dim_params_length); - - // make shape and dim_values consistent - if (num_dims > info->shape.NumDimensions()) { - auto dim_values = info->shape.AsShapeVector(); - dim_values.resize(num_dims, -1); - info->shape = onnxruntime::TensorShape(dim_values); - } - info->dim_params.clear(); - info->dim_params.resize(num_dims, ""); - for (size_t idx = 0; idx < dim_params_length; ++idx) { - info->dim_params[idx] = names[idx]; + info->dim_params.push_back(names[idx]); } - return nullptr; } diff --git a/onnxruntime/core/framework/tensorprotoutils.cc b/onnxruntime/core/framework/tensorprotoutils.cc index 94a2a6677358e..17c37b8882168 100644 --- a/onnxruntime/core/framework/tensorprotoutils.cc +++ b/onnxruntime/core/framework/tensorprotoutils.cc @@ -270,15 +270,10 @@ void SetRawDataInTensorProto(ONNX_NAMESPACE::TensorProto& tensor_proto, std::str tensor_proto.set_raw_data(std::move(param)); } -void ConvertRawDataInTensorProto(TensorProto* tensor, - void* ext_data_buf, - size_t ext_data_len) { +void ConvertRawDataInTensorProto(TensorProto* tensor) { size_t element_size = 1; char* bytes = NULL; size_t num_elements = 0; - if (ext_data_buf && !ext_data_len) { - return; - } switch (tensor->data_type()) { case TensorProto_DataType_FLOAT: bytes = reinterpret_cast(tensor->mutable_float_data()->mutable_data()); @@ -342,15 +337,6 @@ void ConvertRawDataInTensorProto(TensorProto* tensor, num_elements = (tensor->raw_data().size()) / element_size; bytes = const_cast(tensor->mutable_raw_data()->c_str()); } - - if (element_size == 1) { - return; - } - if (ext_data_buf) { - ORT_ENFORCE(ext_data_len % element_size == 0); - num_elements = ext_data_len / element_size; - bytes = reinterpret_cast(ext_data_buf); - } for (size_t i = 0; i < num_elements; ++i) { char* start_byte = bytes + i * element_size; char* end_byte = start_byte + element_size - 1; @@ -1331,15 +1317,22 @@ ONNX_NAMESPACE::TensorProto TensorToTensorProto(const Tensor& tensor, const auto* raw_data = tensor.DataRaw(); ORT_ENFORCE(raw_data, "Missing raw data for tensor proto. Invalid tensor."); static_assert(sizeof(void*) <= sizeof(ExternalDataInfo::OFFSET_TYPE)); + tensor_proto.set_data_location(ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL); // we reinterpret_cast this back to void* in tensorprotoutils.cc:GetExtDataFromTensorProto. // use intptr_t as OFFSET_TYPE is signed. in theory you could get a weird looking value if the address uses the // high bit, but that should be unlikely in a scenario where we care about memory usage enough to use this path. auto offset = narrow(reinterpret_cast(raw_data)); - ExternalDataInfo::SetExternalLocationToProto(onnxruntime::utils::kTensorProtoMemoryAddressTag, - offset, tensor.SizeInBytes(), tensor_proto); - + ONNX_NAMESPACE::StringStringEntryProto* entry = tensor_proto.mutable_external_data()->Add(); + entry->set_key("location"); + entry->set_value(ToUTF8String(onnxruntime::utils::kTensorProtoMemoryAddressTag)); + entry = tensor_proto.mutable_external_data()->Add(); + entry->set_key("offset"); + entry->set_value(std::to_string(offset)); + entry = tensor_proto.mutable_external_data()->Add(); + entry->set_key("length"); + entry->set_value(std::to_string(tensor.SizeInBytes())); } else { utils::SetRawDataInTensorProto(tensor_proto, tensor.DataRaw(), tensor.SizeInBytes()); } diff --git a/onnxruntime/core/framework/tensorprotoutils.h b/onnxruntime/core/framework/tensorprotoutils.h index 79eae48c10411..f5dec7ae988f2 100644 --- a/onnxruntime/core/framework/tensorprotoutils.h +++ b/onnxruntime/core/framework/tensorprotoutils.h @@ -41,18 +41,12 @@ Status GetExternalDataInfo(const ONNX_NAMESPACE::TensorProto& tensor_proto, ExternalDataInfo::PrepackedInfos* prepacked_infos = nullptr); /** * This function is used to convert the endianess of Tensor data. - * If ext_data_buf is provided, then this buffer content's endianess - * will be changed. * Mostly, will be used in big endian system to support the model file * generated on little endian system. - * @param tensor_proto given initializer tensor - * @param ext_data_buf optional externl data buffer - * @param ext_data_len optional externl data buffer lengeh + * @param initializer given initializer tensor * @returns None */ -void ConvertRawDataInTensorProto(ONNX_NAMESPACE::TensorProto* tensor_proto, - void* ext_data_buf = NULL, - size_t ext_data_len = 0); +void ConvertRawDataInTensorProto(ONNX_NAMESPACE::TensorProto* initializer); /** * Wrapper function for set_raw_data. diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index 39ffc6a5b0cee..e4915616b7b7c 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -7,34 +7,30 @@ #include #include #include -#include #include - -#include +#include #include "core/common/common.h" +#include #include "core/common/inlined_containers.h" #include "core/common/logging/logging.h" #include "core/common/narrow.h" #include "core/flatbuffers/flatbuffers_utils.h" -#include "core/framework/tensor_type_and_shape.h" #include "core/flatbuffers/schema/ort.fbs.h" -#include "core/framework/tensor_external_data_info.h" #include "core/framework/tensor_shape.h" -#include "core/framework/tensor_type_and_shape.h" +#include "core/framework/tensor_external_data_info.h" #include "core/framework/tensorprotoutils.h" #include "core/framework/utils.h" -#include "core/graph/function_utils.h" #include "core/graph/graph_flatbuffers_utils.h" #include "core/graph/graph_viewer.h" #include "core/graph/indexed_sub_graph.h" #include "core/graph/model.h" -#include "core/graph/model_editor_api_types.h" #include "core/graph/model_load_utils.h" #include "core/graph/model_saving_options.h" #include "core/graph/node_attr_utils.h" #include "core/graph/op.h" #include "core/graph/runtime_optimization_record_container.h" +#include "core/graph/function_utils.h" #if !defined(ORT_MINIMAL_BUILD) #include "core/graph/function.h" @@ -3504,10 +3500,6 @@ void Graph::RemoveInitializedTensor(const std::string& tensor_name) { #if !defined(DISABLE_SPARSE_TENSORS) sparse_tensor_names_.erase(tensor_name); #endif - - // doesn't matter if it existed or not - ORT_IGNORE_RETURN_VALUE(ortvalue_initializers_.erase(tensor_name)); - SetGraphResolveNeeded(); } else { #if !defined(DISABLE_SPARSE_TENSORS) @@ -3639,8 +3631,8 @@ Status Graph::InjectExternalInitializersFromFilesInMemory( return Status::OK(); } - #endif // DISABLE_EXTERNAL_INITIALIZERS + #endif // !defined(ORT_MINIMAL_BUILD) bool Graph::GetInitializedTensor(const std::string& tensor_name, const TensorProto*& value) const { @@ -3653,16 +3645,6 @@ bool Graph::GetInitializedTensor(const std::string& tensor_name, const TensorPro return true; } -bool Graph::GetOrtValueInitializer(const std::string& name, OrtValue& value) const { - auto it = ortvalue_initializers_.find(name); - if (it == ortvalue_initializers_.end()) { - return false; - } - - value = it->second; - return true; -} - void Graph::CleanAllInitializedTensors() noexcept { name_to_initial_tensor_.clear(); #if !defined(DISABLE_SPARSE_TENSORS) @@ -3678,8 +3660,6 @@ void Graph::CleanAllInitializedTensors() noexcept { delete graph_proto_->mutable_initializer()->ReleaseCleared(); } #endif - - ortvalue_initializers_.clear(); } const ONNX_NAMESPACE::TensorProto* Graph::GetConstantInitializer(const std::string& initializer_name, @@ -3729,14 +3709,13 @@ void Graph::AddValueInfo(const NodeArg* new_value_info) { value_info_.insert(new_value_info); } -template -std::vector Graph::CreateNodeArgs(const StringRange& names, +std::vector Graph::CreateNodeArgs(const google::protobuf::RepeatedPtrField& names, const ArgNameToTypeMap& name_to_type_map) { const auto name_to_type_map_end = name_to_type_map.end(); std::vector results; results.reserve(names.size()); - for (const std::string& name : names) { + for (auto& name : names) { const TypeProto* type = nullptr; auto name_to_type_iter = name_to_type_map.find(name); @@ -4097,51 +4076,27 @@ ONNX_NAMESPACE::GraphProto Graph::ToGraphProto() const { // This is used for constructing full path for external data // if it exists - auto add_initializer = [](TensorList& output_initializers, const TensorProto& initializer) -> void { - TensorProto& output = *output_initializers.Add(); - output = initializer; - - // copy any in-memory external data into raw data - if (utils::HasExternalData(initializer)) { - const std::filesystem::path ignored; - std::basic_string location; - onnxruntime::FileOffsetType file_offset; - SafeInt tensor_byte_size; - - ORT_THROW_IF_ERROR(utils::GetExternalDataInfo(initializer, ignored, location, file_offset, tensor_byte_size)); - - if (location == onnxruntime::utils::kTensorProtoMemoryAddressTag) { - // file_offset is address - void* data = reinterpret_cast(file_offset); - - // set in raw data - output.clear_data_location(); - output.set_raw_data(data, tensor_byte_size); - } - } - }; - - auto* mutable_initializers = result.mutable_initializer(); - #if !defined(DISABLE_SPARSE_TENSORS) const auto& model_path = ModelPath(); // We want to make sure that sparse initializers do not appear // as dense duplicates within the initializers list. - const bool has_sparse_initializers = !sparse_tensor_names_.empty(); - const auto sparse_end = sparse_tensor_names_.end(); - for (const auto& initializer : graph_proto_->initializer()) { - if (!has_sparse_initializers || sparse_end == sparse_tensor_names_.find(initializer.name())) { - add_initializer(*mutable_initializers, initializer); - } else { - auto& sparse_initializer = *result.add_sparse_initializer(); - auto status = utils::DenseTensorToSparseTensorProto(initializer, model_path, sparse_initializer); - ORT_ENFORCE(status.IsOK(), "Failed to convert dense initializer to sparse"); + if (!sparse_tensor_names_.empty()) { + const auto sparse_end = sparse_tensor_names_.end(); + auto* mutable_initializer = result.mutable_initializer(); + for (const auto& initializer : graph_proto_->initializer()) { + if (sparse_end == sparse_tensor_names_.find(initializer.name())) { + *mutable_initializer->Add() = initializer; + } else { + auto& sparse_initializer = *result.add_sparse_initializer(); + auto status = utils::DenseTensorToSparseTensorProto(initializer, model_path, sparse_initializer); + ORT_ENFORCE(status.IsOK(), "Failed to convert dense initializer to sparse"); + } } + } else { + *result.mutable_initializer() = graph_proto_->initializer(); } #else - for (const auto& initializer : graph_proto_->initializer()) { - add_initializer(*mutable_initializers, initializer); - } + *result.mutable_initializer() = graph_proto_->initializer(); #endif return result; @@ -5390,9 +5345,6 @@ Status Graph::InlineFunction(Node& callnode) { } void Graph::SetInputs(gsl::span inputs) { - graph_inputs_including_initializers_.clear(); - graph_inputs_excluding_initializers_.clear(); - // creating graph from scratch // rely on SetGraphInputsOutputs() to fix up graph_inputs_excluding_initializers_ // if is_loaded_from_model_file_ == false @@ -5401,6 +5353,7 @@ void Graph::SetInputs(gsl::span inputs) { if (is_loaded_from_model_file_) { // graph loaded from model file + graph_inputs_excluding_initializers_.clear(); for (const auto* input : inputs) { ORT_ENFORCE(input->Exists(), "Input to set must exist."); if (name_to_initial_tensor_.find(input->Name()) == name_to_initial_tensor_.end()) { @@ -5417,7 +5370,6 @@ void Graph::SetInputs(gsl::span inputs) { } void Graph::SetOutputs(gsl::span outputs) { - graph_outputs_.clear(); graph_outputs_.reserve(outputs.size()); graph_outputs_.assign(outputs.begin(), outputs.end()); @@ -5736,207 +5688,4 @@ common::Status Graph::LoadFromOrtFormat(const onnxruntime::fbs::Graph& fbs_graph return Status::OK(); } -#if !defined(ORT_MINIMAL_BUILD) -namespace { -ValueInfoProto OrtValueInfoToOnnx(const OrtValueInfo& vi) { - // the model builder API checks that the OrtValueInfo has a complete and valid OrtTypeInfo instance and that the - // name is not null/empty. - ORT_ENFORCE(vi.type_info->type == ONNX_TYPE_TENSOR, - "Internal error. Model Editor API should only allow OrtValueInfo for tensor to be created."); - - ValueInfoProto value_info_proto; - value_info_proto.set_name(vi.name); - - auto* tensor = value_info_proto.mutable_type()->mutable_tensor_type(); - const OrtTensorTypeAndShapeInfo& tensor_info = *vi.type_info->tensor_type_info.get(); - tensor->set_elem_type(tensor_info.type); - - auto& shape = *tensor->mutable_shape(); - - size_t idx = 0; - for (auto dim : tensor_info.shape.GetDims()) { - auto& dim_proto = *shape.add_dim(); - if (dim >= 0) { - dim_proto.set_dim_value(dim); - } else { - const std::string& dim_param = tensor_info.dim_params[idx]; - // if empty leave the new dim_proto with neither dim_value nor dim_param set. this represents an 'unknown' dim - if (!dim_param.empty()) { - dim_proto.set_dim_param(dim_param); - } - } - } - - return value_info_proto; -} -} // namespace - -Status Graph::LoadFromModelEditorApiModel(const OrtGraph& api_graph, bool updating_existing_graph) { - ArgNameToTypeMap name_to_type_map; - - // NOTE: need to create NodeArgs as we go along - - // add inputs first. the shape from an input for a non-const initializer is preferred, so we want to create the - // NodeArg for the value using that - - auto add_graph_inputs_outputs = [&, this]( - const InlinedVector>& graph_inputs_or_outputs, - bool is_input) { - // when updating a model we don't require the inputs or outputs to be set if they're unchanged. - if (updating_existing_graph && graph_inputs_or_outputs.empty()) { - return; - } - - std::vector node_args; - node_args.reserve(graph_inputs_or_outputs.size()); - for (auto& ort_value_info : graph_inputs_or_outputs) { - ValueInfoProto value_info = OrtValueInfoToOnnx(*ort_value_info); - - name_to_type_map[value_info.name()] = value_info.type(); - node_args.push_back(&GetOrCreateNodeArg(value_info.name(), &value_info.type())); - } - - if (is_input) { - SetInputs(node_args); - } else { - SetOutputs(node_args); - } - }; - - auto add_initializers = [this](const std::unordered_map>& initializers, - bool is_external) { - for (auto& name_and_ortvalue : initializers) { - // convert from OrtValue to TensorProto - const std::string& name = name_and_ortvalue.first; - OrtValue& v = *name_and_ortvalue.second; - - ORT_ENFORCE(v.IsTensor(), "Initializers must be Tensors"); - const Tensor& t = v.Get(); - TensorProto& tensor_proto = *graph_proto_->add_initializer(); - - tensor_proto.set_name(name); - tensor_proto.set_data_type(t.GetElementType()); - for (auto dim : t.Shape().GetDims()) { - tensor_proto.add_dims(dim); - } - - if (is_external) { - // pre-existing memory that we don't own. avoid a copy by storing the pointer in the ExternalDataInfo - const void* data_offset = t.DataRaw(); // address of memory not offset into file - auto offset = narrow(reinterpret_cast(data_offset)); - - ExternalDataInfo::SetExternalLocationToProto(onnxruntime::utils::kTensorProtoMemoryAddressTag, - offset, t.SizeInBytes(), tensor_proto); - - // add OrtValue to ortvalue_initializers_ to keep it alive and to store the deleter if provided. - ortvalue_initializers_.emplace(name, std::move(v)); - } else { - tensor_proto.set_raw_data(t.DataRaw(), t.SizeInBytes()); - } - - TypeProto type_proto{TypeProtoFromTensorProto(tensor_proto)}; - ORT_IGNORE_RETURN_VALUE(GetOrCreateNodeArg(name, &type_proto)); - - name_to_initial_tensor_.emplace(name, &tensor_proto); - } - }; - - // process graph inputs first as we want the type/shape from them to be preferred if a graph input - // has a matching initializer - add_graph_inputs_outputs(api_graph.inputs, /*input*/ true); - - // add initializers - ortvalue_initializers_.reserve(api_graph.external_initializers.size()); - add_initializers(api_graph.external_initializers, /*is_external*/ true); - add_initializers(api_graph.initializers, /*is_external*/ false); - - // add graph outputs - add_graph_inputs_outputs(api_graph.outputs, /*input*/ false); - - // add nodes - for (const auto& ort_node : api_graph.nodes) { - const OrtNode& node = *ort_node; - - // convert Constant nodes to initializers - if (node.operator_name == "Constant" && node.domain_name == kOnnxDomain) { - // graph_proto_ provides storage - TensorProto& tensor = *graph_proto_->add_initializer(); - - // create NodeProto from OrtNode so we can use the existing conversion functions - NodeProto node_proto; - - // 'Constant' node has no inputs or attributes - ORT_RETURN_IF_NOT(node.input_names.empty() && node.attributes.size() == 1 && node.output_names.size() == 1, - node.node_name, - " is an invalid 'Constant' node. " - "Must have no inputs, one attribute and one output. "); - - node_proto.add_attribute()->CopyFrom(node.attributes[0]); - node_proto.add_output(node.output_names[0]); - - node_proto.set_op_type(node.operator_name); - node_proto.set_name(node.node_name); - node_proto.set_domain(node.domain_name); - - ORT_RETURN_IF_ERROR(utils::ConstantNodeProtoToTensorProto(node_proto, /*model_path*/ "", tensor)); - name_to_initial_tensor_.emplace(node.output_names[0], &tensor); - - continue; - } - - auto input_defs = CreateNodeArgs(node.input_names, name_to_type_map); - auto output_defs = CreateNodeArgs(node.output_names, name_to_type_map); - - const auto num_attributes = node.attributes.size(); - - NodeAttributes attributes; - attributes.reserve(num_attributes); - - for (const auto& attr : node.attributes) { - attributes[attr.name()] = attr; - } - - ORT_IGNORE_RETURN_VALUE(AddNode(node.node_name, node.operator_name, /*doc_string*/ "", - input_defs, output_defs, &attributes, node.domain_name)); - } - - return Resolve(); -} - -// static -Status Graph::LoadFromModelEditorApiModel(const OrtGraph& api_graph, - const Model& owning_model, - const std::unordered_map& domain_to_version, - IOnnxRuntimeOpSchemaCollectionPtr schema_registry, - bool strict_shape_type_inference, - const logging::Logger& logger, - std::unique_ptr& graph) { - graph = std::make_unique(owning_model, - domain_to_version, - schema_registry, - /*parent_graph*/ nullptr, /*parent_node*/ nullptr, - logger, - strict_shape_type_inference); - - return graph->LoadFromModelEditorApiModel(api_graph); -} - -Status Graph::UpdateUsingModelEditorApiModel(const OrtModel& api_model) { - for (auto& entry : api_model.domain_to_version) { - if (auto it = domain_to_version_.find(entry.first); it != domain_to_version_.end()) { - if (it->second != entry.second) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Domain version can not be changed for '", entry.first, - "'. Current version: ", it->second); - } - } else { - domain_to_version_.insert(entry); - } - } - - // this will replace inputs/outputs and add nodes. - return LoadFromModelEditorApiModel(*api_model.graph, /*updating_existing_graph*/ true); -} - -#endif // !defined(ORT_MINIMAL_BUILD) } // namespace onnxruntime diff --git a/onnxruntime/core/graph/graph_flatbuffers_utils.cc b/onnxruntime/core/graph/graph_flatbuffers_utils.cc index 199aa79cc1dde..922759b02e75f 100644 --- a/onnxruntime/core/graph/graph_flatbuffers_utils.cc +++ b/onnxruntime/core/graph/graph_flatbuffers_utils.cc @@ -300,6 +300,8 @@ Status LoadInitializerOrtFormat(const fbs::Tensor& fbs_tensor, TensorProto& init const auto* fbs_raw_data = fbs_tensor.raw_data(); if (fbs_raw_data) { if (load_options.can_use_flatbuffer_for_initializers && fbs_raw_data->size() > 127) { + initializer.set_data_location(ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL); + static_assert(sizeof(void*) <= sizeof(ExternalDataInfo::OFFSET_TYPE)); const void* data_offset = fbs_raw_data->Data(); // we reinterpret_cast this back to void* in tensorprotoutils.cc:GetExtDataFromTensorProto. @@ -307,9 +309,15 @@ Status LoadInitializerOrtFormat(const fbs::Tensor& fbs_tensor, TensorProto& init // high bit, but that should be unlikely in a scenario where we care about memory usage enough to use this path. auto offset = narrow(reinterpret_cast(data_offset)); - ExternalDataInfo::SetExternalLocationToProto(onnxruntime::utils::kTensorProtoMemoryAddressTag, - offset, fbs_raw_data->size(), initializer); - + ONNX_NAMESPACE::StringStringEntryProto* entry = initializer.mutable_external_data()->Add(); + entry->set_key("location"); + entry->set_value(ToUTF8String(onnxruntime::utils::kTensorProtoMemoryAddressTag)); + entry = initializer.mutable_external_data()->Add(); + entry->set_key("offset"); + entry->set_value(std::to_string(offset)); + entry = initializer.mutable_external_data()->Add(); + entry->set_key("length"); + entry->set_value(std::to_string(fbs_raw_data->size())); } else { // fbs_raw_data is uint8_t vector, so the size is byte size initializer.set_raw_data(fbs_raw_data->Data(), fbs_raw_data->size()); diff --git a/onnxruntime/core/graph/model.cc b/onnxruntime/core/graph/model.cc index 7629e40c1b5fe..be0531e6473fb 100644 --- a/onnxruntime/core/graph/model.cc +++ b/onnxruntime/core/graph/model.cc @@ -7,7 +7,6 @@ #include "core/flatbuffers/flatbuffers_utils.h" #include "core/framework/tensorprotoutils.h" #include "core/graph/model.h" -#include "core/graph/model_editor_api_types.h" #include "core/graph/model_load_utils.h" #ifdef _MSC_VER @@ -739,36 +738,6 @@ Status Model::Load(int fd, const PathString& model_path, std::shared_ptr& return Status::OK(); } -// static -common::Status Model::LoadFromModelEditorApiModel(const OrtModel& model_editor_api_model, - const IOnnxRuntimeOpSchemaRegistryList* local_registries, - const ModelOptions& options, - const logging::Logger& logger, - std::unique_ptr& model) { - model = std::make_unique(); - model->model_proto_.set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); - // The optimizer Initializer class requires a path if external data is used, however in the Graph API usage the - // external data is pointing to pre-allocated memory and does not require a path. Set a dummy value to make it happy. - model->model_path_ = std::filesystem::path("_GRAPH_API_MODEL_"); - - auto schema_registry = std::make_shared(); - if (local_registries != nullptr) { - for (const auto& schema_collection : *local_registries) { - schema_registry->RegisterRegistry(schema_collection); - } - } - - ORT_RETURN_IF_ERROR(Graph::LoadFromModelEditorApiModel(*model_editor_api_model.graph, - *model, - model_editor_api_model.domain_to_version, - schema_registry, - options.strict_shape_type_inference, - logger, - model->graph_)); - - return Status::OK(); -} - Status Model::Save(Model& model, int p_fd) { if (p_fd < 0) { return Status(ONNXRUNTIME, INVALID_ARGUMENT, " is less than 0."); @@ -948,4 +917,5 @@ common::Status Model::LoadFromOrtFormat(const fbs::Model& fbs_model, #endif return Status::OK(); } + } // namespace onnxruntime diff --git a/onnxruntime/core/graph/model.h b/onnxruntime/core/graph/model.h index 6fd94c60d6b99..2d2086aef41fd 100644 --- a/onnxruntime/core/graph/model.h +++ b/onnxruntime/core/graph/model.h @@ -280,12 +280,6 @@ class Model { const logging::Logger& logger, const ModelOptions& options = {}); - static common::Status LoadFromModelEditorApiModel(const OrtModel& graph_api_model, - const IOnnxRuntimeOpSchemaRegistryList* local_registries, - const ModelOptions& options, - const logging::Logger& logger, - std::unique_ptr& model); - common::Status SaveToOrtFormat(flatbuffers::FlatBufferBuilder& builder, flatbuffers::Offset& model) const; @@ -339,7 +333,7 @@ class Model { ModelMetaData model_metadata_; // Path to model file. May be empty. - std::filesystem::path model_path_; + const std::filesystem::path model_path_; // Main graph of the model. std::unique_ptr graph_; diff --git a/onnxruntime/core/graph/model_editor_api_types.h b/onnxruntime/core/graph/model_editor_api_types.h deleted file mode 100644 index d72bd13093b61..0000000000000 --- a/onnxruntime/core/graph/model_editor_api_types.h +++ /dev/null @@ -1,47 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/common/inlined_containers_fwd.h" -#include "core/framework/ort_value.h" -#include "core/framework/onnxruntime_typeinfo.h" -#include "core/graph/onnx_protobuf.h" - -// ORT C interface types for OrtGraphApi can't be in a namespace. -// We need to define them here so onnxruntime::Model can be created from OrtModel. - -struct OrtValueInfo { - std::string name; - std::unique_ptr type_info; -}; - -struct OrtOpAttr { - ONNX_NAMESPACE::AttributeProto attr_proto; -}; - -struct OrtNode { - std::string operator_name; - std::string domain_name; - std::string node_name; - - // OrtOpAttr is 1:1 with ONNX_NAMESPACE::AttributeProto currently. - // https://github.com/microsoft/onnxruntime/blob/bd5a759d0cdbed6e7f611c990d4eb5457a9ecf60/onnxruntime/core/session/standalone_op_invoker.cc#L318 - onnxruntime::InlinedVector attributes; - onnxruntime::InlinedVector input_names; - onnxruntime::InlinedVector output_names; - - // FUTURE if we need control flow nodes - // std::unordered_map subgraphs; -}; - -struct OrtGraph { - onnxruntime::InlinedVector> inputs; - onnxruntime::InlinedVector> outputs; - std::unordered_map> initializers; - std::unordered_map> external_initializers; - std::vector> nodes; -}; - -struct OrtModel { - std::unique_ptr graph; - std::unordered_map domain_to_version; -}; diff --git a/onnxruntime/core/optimizer/constant_folding.cc b/onnxruntime/core/optimizer/constant_folding.cc index e36eef672c1ed..e755b4bfa6364 100644 --- a/onnxruntime/core/optimizer/constant_folding.cc +++ b/onnxruntime/core/optimizer/constant_folding.cc @@ -21,16 +21,7 @@ ConstantFolding::ConstantFolding(const IExecutionProvider& execution_provider, const ConfigOptions& config_options, const InlinedHashSet& compatible_execution_providers, const InlinedHashSet& excluded_initializers) noexcept - : ConstantFolding("ConstantFolding", execution_provider, skip_dequantize_linear, config_options, compatible_execution_providers, excluded_initializers) { -} - -ConstantFolding::ConstantFolding(const std::string& name, - const IExecutionProvider& execution_provider, - bool skip_dequantize_linear, - const ConfigOptions& config_options, - const InlinedHashSet& compatible_execution_providers, - const InlinedHashSet& excluded_initializers) noexcept - : GraphTransformer(name, compatible_execution_providers), + : GraphTransformer("ConstantFolding", compatible_execution_providers), skip_dequantize_linear_(skip_dequantize_linear), config_options_(config_options), excluded_initializers_(excluded_initializers), @@ -153,7 +144,7 @@ Status ConstantFolding::ApplyImpl(Graph& graph, bool& modified, int graph_level, for (NodeIndex i : order) { auto* node = graph.GetNode(i); - if (!node || !AllowConstantFolding(*node)) { + if (!node) { continue; } diff --git a/onnxruntime/core/optimizer/constant_folding.h b/onnxruntime/core/optimizer/constant_folding.h index 29bc67d560788..14eb2a9c5f06b 100644 --- a/onnxruntime/core/optimizer/constant_folding.h +++ b/onnxruntime/core/optimizer/constant_folding.h @@ -28,24 +28,6 @@ class ConstantFolding : public GraphTransformer { const InlinedHashSet& compatible_execution_providers = {}, const InlinedHashSet& excluded_initializers = {}) noexcept; - protected: - /** - * Same as the constructor above but with a name provided by derived class. - */ - ConstantFolding(const std::string& name, - const IExecutionProvider& execution_provider, - bool skip_dequantize_linear, - const ConfigOptions& config_options, - const InlinedHashSet& compatible_execution_providers = {}, - const InlinedHashSet& excluded_initializers = {}) noexcept; - /** - * Derived class can implement this virtual function to limit the nodes that can be constant folded. - */ - virtual bool AllowConstantFolding(const Node& node) const { - ORT_UNUSED_PARAMETER(node); - return true; - } - private: Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override; diff --git a/onnxruntime/core/optimizer/graph_optimizer_registry.cc b/onnxruntime/core/optimizer/graph_optimizer_registry.cc deleted file mode 100644 index 8ede372470485..0000000000000 --- a/onnxruntime/core/optimizer/graph_optimizer_registry.cc +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/optimizer/graph_optimizer_registry.h" -#include "core/optimizer/graph_transformer_utils.h" -#include "core/optimizer/selection_and_optimization_func.h" -#include "core/optimizer/qdq_transformer/constant_folding_dq_node.h" - -using namespace onnxruntime; -using namespace ::onnxruntime::common; - -namespace onnxruntime { -#if !defined(ORT_MINIMAL_BUILD) -GraphOptimizerRegistry::GraphOptimizerRegistry(const onnxruntime::SessionOptions* sess_options, - const onnxruntime::IExecutionProvider* cpu_ep, - const logging::Logger* logger) : session_options_(sess_options), - cpu_ep_(cpu_ep), - logger_(logger) { - auto status = CreatePredefinedSelectionFuncs(); - ORT_ENFORCE(status.IsOK(), "Could not create pre-defined selection functions. Error Message: ", - status.ErrorMessage()); -} - -Status GraphOptimizerRegistry::CreatePredefinedSelectionFuncs() { - transformer_name_to_selection_func_[kConstantFoldingDQ] = ConstantFoldingDQFuncs::Select; - - return Status::OK(); -} - -std::optional GraphOptimizerRegistry::GetSelectionFunc(std::string& name) const { - auto lookup = transformer_name_to_selection_func_.find(name); - if (lookup != transformer_name_to_selection_func_.end()) { - return transformer_name_to_selection_func_.at(name); - } - LOGS(*logger_, WARNING) << "Can't find selection function of " << name; - return std::nullopt; -} -#else -GraphOptimizerRegistry::GraphOptimizerRegistry(const onnxruntime::SessionOptions* sess_options, - const onnxruntime::IExecutionProvider* cpu_ep, - const logging::Logger* logger) : session_options_(sess_options), - cpu_ep_(cpu_ep), - logger_(logger) {} - -std::optional GraphOptimizerRegistry::GetSelectionFunc(std::string& /*name*/) const { - return std::nullopt; -} -#endif -} // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/graph_optimizer_registry.h b/onnxruntime/core/optimizer/graph_optimizer_registry.h deleted file mode 100644 index 15c9287c0eac8..0000000000000 --- a/onnxruntime/core/optimizer/graph_optimizer_registry.h +++ /dev/null @@ -1,77 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include "core/common/inlined_containers.h" -#include "core/common/logging/logging.h" -#include "core/common/common.h" -#include "core/optimizer/graph_transformer.h" -#include "core/framework/execution_providers.h" -#include "core/framework/compute_capability.h" - -namespace onnxruntime { -/** - * Optimizer's selection function: Selects a set of nodes from a given graph for optimization. Additional key/value strings can be provided to configure the optimizer. - * If needed, use graph_optimizer_registry to access the session options, the CPU EP and the logger. - * - * Optimizer's optimization function: Gets the nodes in ComputeCapability from nodes_to_optimize. Use graph_optimizer_registry to access the session options, the CPU EP - * and the logger if needed to create the optimizer. Run optimization on the nodes/subgraph, and finally, update the ComputeCapability. - * - */ -using KeyValueConfig = std::unordered_map; -using SelectionFunc = std::function>(const GraphViewer&, - const KeyValueConfig&, - const GraphOptimizerRegistry& graph_optimizer_registry)>; -using OptimizationFunc = std::function; - -/** - * A registration/lookup class for re-usable optimizers for EPs. - */ -class GraphOptimizerRegistry { - public: - /** - * The constructor takes in session options, the CPU EP and a logger as these are required by some optimizers. - */ - GraphOptimizerRegistry(const onnxruntime::SessionOptions* sess_options, - const onnxruntime::IExecutionProvider* cpu_ep, - const logging::Logger* logger); - - /** - * Get optimizer selection function. If the optimizer name can't be found, return nullopt. - */ - std::optional GetSelectionFunc(std::string& name) const; - - /** - * Get CPU EP. - */ - const onnxruntime::IExecutionProvider& GetCpuEp() const { return *cpu_ep_; } - - /** - * Get Session Options. - */ - const onnxruntime::SessionOptions& GetSessionOptions() const { return *session_options_; } - - /** - * Get Logger. - */ - const logging::Logger* GetLogger() const { return logger_; } - - private: - const onnxruntime::SessionOptions* session_options_; - const onnxruntime::IExecutionProvider* cpu_ep_; - const logging::Logger* logger_; - -#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) - InlinedHashMap transformer_name_to_selection_func_; - - /** - * Create pre-defined selection functions. - */ - Status CreatePredefinedSelectionFuncs(); -#endif -}; -} // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/qdq_transformer/constant_folding_dq_node.cc b/onnxruntime/core/optimizer/qdq_transformer/constant_folding_dq_node.cc deleted file mode 100644 index a2f46d6ae693c..0000000000000 --- a/onnxruntime/core/optimizer/qdq_transformer/constant_folding_dq_node.cc +++ /dev/null @@ -1,26 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/optimizer/qdq_transformer/constant_folding_dq_node.h" -#include "core/optimizer/graph_optimizer_registry.h" -#include "core/graph/graph_utils.h" - -namespace onnxruntime { - -ConstantFoldingDQ::ConstantFoldingDQ(const IExecutionProvider& execution_provider, - bool skip_dequantize_linear, - const ConfigOptions& config_options, - const InlinedHashSet& node_index_set, - const InlinedHashSet& compatible_execution_providers, - const InlinedHashSet& excluded_initializers) noexcept - : ConstantFolding("ConstantFoldingDQ", execution_provider, skip_dequantize_linear, config_options, compatible_execution_providers, excluded_initializers), - node_index_set_(node_index_set) {} - -bool ConstantFoldingDQ::AllowConstantFolding(const Node& node) const { - if (node_index_set_.find(node.Index()) != node_index_set_.end()) { - return true; - } - return false; -} - -} // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/qdq_transformer/constant_folding_dq_node.h b/onnxruntime/core/optimizer/qdq_transformer/constant_folding_dq_node.h deleted file mode 100644 index 7aed87fa06adb..0000000000000 --- a/onnxruntime/core/optimizer/qdq_transformer/constant_folding_dq_node.h +++ /dev/null @@ -1,37 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include "core/optimizer/graph_transformer.h" -#include "core/optimizer/constant_folding.h" -#include "core/framework/ort_value.h" -#include -#include "core/framework/execution_provider.h" - -namespace onnxruntime { - -/** -@class ConstantFoldingDQ - -It's the derived class from ConstantFolding. -*/ -class ConstantFoldingDQ : public ConstantFolding { - public: - /*! Constant folding will not be applied to nodes that have one of initializers from excluded_initializers as input. - \param execution_provider Execution provider instance to execute constant folding. - */ - ConstantFoldingDQ(const IExecutionProvider& execution_provider, - bool skip_dequantize_linear, - const ConfigOptions& config_options, - const InlinedHashSet& node_index_set, - const InlinedHashSet& compatible_execution_providers = {}, - const InlinedHashSet& excluded_initializers = {}) noexcept; - - bool AllowConstantFolding(const Node& node) const override; - - private: - InlinedHashSet node_index_set_; -}; - -} // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/selection_and_optimization_func.cc b/onnxruntime/core/optimizer/selection_and_optimization_func.cc deleted file mode 100644 index 151c61952a631..0000000000000 --- a/onnxruntime/core/optimizer/selection_and_optimization_func.cc +++ /dev/null @@ -1,99 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "selection_and_optimization_func.h" -#include "core/graph/graph_utils.h" -#include "core/framework/compute_capability.h" -#include "core/optimizer/qdq_transformer/constant_folding_dq_node.h" - -namespace onnxruntime { - -std::vector> ConstantFoldingDQFuncs::Select(const GraphViewer& graph_viewer, - const KeyValueConfig& /*config*/, - const GraphOptimizerRegistry& /*graph_optimizer_registry*/) { - std::vector> result; - std::unique_ptr sub_graph = std::make_unique(); - const std::vector& node_index = graph_viewer.GetNodesInTopologicalOrder(ExecutionOrder::PRIORITY_BASED /*priority-based topological sort*/); - InitializedTensorSet constant_inputs; - const InlinedHashSet excluded_initializers; - - // Select DequantizeLinear node where all inputs are constant - for (const auto& index : node_index) { - const auto& node = graph_viewer.GetNode(index); - if (node->OpType() != "DequantizeLinear") { - continue; - } - if (!graph_utils::AllNodeInputsAreConstant(graph_viewer.GetGraph(), *node, constant_inputs, excluded_initializers)) { - continue; - } - sub_graph->nodes.push_back(index); - } - - result.push_back(std::make_unique(std::move(sub_graph))); - result.back()->optimization_func = ConstantFoldingDQFuncs::Optimize; - return result; -} - -Status ConstantFoldingDQFuncs::Optimize(Graph& graph, - const ComputeCapability& optimization_cc, - ComputeCapability& cc_to_update, - const GraphOptimizerRegistry& graph_optimizer_registry) { - std::string optimizer_name = kConstantFoldingDQ; - std::unordered_set original_initializers_to_remove; - std::unordered_set new_initializers_to_add; - InlinedHashSet dq_node_index_set; - - // iterate the nodes in node_to_optimize to: - // 1. get original initializers to remove - // 2. add new initializers - // 3. create dq node index set - for (const auto& index : optimization_cc.sub_graph->nodes) { - auto node = graph.GetNode(index); - if (node->OpType() != "DequantizeLinear") { - continue; - } - auto input_0 = node->InputDefs()[0]; - auto output_0 = node->OutputDefs()[0]; - original_initializers_to_remove.insert(input_0->Name()); - new_initializers_to_add.insert(output_0->Name()); - dq_node_index_set.insert(index); - } - - static auto transformer = std::make_unique(graph_optimizer_registry.GetCpuEp(), - false /*skip_dequantize_linear*/, - graph_optimizer_registry.GetSessionOptions().config_options, - dq_node_index_set); - - bool modified = false; - ORT_RETURN_IF_ERROR(transformer->Apply(graph, modified, *graph_optimizer_registry.GetLogger())); - - // update the overall ComputeCapability - std::vector updated_nodes; - for (auto index : cc_to_update.sub_graph->nodes) { - if (dq_node_index_set.find(index) != dq_node_index_set.end()) { - continue; - } - updated_nodes.push_back(index); - } - cc_to_update.sub_graph->nodes = updated_nodes; - - auto meta_def = cc_to_update.sub_graph->GetMutableMetaDef(); - std::vector updated_constant_initializers; - - for (auto constant_initializer : meta_def->constant_initializers) { - if (original_initializers_to_remove.find(constant_initializer) != original_initializers_to_remove.end()) { - continue; - } - updated_constant_initializers.push_back(constant_initializer); - } - - for (auto constant_initializer : new_initializers_to_add) { - updated_constant_initializers.push_back(constant_initializer); - } - - meta_def->constant_initializers = updated_constant_initializers; - - return Status::OK(); -} - -} // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/selection_and_optimization_func.h b/onnxruntime/core/optimizer/selection_and_optimization_func.h deleted file mode 100644 index 6ad62518833b0..0000000000000 --- a/onnxruntime/core/optimizer/selection_and_optimization_func.h +++ /dev/null @@ -1,31 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include "core/optimizer/graph_optimizer_registry.h" -#include "core/framework/compute_capability.h" -#include "core/graph/graph_viewer.h" - -namespace onnxruntime { -static const std::string kConstantFoldingDQ = "ConstantFoldingDQ"; - -/** - * Optimizer's selection function: Selects a set of nodes from a given graph for optimization. Additional key/value strings can be provided to configure the optimizer. - * If needed, use graph_optimizer_registry to access the session options, the CPU EP and the logger. - * - * Optimizer's optimization function: Gets the nodes in ComputeCapability from nodes_to_optimize. Use graph_optimizer_registry to access the session options, the CPU EP - * and the logger if needed to create the optimizer. Run optimization on the nodes/subgraph, and finally, update the ComputeCapability. - * - */ - -struct ConstantFoldingDQFuncs { - static std::vector> Select(const GraphViewer& graph_viewer, - const KeyValueConfig& configs, - const GraphOptimizerRegistry& graph_optimizer_registry); - static Status Optimize(Graph& graph, - const ComputeCapability& optimization_cc, - ComputeCapability& cc_to_update, - const GraphOptimizerRegistry& graph_optimizer_registry); -}; -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/acl/acl_execution_provider.cc b/onnxruntime/core/providers/acl/acl_execution_provider.cc index def1d5e4b704c..ede476ff74d1b 100644 --- a/onnxruntime/core/providers/acl/acl_execution_provider.cc +++ b/onnxruntime/core/providers/acl/acl_execution_provider.cc @@ -153,7 +153,6 @@ std::shared_ptr ACLExecutionProvider::GetKernelRegistry() const std::vector> ACLExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph, const IKernelLookup& kernel_lookup, - const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant*) const { std::vector> result; for (const auto& node : graph.Nodes()) { diff --git a/onnxruntime/core/providers/acl/acl_execution_provider.h b/onnxruntime/core/providers/acl/acl_execution_provider.h index 80e4aaaf021e3..d635e56add30b 100755 --- a/onnxruntime/core/providers/acl/acl_execution_provider.h +++ b/onnxruntime/core/providers/acl/acl_execution_provider.h @@ -39,7 +39,6 @@ class ACLExecutionProvider : public IExecutionProvider { std::vector> GetCapability( const onnxruntime::GraphViewer& graph, const IKernelLookup& kernel_lookup, - const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* resource_accountant) const override; Status OnRunStart(const onnxruntime::RunOptions&) override; diff --git a/onnxruntime/core/providers/cann/cann_execution_provider.cc b/onnxruntime/core/providers/cann/cann_execution_provider.cc index be09eefba791b..07e83933a890c 100644 --- a/onnxruntime/core/providers/cann/cann_execution_provider.cc +++ b/onnxruntime/core/providers/cann/cann_execution_provider.cc @@ -1254,7 +1254,6 @@ GetSubGraphPartition(const std::vector& topological_order, const std: std::vector> CANNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer, const IKernelLookup& kernel_lookup, - const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant*) const { std::vector> result; diff --git a/onnxruntime/core/providers/cann/cann_execution_provider.h b/onnxruntime/core/providers/cann/cann_execution_provider.h index f28ae77e49f83..5ff935463a1c1 100644 --- a/onnxruntime/core/providers/cann/cann_execution_provider.h +++ b/onnxruntime/core/providers/cann/cann_execution_provider.h @@ -56,7 +56,6 @@ class CANNExecutionProvider : public IExecutionProvider { std::vector> GetCapability( const onnxruntime::GraphViewer& graph_viewer, const IKernelLookup& kernel_lookup, - const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* resource_accountant) const override; Status Compile(const std::vector& fused_nodes_and_graphs, diff --git a/onnxruntime/core/providers/coreml/coreml_execution_provider.cc b/onnxruntime/core/providers/coreml/coreml_execution_provider.cc index cc7beed6bb298..3fa3868267c9b 100644 --- a/onnxruntime/core/providers/coreml/coreml_execution_provider.cc +++ b/onnxruntime/core/providers/coreml/coreml_execution_provider.cc @@ -39,7 +39,6 @@ CoreMLExecutionProvider::~CoreMLExecutionProvider() {} std::vector> CoreMLExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer, const IKernelLookup& /*kernel_lookup*/, - const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* /* resource_accountant */) const { std::vector> result; diff --git a/onnxruntime/core/providers/coreml/coreml_execution_provider.h b/onnxruntime/core/providers/coreml/coreml_execution_provider.h index 574ae1fc0106b..0609bf6af726d 100644 --- a/onnxruntime/core/providers/coreml/coreml_execution_provider.h +++ b/onnxruntime/core/providers/coreml/coreml_execution_provider.h @@ -20,7 +20,6 @@ class CoreMLExecutionProvider : public IExecutionProvider { std::vector> GetCapability(const onnxruntime::GraphViewer& graph_viewer, const IKernelLookup& /*kernel_lookup*/, - const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* resource_accountant) const override; #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) diff --git a/onnxruntime/core/providers/cpu/controlflow/loop.cc b/onnxruntime/core/providers/cpu/controlflow/loop.cc index b33b1f189594b..c65dd2a04bf55 100644 --- a/onnxruntime/core/providers/cpu/controlflow/loop.cc +++ b/onnxruntime/core/providers/cpu/controlflow/loop.cc @@ -244,7 +244,7 @@ static Status ConcatenateCpuOutput(void* /*stream*/, // we can't easily use a C++ template for the tensor element type, // so use a span for some protection but work in bytes - gsl::span output_span = gsl::make_span(static_cast(output), + gsl::span output_span = gsl::make_span(static_cast(output), output_size_in_bytes); for (size_t i = 0, num_iterations = per_iteration_output.size(); i < num_iterations; ++i) { @@ -257,7 +257,7 @@ static Status ConcatenateCpuOutput(void* /*stream*/, " Expected:", per_iteration_shape, " Got:", iteration_data.Shape()); } - auto src = gsl::make_span(static_cast(iteration_data.DataRaw()), + auto src = gsl::make_span(static_cast(iteration_data.DataRaw()), bytes_per_iteration); auto dst = output_span.subspan(i * bytes_per_iteration, bytes_per_iteration); gsl::copy(src, dst); diff --git a/onnxruntime/core/providers/cpu/quantization/conv_integer.cc b/onnxruntime/core/providers/cpu/quantization/conv_integer.cc index f3c6b18f8e753..03b39e19ed748 100644 --- a/onnxruntime/core/providers/cpu/quantization/conv_integer.cc +++ b/onnxruntime/core/providers/cpu/quantization/conv_integer.cc @@ -34,18 +34,17 @@ ONNX_OPERATOR_KERNEL_EX( ConvInteger); Status ConvInteger::Compute(OpKernelContext* context) const { - const auto input_defs = Node().InputDefs(); - size_t num_inputs = input_defs.size(); + size_t num_inputs = OpKernel::Node().InputDefs().size(); const auto* X = context->Input(0); const auto* W = context->Input(1); uint8_t input_offset = 0; uint8_t filter_offset = 0; - if (num_inputs >= 3 && input_defs[2]->Exists()) { + if (num_inputs >= 3) { const auto* X_Zero_Point = context->Input(2); ORT_ENFORCE(IsScalarOr1ElementVector(X_Zero_Point), "Must be a scalar or 1D tensor or size 1."); input_offset = *(X_Zero_Point->Data()); } - if (num_inputs >= 4 && input_defs[3]->Exists()) { + if (num_inputs >= 4) { const auto* W_Zero_Point = context->Input(3); ORT_ENFORCE(IsScalarOr1ElementVector(W_Zero_Point), "Non per-tensor quantization is not supported now."); filter_offset = *(W_Zero_Point->Data()); diff --git a/onnxruntime/core/providers/cuda/controlflow/loop.cc b/onnxruntime/core/providers/cuda/controlflow/loop.cc index d66de7c74e647..3295b73a800c9 100644 --- a/onnxruntime/core/providers/cuda/controlflow/loop.cc +++ b/onnxruntime/core/providers/cuda/controlflow/loop.cc @@ -84,10 +84,10 @@ static Status ConcatenateGpuOutput(void* stream, std::vector& per_iter CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(cur_output, iteration_data.DataRaw(), bytes_per_iteration, cudaMemcpyDeviceToDevice, static_cast(stream))); - cur_output = static_cast((static_cast(cur_output) + bytes_per_iteration)); + cur_output = static_cast((static_cast(cur_output) + bytes_per_iteration)); } - ORT_ENFORCE(static_cast(cur_output) - static_cast(output) == output_size_in_bytes, + ORT_ENFORCE(static_cast(cur_output) - static_cast(output) == output_size_in_bytes, "Concatenation did not fill output buffer as expected."); return Status::OK(); diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index 54fb4429c0536..b675c08e5f804 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -2660,7 +2660,6 @@ std::unique_ptr CUDAExecutionProvider::GetDataTransf std::vector> CUDAExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph, const IKernelLookup& kernel_lookup, - const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* resource_accountant) const { std::vector> result; const logging::Logger& logger = *GetLogger(); diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.h b/onnxruntime/core/providers/cuda/cuda_execution_provider.h index a75e81f1f0c6d..79a48e7cb89e1 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.h +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.h @@ -73,7 +73,6 @@ class CUDAExecutionProvider : public IExecutionProvider { std::vector> GetCapability( const onnxruntime::GraphViewer& graph, const IKernelLookup& kernel_lookup, - const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* resource_accountant) const override; int GetDeviceId() const override { return info_.device_id; } diff --git a/onnxruntime/core/providers/cuda/tensor/upsample.cc b/onnxruntime/core/providers/cuda/tensor/upsample.cc index a38fe1efad540..cbf745d3c7b4f 100644 --- a/onnxruntime/core/providers/cuda/tensor/upsample.cc +++ b/onnxruntime/core/providers/cuda/tensor/upsample.cc @@ -290,16 +290,16 @@ Status Upsample::BaseCompute(OpKernelContext* context, scales_div[i] = fast_divmod(gsl::narrow_cast(ceil(scales[i]))); } - UpsampleImpl(Stream(context), - mode_, - rank, - (UpsampleMode::LINEAR == mode_) ? (rank == 2 ? X_dims[0] : X_dims[2]) : 0, - input_strides, - output_div_pitches, - scales_div, - reinterpret_cast(X->Data()), - reinterpret_cast(Y->MutableData()), - output_count); + UpampleImpl(Stream(context), + mode_, + rank, + (UpsampleMode::LINEAR == mode_) ? (rank == 2 ? X_dims[0] : X_dims[2]) : 0, + input_strides, + output_div_pitches, + scales_div, + reinterpret_cast(X->Data()), + reinterpret_cast(Y->MutableData()), + output_count); } return Status::OK(); diff --git a/onnxruntime/core/providers/cuda/tensor/upsample_impl.cu b/onnxruntime/core/providers/cuda/tensor/upsample_impl.cu index 24aeada559979..d1c2ae6332994 100644 --- a/onnxruntime/core/providers/cuda/tensor/upsample_impl.cu +++ b/onnxruntime/core/providers/cuda/tensor/upsample_impl.cu @@ -8,12 +8,12 @@ namespace onnxruntime { namespace cuda { template -__global__ void _UpsampleNearestKernel(const TArray input_pitches, - const TArray output_div_pitches, - const TArray scales_div, - const T* __restrict__ input_data, - T* __restrict__ output_data, - const size_t N) { +__global__ void _UpampleNearestKernel(const TArray input_pitches, + const TArray output_div_pitches, + const TArray scales_div, + const T* __restrict__ input_data, + T* __restrict__ output_data, + const size_t N) { CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N); CUDA_LONG input_index = 0; CUDA_LONG output_index = id; @@ -36,13 +36,13 @@ __global__ void _UpsampleNearestKernel(const TArray input_pitches, // This is the common use-case where the 4-D input (batched multi-channel images) // is usually of shape [N, C, H, W] and the scales are [1.0, 1.0, height_scale, width_scale] template -__global__ void _UpsampleBilinear4DInputKernel(const int64_t input_dim2, - const TArray input_pitches, - const TArray output_div_pitches, - const TArray scales_div, - const T* __restrict__ input_data, - T* __restrict__ output_data, - const size_t N) { +__global__ void _UpampleBilinear4DInputKernel(const int64_t input_dim2, + const TArray input_pitches, + const TArray output_div_pitches, + const TArray scales_div, + const T* __restrict__ input_data, + T* __restrict__ output_data, + const size_t N) { CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N); CUDA_LONG input_index = 0; @@ -95,13 +95,13 @@ __global__ void _UpsampleBilinear4DInputKernel(const int64_t input_dim2, // The following method supports a 2-D input in 'Linear mode' template -__global__ void _UpsampleBilinear2DInputKernel(const int64_t input_dim0, - const TArray input_pitches, - const TArray output_div_pitches, - const TArray scales_div, - const T* __restrict__ input_data, - T* __restrict__ output_data, - const size_t N) { +__global__ void _UpampleBilinear2DInputKernel(const int64_t input_dim0, + const TArray input_pitches, + const TArray output_div_pitches, + const TArray scales_div, + const T* __restrict__ input_data, + T* __restrict__ output_data, + const size_t N) { CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N); CUDA_LONG input_index = 0; @@ -147,32 +147,32 @@ __global__ void _UpsampleBilinear2DInputKernel(const int64_t input_dim0, } template -void UpsampleImpl(cudaStream_t stream, - const onnxruntime::UpsampleMode upsample_mode, - const size_t rank, - const int64_t input_dim2, - const TArray& input_pitches, - const TArray& output_div_pitches, - const TArray& scales_div, - const T* input_data, - T* output_data, - const size_t N) { +void UpampleImpl(cudaStream_t stream, + const onnxruntime::UpsampleMode upsample_mode, + const size_t rank, + const int64_t input_dim2, + const TArray& input_pitches, + const TArray& output_div_pitches, + const TArray& scales_div, + const T* input_data, + T* output_data, + const size_t N) { int blocksPerGrid = (int)(ceil(static_cast(N) / GridDim::maxThreadsPerBlock)); if (onnxruntime::UpsampleMode::NN == upsample_mode) { if (rank == 4) { - _UpsampleNearestKernel<<>>( + _UpampleNearestKernel<<>>( input_pitches, output_div_pitches, scales_div, input_data, output_data, N); } else if (rank == 3) { - _UpsampleNearestKernel<<>>( + _UpampleNearestKernel<<>>( input_pitches, output_div_pitches, scales_div, input_data, output_data, N); } else if (rank == 2) { - _UpsampleNearestKernel<<>>( + _UpampleNearestKernel<<>>( input_pitches, output_div_pitches, scales_div, input_data, output_data, N); } else if (rank == 1) { - _UpsampleNearestKernel<<>>( + _UpampleNearestKernel<<>>( input_pitches, output_div_pitches, scales_div, input_data, output_data, N); } else { @@ -180,11 +180,11 @@ void UpsampleImpl(cudaStream_t stream, } } else if (onnxruntime::UpsampleMode::LINEAR == upsample_mode) { if (rank == 4) { - _UpsampleBilinear4DInputKernel<<>>( + _UpampleBilinear4DInputKernel<<>>( input_dim2, input_pitches, output_div_pitches, scales_div, input_data, output_data, N); } else if (rank == 2) { - _UpsampleBilinear2DInputKernel<<>>( + _UpampleBilinear2DInputKernel<<>>( input_dim2, input_pitches, output_div_pitches, scales_div, input_data, output_data, N); } else { @@ -197,17 +197,17 @@ void UpsampleImpl(cudaStream_t stream, } } -#define SPECIALIZED_IMPL(T) \ - template void UpsampleImpl(cudaStream_t stream, \ - const onnxruntime::UpsampleMode upsample_mode, \ - const size_t rank, \ - const int64_t input_dim2, \ - const TArray& input_pitches, \ - const TArray& output_div_pitches, \ - const TArray& scales_div, \ - const T* input_data, \ - T* output_data, \ - const size_t N); +#define SPECIALIZED_IMPL(T) \ + template void UpampleImpl(cudaStream_t stream, \ + const onnxruntime::UpsampleMode upsample_mode, \ + const size_t rank, \ + const int64_t input_dim2, \ + const TArray& input_pitches, \ + const TArray& output_div_pitches, \ + const TArray& scales_div, \ + const T* input_data, \ + T* output_data, \ + const size_t N); SPECIALIZED_IMPL(float) SPECIALIZED_IMPL(double) diff --git a/onnxruntime/core/providers/cuda/tensor/upsample_impl.h b/onnxruntime/core/providers/cuda/tensor/upsample_impl.h index fb47ad8301615..250ec6b272e34 100644 --- a/onnxruntime/core/providers/cuda/tensor/upsample_impl.h +++ b/onnxruntime/core/providers/cuda/tensor/upsample_impl.h @@ -11,16 +11,16 @@ namespace onnxruntime { namespace cuda { template -void UpsampleImpl(cudaStream_t stream, - const onnxruntime::UpsampleMode upsample_mode, - const size_t rank, - const int64_t input_dim2, - const TArray& input_pitches, - const TArray& output_div_pitches, - const TArray& scales_div, - const T* input_data, - T* output_data, - const size_t N); +void UpampleImpl(cudaStream_t stream, + const onnxruntime::UpsampleMode upsample_mode, + const size_t rank, + const int64_t input_dim2, + const TArray& input_pitches, + const TArray& output_div_pitches, + const TArray& scales_div, + const T* input_data, + T* output_data, + const size_t N); } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp index 868b2103586f9..9d23b8b950272 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp @@ -93,13 +93,12 @@ namespace Dml ExecutionProvider::GetCapability( const onnxruntime::GraphViewer& graph, const onnxruntime::IExecutionProvider::IKernelLookup& kernel_lookup, - const onnxruntime::GraphOptimizerRegistry& graph_optimizer_registry, onnxruntime::IResourceAccountant* resource_accountant) const { #ifdef ENABLE_GRAPH_COMPILATION - return m_impl->GetCapability(graph, kernel_lookup, graph_optimizer_registry, resource_accountant, *GetLogger()); + return m_impl->GetCapability(graph, kernel_lookup, resource_accountant, *GetLogger()); #else - return onnxruntime::IExecutionProvider::GetCapability(graph, kernel_lookup, graph_optimizer_registry, resource_accountant); + return onnxruntime::IExecutionProvider::GetCapability(graph, kernel_lookup, resource_accountant); #endif } @@ -879,7 +878,6 @@ namespace Dml ExecutionProviderImpl::GetCapability( const onnxruntime::GraphViewer& graph, const onnxruntime::IExecutionProvider::IKernelLookup& kernel_lookup, - const onnxruntime::GraphOptimizerRegistry& /* graph_optimizer_registry */, onnxruntime::IResourceAccountant*, const onnxruntime::logging::Logger& logger) const { uint32_t deviceDataTypeMask = GetSupportedDeviceDataTypeMask(); // Each bit corresponds to each DML_TENSOR_DATA_TYPE. diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h index aa3d8b0b4a409..7f420f8850001 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h @@ -13,7 +13,6 @@ namespace onnxruntime { class IResourceAccountant; -class GraphOptimizerRegistry; } namespace WRL { @@ -94,7 +93,6 @@ namespace Dml GetCapability( const onnxruntime::GraphViewer& graph, const onnxruntime::IExecutionProvider::IKernelLookup& kernel_lookup, - const onnxruntime::GraphOptimizerRegistry& graph_optimizer_registry, onnxruntime::IResourceAccountant* resource_accountant, const onnxruntime::logging::Logger& logger) const; @@ -290,7 +288,6 @@ namespace Dml std::vector> GetCapability(const onnxruntime::GraphViewer& graph, const onnxruntime::IExecutionProvider::IKernelLookup& kernel_lookup, - const onnxruntime::GraphOptimizerRegistry& /* graph_optimizer_registry */, onnxruntime::IResourceAccountant* resource_accountant) const final override; onnxruntime::common::Status OnSessionInitializationEnd() override diff --git a/onnxruntime/core/providers/dnnl/dnnl_execution_provider.cc b/onnxruntime/core/providers/dnnl/dnnl_execution_provider.cc index d0e5b0b1588ef..4da82b351f1d6 100644 --- a/onnxruntime/core/providers/dnnl/dnnl_execution_provider.cc +++ b/onnxruntime/core/providers/dnnl/dnnl_execution_provider.cc @@ -147,7 +147,6 @@ std::vector> DnnlExecutionProvider::GetSupportedNodes(con std::vector> DnnlExecutionProvider::GetCapability( const GraphViewer& graph_viewer, const IKernelLookup& /*kernel_lookup*/, - const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* /* resource_accountant */) const { // follow from coreml ep's Getcapability diff --git a/onnxruntime/core/providers/dnnl/dnnl_execution_provider.h b/onnxruntime/core/providers/dnnl/dnnl_execution_provider.h index 8f951efef2a94..bde18e139f2a3 100644 --- a/onnxruntime/core/providers/dnnl/dnnl_execution_provider.h +++ b/onnxruntime/core/providers/dnnl/dnnl_execution_provider.h @@ -25,7 +25,6 @@ class DnnlExecutionProvider : public IExecutionProvider { std::vector> GetCapability(const onnxruntime::GraphViewer& graph, const IKernelLookup& /*kernel_lookup*/, - const GraphOptimizerRegistry& /* graph_optimizer_registry */, onnxruntime::IResourceAccountant* /* resource_accountant */) const override; common::Status Compile(const std::vector& fused_nodes_and_graphs, diff --git a/onnxruntime/core/providers/js/js_execution_provider.cc b/onnxruntime/core/providers/js/js_execution_provider.cc index d8e24ff1f5053..9d00436150286 100644 --- a/onnxruntime/core/providers/js/js_execution_provider.cc +++ b/onnxruntime/core/providers/js/js_execution_provider.cc @@ -791,7 +791,6 @@ std::vector JsExecutionProvider::CreatePreferredAllocators() { std::vector> JsExecutionProvider::GetCapability( const onnxruntime::GraphViewer& graph, const IKernelLookup& kernel_lookup, - const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* /* resource_accountant */) const { InlinedVector candidates; // `tenative_candidates` is a subset of `candidates`. diff --git a/onnxruntime/core/providers/js/js_execution_provider.h b/onnxruntime/core/providers/js/js_execution_provider.h index c87303209c689..4bead50fc782e 100644 --- a/onnxruntime/core/providers/js/js_execution_provider.h +++ b/onnxruntime/core/providers/js/js_execution_provider.h @@ -45,7 +45,6 @@ class JsExecutionProvider : public IExecutionProvider { std::vector> GetCapability( const onnxruntime::GraphViewer& graph_viewer, const IKernelLookup& /*kernel_lookup*/, - const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* /* resource_accountant */) const override; std::shared_ptr GetKernelRegistry() const override; diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc index 9a694b03387ae..1558d22137c05 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc @@ -993,7 +993,6 @@ GetPartitionedSubgraphs(const std::vector& topological_order, std::vector> MIGraphXExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer, const IKernelLookup& /*kernel_lookup*/, - const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* /* resource_accountant */) const { std::vector> result; auto model = graph_viewer.CreateModel(*GetLogger()); diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h index 7c89b5ec544a1..d6af991f9b77e 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h @@ -69,7 +69,6 @@ class MIGraphXExecutionProvider : public IExecutionProvider { std::vector> GetCapability(const onnxruntime::GraphViewer& graph_viewer, const IKernelLookup& /*kernel_lookup*/, - const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* /* resource_accountant */) const override; common::Status Compile(const std::vector& fused_nodes, diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.cc index 28cfde817a620..27bd584e2d3c6 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.cc @@ -81,7 +81,6 @@ NnapiExecutionProvider::~NnapiExecutionProvider() {} std::vector> NnapiExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer, const IKernelLookup& /*kernel_lookup*/, - const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* /* resource_accountant */) const { std::vector> result; const logging::Logger& logger = *GetLogger(); diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.h b/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.h index a2269fdd89436..ebf9372eb668d 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.h +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.h @@ -26,7 +26,6 @@ class NnapiExecutionProvider : public IExecutionProvider { std::vector> GetCapability(const onnxruntime::GraphViewer& graph_view, const IKernelLookup& /*kernel_lookup*/, - const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* /* resource_accountant */) const override; #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) diff --git a/onnxruntime/core/providers/openvino/backends/basic_backend.cc b/onnxruntime/core/providers/openvino/backends/basic_backend.cc index d026ce386e5c3..9d4ad88e2c2b3 100644 --- a/onnxruntime/core/providers/openvino/backends/basic_backend.cc +++ b/onnxruntime/core/providers/openvino/backends/basic_backend.cc @@ -647,7 +647,7 @@ void BasicBackend::CompleteAsyncInference(Ort::KernelContext& context, OVInferRe const auto& out_name = item.first; auto node = item.second; Ort::UnownedValue output_tensor = GetOutputTensor(context, - out_name, + std::move(out_name), subgraph_context_.output_names, node); auto mem_info = output_tensor.GetTensorMemoryInfo(); diff --git a/onnxruntime/core/providers/openvino/openvino_execution_provider.cc b/onnxruntime/core/providers/openvino/openvino_execution_provider.cc index 6482a07ee92bc..12c16e9c9b8f6 100644 --- a/onnxruntime/core/providers/openvino/openvino_execution_provider.cc +++ b/onnxruntime/core/providers/openvino/openvino_execution_provider.cc @@ -107,7 +107,6 @@ OpenVINOExecutionProvider::~OpenVINOExecutionProvider() { std::vector> OpenVINOExecutionProvider::GetCapability(const GraphViewer& graph_viewer, const IKernelLookup& /*kernel_lookup*/, - const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* /* resource_accountant */) const { std::vector> result; diff --git a/onnxruntime/core/providers/openvino/openvino_execution_provider.h b/onnxruntime/core/providers/openvino/openvino_execution_provider.h index 020aec16e507c..bbcca583b074b 100644 --- a/onnxruntime/core/providers/openvino/openvino_execution_provider.h +++ b/onnxruntime/core/providers/openvino/openvino_execution_provider.h @@ -51,7 +51,6 @@ class OpenVINOExecutionProvider : public IExecutionProvider { std::vector> GetCapability(const GraphViewer& graph_viewer, const IKernelLookup& /*kernel_lookup*/, - const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* /* resource_accountant */) const override; Status Compile(const std::vector& fused_nodes, diff --git a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc index d85277627a3de..3df231e53e7c0 100644 --- a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc +++ b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc @@ -198,13 +198,35 @@ Status LoadQnnCtxFromOnnxGraph(const onnxruntime::GraphViewer& graph_viewer, return Status::OK(); } +// Figure out the real context cache file path +// return true if context cache file exists +bool ValidateContextCacheFilePath(bool is_qnn_ctx_model, + const std::string& customer_context_cache_path, + const onnxruntime::PathString& model_pathstring, + onnxruntime::PathString& context_cache_path) { + // always try the path set by user first, it's the only way to set it if load model from memory + if (!customer_context_cache_path.empty()) { + context_cache_path = ToPathString(customer_context_cache_path); + } else if (!model_pathstring.empty()) { // model loaded from file + if (is_qnn_ctx_model) { + // it's a context cache model, just use the model path + context_cache_path = model_pathstring; + } else if (!model_pathstring.empty()) { + // this is not a normal Onnx model, no customer path, create a default path for generation: model_path + _ctx.onnx + context_cache_path = model_pathstring + ToPathString("_ctx.onnx"); + } + } + + return std::filesystem::is_regular_file(context_cache_path) && std::filesystem::exists(context_cache_path); +} + Status CreateEPContextNodes(Model* model, unsigned char* buffer, uint64_t buffer_size, const std::string& sdk_build_version, const std::vector& fused_nodes_and_graphs, const QnnModelLookupTable& qnn_models, - const onnxruntime::PathString& context_model_path, + const onnxruntime::PathString& context_cache_path, bool qnn_context_embed_mode, uint64_t max_spill_fill_buffer_size, const logging::Logger& logger) { @@ -240,19 +262,7 @@ Status CreateEPContextNodes(Model* model, std::string cache_payload(buffer, buffer + buffer_size); ep_node.AddAttribute(EP_CACHE_CONTEXT, cache_payload); } else { - onnxruntime::PathString context_bin_path; - auto pos = context_model_path.find_last_of(ORT_TSTR(".")); - if (pos != std::string::npos) { - context_bin_path = context_model_path.substr(0, pos); - } else { - context_bin_path = context_model_path; - } - std::string graph_name_in_file(graph_name); - auto name_pos = graph_name_in_file.find_first_of(kQnnExecutionProvider); - if (name_pos != std::string::npos) { - graph_name_in_file.replace(name_pos, strlen(kQnnExecutionProvider), ""); - } - context_bin_path = context_bin_path + ToPathString(graph_name_in_file + ".bin"); + onnxruntime::PathString context_bin_path = context_cache_path + ToPathString("_" + graph_name + ".bin"); std::string context_cache_name(std::filesystem::path(context_bin_path).filename().string()); std::ofstream of_stream(context_bin_path.c_str(), std::ofstream::binary); if (!of_stream) { diff --git a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h index c54cd3ca6e90c..3dfa0ae21001b 100644 --- a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h +++ b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h @@ -38,6 +38,11 @@ Status CreateNodeArgs(const std::vector& names, std::vector& node_args, onnxruntime::Graph& graph); +bool ValidateContextCacheFilePath(bool is_qnn_ctx_model, + const std::string& customer_context_cache_path, + const onnxruntime::PathString& model_pathstring, + onnxruntime::PathString& context_cache_path); + Status GetEpContextFromMainNode(const onnxruntime::Node& main_context_node, const onnxruntime::PathString& ctx_onnx_model_path, QnnBackendManager* qnn_backend_manager, @@ -62,7 +67,7 @@ Status CreateEPContextNodes(Model* model, const std::string& sdk_build_version, const std::vector& fused_nodes_and_graphs, const std::unordered_map>& qnn_models, - const onnxruntime::PathString& context_model_path, + const onnxruntime::PathString& context_cache_path, bool qnn_context_embed_mode, uint64_t max_spill_fill_buffer_size, const logging::Logger& logger); diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc index 26d792c008edc..bcde69beceef7 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc @@ -470,10 +470,8 @@ Status QnnBackendManager::InitializeProfiling() { QnnProfile_Level_t qnn_profile_level = QNN_PROFILE_LEVEL_BASIC; if (ProfilingLevel::BASIC == profiling_level_merge_) { qnn_profile_level = QNN_PROFILE_LEVEL_BASIC; - LOGS_DEFAULT(VERBOSE) << "Profiling level set to basic."; } else if (ProfilingLevel::DETAILED == profiling_level_merge_) { qnn_profile_level = QNN_PROFILE_LEVEL_DETAILED; - LOGS_DEFAULT(VERBOSE) << "Profiling level set to detailed."; } Qnn_ErrorHandle_t result = qnn_interface_.profileCreate(backend_handle_, qnn_profile_level, &profile_backend_handle_); ORT_RETURN_IF(QNN_PROFILE_NO_ERROR != result, "Failed to create QNN profile! Error: ", QnnErrorHandleToString(result)); diff --git a/onnxruntime/core/providers/qnn/qnn_allocator.cc b/onnxruntime/core/providers/qnn/qnn_allocator.cc index cb92e927ff65a..1fb8742f724cd 100644 --- a/onnxruntime/core/providers/qnn/qnn_allocator.cc +++ b/onnxruntime/core/providers/qnn/qnn_allocator.cc @@ -181,9 +181,7 @@ void HtpSharedMemoryAllocator::Free(void* allocation_address) { // Avoid throwing exceptions as this may be running from a destructor. try { // take ownership of shared memory and free at end of scope - const size_t allocation_offset = AllocationOffsetFromStartOfHeader(); - void* raw_allocation_address = (void*)((std::byte*)allocation_address - allocation_offset); - auto shared_memory = WrapSharedMemoryWithUniquePtr(raw_allocation_address, rpcmem_lib_->Api()); + auto shared_memory = WrapSharedMemoryWithUniquePtr(allocation_address, rpcmem_lib_->Api()); // destroy header allocation_header.~AllocationHeader(); diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc index a5813dc2a4adc..3fc537066ae0b 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc @@ -195,10 +195,6 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio share_ep_contexts_ = config_options->GetConfigOrDefault(kOrtSessionOptionShareEpContexts, "0") == "1"; LOGS_DEFAULT(VERBOSE) << "User specified option - share EP contexts across sessions: " << share_ep_contexts_; - - stop_share_ep_contexts_ = - config_options->GetConfigOrDefault(kOrtSessionOptionStopShareEpContexts, "0") == "1"; - LOGS_DEFAULT(VERBOSE) << "User specified option - stop share EP contexts across sessions: " << stop_share_ep_contexts_; } static const std::string BACKEND_PATH = "backend_path"; @@ -388,27 +384,17 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio } } - // For context binary generation with weight sharing enabled, use the QnnBackendManager from the shared context if it exits - // So that all graphs from later sessions will be compiled into the same QNN context - if (context_cache_enabled_ && share_ep_contexts_ && SharedContext::GetInstance().GetSharedQnnBackendManager()) { - qnn_backend_manager_ = SharedContext::GetInstance().GetSharedQnnBackendManager(); - // Clear the QnnBackendManager from singleton to stop the resource share - if (stop_share_ep_contexts_) { - SharedContext::GetInstance().ResetSharedQnnBackendManager(); - } - } else { - qnn_backend_manager_ = qnn::QnnBackendManager::Create( - qnn::QnnBackendManagerConfig{backend_path, - profiling_level_etw, - profiling_level, - profiling_file_path, - context_priority, - qnn_saver_path, - device_id_, - htp_arch, - soc_model, - enable_htp_weight_sharing}); - } + qnn_backend_manager_ = qnn::QnnBackendManager::Create( + qnn::QnnBackendManagerConfig{backend_path, + profiling_level_etw, + profiling_level, + profiling_file_path, + context_priority, + qnn_saver_path, + device_id_, + htp_arch, + soc_model, + enable_htp_weight_sharing}); #if defined(_WIN32) if (onnxruntime::logging::EtwRegistrationManager::SupportsETW()) { @@ -669,7 +655,6 @@ static void PartitionCtxModel(const onnxruntime::GraphViewer& graph_viewer, std::vector> QNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer, const IKernelLookup& /*kernel_lookup*/, - const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* /* resource_accountant */) const { std::vector> result; @@ -919,33 +904,25 @@ Status QNNExecutionProvider::CompileFromOrtGraph(const std::vector& fused_nodes_and_graphs, std::vector& node_compute_funcs) { const auto& logger = *GetLogger(); bool is_qnn_ctx_model = qnn::IsFusedGraphHasCtxNode(fused_nodes_and_graphs); - onnxruntime::PathString context_model_path; + onnxruntime::PathString context_cache_path; bool is_ctx_file_exist = false; if (is_qnn_ctx_model || context_cache_enabled_) { const onnxruntime::GraphViewer& graph_viewer_0(fused_nodes_and_graphs[0].filtered_graph); - // Figure out the EP context model path from model path or session option - GetContextOnnxModelFilePath(context_cache_path_cfg_, - graph_viewer_0.ModelPath().native(), - context_model_path); + is_ctx_file_exist = qnn::ValidateContextCacheFilePath(is_qnn_ctx_model, + context_cache_path_cfg_, + graph_viewer_0.ModelPath().native(), + context_cache_path); } + ORT_RETURN_IF(is_ctx_file_exist && !is_qnn_ctx_model && context_cache_enabled_, + "The inference session is created from normal ONNX model. And an EP context model file is provided and existed. ", + "Please remove the EP context model manually if you want to re-generate it."); + if (is_qnn_ctx_model) { // Get QnnModel from EP shared contexts if (share_ep_contexts_ && SharedContext::GetInstance().HasSharedQnnModels()) { @@ -988,7 +965,7 @@ Status QNNExecutionProvider::Compile(const std::vector& fused const onnxruntime::GraphViewer& main_ctx_graph_viewer(fused_nodes_and_graphs[main_context_pos].filtered_graph); // Create QNN context from the cached binary, deserialize the QNN graph from the binary ORT_RETURN_IF_ERROR(qnn::LoadQnnCtxFromOnnxGraph(main_ctx_graph_viewer, - context_model_path, + context_cache_path, qnn_backend_manager_.get(), qnn_models, logger, @@ -1048,16 +1025,10 @@ Status QNNExecutionProvider::Compile(const std::vector& fused qnn_backend_manager_->GetSdkVersion(), fused_nodes_and_graphs, qnn_models_, - context_model_path, + context_cache_path, qnn_context_embed_mode_, max_spill_fill_buffer_size, logger)); - - if (share_ep_contexts_ && !stop_share_ep_contexts_ && - nullptr == SharedContext::GetInstance().GetSharedQnnBackendManager()) { - ORT_RETURN_IF_NOT(SharedContext::GetInstance().SetSharedQnnBackendManager(qnn_backend_manager_), - "Failed to set shared QnnBackendManager."); - } } return Status::OK(); } diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.h b/onnxruntime/core/providers/qnn/qnn_execution_provider.h index d7a5d04d22692..31c34855ca4c0 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.h +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.h @@ -31,7 +31,6 @@ class QNNExecutionProvider : public IExecutionProvider { std::vector> GetCapability(const onnxruntime::GraphViewer& graph_view, const IKernelLookup& /*kernel_lookup*/, - const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* /* resource_accountant */) const override; Status Compile(const std::vector& fused_nodes_and_graphs, @@ -91,7 +90,6 @@ class QNNExecutionProvider : public IExecutionProvider { uint32_t default_rpc_control_latency_ = 0; bool enable_HTP_FP16_precision_ = true; bool share_ep_contexts_ = false; - bool stop_share_ep_contexts_ = false; bool enable_spill_fill_buffer_ = false; #if defined(_WIN32) onnxruntime::logging::EtwRegistrationManager::EtwInternalCallback callback_ETWSink_provider_ = nullptr; diff --git a/onnxruntime/core/providers/qnn/shared_context.h b/onnxruntime/core/providers/qnn/shared_context.h index 277a484ad8528..81de357dbe677 100644 --- a/onnxruntime/core/providers/qnn/shared_context.h +++ b/onnxruntime/core/providers/qnn/shared_context.h @@ -61,39 +61,13 @@ class SharedContext { return graph_exist; } - bool SetSharedQnnBackendManager(std::shared_ptr& qnn_backend_manager) { - const std::lock_guard lock(mtx_); - - if (qnn_backend_manager_ != nullptr) { - if (qnn_backend_manager_ == qnn_backend_manager) { - return true; - } - return false; - } - qnn_backend_manager_ = qnn_backend_manager; - return true; - } - - std::shared_ptr GetSharedQnnBackendManager() { - const std::lock_guard lock(mtx_); - return qnn_backend_manager_; - } - - void ResetSharedQnnBackendManager() { - const std::lock_guard lock(mtx_); - qnn_backend_manager_.reset(); - } - private: SharedContext() = default; ~SharedContext() = default; ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(SharedContext); - // Used for passing through QNN models (deserialized from context binary) across sessions std::vector> shared_qnn_models_; - // Used for compiling multiple models into same QNN context binary - std::shared_ptr qnn_backend_manager_; // Producer sessions can be in parallel // Consumer sessions have to be after producer sessions initialized std::mutex mtx_; diff --git a/onnxruntime/core/providers/rknpu/rknpu_execution_provider.cc b/onnxruntime/core/providers/rknpu/rknpu_execution_provider.cc index e9343e2b2e06a..10fd81786f977 100644 --- a/onnxruntime/core/providers/rknpu/rknpu_execution_provider.cc +++ b/onnxruntime/core/providers/rknpu/rknpu_execution_provider.cc @@ -51,7 +51,6 @@ std::vector> RknpuExecutionProvider::GetSupportedNodes( std::vector> RknpuExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer, const IKernelLookup& /*kernel_lookup*/, - const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* /* resource_accountant */) const { // Find inputs, initializers and outputs for each supported subgraph std::vector> result; diff --git a/onnxruntime/core/providers/rknpu/rknpu_execution_provider.h b/onnxruntime/core/providers/rknpu/rknpu_execution_provider.h index 75cae37d117a0..ce16d63e111d9 100644 --- a/onnxruntime/core/providers/rknpu/rknpu_execution_provider.h +++ b/onnxruntime/core/providers/rknpu/rknpu_execution_provider.h @@ -20,7 +20,6 @@ class RknpuExecutionProvider : public IExecutionProvider { std::vector> GetCapability(const onnxruntime::GraphViewer& graph, const IKernelLookup& /*kernel_lookup*/, - const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* /* resource_accountant */) const override; common::Status Compile(const std::vector& fused_nodes_and_graphs, std::vector& node_compute_funcs) override; diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc index 49771488efc44..9d6e9df907ce3 100644 --- a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc +++ b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc @@ -2441,7 +2441,6 @@ std::unique_ptr ROCMExecutionProvider::GetDataTransf std::vector> ROCMExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph, const IKernelLookup& kernel_lookup, - const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* /* resource_accountant */) const { InlinedVector candidates; // A subset of the above vector. A subset of the tentative_nodes might be moved to CPU. diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider.h b/onnxruntime/core/providers/rocm/rocm_execution_provider.h index 2baaf2ff1a886..ff2bff7c98723 100644 --- a/onnxruntime/core/providers/rocm/rocm_execution_provider.h +++ b/onnxruntime/core/providers/rocm/rocm_execution_provider.h @@ -62,7 +62,6 @@ class ROCMExecutionProvider : public IExecutionProvider { std::vector> GetCapability( const onnxruntime::GraphViewer& graph, const IKernelLookup& kernel_lookup, - const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* /* resource_accountant */) const override; int GetDeviceId() const override { return info_.device_id; } diff --git a/onnxruntime/core/providers/shared_library/provider_api.h b/onnxruntime/core/providers/shared_library/provider_api.h index 9d61e1f12f5b6..6ff2572e5e668 100644 --- a/onnxruntime/core/providers/shared_library/provider_api.h +++ b/onnxruntime/core/providers/shared_library/provider_api.h @@ -200,7 +200,6 @@ struct SparseTensor; class TensorSeq; class SessionState; class ModelMetadefIdGenerator; -class GraphOptimizerRegistry; class If; class Loop; diff --git a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc index 90fd36ea29956..2dab9f6a402a0 100644 --- a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc +++ b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc @@ -332,9 +332,8 @@ bool IAllocator::CalcMemSizeForArrayWithAlignment(size_t nmemb, size_t size, siz std::vector> IExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer, const IKernelLookup& kernel_lookup, - const GraphOptimizerRegistry& graph_optimizer_registry, IResourceAccountant* resource_accountant) const { - return g_host->IExecutionProvider__GetCapability(this, graph_viewer, kernel_lookup, graph_optimizer_registry, resource_accountant); + return g_host->IExecutionProvider__GetCapability(this, graph_viewer, kernel_lookup, resource_accountant); } common::Status IExecutionProvider::Compile(const std::vector& fused_nodes_and_graphs, std::vector& node_compute_funcs) { diff --git a/onnxruntime/core/providers/shared_library/provider_interfaces.h b/onnxruntime/core/providers/shared_library/provider_interfaces.h index 83d615c1bde0a..a77f0cb4c27b0 100644 --- a/onnxruntime/core/providers/shared_library/provider_interfaces.h +++ b/onnxruntime/core/providers/shared_library/provider_interfaces.h @@ -105,8 +105,6 @@ using ModelMetaData = std::unordered_map; using IOnnxRuntimeOpSchemaCollectionPtr = std::shared_ptr; using IOnnxRuntimeOpSchemaRegistryList = std::list; using InitializedTensorSet = std::unordered_map; -using KeyValueConfig = std::unordered_map; -using SelectionFunc = std::function>(const GraphViewer&, const KeyValueConfig&, const GraphOptimizerRegistry&)>; struct Node__NodeIterator { virtual ~Node__NodeIterator() {} @@ -153,10 +151,6 @@ struct ConstGraphNodes_Iterator { struct ProviderHost { virtual const OrtApiBase* OrtGetApiBase() = 0; - virtual Status GetOptimizerByName(const std::string& name, - const GraphOptimizerRegistry& graph_optimizer_registry, - SelectionFunc& selection_func) = 0; - virtual void* HeapAllocate(size_t size) = 0; virtual void HeapFree(void*) = 0; @@ -259,7 +253,6 @@ struct ProviderHost { // IExecutionProvider virtual std::vector> IExecutionProvider__GetCapability(const IExecutionProvider* p, const onnxruntime::GraphViewer& graph_viewer, const IExecutionProvider::IKernelLookup& kernel_lookup, - const GraphOptimizerRegistry& graph_optimizer_registry, IResourceAccountant* resource_accountant) = 0; virtual common::Status IExecutionProvider__Compile(IExecutionProvider* p, const std::vector& fused_nodes_and_graphs, std::vector& node_compute_funcs) = 0; @@ -634,8 +627,6 @@ struct ProviderHost { virtual std::unique_ptr ComputeCapability__construct(std::unique_ptr t_sub_graph) = 0; virtual void ComputeCapability__operator_delete(ComputeCapability* p) = 0; virtual std::unique_ptr& ComputeCapability__SubGraph(ComputeCapability* p) = 0; - virtual void ComputeCapability__copy_optimization_func(ComputeCapability* p, ComputeCapability* selection_cc) = 0; - virtual void ComputeCapability__add_nodes_to_optimize(ComputeCapability* p, std::unique_ptr optimization_cc) = 0; // DataTransferManager virtual Status DataTransferManager__CopyTensor(const DataTransferManager* p, const Tensor& src, Tensor& dst) = 0; diff --git a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h index e2af144f455e4..a502ce9c66f69 100644 --- a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h +++ b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h @@ -527,9 +527,6 @@ struct ComputeCapability final { std::unique_ptr& SubGraph() { return g_host->ComputeCapability__SubGraph(this); } - void copy_optimization_func(ComputeCapability* selection_cc) { g_host->ComputeCapability__copy_optimization_func(this, selection_cc); } - void add_nodes_to_optimize(std::unique_ptr optimization_cc) { g_host->ComputeCapability__add_nodes_to_optimize(this, std::move(optimization_cc)); } - ComputeCapability() = delete; ComputeCapability(const ComputeCapability&) = delete; void operator=(const ComputeCapability&) = delete; diff --git a/onnxruntime/core/providers/snpe/snpe_execution_provider.cc b/onnxruntime/core/providers/snpe/snpe_execution_provider.cc index 4eae7c97f9ab0..c7fc6d3a556a7 100644 --- a/onnxruntime/core/providers/snpe/snpe_execution_provider.cc +++ b/onnxruntime/core/providers/snpe/snpe_execution_provider.cc @@ -72,7 +72,6 @@ SNPEExecutionProvider::~SNPEExecutionProvider() {} std::vector> SNPEExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph, const IKernelLookup& kernel_lookup, - const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* /* resource_accountant */) const { std::vector candidates; for (auto& node_index : graph.GetNodesInTopologicalOrder()) { diff --git a/onnxruntime/core/providers/snpe/snpe_execution_provider.h b/onnxruntime/core/providers/snpe/snpe_execution_provider.h index 4b7987b38ee93..99033649fcbbf 100644 --- a/onnxruntime/core/providers/snpe/snpe_execution_provider.h +++ b/onnxruntime/core/providers/snpe/snpe_execution_provider.h @@ -19,7 +19,6 @@ class SNPEExecutionProvider : public IExecutionProvider { std::vector> GetCapability( const onnxruntime::GraphViewer& graph, const IKernelLookup& kernel_lookup, - const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* /* resource_accountant */) const override; std::shared_ptr GetKernelRegistry() const override; diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index 523ebbfae807a..e59d252793532 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -2459,7 +2459,6 @@ bool TensorrtExecutionProvider::DetectTensorRTGraphCycles(SubGraphCollection_t& std::vector> TensorrtExecutionProvider::GetCapability(const GraphViewer& graph, const IKernelLookup& /*kernel_lookup*/, - const GraphOptimizerRegistry& graph_optimizer_registry, IResourceAccountant* /* resource_accountant */) const { // Construct subgraph capability from node list std::vector> result; @@ -2665,61 +2664,11 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph, } } - /** - * Enable EP related L2+ graph optimizations: - * - * 1. Calls provider bridge API to lookup pre-defined optimizer by name and get selection function. - * - Example: g_host->GetOptimizerByName(optimizer_name, graph_optimizer_registry, selection_func) - * 2. Executes the selection function to obtain the selection ComputeCapability. - * - ComputeCapability.optimize_func would be set by the optimizer to the function that does the optimization. - * 3. Uses the selection ComputeCapability to create the optimization ComputeCapability. - * 4. Returns the final ComputeCapability, with nodes_to_optimize set to the optimization ComputeCapability. - * - * Current available optimizations: - * - (ConstantFoldingDQ) constant folding on DQ nodes, i.e. dequantize INT32, UINT16, INT16 constant to FP32. - */ - - SelectionFunc selection_func; - std::vector> selection_cc; - - // Prepare for ConstantFoldingDQ optimizer - // Note: The NodeIndex here is the node index in the graph, not the index in node vector in supported_nodes_vector. - std::unordered_set trt_selection_node_set; // The qualified dq nodes selected by TRT EP - std::unordered_map consumer_to_dq; // consumer node -> dq node - - if (dla_enable_) { - std::string optimizer_name = "ConstantFoldingDQ"; - const std::unordered_map key_value_config; - auto status = g_host->GetOptimizerByName(optimizer_name, graph_optimizer_registry, selection_func); - if (status == Status::OK()) { - if (selection_func) { - selection_cc = selection_func(graph, key_value_config, graph_optimizer_registry); - SelectQualifiedDQNode(graph, trt_selection_node_set, consumer_to_dq); - } - } else { - LOGS_DEFAULT(WARNING) << "[TensorRT EP] Can't get optimizer " << optimizer_name; - } - } - - // Create ComputeCapability int number_of_trt_nodes = 0, subgraph_index = 0; - for (auto& group : supported_nodes_vector) { + for (const auto& group : supported_nodes_vector) { if (!group.first.empty()) { - if (!selection_cc.empty()) { - // Include DQ nodes that are filtered out by TRT parser - UpdateSupportedNodeVectorForDQ(graph, group, supported_nodes_vector, consumer_to_dq); - } - std::unique_ptr sub_graph = GetSubGraph(group, graph, model_hash, subgraph_index); - auto compute_capability = ComputeCapability::Create(std::move(sub_graph)); - - // add optimization ComputeCapability to node_to_optimize - for (auto& cc : selection_cc) { - std::unique_ptr optimization_cc = CreateOptimizationComputeCapability(cc.get(), trt_selection_node_set, compute_capability.get()); - compute_capability->add_nodes_to_optimize(std::move(optimization_cc)); - } - - result.push_back(std::move(compute_capability)); + result.push_back(ComputeCapability::Create(std::move(sub_graph))); number_of_trt_nodes += static_cast(group.first.size()); subgraph_index++; } diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h index 934cc06eed45f..873826a81c51b 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h @@ -249,7 +249,6 @@ class TensorrtExecutionProvider : public IExecutionProvider { std::vector> GetCapability(const GraphViewer& graph, const IKernelLookup& /*kernel_lookup*/, - const GraphOptimizerRegistry& graph_optimizer_registry, IResourceAccountant* /* resource_accountant */) const override; int GetDeviceId() const { return device_id_; } @@ -593,35 +592,5 @@ class TensorrtExecutionProvider : public IExecutionProvider { * This function only creates the instance at the first time it's being called." */ nvinfer1::IBuilder* GetBuilder(TensorrtLogger& trt_logger) const; - - /** - * This is the helper function for ConstantFoldingDQ graph transformer. - * - * It selects the qualified/required DQ node to be optimized as well as provides a mapping table - * to help TRT EP later include the DQ node which is filtered out by TRT parser. - */ - void SelectQualifiedDQNode(const GraphViewer& graph, - std::unordered_set& selection_node_set, - std::unordered_map& consumer_to_dq) const; - - /** - * This function returns an optimization ComputeCapability that is limited to: - * 1. the DQ nodes in this individual TRT ComputeCapability - * 2. the DQ nodes that are qualified and selected by TRT EP - * - * It also needs to make sure the DQ nodes is a subset of the complete list of DQ nodes to optimize in original selection ComputeCapability. - * Finally, copy the optimization function from the original selection ComputeCapability. - */ - std::unique_ptr CreateOptimizationComputeCapability(ComputeCapability* selection_cc, - std::unordered_set& trt_selection_node_set, - ComputeCapability* trt_cc) const; - /** - * This function helps add back the DQ nodes that are filtered out by TRT parser. - * The reason is the DQ nodes can be optimized and dequantized by applying ConstantFoldingDQ optimizer by ORT L2+ optimization. - */ - void UpdateSupportedNodeVectorForDQ(const GraphViewer& graph, - SubGraph_t& supported_node_vector, - SubGraphCollection_t& supported_nodes_vector, - std::unordered_map consumer_to_dq) const; }; } // namespace onnxruntime diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_helper.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_helper.cc index 71674f7c9c557..92fa101118506 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_helper.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_helper.cc @@ -258,133 +258,4 @@ void TensorrtExecutionProvider::SetAllGraphInputs(Graph& graph) const { graph.SetInputs(graph_inputs_including_initializers); } - -/** - * This is the helper function for ConstantFoldingDQ graph transformer. - * - * It selects the qualified/required DQ node to be optimized as well as provides a mapping table - * to help TRT EP later include the DQ node which is filtered out by TRT parser. - */ -void TensorrtExecutionProvider::SelectQualifiedDQNode(const GraphViewer& graph, - std::unordered_set& selection_node_set, - std::unordered_map& consumer_to_dq) const { - LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Select qualified DQ nodes ..."; - const std::vector& node_index = graph.GetNodesInTopologicalOrder(1 /*priority-based topological sort*/); - for (auto index : node_index) { - auto* node = graph.GetNode(index); - if (!node) { - continue; - } - - const auto* input_def = node->InputDefs()[0]; // Get NodeArg of the initializer of the DequantizeLinear node; - auto data_type = input_def->TypeAsProto()->tensor_type().elem_type(); - auto constant_initializer = graph.IsConstantInitializer(input_def->Name(), true); - - // Node selection: (i.e. initializer -> DQ -> bias of X) - // 1. DequantizeLinear op - // 2. DQ node does not produce graph output, single consumer - // 3. The first input of DQ is constant initializer. - // 4. The data type of initializer is INT32, UINT16 or INT16 - // 5. X should be Gemm, Conv or LayerNormalization ? - if (node->OpType() == "DequantizeLinear" && - node->GetOutputEdgesCount() == 1 && - (data_type == ONNX_NAMESPACE::TensorProto_DataType_INT32 || data_type == ONNX_NAMESPACE::TensorProto_DataType_INT16 || data_type == ONNX_NAMESPACE::TensorProto_DataType_UINT16) && - constant_initializer) { - const Node& consumer_node = *node->OutputNodesBegin(); - selection_node_set.insert(index); - consumer_to_dq[consumer_node.Index()] = index; - LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] " << consumer_node.Name() << " <- " << node->Name(); - } - } - LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Total " << selection_node_set.size() << " DequantizeLinear node(s) are selected."; -} - -/** - * This function returns an optimization ComputeCapability that is limited to: - * 1. the DQ nodes in this individual TRT ComputeCapability - * 2. the DQ nodes that are qualified and selected by TRT EP - * - * It also needs to make sure the DQ nodes is a subset of the complete list of DQ nodes to optimize in original selection ComputeCapability. - * Finally, copy the optimization function from the original selection ComputeCapability. - */ -std::unique_ptr TensorrtExecutionProvider::CreateOptimizationComputeCapability(ComputeCapability* selection_cc, - std::unordered_set& trt_selection_node_set, - ComputeCapability* trt_cc) const { - auto sub_graph = onnxruntime::IndexedSubGraph::Create(); - std::unordered_set selection_node_set; - - for (auto index : selection_cc->SubGraph()->Nodes()) { - selection_node_set.insert(index); - } - - for (auto index : trt_cc->SubGraph()->Nodes()) { - if (selection_node_set.find(index) == selection_node_set.end()) { - continue; - } - if (trt_selection_node_set.find(index) == trt_selection_node_set.end()) { - continue; - } - sub_graph->Nodes().push_back(index); - } - auto compute_capability = ComputeCapability::Create(std::move(sub_graph)); - compute_capability->copy_optimization_func(selection_cc); - return compute_capability; -} - -/** - * This function helps add back the DQ nodes that are filtered out by TRT parser. - * The reason is the DQ nodes can be optimized and dequantized by applying ConstantFoldingDQ optimizer by ORT L2+ optimization. - */ -void TensorrtExecutionProvider::UpdateSupportedNodeVectorForDQ(const GraphViewer& graph, - SubGraph_t& supported_node_vector, - SubGraphCollection_t& supported_nodes_vector, - std::unordered_map consumer_to_dq) const { - if (consumer_to_dq.empty()) { - return; - } - - if (!supported_node_vector.second) { - return; - } - - const std::vector& node_index = graph.GetNodesInTopologicalOrder(1); - auto supported_nodes = supported_node_vector.first; - for (auto index : supported_nodes) { - if (consumer_to_dq.find(node_index[index]) == consumer_to_dq.end()) { - continue; - } - - auto dq_node_index = consumer_to_dq[node_index[index]]; - - // Check if DQ node is included in one of the subgraphs - auto in_the_subgraph_collection = [&](NodeIndex node_idx) -> bool { - for (auto& node_vector : supported_nodes_vector) { - if (!node_vector.second) { - continue; - } - for (auto i : node_vector.first) { - if (node_index[i] == node_idx) { - return true; - } - } - } - return false; - }; - - // If the DQ node is already in the subgraph, do nothing. - if (in_the_subgraph_collection(dq_node_index)) { - continue; - } - - // Find the iterator pointing to the target element - auto it = std::find(node_index.begin(), node_index.end(), dq_node_index); - if (it != node_index.end()) { - // Calculate the index - size_t idx = std::distance(node_index.begin(), it); - supported_node_vector.first.push_back(idx); - auto node = graph.GetNode(dq_node_index); - LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] " << node->Name() << " is included which is filtered out by TRT parser."; - } - } -} } // namespace onnxruntime diff --git a/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc b/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc index ab8a95b38491d..5d2204b0b1979 100644 --- a/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc +++ b/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc @@ -51,7 +51,7 @@ const InlinedVector VitisAIExecutionProvider::GetEpContextNodes() c return ep_context_node_ptrs; } std::vector> VitisAIExecutionProvider::GetCapability( - const onnxruntime::GraphViewer& graph_viewer, const IKernelLookup& kernel_lookup, const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* /* resource_accountant */) const { + const onnxruntime::GraphViewer& graph_viewer, const IKernelLookup& kernel_lookup, IResourceAccountant* /* resource_accountant */) const { if (graph_viewer.IsSubgraph()) { // VITIS AI EP not support sungraph. Assigned to CPU. return {}; diff --git a/onnxruntime/core/providers/vitisai/vitisai_execution_provider.h b/onnxruntime/core/providers/vitisai/vitisai_execution_provider.h index f72f8cc721fbd..5b031ab882839 100644 --- a/onnxruntime/core/providers/vitisai/vitisai_execution_provider.h +++ b/onnxruntime/core/providers/vitisai/vitisai_execution_provider.h @@ -29,7 +29,6 @@ class VitisAIExecutionProvider : public IExecutionProvider { std::vector> GetCapability(const onnxruntime::GraphViewer& graph_viewer, const IKernelLookup& /*kernel_lookup*/, - const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* /* resource_accountant */) const override; int GetDeviceId() const { return 0; } diff --git a/onnxruntime/core/providers/vsinpu/vsinpu_execution_provider.cc b/onnxruntime/core/providers/vsinpu/vsinpu_execution_provider.cc index 3b5daef04dd50..4b9f6fae86423 100644 --- a/onnxruntime/core/providers/vsinpu/vsinpu_execution_provider.cc +++ b/onnxruntime/core/providers/vsinpu/vsinpu_execution_provider.cc @@ -62,7 +62,6 @@ VSINPUExecutionProvider::~VSINPUExecutionProvider() {} std::vector> VSINPUExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer, const IKernelLookup& /*kernel_lookup*/, - const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* /* resource_accountant */) const { std::vector> result; diff --git a/onnxruntime/core/providers/vsinpu/vsinpu_execution_provider.h b/onnxruntime/core/providers/vsinpu/vsinpu_execution_provider.h index 1c0b8b63a8e6c..16cfbc8a9c581 100644 --- a/onnxruntime/core/providers/vsinpu/vsinpu_execution_provider.h +++ b/onnxruntime/core/providers/vsinpu/vsinpu_execution_provider.h @@ -40,7 +40,6 @@ class VSINPUExecutionProvider : public IExecutionProvider { std::vector> GetCapability( const onnxruntime::GraphViewer& graph_viewer, const IKernelLookup& kernel_lookup, - const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* /* resource_accountant */) const override; std::shared_ptr GetKernelRegistry() const override; Status Compile(const std::vector& fused_nodes_and_graphs, diff --git a/onnxruntime/core/providers/webgpu/external_data_loader.cc b/onnxruntime/core/providers/webgpu/external_data_loader.cc deleted file mode 100644 index 6da9598b146f5..0000000000000 --- a/onnxruntime/core/providers/webgpu/external_data_loader.cc +++ /dev/null @@ -1,40 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#if defined(__wasm__) - -#include - -#include "core/framework/tensor.h" -#include "core/providers/webgpu/external_data_loader.h" - -namespace onnxruntime { -namespace webgpu { - -bool ExternalDataLoader::CanLoad(const OrtMemoryInfo& target_memory_info) const { - return target_memory_info.device.Type() == OrtDevice::CPU || - (target_memory_info.device.Type() == OrtDevice::GPU && target_memory_info.name == WEBGPU_BUFFER); -} - -common::Status ExternalDataLoader::LoadTensor(const Env& env, - const std::filesystem::path& data_file_path, - FileOffsetType data_offset, - SafeInt data_length, - Tensor& tensor) const { - ExternalDataLoadType load_type; - if (tensor.Location().device.Type() == OrtDevice::CPU) { - load_type = ExternalDataLoadType::CPU; - } else if (tensor.Location().device.Type() == OrtDevice::GPU && - tensor.Location().name == WEBGPU_BUFFER) { - load_type = ExternalDataLoadType::WEBGPU_BUFFER; - } else { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported tensor location: ", tensor.Location().ToString()); - } - - return LoadWebAssemblyExternalData(env, data_file_path, data_offset, data_length, load_type, tensor.MutableDataRaw()); -} - -} // namespace webgpu -} // namespace onnxruntime - -#endif diff --git a/onnxruntime/core/providers/webgpu/external_data_loader.h b/onnxruntime/core/providers/webgpu/external_data_loader.h deleted file mode 100644 index 7ced4e930bf7a..0000000000000 --- a/onnxruntime/core/providers/webgpu/external_data_loader.h +++ /dev/null @@ -1,30 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#if defined(__wasm__) - -#include "core/framework/external_data_loader.h" - -namespace onnxruntime { -namespace webgpu { - -class ExternalDataLoader : public IExternalDataLoader { - public: - ExternalDataLoader() {}; - ~ExternalDataLoader() {}; - - bool CanLoad(const OrtMemoryInfo& target_memory_info) const override; - - common::Status LoadTensor(const Env& env, - const std::filesystem::path& data_file_path, - FileOffsetType data_offset, - SafeInt data_length, - Tensor& tensor) const override; -}; - -} // namespace webgpu -} // namespace onnxruntime - -#endif diff --git a/onnxruntime/core/providers/webgpu/generator/range.cc b/onnxruntime/core/providers/webgpu/generator/range.cc index 99c5a1c1b5566..a0b65f08a5b4e 100644 --- a/onnxruntime/core/providers/webgpu/generator/range.cc +++ b/onnxruntime/core/providers/webgpu/generator/range.cc @@ -23,7 +23,7 @@ Status Range::ComputeInternal(ComputeContext& context) const { return Status::OK(); } - uint32_t output_size = onnxruntime::narrow(n); + uint32_t output_size = gsl::narrow(n); RangeProgram program{}; #if defined(__GNUC__) #pragma GCC diagnostic push diff --git a/onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.cc b/onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.cc index 8a22e45f17047..75866513e2c7d 100644 --- a/onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.cc +++ b/onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.cc @@ -141,7 +141,7 @@ Status BinaryElementwise::ComputeInternal(ComputeContext& context) const { } } - uint32_t vec_size = onnxruntime::narrow((size + 3) / 4); + uint32_t vec_size = gsl::narrow((size + 3) / 4); BinaryElementwiseProgram program{kernel_name_, expression_, is_broadcast, diff --git a/onnxruntime/core/providers/webgpu/math/softmax.cc b/onnxruntime/core/providers/webgpu/math/softmax.cc deleted file mode 100644 index d06fc5a57eb8c..0000000000000 --- a/onnxruntime/core/providers/webgpu/math/softmax.cc +++ /dev/null @@ -1,238 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include - -#include "core/common/inlined_containers.h" -#include "core/providers/common.h" -#include "core/providers/webgpu/math/softmax.h" -#include "core/providers/webgpu/tensor/transpose.h" -#include "core/providers/cpu/tensor/utils.h" -#include "core/providers/webgpu/shader_variable.h" -#include "core/providers/webgpu/shader_helper.h" -#include "core/providers/webgpu/webgpu_supported_types.h" -namespace onnxruntime { -namespace webgpu { - -ONNX_OPERATOR_VERSIONED_KERNEL_EX( - Softmax, - kOnnxDomain, - 1, 10, - kWebGpuExecutionProvider, - (*KernelDefBuilder::Create()) - .TypeConstraint("T", WebGpuSupportedNumberTypes()), - Softmax); - -ONNX_OPERATOR_VERSIONED_KERNEL_EX( - Softmax, - kOnnxDomain, - 11, 12, - kWebGpuExecutionProvider, - (*KernelDefBuilder::Create()) - .TypeConstraint("T", WebGpuSupportedNumberTypes()), - Softmax); - -ONNX_OPERATOR_KERNEL_EX( - Softmax, - kOnnxDomain, - 13, - kWebGpuExecutionProvider, - (*KernelDefBuilder::Create()) - .TypeConstraint("T", WebGpuSupportedNumberTypes()), - Softmax); - -static std::string MaxVector(const std::string& name, int components) { - switch (components) { - case 1: - return name; - case 2: - return "max(" + name + ".x, " + name + ".y)"; - case 3: - return "max(max(" + name + ".x, " + name + ".y), " + name + ".z)"; - case 4: - return "max(max(" + name + ".x, " + name + ".y), max(" + name + ".z, " + name + ".w))"; - default: - ORT_THROW("Unsupported number of components: ", components); - } -} - -static std::string SumVector(const std::string& x, int components) { - switch (components) { - case 1: - return x; - case 2: - return "(" + x + ".x + " + x + ".y" + ")"; - case 4: - return "(" + x + ".x + " + x + ".y + " + x + ".w + " + x + ".z" + ")"; - default: - ORT_THROW("Unsupported number of components: ", components); - } -} - -static int GetMaxComponents(int64_t size) { - if (size % 4 == 0) { - return 4; - } else if (size % 2 == 0) { - return 2; - } - return 1; -} - -Status SoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const { - // Add input and output variables - const auto& input = shader.AddInput("x", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias); - shader.AddOutput("result", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias); - int components = input.NumComponents(); - - const std::string thread_max_decl = is_fp32_ - ? "var thread_max = x_value_t(-3.402823e+38f);\n" - : "var thread_max = x_value_t(-65504.0h);\n"; - - // Define shared memory for row max and row sum - shader.AdditionalImplementation() - << "var row_max_shared : x_value_t;\n" - << "var row_sum_shared : x_value_t;\n" - << "var thread_shared : array;\n"; - - // Define helper functions to get and set values - shader.AdditionalImplementation() - << "fn getValue(row: i32, col: i32, row_stride: i32) -> x_value_t {\n" - << " let index = row * row_stride + col;\n" - << " return x[index];\n" - << "}\n" - << "fn setValue(row: i32, col: i32, row_stride: i32, value: x_value_t) {\n" - << " let index = row * row_stride + col;\n" - << " result[index] = value;\n" - << "}\n"; - - // Main function body - shader.MainFunctionBody() - << " let gindex = i32(global_idx);\n" - << " let lindex = i32(local_idx);\n" - << " const wg = " << wg_ << ";\n" - << " let row = gindex / wg;\n" - << " let cols = uniforms.packedCols;\n" - << " let row_stride : i32 = uniforms.packedCols;\n" - - // Find the row's max value - << thread_max_decl - << " for (var col = lindex; col < cols; col += wg) {\n" - << " let value = getValue(row, col, row_stride);\n" - << " thread_max = max(thread_max, value);\n" - << " }\n" - << " if (lindex < cols) {\n" - << " thread_shared[lindex] = thread_max;\n" - << " }\n" - << " workgroupBarrier();\n" - - // Reduce to find the max value - << " var reduce_size = min(cols, wg);\n" - << " for (var curr_size = reduce_size >> 1; curr_size > 0; curr_size = reduce_size >> 1) {\n" - << " reduce_size = curr_size + (reduce_size & 1);\n" - << " if (lindex < curr_size) {\n" - << " thread_shared[lindex] = max(thread_shared[lindex], thread_shared[lindex + reduce_size]);\n" - << " }\n" - << " workgroupBarrier();\n" - << " }\n" - << " if (lindex == 0) {\n" - << " row_max_shared = x_value_t(" << MaxVector("thread_shared[0]", components) << ");\n" - << " }\n" - << " workgroupBarrier();\n" - - // Find the row's sum of exponentials - << " var thread_sum = x_value_t(0.0);\n" - << " for (var col = lindex; col < cols; col += wg) {\n" - << " let sub_exp = exp(getValue(row, col, row_stride) - row_max_shared);\n" - << " thread_sum += sub_exp;\n" - << " }\n" - << " thread_shared[lindex] = thread_sum;\n" - << " workgroupBarrier();\n" - - // Reduce to find the sum of exponentials - << " for (var curr_size = wg >> 1; curr_size > 0; curr_size = curr_size >> 1) {\n" - << " if (lindex < curr_size) {\n" - << " thread_shared[lindex] = thread_shared[lindex] + thread_shared[lindex + curr_size];\n" - << " }\n" - << " workgroupBarrier();\n" - << " }\n" - << " if (lindex == 0) {\n" - << " row_sum_shared = x_value_t(" << SumVector("thread_shared[0]", components) << ");\n" - << " }\n" - << " workgroupBarrier();\n" - - // Calculate the final value for each element in the row - << " for (var col = lindex; col < cols; col += wg) {\n" - << " let value = exp(getValue(row, col, row_stride) - row_max_shared) / row_sum_shared;\n" - << " setValue(row, col, row_stride, value);\n" - << " }\n"; - - return Status::OK(); -} - -Status Softmax::ComputeInternal(ComputeContext& context) const { - const auto* input_tensor = context.Input(0); - const TensorShape& input_shape = input_tensor->Shape(); - size_t input_rank = input_shape.NumDimensions(); - auto* output_tensor = context.Output(0, input_shape); - - // normalize axis - size_t axis = static_cast(HandleNegativeAxis(axis_, input_rank)); - bool is_transpose_required = axis < input_rank - 1; - - TensorShape transposed_input_shape; - Tensor transposed_input_tensor; - Tensor intermediate_output; - InlinedVector perm(input_rank); - - if (is_transpose_required) { - std::iota(std::begin(perm), std::end(perm), 0); - perm[axis] = input_rank - 1; - perm[input_rank - 1] = axis; - - TensorShapeVector transposed_input_dims; - for (auto e : perm) { - transposed_input_dims.push_back(input_shape[e]); - } - - transposed_input_shape = TensorShape(transposed_input_dims); - transposed_input_tensor = context.CreateGPUTensor(input_tensor->DataType(), transposed_input_shape); - ORT_RETURN_IF_ERROR(Transpose::DoTranspose(context, perm, *input_tensor, transposed_input_tensor)); - intermediate_output = context.CreateGPUTensor(output_tensor->DataType(), transposed_input_shape); - } - - const int64_t cols = is_transpose_required ? transposed_input_shape[input_rank - 1] : input_shape[input_rank - 1]; - const int64_t rows = input_shape.Size() / cols; - const int64_t components = GetMaxComponents(cols); - const auto packed_cols = cols / components; - uint32_t workgroup_size = rows == 1 ? 256 : 64; - // check input tensor element type is float - const bool is_fp32 = input_tensor->GetElementType() == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; - - SoftmaxProgram program{workgroup_size, is_fp32}; - if (is_transpose_required) { - program - .AddInputs({{&transposed_input_tensor, ProgramTensorMetadataDependency::TypeAndRank, static_cast(components)}}) - .AddOutputs({{&intermediate_output, ProgramTensorMetadataDependency::TypeAndRank, static_cast(components)}}); - } else { - program - .AddInputs({{input_tensor, ProgramTensorMetadataDependency::TypeAndRank, static_cast(components)}}) - .AddOutputs({{output_tensor, ProgramTensorMetadataDependency::TypeAndRank, static_cast(components)}}); - } - - program - .CacheHint(std::to_string(components), std::to_string(workgroup_size)) - .SetWorkgroupSize(workgroup_size) - .SetDispatchGroupSize(static_cast(rows)) - .AddUniformVariables({{static_cast(packed_cols)}}); - - ORT_RETURN_IF_ERROR(context.RunProgram(program)); - - // If transpose was required, transpose the result back - if (is_transpose_required) { - ORT_RETURN_IF_ERROR(Transpose::DoTranspose(context, perm, intermediate_output, *output_tensor)); - } - - return Status::OK(); -} -} // namespace webgpu -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/math/softmax.h b/onnxruntime/core/providers/webgpu/math/softmax.h deleted file mode 100644 index cc97611dcb4bc..0000000000000 --- a/onnxruntime/core/providers/webgpu/math/softmax.h +++ /dev/null @@ -1,54 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include "core/providers/webgpu/webgpu_supported_types.h" -#include "core/providers/webgpu/webgpu_kernel.h" -#include "core/providers/webgpu/program.h" -#include "core/framework/op_kernel.h" - -namespace onnxruntime { -namespace webgpu { - -class Softmax final : public WebGpuKernel { - public: - Softmax(const OpKernelInfo& info) : WebGpuKernel{info} { - int opset_ = info.node().SinceVersion(); - int64_t axis; - Status status = info.GetAttr("axis", &axis); - - if (status.IsOK()) { - axis_ = axis; - } else { - if (opset_ < 13) { - axis_ = 1; // opset-12 and below, the default axis value is 1 - } else { - axis_ = -1; // opset-13, the default axis value is -1 - } - } - } - - Status ComputeInternal(ComputeContext& context) const override; - - private: - int64_t axis_; -}; - -class SoftmaxProgram final : public Program { - public: - SoftmaxProgram(uint32_t wg, bool is_fp32) - : Program{"Softmax"}, wg_{wg}, is_fp32_{is_fp32} { - } - - Status GenerateShaderCode(ShaderHelper& sh) const override; - - WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"packedCols", ProgramUniformVariableDataType::Int32}); - - private: - uint32_t wg_; - bool is_fp32_; -}; - -} // namespace webgpu -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc index 189d7baafce6a..eaaad206ebaf5 100644 --- a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc +++ b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc @@ -27,7 +27,7 @@ Status UnaryElementwise::ComputeInternal(ComputeContext& context) const { if (size == 0) { return Status::OK(); } - uint32_t vec_size = onnxruntime::narrow((size + 3) / 4); + uint32_t vec_size = gsl::narrow((size + 3) / 4); UnaryElementwiseProgram program{kernel_name_, expression_, additional_impl_, additional_usage_}; program .AddInputs({{input_tensor, ProgramTensorMetadataDependency::Type, {vec_size}, 4}}) diff --git a/onnxruntime/core/providers/webgpu/nn/layer_norm.cc b/onnxruntime/core/providers/webgpu/nn/layer_norm.cc index 28ad686909a47..64172021e82f1 100644 --- a/onnxruntime/core/providers/webgpu/nn/layer_norm.cc +++ b/onnxruntime/core/providers/webgpu/nn/layer_norm.cc @@ -23,7 +23,7 @@ static size_t NormalizeAxis(int64_t axis, size_t tensor_rank) { if (axis < -rank && axis >= rank) { ORT_THROW("invalid axis: ", axis); } - return onnxruntime::narrow(axis < 0 ? axis + rank : axis); + return gsl::narrow(axis < 0 ? axis + rank : axis); } static std::string SumVector(std::string x, int components) { @@ -92,10 +92,10 @@ Status LayerNorm::ComputeInternal(onnxruntime::webgpu::ComputeContex const bool is_fp16 = x->GetElementType() == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16; const size_t axis = NormalizeAxis(axis_, x_shape.NumDimensions()); - const uint32_t norm_count = onnxruntime::narrow(x_shape.SizeToDimension(axis)); + const uint32_t norm_count = gsl::narrow(x_shape.SizeToDimension(axis)); const int64_t norm_size = x_shape.SizeFromDimension(axis); const int components = GetMaxComponents(norm_size); - const uint32_t norm_size_vectorized = onnxruntime::narrow((norm_size + components - 1) / components); + const uint32_t norm_size_vectorized = gsl::narrow((norm_size + components - 1) / components); const auto scale_size = scale->Shape().Size(); const auto bias_size = (bias) ? bias->Shape().Size() : 0; diff --git a/onnxruntime/core/providers/webgpu/program.cc b/onnxruntime/core/providers/webgpu/program.cc index 976b7927ac3dd..d1d4c242c4697 100644 --- a/onnxruntime/core/providers/webgpu/program.cc +++ b/onnxruntime/core/providers/webgpu/program.cc @@ -206,26 +206,6 @@ ProgramVariableDataType ToProgramVariableDataType(int32_t element_type, int comp } } -std::ostream& operator<<(std::ostream& os, ValidationMode mode) { - switch (mode) { - case ValidationMode::Disabled: - os << "Disabled"; - break; - case ValidationMode::WGPUOnly: - os << "WGPUOnly"; - break; - case ValidationMode::Basic: - os << "Basic"; - break; - case ValidationMode::Full: - os << "Full"; - break; - default: - os << "Unknown(" << static_cast(mode) << ")"; - } - return os; -} - namespace { TensorShape GetReducedShape(const TensorShape& shape, int component /* > 1 */) { ORT_ENFORCE(shape.NumDimensions() > 0 && shape.GetDims()[shape.NumDimensions() - 1] % component == 0, diff --git a/onnxruntime/core/providers/webgpu/program.h b/onnxruntime/core/providers/webgpu/program.h index 95fef36144025..7bfd9e8800099 100644 --- a/onnxruntime/core/providers/webgpu/program.h +++ b/onnxruntime/core/providers/webgpu/program.h @@ -237,7 +237,6 @@ enum class ValidationMode { Basic, Full }; -std::ostream& operator<<(std::ostream& os, ValidationMode mode); namespace details { class ProgramWrapper; diff --git a/onnxruntime/core/providers/webgpu/program_manager.cc b/onnxruntime/core/providers/webgpu/program_manager.cc index 7a4a873a1adf3..1fdd312d4f0d8 100644 --- a/onnxruntime/core/providers/webgpu/program_manager.cc +++ b/onnxruntime/core/providers/webgpu/program_manager.cc @@ -24,14 +24,14 @@ Status ProgramManager::NormalizeDispatchGroupSize(uint32_t& x, uint32_t& y, uint auto limit_per_dimension = limits_.maxComputeWorkgroupsPerDimension; if (x > limit_per_dimension || y > limit_per_dimension || z > limit_per_dimension) { - double size = static_cast(x) * static_cast(y) * static_cast(z); - double dispatch_avg = std::ceil(std::sqrt(size)); + auto size = static_cast(x) * static_cast(y) * static_cast(z); + uint32_t dispatch_avg = gsl::narrow(std::ceil(std::sqrt(size))); if (dispatch_avg > limit_per_dimension) { - dispatch_avg = std::ceil(std::cbrt(size)); + dispatch_avg = gsl::narrow(std::ceil(std::cbrt(size))); ORT_RETURN_IF(dispatch_avg > limit_per_dimension, "The dispatch group size exceeds WebGPU maximum."); - x = y = z = static_cast(dispatch_avg); + x = y = z = dispatch_avg; } else { - x = y = static_cast(dispatch_avg); + x = y = dispatch_avg; z = 1; } } diff --git a/onnxruntime/core/providers/webgpu/reduction/reduction_ops.cc b/onnxruntime/core/providers/webgpu/reduction/reduction_ops.cc deleted file mode 100644 index eb7903e7903b6..0000000000000 --- a/onnxruntime/core/providers/webgpu/reduction/reduction_ops.cc +++ /dev/null @@ -1,168 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/providers/webgpu/reduction/reduction_ops.h" -#include -#include "core/framework/data_transfer_manager.h" -#include "core/providers/webgpu/data_transfer.h" -#include "core/providers/webgpu/shader_helper.h" -#include "core/providers/webgpu/webgpu_supported_types.h" - -namespace onnxruntime { -namespace webgpu { - -#define REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceOp, begin, end) \ - ONNX_OPERATOR_VERSIONED_KERNEL_EX( \ - ReduceOp, \ - kOnnxDomain, \ - begin, end, \ - kWebGpuExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedNumberTypes()), \ - ReduceOp); - -#define REGISTER_UNARY_ELEMENTWISE_KERNEL(ReduceOp, version) \ - ONNX_OPERATOR_KERNEL_EX( \ - ReduceOp, \ - kOnnxDomain, \ - version, \ - kWebGpuExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedNumberTypes()).InputMemoryType(OrtMemTypeCPUInput, 1), \ - ReduceOp); - -REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceMean, 1, 10); -REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceMean, 11, 12); -REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceMean, 13, 17); -REGISTER_UNARY_ELEMENTWISE_KERNEL(ReduceMean, 18); - -Status ReduceKernelProgram::GenerateShaderCode(ShaderHelper& shader) const { - const auto& input = shader.AddInput("input", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias); - const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias); - bool reduce_on_all_axes = no_op_with_empty_axes_ == false && axes_.empty(); - std::string loop_header = code_[0]; - std::string loop_body = "let current_element: input_value_t = " + input.GetByIndices("input_indices") + ";\n" + code_[1]; - std::string loop_footer = code_[2]; - const auto input_rank = input.Rank(); - for (int i = 0, l = 0; i < input_rank; ++i) { - if (reduce_on_all_axes || std::find(axes_.begin(), axes_.end(), i) != axes_.end()) { - if (keepdims_) { - l++; - } - std::stringstream ss; - std::string index = "i" + std::to_string(i); - ss << "for (var " << index << " : u32 = 0; " << index << " < " << input.IndicesGet("uniforms.input_shape", i) << "; " << index << "++) {\n"; - ss << input.IndicesSet("input_indices", i, index) << ";\n"; - ss << loop_body << "\n"; - ss << "}\n"; - loop_body = ss.str(); - } else { - std::stringstream ss; - ss << loop_header << "\n"; - std::string index = "i" + std::to_string(i); - ss << "let " << index << " = " << output.IndicesGet("output_indices", l) << ";\n"; - ss << input.IndicesSet("input_indices", i, index) << ";\n"; - loop_header = ss.str(); - l++; - } - } - std::stringstream input_indices_init_value; - for (int i = 0; i < input_rank - 1; ++i) { - input_indices_init_value << "0, "; - } - input_indices_init_value << "0"; - shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size") - << "let output_indices: output_indices_t = " << output.OffsetToIndices("global_idx") << ";\n" - << "var input_indices: input_indices_t = input_indices_t(" << input_indices_init_value.str() << ");\n" - << loop_header << loop_body << loop_footer; - shader.MainFunctionBody() << output.SetByOffset("global_idx", "output_value"); - return Status::OK(); -} - -template -Status ReduceKernel::ComputeInternal(ComputeContext& context) const { - const auto* input_tensor = context.Input(0); - InlinedVector input_axes; - auto rank = input_tensor->Shape().NumDimensions(); - auto transform_axis = [rank](int64_t axis) { - if (axis < 0) { - axis += rank; - } - if (axis < 0 || static_cast(axis) >= rank) { - ORT_THROW("Axes values must be in the range [-rank, rank-1]. Got: ", axis); - } - return static_cast(axis); - }; - // Check if axes input is provided and copy the axes values to input_axes - if (context.InputCount() > 1) { - ORT_ENFORCE(axes_.empty(), "Axes attribute may not be specified when axes input is also provided."); - const Tensor* axes_tensor = context.Input(1); - auto size = static_cast(axes_tensor->Shape()[0]); - const auto* data = axes_tensor->Data(); - input_axes.reserve(size); - std::transform(data, data + size, std::back_inserter(input_axes), transform_axis); - } else { - input_axes.reserve(axes_.size()); - std::transform(axes_.begin(), axes_.end(), std::back_inserter(input_axes), transform_axis); - } - if (input_axes.empty()) { - if (noop_with_empty_axes_ || rank == 0) { - // If axes is empty and noop_with_empty_axes_ is true, it is a no-op according to the spec - // If input tensor is a scalar, return the input tensor as is. - // This is not correct for ReduceLogSum and ReduceSumSquare - // TODO handle these cases separately. - auto output = context.Output(0, input_tensor->Shape()); - if (output->DataRaw() != input_tensor->DataRaw()) { - ORT_RETURN_IF_ERROR(Info().GetDataTransferManager().CopyTensor(*input_tensor, *output)); - } - return Status::OK(); - } else { - // If axes is empty and noop_with_empty_axes_ is false, it is a reduction over all axes - input_axes.resize(rank); - std::iota(input_axes.begin(), input_axes.end(), 0); - } - } - const auto code = GetOpSpecificCode(input_tensor, input_axes.size()); - // Compute output shape - std::vector output_shape; - for (size_t i = 0; i < input_tensor->Shape().NumDimensions(); ++i) { - if (std::find(input_axes.begin(), input_axes.end(), i) != input_axes.end()) { - if (keepdims_) { - output_shape.push_back(1); - } - } else { - output_shape.push_back(input_tensor->Shape()[i]); - } - } - TensorShape output_tensor_shape(output_shape); - int64_t output_size = output_tensor_shape.Size(); - ReduceKernelProgram program("ReduceMean", keepdims_, noop_with_empty_axes_, input_axes, code); - program.AddInput({input_tensor, ProgramTensorMetadataDependency::TypeAndRank}) - .AddOutput({context.Output(0, output_shape), ProgramTensorMetadataDependency::TypeAndRank}) - .SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) - .AddUniformVariables({{static_cast(output_size)}, - {static_cast(noop_with_empty_axes_ ? 1 : 0)}, - {input_axes}, - {static_cast(input_axes.size())}}); - - return context.RunProgram(program); -} - -ReduceOpSpecificCode ReduceMean::GetOpSpecificCode(const Tensor* input_tensor, size_t axes_size) const { - const TensorShape& input_shape = input_tensor->Shape(); - size_t input_rank = input_shape.NumDimensions(); - std::stringstream ss; - ss << "var size: u32 = 1;\n" - << "for (var i: u32 = 0; i < uniforms.axes_size; i += 1) { \n" - << " let index = " << GetElementAt("uniforms.axes", "i", axes_size) << ";\n" - << " size = size * " << GetElementAt("uniforms.input_shape", "index", input_rank) << ";\n" - << "}\n" - << "let output_value = output_value_t(sum / f32(size));"; - ReduceOpSpecificCode code({"var sum = f32(0);", "sum += f32(current_element);", ss.str()}); - return code; -} - -Status ReduceMean::ComputeInternal(ComputeContext& ctx) const { - return ReduceKernel::ComputeInternal(ctx); -} - -} // namespace webgpu -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/reduction/reduction_ops.h b/onnxruntime/core/providers/webgpu/reduction/reduction_ops.h deleted file mode 100644 index e93eb06f20886..0000000000000 --- a/onnxruntime/core/providers/webgpu/reduction/reduction_ops.h +++ /dev/null @@ -1,62 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once -#include "core/common/optional.h" -#include "core/providers/webgpu/webgpu_supported_types.h" -#include "core/providers/webgpu/webgpu_kernel.h" -#include "core/providers/cpu/reduction/reduction_kernel_base.h" -#include "core/providers/webgpu/program.h" -#include "core/providers/webgpu/shader_helper.h" -namespace onnxruntime { -namespace webgpu { -// reduceOpSpecificCode is a 3-element array of strings that represent the op specific code for the reduce operation. -// The first element is the loop header, the second element is the loop body, and the third element is the loop footer. -// The loop header is the code that is executed before the loop starts. The loop body is the code that is executed for each element in the loop. -// The loop footer is the code that is executed after the loop ends. -typedef std::array ReduceOpSpecificCode; -class ReduceKernelProgram final : public Program { - public: - ReduceKernelProgram(std::string name, bool keepdims, bool no_op_with_empty_axes, const InlinedVector& axes, ReduceOpSpecificCode code) : Program{name}, keepdims_(keepdims), no_op_with_empty_axes_(no_op_with_empty_axes), axes_(axes.begin(), axes.end()), code_(code) {} - Status GenerateShaderCode(ShaderHelper& wgpuShaderModuleAddRef) const override; - WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"output_size", ProgramUniformVariableDataType::Uint32}, - {"no_op_with_empty_axes", ProgramUniformVariableDataType::Uint32}, - {"axes", ProgramUniformVariableDataType::Uint32}, - {"axes_size", ProgramUniformVariableDataType::Uint32}); - - private: - const bool keepdims_; - const bool no_op_with_empty_axes_; - InlinedVector axes_; - ReduceOpSpecificCode code_; -}; - -template -class ReduceKernel : public WebGpuKernel, public ReduceKernelBase { - protected: - using ReduceKernelBase::axes_; - using ReduceKernelBase::noop_with_empty_axes_; - using ReduceKernelBase::keepdims_; - using ReduceKernelBase::select_last_index_; - - ReduceKernel(const OpKernelInfo& info, std::string name, optional keepdims_override = {}) - : WebGpuKernel(info), - ReduceKernelBase(info, keepdims_override), - name_(name) { - } - Status ComputeInternal(ComputeContext& ctx) const; - virtual ReduceOpSpecificCode GetOpSpecificCode(const Tensor* input_tensor, size_t axes_size) const = 0; - - private: - std::string name_; -}; - -class ReduceMean final : public ReduceKernel { - public: - ReduceMean(const OpKernelInfo& info) : ReduceKernel(info, "ReduceMean") {} - ReduceOpSpecificCode GetOpSpecificCode(const Tensor* input_tensor, size_t axes_size) const override; - Status ComputeInternal(ComputeContext& ctx) const override; -}; - -} // namespace webgpu -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/shader_helper.cc b/onnxruntime/core/providers/webgpu/shader_helper.cc index 19cab9b178b1f..8fccbacac903b 100644 --- a/onnxruntime/core/providers/webgpu/shader_helper.cc +++ b/onnxruntime/core/providers/webgpu/shader_helper.cc @@ -345,6 +345,9 @@ Status ShaderHelper::GenerateSourceCode(std::string& code, std::vector& sha })) { ORT_RETURN_IF_NOT(device_.HasFeature(wgpu::FeatureName::ShaderF16), "Program ", program_.Name(), " requires f16 but the device does not support it."); ss << "enable f16;\n"; + if (device_.HasFeature(wgpu::FeatureName::SubgroupsF16)) { + ss << "enable subgroups_f16;\n"; + } } if (device_.HasFeature(wgpu::FeatureName::Subgroups)) { ss << "enable subgroups;\n"; diff --git a/onnxruntime/core/providers/webgpu/shader_variable.cc b/onnxruntime/core/providers/webgpu/shader_variable.cc index f8e1e0b3b8d2b..5e5920f582251 100644 --- a/onnxruntime/core/providers/webgpu/shader_variable.cc +++ b/onnxruntime/core/providers/webgpu/shader_variable.cc @@ -91,7 +91,7 @@ ShaderIndicesHelper::ShaderIndicesHelper(std::string_view name, ProgramVariableD : name_(name), type_(type), num_components_{NumberOfComponents(type)}, - rank_{static_cast(dims.NumDimensions())}, + rank_{gsl::narrow(dims.NumDimensions())}, dims_{dims}, usage_(usage), indices_type_{GetIndicesType(rank_)}, diff --git a/onnxruntime/core/providers/webgpu/tensor/cast.cc b/onnxruntime/core/providers/webgpu/tensor/cast.cc index 7f92ea4ed3776..8b5bede34e6d0 100644 --- a/onnxruntime/core/providers/webgpu/tensor/cast.cc +++ b/onnxruntime/core/providers/webgpu/tensor/cast.cc @@ -69,7 +69,7 @@ Status Cast::ComputeInternal(ComputeContext& context) const { if (size == 0) { return Status::OK(); } - uint32_t vec_size = onnxruntime::narrow((size + 3) / 4); + uint32_t vec_size = gsl::narrow((size + 3) / 4); CastProgram program{to_}; program diff --git a/onnxruntime/core/providers/webgpu/tensor/cast.h b/onnxruntime/core/providers/webgpu/tensor/cast.h index 925cd200f0aba..ef5c4d5d0dabe 100644 --- a/onnxruntime/core/providers/webgpu/tensor/cast.h +++ b/onnxruntime/core/providers/webgpu/tensor/cast.h @@ -26,7 +26,7 @@ class Cast final : public WebGpuKernel { int64_t to; Status status = info.GetAttr("to", &to); ORT_ENFORCE(status.IsOK(), "Attribute to is not set."); - to_ = onnxruntime::narrow(to); + to_ = gsl::narrow(to); // ignore attribute 'saturate' as float8 is not supported in WebGPU } diff --git a/onnxruntime/core/providers/webgpu/tensor/concat.cc b/onnxruntime/core/providers/webgpu/tensor/concat.cc index 5cfd6c78f8929..5ed8099fde05e 100644 --- a/onnxruntime/core/providers/webgpu/tensor/concat.cc +++ b/onnxruntime/core/providers/webgpu/tensor/concat.cc @@ -104,7 +104,7 @@ Status Concat::ComputeInternal(ComputeContext& context) const { return Status::OK(); } - uint32_t output_size = onnxruntime::narrow(prepare.output_tensor->Shape().Size()); + uint32_t output_size = gsl::narrow_cast(prepare.output_tensor->Shape().Size()); size_t axis = static_cast(prepare.axis); ConcatProgram program{axis}; diff --git a/onnxruntime/core/providers/webgpu/tensor/expand.cc b/onnxruntime/core/providers/webgpu/tensor/expand.cc index 9bdebe2c1e0d3..809616660aa9e 100644 --- a/onnxruntime/core/providers/webgpu/tensor/expand.cc +++ b/onnxruntime/core/providers/webgpu/tensor/expand.cc @@ -42,7 +42,7 @@ Status Expand::ComputeInternal(ComputeContext& context) const { : 1; const int components_o = output_shape.IsScalar() ? 1 : output_shape[output_shape.NumDimensions() - 1] % 4 == 0 ? 4 : 1; - uint32_t data_size = onnxruntime::narrow(output_shape.Size() / components_o); + uint32_t data_size = gsl::narrow(output_shape.Size() / components_o); ExpandProgram program{}; program diff --git a/onnxruntime/core/providers/webgpu/tensor/gather.cc b/onnxruntime/core/providers/webgpu/tensor/gather.cc index 39d07991f3c5a..9f6e5f2420d86 100644 --- a/onnxruntime/core/providers/webgpu/tensor/gather.cc +++ b/onnxruntime/core/providers/webgpu/tensor/gather.cc @@ -42,7 +42,7 @@ Status GatherProgram::GenerateShaderCode(ShaderHelper& shader) const { Status Gather::ComputeInternal(ComputeContext& context) const { Prepare p; ORT_RETURN_IF_ERROR(PrepareForCompute(&context.KernelContext(), p)); - uint32_t data_size = onnxruntime::narrow(p.output_tensor->Shape().Size()); + uint32_t data_size = gsl::narrow(p.output_tensor->Shape().Size()); if (data_size == 0) { return Status::OK(); } diff --git a/onnxruntime/core/providers/webgpu/tensor/pad.cc b/onnxruntime/core/providers/webgpu/tensor/pad.cc deleted file mode 100644 index 6a8bc6554b772..0000000000000 --- a/onnxruntime/core/providers/webgpu/tensor/pad.cc +++ /dev/null @@ -1,261 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include -#include - -#include "core/util/math.h" -#include "core/providers/webgpu/tensor/pad.h" -#include "core/providers/webgpu/shader_helper.h" -#include "core/providers/webgpu/webgpu_supported_types.h" - -namespace onnxruntime { -namespace webgpu { - -Status PadProgram::GenerateShaderCode(ShaderHelper& shader) const { - if (!dim_value_zero_) { - shader.AddInput("data", ShaderUsage::UseUniform | ShaderUsage::UseShapeAndStride); - } - const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseShapeAndStride | ShaderUsage::UseValueTypeAlias); - - shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size"); - std::string constant_value_str = std::string("let constant_value = ") + - (is_float16_ ? "bitcast>(uniforms.constant_value)[0];\n" : "bitcast(uniforms.constant_value);\n"); - if (dim_value_zero_) { - // Only Constant mode needs fill output if the one dim value or mores dims' values of input are zero. - shader.MainFunctionBody() << constant_value_str - << "output[global_idx] = constant_value;\n"; - return Status::OK(); - } - - shader.MainFunctionBody() << " let output_indices = " << output.OffsetToIndices("global_idx") << ";\n" - << " var input_index = u32(0);\n" - << " var use_pad_value = false;\n" - << " var in_coord = i32(0);\n"; - - const int rank = output.Rank(); - std::string output_indices_str = "i32(" + GetElementAt("output_indices", "dim", rank) + ")"; - std::string lower_pads_str = GetElementAt("uniforms.lower_pads", "dim", rank); - std::string data_shape_str = "i32(" + GetElementAt("uniforms.data_shape", "dim", rank) + ")"; - std::string data_stride_str = rank == 1 ? "" : " * " + GetElementAt("uniforms.data_stride", "dim", rank - 1); - std::string begin_axis_statement = "in_coord = "; - std::string end_axis_statement = "in_coord = "; - std::string in_axis_statement = "in_coord = " + output_indices_str + " - " + lower_pads_str + ";\n"; - switch (mode_) { - case Mode::Constant: - begin_axis_statement = "use_pad_value = true;\n"; - end_axis_statement = "use_pad_value = true;\n"; - break; - case Mode::Edge: - begin_axis_statement += "0;\n"; - end_axis_statement += data_shape_str + " - 1;\n"; - break; - case Mode::Reflect: - begin_axis_statement += lower_pads_str + " - " + output_indices_str + ";\n"; - end_axis_statement += data_shape_str + " - 2 - (" + output_indices_str + - " - (" + lower_pads_str + " + " + data_shape_str + "));\n"; - break; - case Mode::Wrap: - begin_axis_statement += data_shape_str + " + " + output_indices_str + " - " + lower_pads_str + ";\n"; - end_axis_statement += output_indices_str + " - " + lower_pads_str + " - " + data_shape_str + ";\n"; - break; - default: - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unsupported mode type: ", static_cast(mode_)); - } - - shader.MainFunctionBody() << " for (var dim = 0; dim < " << rank << " && !use_pad_value; dim++) {\n" - << " if (" << output_indices_str << " < " << lower_pads_str << ") {\n" - << " " << begin_axis_statement << " }\n" - << " else if (" << output_indices_str << " >= " << lower_pads_str << " + " << data_shape_str << ") {\n" - << " " << end_axis_statement << " }\n" - << " else {\n" - << " " << in_axis_statement << " }\n" - << " input_index += select(u32(in_coord)" << data_stride_str << ", u32(in_coord), dim == " << rank - 1 << ");\n" - << " }\n" - << " " << constant_value_str - << " " << output.SetByOffset("global_idx", "select(data[input_index], constant_value, use_pad_value)"); - - return Status::OK(); -} - -Status Pad::ComputeInternal(ComputeContext& context) const { - const Tensor* input_tensor = context.Input(0); - auto const& input_shape = input_tensor->Shape(); - size_t dimension_count = input_shape.NumDimensions(); - - const PadsVector* p_pads = &pads_; - const PadsVector* p_slices = &slices_; - - PadsVector pads; - PadsVector slices; - // kOnnxDomain Pad opset >= 11 (Or) kMsDomain opset == 1 - if (is_dynamic_) { - size_t data_rank = input_tensor->Shape().NumDimensions(); - - const Tensor* pads_tensor = context.Input(1); - auto pads_tensor_dims = pads_tensor->Shape().GetDims(); - ORT_ENFORCE(pads_tensor_dims.size() == 1 || (pads_tensor_dims.size() == 2 && pads_tensor_dims[0] == 1), - "Pads tensor should be a 1D tensor of shape [2 * num_axes] " - "or a 2D tensor of shape [1, 2 * num_axes]"); - - const auto pads_data = pads_tensor->DataAsSpan(); - - // Compute Pads by applying axes if specified otherwise copy the supplied pads. - PadBase::ComputePads(context.KernelContext(), data_rank, pads_data, pads); - - // Separate out any negative pads into the slices array - PadBase::SeparateNegativeToSlices(pads, slices); - - p_pads = &pads; - p_slices = &slices; - } - - auto output_dims(input_shape.AsShapeVector()); - ORT_ENFORCE(dimension_count * 2 == p_pads->size(), "'pads' attribute has wrong number of values"); - - // Calculate output dimensions, and handle any negative padding - std::vector lower_pads(dimension_count); - for (size_t i = 0; i < dimension_count; i++) { - int64_t lower_pad = (*p_pads)[i] + (*p_slices)[i]; - int64_t upper_pad = (*p_pads)[i + dimension_count] + (*p_slices)[i + dimension_count]; - lower_pads[i] = static_cast(lower_pad); - output_dims[i] += lower_pad + upper_pad; - } - TensorShape output_shape(output_dims); - - // special case when there is a dim value of 0 in the shape. behavior depends on mode - bool dim_value_zero = input_shape.Size() == 0; - if (dim_value_zero) { - ORT_RETURN_IF_ERROR(PadBase::HandleDimValueZero(mode_, input_shape, output_shape)); - } - - auto* output_tensor = context.Output(0, output_shape); - uint32_t output_size = onnxruntime::narrow(output_shape.Size()); - if (output_size == 0) { - // Do not need to fill output, return - return Status::OK(); - } - - // Read constant value and bitcast to uint32. - uint32_t value_uint32 = 0; - const auto data_type = input_tensor->GetElementType(); - bool is_float16 = data_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16; - const Tensor* value_tensor = context.Input(2); - if (!is_dynamic_) { - if (is_float16) { - uint16_t value = math::floatToHalf(value_); - std::memcpy(&value_uint32, &value, sizeof(value)); - } else { - value_uint32 = *reinterpret_cast(&value_); - } - } else if (value_tensor) { - ORT_ENFORCE(value_tensor->DataType() == input_tensor->DataType() && value_tensor->Shape().Size() == 1, - "Value tensor should be a 1D tensor of size 1 with the same type as that of the input tensor"); - switch (data_type) { - case ONNX_NAMESPACE::TensorProto_DataType_INT32: { - int32_t value = value_tensor->Data()[0]; - value_uint32 = *reinterpret_cast(&value); - } break; - case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: { - float value = value_tensor->Data()[0]; - value_uint32 = *reinterpret_cast(&value); - } break; - case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: { - uint16_t value = value_tensor->Data()[0].val; - std::memcpy(&value_uint32, &value, sizeof(value)); - } break; - case ONNX_NAMESPACE::TensorProto_DataType_UINT32: { - value_uint32 = value_tensor->Data()[0]; - } break; - default: - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unsupported input type: ", static_cast(data_type)); - } - } - - PadProgram program{mode_, dim_value_zero, is_float16}; - if (!dim_value_zero) { - program.AddInput({input_tensor, ProgramTensorMetadataDependency::TypeAndRank}); - } - program.AddOutput({output_tensor, ProgramTensorMetadataDependency::Rank}) - .SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) - .CacheHint(std::to_string(static_cast(mode_)), dim_value_zero) - .AddUniformVariables({{gsl::span(lower_pads.data(), lower_pads.size())}, {output_size}, {value_uint32}}); - - return context.RunProgram(program); -} - -ONNX_OPERATOR_VERSIONED_KERNEL_EX( - Pad, - kOnnxDomain, - 2, 10, - kWebGpuExecutionProvider, - (*KernelDefBuilder::Create()) - .TypeConstraint("T", WebGpuSupportedNumberTypes()), - Pad); -ONNX_OPERATOR_VERSIONED_KERNEL_EX( - Pad, - kOnnxDomain, - 11, 12, - kWebGpuExecutionProvider, - (*KernelDefBuilder::Create()) - .InputMemoryType(OrtMemTypeCPUInput, 1) - .InputMemoryType(OrtMemTypeCPUInput, 2) - .TypeConstraint("T", WebGpuSupportedNumberTypes()), - Pad); -ONNX_OPERATOR_VERSIONED_KERNEL_EX( - Pad, - kOnnxDomain, - 13, 17, - kWebGpuExecutionProvider, - (*KernelDefBuilder::Create()) - .InputMemoryType(OrtMemTypeCPUInput, 1) - .InputMemoryType(OrtMemTypeCPUInput, 2) - .TypeConstraint("T", WebGpuSupportedNumberTypes()), - Pad); -ONNX_OPERATOR_VERSIONED_KERNEL_EX( - Pad, - kOnnxDomain, - 18, 18, - kWebGpuExecutionProvider, - (*KernelDefBuilder::Create()) - .InputMemoryType(OrtMemTypeCPUInput, 1) - .InputMemoryType(OrtMemTypeCPUInput, 2) - .InputMemoryType(OrtMemTypeCPUInput, 3) - .TypeConstraint("T", WebGpuSupportedNumberTypes()), - Pad); -ONNX_OPERATOR_VERSIONED_KERNEL_EX( - Pad, - kOnnxDomain, - 19, 20, - kWebGpuExecutionProvider, - (*KernelDefBuilder::Create()) - .InputMemoryType(OrtMemTypeCPUInput, 1) - .InputMemoryType(OrtMemTypeCPUInput, 2) - .InputMemoryType(OrtMemTypeCPUInput, 3) - .TypeConstraint("T", WebGpuSupportedNumberTypes()), - Pad); -ONNX_OPERATOR_VERSIONED_KERNEL_EX( - Pad, - kOnnxDomain, - 21, 22, - kWebGpuExecutionProvider, - (*KernelDefBuilder::Create()) - .InputMemoryType(OrtMemTypeCPUInput, 1) - .InputMemoryType(OrtMemTypeCPUInput, 2) - .InputMemoryType(OrtMemTypeCPUInput, 3) - .TypeConstraint("T", WebGpuSupportedNumberTypes()), - Pad); -ONNX_OPERATOR_KERNEL_EX( - Pad, - kOnnxDomain, - 23, - kWebGpuExecutionProvider, - (*KernelDefBuilder::Create()) - .InputMemoryType(OrtMemTypeCPUInput, 1) - .InputMemoryType(OrtMemTypeCPUInput, 2) - .InputMemoryType(OrtMemTypeCPUInput, 3) - .TypeConstraint("T", WebGpuSupportedNumberTypes()), - Pad); - -} // namespace webgpu -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/tensor/pad.h b/onnxruntime/core/providers/webgpu/tensor/pad.h deleted file mode 100644 index 58049ddb0e5ce..0000000000000 --- a/onnxruntime/core/providers/webgpu/tensor/pad.h +++ /dev/null @@ -1,40 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include "core/providers/webgpu/program.h" -#include "core/providers/webgpu/webgpu_kernel.h" -#include "core/providers/cpu/tensor/padbase.h" - -namespace onnxruntime { -namespace webgpu { - -class PadProgram final : public Program { - public: - PadProgram(const Mode mode, bool dim_value_zero, bool is_float16) : Program{"Pad"}, - mode_{mode}, - dim_value_zero_{dim_value_zero}, - is_float16_{is_float16} {} - - Status GenerateShaderCode(ShaderHelper& sh) const override; - - WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"lower_pads", ProgramUniformVariableDataType::Int32}, - {"output_size", ProgramUniformVariableDataType::Uint32}, - {"constant_value", ProgramUniformVariableDataType::Uint32}); - - private: - Mode mode_; - bool dim_value_zero_; - bool is_float16_; -}; - -class Pad final : public PadBase, public WebGpuKernel { - public: - Pad(const OpKernelInfo& info) : PadBase(info), WebGpuKernel(info) {} - - Status ComputeInternal(ComputeContext& context) const override; -}; - -} // namespace webgpu -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/tensor/resize_impl.cc b/onnxruntime/core/providers/webgpu/tensor/resize_impl.cc index f68ace3c1d8a1..455e7dc54bf1d 100644 --- a/onnxruntime/core/providers/webgpu/tensor/resize_impl.cc +++ b/onnxruntime/core/providers/webgpu/tensor/resize_impl.cc @@ -211,7 +211,7 @@ Status ResizeNearestImpl(ComputeContext& context, onnxruntime::ResizeNearestMode nearest_mode) { TensorShape output_shape(output_dims); auto* output_tensor = context.Output(0, output_shape); - uint32_t output_size = onnxruntime::narrow(output_shape.Size()); + uint32_t output_size = gsl::narrow(output_shape.Size()); ResizeNearestProgram program{coordinate_transform_mode, nearest_mode, extrapolation_enabled, rank}; program.AddInput({input_tensor, ProgramTensorMetadataDependency::TypeAndRank}) @@ -299,7 +299,7 @@ Status ResizeBilinearImpl(ComputeContext& context, onnxruntime::ResizeCoordinateTransformationMode coordinate_transform_mode) { TensorShape output_shape(output_dims); auto* output_tensor = context.Output(0, output_shape); - uint32_t output_size = onnxruntime::narrow(output_shape.Size()); + uint32_t output_size = gsl::narrow(output_shape.Size()); ResizeBilinearProgram program{coordinate_transform_mode, extrapolation_enabled, rank}; program.AddInput({input_tensor, ProgramTensorMetadataDependency::TypeAndRank}) @@ -413,7 +413,7 @@ Status ResizeTrilinearImpl(ComputeContext& context, onnxruntime::ResizeCoordinateTransformationMode coordinate_transform_mode) { TensorShape output_shape(output_dims); auto* output_tensor = context.Output(0, output_shape); - uint32_t output_size = onnxruntime::narrow(output_shape.Size()); + uint32_t output_size = gsl::narrow(output_shape.Size()); ResizeTrilinearProgram program{coordinate_transform_mode, extrapolation_enabled, rank}; program.AddInput({input_tensor, ProgramTensorMetadataDependency::TypeAndRank}) @@ -534,7 +534,7 @@ Status ResizeBiCubicImpl(ComputeContext& context, onnxruntime::ResizeCoordinateTransformationMode coordinate_transform_mode) { TensorShape output_shape(output_dims); auto* output_tensor = context.Output(0, output_shape); - uint32_t output_size = onnxruntime::narrow(output_shape.Size()); + uint32_t output_size = gsl::narrow(output_shape.Size()); ResizeBiCubicProgram program{coordinate_transform_mode, extrapolation_enabled, exclude_outside, rank}; program.AddInput({input_tensor, ProgramTensorMetadataDependency::TypeAndRank}) diff --git a/onnxruntime/core/providers/webgpu/tensor/split.cc b/onnxruntime/core/providers/webgpu/tensor/split.cc index d93b75fa21c16..83bf832cc5b11 100644 --- a/onnxruntime/core/providers/webgpu/tensor/split.cc +++ b/onnxruntime/core/providers/webgpu/tensor/split.cc @@ -107,7 +107,7 @@ Status Split::ComputeInternal(ComputeContext& context) const { ORT_RETURN_IF_ERROR(PrepareForCompute(input_shape, num_outputs, axis, before_dims, after_dims_including_split_axis, after_dims_excluding_split, split_sizes)); - SplitProgram program{static_cast(axis)}; + SplitProgram program{gsl::narrow_cast(axis)}; program.AddInput({input, ProgramTensorMetadataDependency::TypeAndRank}); auto output_dimensions = input_shape.AsShapeVector(); @@ -120,7 +120,7 @@ Status Split::ComputeInternal(ComputeContext& context) const { program.AddOutput({output, ProgramTensorMetadataDependency::Rank}); } - uint32_t input_size = onnxruntime::narrow(input_shape.Size()); + uint32_t input_size = gsl::narrow(input_shape.Size()); // Early return if the input tensor is empty. if (input_size == 0) { return Status::OK(); @@ -130,7 +130,7 @@ Status Split::ComputeInternal(ComputeContext& context) const { std::vector sizes_in_split_axis; // sizes_in_split_axis are the cumulative sizes of the splits in the split axis. for (auto split_size : split_sizes) { - previous_sum += onnxruntime::narrow(split_size); + previous_sum += gsl::narrow(split_size); sizes_in_split_axis.push_back(previous_sum); } diff --git a/onnxruntime/core/providers/webgpu/tensor/transpose.cc b/onnxruntime/core/providers/webgpu/tensor/transpose.cc index 0df7d1ae9fa2f..c40ec43dd0009 100644 --- a/onnxruntime/core/providers/webgpu/tensor/transpose.cc +++ b/onnxruntime/core/providers/webgpu/tensor/transpose.cc @@ -47,10 +47,7 @@ ONNX_OPERATOR_KERNEL_EX( .TypeConstraint("T", WebGpuSupportedNumberTypes()), Transpose); -auto SqueezeShape(const gsl::span& shape, - const gsl::span& adjusted_perm, - TensorShapeVector& new_shape, - TensorShapeVector& new_perm) { +auto SqueezeShape(const gsl::span& shape, const gsl::span& adjusted_perm, InlinedVector& new_shape, InlinedVector& new_perm) { for (size_t i = 0; i < shape.size(); ++i) { if (shape[i] != 1) { new_shape.push_back(shape[i]); @@ -100,28 +97,26 @@ Status TransposeProgram::GenerateShaderCode(ShaderHelper& shader) const { return Status::OK(); } -Status Transpose::DoTranspose(onnxruntime::webgpu::ComputeContext& context, - gsl::span permutations, - const Tensor& input, Tensor& output) { - const auto& input_shape = input.Shape(); - const auto& input_dims = input_shape.GetDims(); - int32_t rank = static_cast(input_shape.NumDimensions()); +Status Transpose::ComputeInternal(ComputeContext& context) const { + const auto* input_tensor = context.Input(0); + const TensorShape& input_shape = input_tensor->Shape(); + int32_t rank = gsl::narrow_cast(input_shape.NumDimensions()); TensorShapeVector output_dims(rank); + InlinedVector default_perm(rank); + const InlinedVector* p_perm = nullptr; + ORT_RETURN_IF_ERROR(ComputeOutputShape(*input_tensor, output_dims, default_perm, p_perm)); + TensorShape output_shape(output_dims); + auto* output_tensor = context.Output(0, output_shape); - for (int32_t i = 0; i < rank; i++) { - output_dims[i] = input_dims[permutations[i]]; - } - - TensorShapeVector new_shape{}; - TensorShapeVector new_perm{}; - SqueezeShape(input_shape.GetDims(), permutations, new_shape, new_perm); - const bool channels_last = new_perm == TensorShapeVector({2, 3, 1}); - const bool channels_first = new_perm == TensorShapeVector({3, 1, 2}); + InlinedVector new_shape{}; + InlinedVector new_perm{}; + SqueezeShape(input_shape.GetDims(), *p_perm, new_shape, new_perm); + const bool channels_last = new_perm == InlinedVector({2, 3, 1}); + const bool channels_first = new_perm == InlinedVector({3, 1, 2}); const bool use_shared = (new_shape.size() == 2 && new_perm[0] > new_perm[1]) || channels_last || channels_first; auto new_input_shape = input_shape; TensorShape new_output_shape(output_dims); - if (use_shared) { new_input_shape = channels_last ? TensorShape({new_shape[0], new_shape[1] * new_shape[2]}) @@ -131,16 +126,16 @@ Status Transpose::DoTranspose(onnxruntime::webgpu::ComputeContext& context, new_output_shape = TensorShape({new_input_shape[1], new_input_shape[0]}); } - uint32_t output_size = onnxruntime::narrow(input_shape.Size()); - TransposeProgram program{permutations, use_shared}; - + uint32_t output_size = gsl::narrow_cast(input_tensor->Shape().Size()); + TransposeProgram program{*p_perm, use_shared}; if (use_shared) { program.SetWorkgroupSize(TILE_SIZE, TILE_SIZE, 1); } + program - .CacheHint(absl::StrJoin(permutations, "-")) - .AddInputs({{&input, ProgramTensorMetadataDependency::TypeAndRank, new_input_shape, 1}}) - .AddOutputs({{&output, ProgramTensorMetadataDependency::None, new_output_shape, 1}}) + .CacheHint(absl::StrJoin(*p_perm, "-")) + .AddInputs({{input_tensor, ProgramTensorMetadataDependency::TypeAndRank, new_input_shape, 1}}) + .AddOutputs({{output_tensor, ProgramTensorMetadataDependency::None, new_output_shape, 1}}) .SetDispatchGroupSize(static_cast((new_output_shape[1] + TILE_SIZE - 1) / TILE_SIZE), static_cast(((new_output_shape[0] + TILE_SIZE - 1) / TILE_SIZE))) .AddUniformVariables({ @@ -153,20 +148,5 @@ Status Transpose::DoTranspose(onnxruntime::webgpu::ComputeContext& context, return context.RunProgram(program); } -Status Transpose::ComputeInternal(ComputeContext& context) const { - const auto* input_tensor = context.Input(0); - const TensorShape& input_shape = input_tensor->Shape(); - int32_t rank = static_cast(input_shape.NumDimensions()); - - TensorShapeVector output_dims(rank); - InlinedVector default_perm(rank); - const InlinedVector* p_perm = nullptr; - ORT_RETURN_IF_ERROR(ComputeOutputShape(*input_tensor, output_dims, default_perm, p_perm)); - TensorShape output_shape(output_dims); - auto* output_tensor = context.Output(0, output_shape); - - return DoTranspose(context, *p_perm, *input_tensor, *output_tensor); -} - } // namespace webgpu } // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/tensor/transpose.h b/onnxruntime/core/providers/webgpu/tensor/transpose.h index b62a419fa12bc..7cf5c1fe0865d 100644 --- a/onnxruntime/core/providers/webgpu/tensor/transpose.h +++ b/onnxruntime/core/providers/webgpu/tensor/transpose.h @@ -16,8 +16,6 @@ class Transpose final : public WebGpuKernel, public TransposeBase { Transpose(const OpKernelInfo& info) : WebGpuKernel{info}, TransposeBase{info} { } Status ComputeInternal(ComputeContext& context) const override; - static Status DoTranspose(onnxruntime::webgpu::ComputeContext& context, gsl::span permutations, const Tensor& input, Tensor& output); - constexpr static uint32_t TILE_SIZE = 16; }; diff --git a/onnxruntime/core/providers/webgpu/tensor/where.cc b/onnxruntime/core/providers/webgpu/tensor/where.cc index d7272ec525296..e8cdabb9dbe40 100644 --- a/onnxruntime/core/providers/webgpu/tensor/where.cc +++ b/onnxruntime/core/providers/webgpu/tensor/where.cc @@ -127,7 +127,7 @@ Status Where::ComputeInternal(ComputeContext& context) const { ORT_RETURN_IF_ERROR(ComputeOutputShape(cond_shape, x_shape, y_shape, output_shape)); auto* output_tensor = context.Output(0, output_shape); constexpr int component = 4; - uint32_t vec_size = onnxruntime::narrow((output_shape.Size() + 3) / component); + uint32_t vec_size = gsl::narrow_cast((output_shape.Size() + 3) / component); const auto is_broadcast = !(x_shape == y_shape && y_shape == cond_shape); WhereProgram program{is_broadcast}; diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.cc b/onnxruntime/core/providers/webgpu/webgpu_context.cc index 97144573dde2d..163dd691b7f16 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_context.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_context.cc @@ -134,8 +134,6 @@ void WebGpuContext::Initialize(const WebGpuBufferCacheConfig& buffer_cache_confi ORT_ENFORCE(device_ != nullptr, "Failed to get a WebGPU device."); } - LOGS_DEFAULT(VERBOSE) << "WebGPU EP Context is created for: Instance=" << instance_.Get() << ", Device=" << device_.Get() << "."; - // cache adapter info ORT_ENFORCE(Device().GetAdapterInfo(&adapter_info_)); // cache device limits @@ -167,6 +165,7 @@ void WebGpuContext::Initialize(const WebGpuBufferCacheConfig& buffer_cache_confi #if defined(ENABLE_PIX_FOR_WEBGPU_EP) // set pix frame generator pix_frame_generator_ = std::make_unique(instance_, + Adapter(), Device()); #else ORT_THROW("Support PIX capture requires extra build flags (--enable_pix_capture)"); @@ -322,9 +321,9 @@ Status WebGpuContext::Run(ComputeContext& context, const ProgramBase& program) { std::vector dims(expected_rank); std::vector stride(expected_rank - 1); for (size_t j = 0; j < expected_rank; ++j) { - dims[j] = onnxruntime::narrow(shape[j]); + dims[j] = gsl::narrow(shape[j]); if (j < expected_rank - 1) { - stride[j] = onnxruntime::narrow(shape.SizeFromDimension(j + 1)); + stride[j] = gsl::narrow(shape.SizeFromDimension(j + 1)); } } @@ -491,7 +490,8 @@ std::vector WebGpuContext::GetAvailableRequiredFeatures(const #endif wgpu::FeatureName::TimestampQuery, wgpu::FeatureName::ShaderF16, - wgpu::FeatureName::Subgroups}; + wgpu::FeatureName::Subgroups, + wgpu::FeatureName::SubgroupsF16}; for (auto feature : features) { if (adapter.HasFeature(feature)) { required_features.push_back(feature); @@ -708,46 +708,45 @@ WebGpuContext& WebGpuContextFactory::CreateContext(const WebGpuContextConfig& co WGPUInstance instance = config.instance; WGPUDevice device = config.device; - std::call_once(init_default_flag_, [ + if (context_id == 0) { + // context ID is preserved for the default context. User cannot use context ID 0 as a custom context. + ORT_ENFORCE(instance == nullptr && device == nullptr, + "WebGPU EP default context (contextId=0) must not have custom WebGPU instance or device."); + + std::call_once(init_default_flag_, [ #if !defined(__wasm__) - dawn_proc_table = config.dawn_proc_table + dawn_proc_table = config.dawn_proc_table #endif - ]() { - // Step.1 - setup dawn proc table (only for non-WASM build) + ]() { + // Step.1 - setup dawn proc table (only for non-WASM build) #if !defined(__wasm__) - const DawnProcTable* dawn_procs = reinterpret_cast(dawn_proc_table); + const DawnProcTable* dawn_procs = reinterpret_cast(dawn_proc_table); #if defined(BUILD_DAWN_MONOLITHIC_LIBRARY) - ORT_ENFORCE(dawn_procs == nullptr, "setting DawnProcTable is not allowed when dynamically linked to webgpu_dawn."); + ORT_ENFORCE(dawn_procs == nullptr, "setting DawnProcTable is not allowed when dynamically linked to webgpu_dawn."); #else #if !defined(USE_EXTERNAL_DAWN) - if (dawn_procs == nullptr) { - dawn_procs = &dawn::native::GetProcs(); - } + if (dawn_procs == nullptr) { + dawn_procs = &dawn::native::GetProcs(); + } #else - ORT_ENFORCE(dawn_procs != nullptr, "DawnProcTable must be provided."); + ORT_ENFORCE(dawn_procs != nullptr, "DawnProcTable must be provided."); #endif - dawnProcSetProcs(dawn_procs); + dawnProcSetProcs(dawn_procs); #endif #endif - // Step.2 - Create wgpu::Instance + // Step.2 - Create wgpu::Instance #if !defined(__wasm__) - wgpu::InstanceDescriptor instance_desc{}; - instance_desc.capabilities.timedWaitAnyEnable = true; - default_instance_ = wgpu::CreateInstance(&instance_desc); + wgpu::InstanceDescriptor instance_desc{}; + instance_desc.capabilities.timedWaitAnyEnable = true; + default_instance_ = wgpu::CreateInstance(&instance_desc); #else - default_instance_ = wgpu::CreateInstance(nullptr); + default_instance_ = wgpu::CreateInstance(nullptr); #endif - ORT_ENFORCE(default_instance_ != nullptr, "Failed to create wgpu::Instance."); - }); - - if (context_id == 0) { - // context ID is preserved for the default context. User cannot use context ID 0 as a custom context. - ORT_ENFORCE(instance == nullptr && device == nullptr, - "WebGPU EP default context (contextId=0) must not have custom WebGPU instance or device."); - + ORT_ENFORCE(default_instance_ != nullptr, "Failed to create wgpu::Instance."); + }); instance = default_instance_.Get(); } else { // for context ID > 0, user must provide custom WebGPU instance and device. @@ -801,9 +800,5 @@ void CleanupWebGpuContexts() { WebGpuContextFactory::Cleanup(); } -WGPUDevice GetDevice(int context_id) { - return WebGpuContextFactory::GetContext(context_id).Device().Get(); -} - } // namespace webgpu } // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc index df7f2d6dcdeab..d44cf4674d8a3 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc @@ -23,7 +23,6 @@ #include "core/providers/webgpu/webgpu_context.h" #include "core/providers/webgpu/data_transfer.h" -#include "core/providers/webgpu/external_data_loader.h" #include "core/providers/webgpu/webgpu_profiler.h" namespace onnxruntime { @@ -364,9 +363,7 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxD class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, Pad); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 17, Pad); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 18, 18, Pad); -class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 19, 20, Pad); -class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 21, 22, Pad); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 23, Pad); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 19, Pad); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, If); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, If); @@ -519,10 +516,10 @@ std::unique_ptr RegisterKernels() { // BuildKernelCreateInfo, // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -628,9 +625,9 @@ std::unique_ptr RegisterKernels() { // BuildKernelCreateInfo, // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -688,13 +685,11 @@ std::unique_ptr RegisterKernels() { // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, // BuildKernelCreateInfo, // BuildKernelCreateInfo, @@ -765,7 +760,6 @@ std::vector WebGpuExecutionProvider::CreatePreferredAllocators() { std::vector> WebGpuExecutionProvider::GetCapability( const onnxruntime::GraphViewer& graph, const IKernelLookup& kernel_lookup, - const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* /* resource_accountant */) const { InlinedVector candidates; // `tenative_candidates` is a subset of `candidates`. @@ -827,12 +821,6 @@ std::unique_ptr WebGpuExecutionProvider::GetDataTran return std::make_unique(context_); } -#if defined(__wasm__) -std::unique_ptr WebGpuExecutionProvider::GetExternalDataLoader() const { - return std::make_unique(); -} -#endif - WebGpuExecutionProvider::~WebGpuExecutionProvider() { WebGpuContextFactory::ReleaseContext(context_id_); } diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h index e2e23b6a307cf..7a0ade97aa3df 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h @@ -45,14 +45,10 @@ class WebGpuExecutionProvider : public IExecutionProvider { std::vector> GetCapability( const onnxruntime::GraphViewer& graph_viewer, const IKernelLookup& /*kernel_lookup*/, - const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* /* resource_accountant */) const override; std::shared_ptr GetKernelRegistry() const override; std::unique_ptr GetDataTransfer() const override; -#if defined(__wasm__) - std::unique_ptr GetExternalDataLoader() const override; -#endif DataLayout GetPreferredLayout() const override { return preferred_data_layout_; } diff --git a/onnxruntime/core/providers/webgpu/webgpu_pix_frame_generator.cc b/onnxruntime/core/providers/webgpu/webgpu_pix_frame_generator.cc index 9b287b7b7df99..90b99b7b38bb1 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_pix_frame_generator.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_pix_frame_generator.cc @@ -11,7 +11,7 @@ namespace onnxruntime { namespace webgpu { -WebGpuPIXFrameGenerator::WebGpuPIXFrameGenerator(wgpu::Instance instance, wgpu::Device device) { +WebGpuPIXFrameGenerator::WebGpuPIXFrameGenerator(wgpu::Instance instance, wgpu::Adapter adapter, wgpu::Device device) { // Trivial window size for surface texture creation and provide frame concept for PIX. static constexpr uint32_t kWidth = 512u; static constexpr uint32_t kHeight = 512u; @@ -32,7 +32,7 @@ WebGpuPIXFrameGenerator::WebGpuPIXFrameGenerator(wgpu::Instance instance, wgpu:: wgpu::TextureFormat format; wgpu::SurfaceCapabilities capabilities; - surface_.GetCapabilities(device.GetAdapter(), &capabilities); + surface_.GetCapabilities(adapter, &capabilities); format = capabilities.formats[0]; wgpu::SurfaceConfiguration config; diff --git a/onnxruntime/core/providers/webgpu/webgpu_pix_frame_generator.h b/onnxruntime/core/providers/webgpu/webgpu_pix_frame_generator.h index 0d9393321284d..52a7459a81eba 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_pix_frame_generator.h +++ b/onnxruntime/core/providers/webgpu/webgpu_pix_frame_generator.h @@ -41,7 +41,7 @@ namespace webgpu { // WebGpuContext destruction. class WebGpuPIXFrameGenerator { public: - WebGpuPIXFrameGenerator(wgpu::Instance instance, wgpu::Device device); + WebGpuPIXFrameGenerator(wgpu::Instance instance, wgpu::Adapter adapter, wgpu::Device device); ~WebGpuPIXFrameGenerator(); void GeneratePIXFrame(); diff --git a/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc b/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc index 1d779152f91f3..60c61b2ca5665 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc @@ -151,12 +151,6 @@ std::shared_ptr WebGpuProviderFactoryCreator::Create( validation_mode, }; - LOGS_DEFAULT(VERBOSE) << "WebGPU EP Device ID: " << context_id; - LOGS_DEFAULT(VERBOSE) << "WebGPU EP WGPUInstance: " << webgpu_instance; - LOGS_DEFAULT(VERBOSE) << "WebGPU EP WGPUDevice: " << webgpu_device; - LOGS_DEFAULT(VERBOSE) << "WebGPU EP DawnProcTable: " << dawn_proc_table; - LOGS_DEFAULT(VERBOSE) << "WebGPU EP ValidationMode: " << validation_mode; - // // STEP.3 - prepare parameters for WebGPU context initialization. // diff --git a/onnxruntime/core/providers/webnn/builders/impl/rotaryEmbedding_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/rotaryEmbedding_op_builder.cc index 966deb14196dd..cbaff79f4fd4f 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/rotaryEmbedding_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/rotaryEmbedding_op_builder.cc @@ -219,17 +219,9 @@ Status RotaryEmbeddingOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_build sign_buffer.set(0, -1.0f); sign_buffer.set(1, 1.0f); } else if (input_data_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) { - if (model_builder.IsFloat16ArrayAvailable()) { - // Float16Array is avaliable - use Float16Array. - sign_buffer = emscripten::val::global("Float16Array").new_(2); - sign_buffer.set(0, -1.0f); - sign_buffer.set(1, 1.0f); - } else { - // Float16Array is not available - use Uint16Array instead. - sign_buffer = emscripten::val::global("Uint16Array").new_(2); - sign_buffer.set(0, PackFloat32ToUint16AsFloat16(-1.0f)); - sign_buffer.set(1, PackFloat32ToUint16AsFloat16(1.0f)); - } + sign_buffer = emscripten::val::global("Uint16Array").new_(2); + sign_buffer.set(0, PackFloat32ToUint16AsFloat16(-1.0f)); + sign_buffer.set(1, PackFloat32ToUint16AsFloat16(1.0f)); } else { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unsupported input data type: ", input_data_type); } diff --git a/onnxruntime/core/providers/webnn/builders/model_builder.cc b/onnxruntime/core/providers/webnn/builders/model_builder.cc index cf4ce216ed5b3..ace6519a1fc11 100644 --- a/onnxruntime/core/providers/webnn/builders/model_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/model_builder.cc @@ -197,8 +197,7 @@ Status ModelBuilder::RegisterInitializers() { // Wasm memory grow will cause all array buffers reallocation, which will be treated as detached // buffers in JS side. Simply create a copy to fix it. - view = view.call("slice"); - operand = wnn_builder_.call("constant", desc, view["buffer"]); + operand = wnn_builder_.call("constant", desc, view.call("slice")); } } else { // TODO: support other type. @@ -351,8 +350,7 @@ Status ModelBuilder::AddOperandFromPersistMemoryBuffer( emscripten::val operand = emscripten::val::object(); // Wasm memory grow will cause all array buffers reallocation, which will be treated as detached // buffers in JS side. Simply create a copy to fix it. - view = view.call("slice"); - operand = wnn_builder_.call("constant", desc, view["buffer"]); + operand = wnn_builder_.call("constant", desc, view.call("slice")); AddOperand(name, operand); mem_persist_buffers_.push_back(std::move(persist_buffer)); diff --git a/onnxruntime/core/providers/webnn/builders/model_builder.h b/onnxruntime/core/providers/webnn/builders/model_builder.h index 1e5f859506d6b..4e2d84f481df0 100644 --- a/onnxruntime/core/providers/webnn/builders/model_builder.h +++ b/onnxruntime/core/providers/webnn/builders/model_builder.h @@ -30,7 +30,6 @@ class ModelBuilder { Status Compile(std::unique_ptr& model) ORT_MUST_USE_RESULT; // Accessors for members. - bool IsFloat16ArrayAvailable() const { return is_float16array_available_; } const GraphViewer& GetGraphViewer() const { return graph_viewer_; } InitializedTensorSet GetInitializerTensors(); @@ -69,8 +68,6 @@ class ModelBuilder { private: const GraphViewer& graph_viewer_; const logging::Logger& logger_; - const bool is_float16array_available_ = !emscripten::val::global("Float16Array").isUndefined() && - emscripten::val::global("Float16Array").hasOwnProperty("from"); emscripten::val wnn_context_ = emscripten::val::undefined(); emscripten::val wnn_builder_ = emscripten::val::undefined(); @@ -175,12 +172,9 @@ const emscripten::val& ModelBuilder::CreateOrGetConstant(const int32_t& data_typ } break; case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: - buffer = is_float16array_available_ - ? emscripten::val::global("Float16Array").new_(num_elements) - : emscripten::val::global("Uint16Array").new_(num_elements); + buffer = emscripten::val::global("Uint16Array").new_(num_elements); if (value) { - buffer.call("fill", - emscripten::val(is_float16array_available_ ? value : PackFloat32ToUint16AsFloat16(value))); + buffer.call("fill", emscripten::val(PackFloat32ToUint16AsFloat16(value))); } break; case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: diff --git a/onnxruntime/core/providers/webnn/webnn_execution_provider.cc b/onnxruntime/core/providers/webnn/webnn_execution_provider.cc index 7410ff66add30..39e6520e3912b 100644 --- a/onnxruntime/core/providers/webnn/webnn_execution_provider.cc +++ b/onnxruntime/core/providers/webnn/webnn_execution_provider.cc @@ -56,7 +56,6 @@ WebNNExecutionProvider::~WebNNExecutionProvider() {} std::vector> WebNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer, const IKernelLookup& /*kernel_registries*/, - const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* /* resource_accountant */) const { // For subgraph which is the attribute of the control flow nodes, part of its initializers are stored in its // ancestor graphs as common initializers shared for other subgraphs. We need to collect all of them used for diff --git a/onnxruntime/core/providers/webnn/webnn_execution_provider.h b/onnxruntime/core/providers/webnn/webnn_execution_provider.h index b8775e717668a..e806dc340d53e 100644 --- a/onnxruntime/core/providers/webnn/webnn_execution_provider.h +++ b/onnxruntime/core/providers/webnn/webnn_execution_provider.h @@ -25,7 +25,6 @@ class WebNNExecutionProvider : public IExecutionProvider { std::vector> GetCapability(const onnxruntime::GraphViewer& graph_viewer, const IKernelLookup& /*kernel_registries*/, - const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* /* resource_accountant */) const override; DataLayout GetPreferredLayout() const override { return preferred_layout_; } diff --git a/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.cc b/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.cc index ab14c083884d3..641f8b0729d0a 100644 --- a/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.cc +++ b/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.cc @@ -258,7 +258,6 @@ static void AddComputeCapabilityForEachNodeInNodeUnit( std::vector> XnnpackExecutionProvider::GetCapability( const onnxruntime::GraphViewer& graph, const IKernelLookup& /*kernel_lookup*/, - const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* /* resource_accountant */) const { const auto& logger = *GetLogger(); std::vector> capabilities; diff --git a/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.h b/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.h index 9c4d2484f9f4b..152bef1a1c52c 100644 --- a/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.h +++ b/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.h @@ -33,7 +33,6 @@ class XnnpackExecutionProvider : public IExecutionProvider { std::vector> GetCapability( const onnxruntime::GraphViewer& graph_viewer, const IKernelLookup& /*kernel_lookup*/, - const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* /* resource_accountant */) const override; std::shared_ptr GetKernelRegistry() const override; diff --git a/onnxruntime/core/session/abi_session_options.cc b/onnxruntime/core/session/abi_session_options.cc index 2e733f67a888c..7ef23d6c9e895 100644 --- a/onnxruntime/core/session/abi_session_options.cc +++ b/onnxruntime/core/session/abi_session_options.cc @@ -1,18 +1,17 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include -#include -#include - -#include "core/common/inlined_containers.h" -#include "core/framework/error_code_helper.h" #include "core/graph/onnx_protobuf.h" -#include "core/session/abi_session_options_impl.h" -#include "core/session/inference_session.h" +#include "core/common/inlined_containers.h" #include "core/session/onnxruntime_c_api.h" #include "core/session/ort_apis.h" -#include "core/session/utils.h" +#include "core/framework/error_code_helper.h" +#include +#include +#include +#include "core/session/inference_session.h" +#include "abi_session_options_impl.h" +#include "api_utils.h" OrtSessionOptions::~OrtSessionOptions() = default; diff --git a/onnxruntime/core/session/api_utils.cc b/onnxruntime/core/session/api_utils.cc new file mode 100644 index 0000000000000..f7cb8520b1e5d --- /dev/null +++ b/onnxruntime/core/session/api_utils.cc @@ -0,0 +1,25 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "api_utils.h" + +onnxruntime::common::Status CopyStringToOutputArg(std::string_view str, const char* err_msg, char* out, size_t* size) { + const size_t str_len = str.size(); + const size_t req_size = str_len + 1; + + if (out == nullptr) { // User is querying the total output buffer size + *size = req_size; + return onnxruntime::common::Status::OK(); + } + + if (*size >= req_size) { // User provided a buffer of sufficient size + std::memcpy(out, str.data(), str_len); + out[str_len] = '\0'; + *size = req_size; + return onnxruntime::common::Status::OK(); + } + + // User has provided a buffer that is not large enough + *size = req_size; + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, err_msg); +} diff --git a/onnxruntime/core/session/api_utils.h b/onnxruntime/core/session/api_utils.h new file mode 100644 index 0000000000000..27c2bbd66f8d5 --- /dev/null +++ b/onnxruntime/core/session/api_utils.h @@ -0,0 +1,9 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/common/common.h" +#include + +onnxruntime::common::Status CopyStringToOutputArg(std::string_view str, const char* err_msg, char* out, size_t* size); diff --git a/onnxruntime/core/session/custom_ops.cc b/onnxruntime/core/session/custom_ops.cc index f583767346d88..8492391172133 100644 --- a/onnxruntime/core/session/custom_ops.cc +++ b/onnxruntime/core/session/custom_ops.cc @@ -20,7 +20,7 @@ #include "core/framework/tensorprotoutils.h" #include "core/graph/onnx_protobuf.h" #include "core/session/allocator_adapters.h" -#include "core/session/utils.h" +#include "core/session/api_utils.h" #include "core/session/custom_ops.h" #include "core/session/inference_session.h" #include "core/session/ort_apis.h" @@ -900,14 +900,13 @@ KernelCreateInfo CreateKernelCreateInfo(const std::string& domain, const OrtCust ONNX_NAMESPACE::OpSchema CreateSchema(const std::string& domain, const std::vector& ops) { // The function registers the first schema assuming all the other one are the same except the types constraints. ORT_ENFORCE(ops.size() > 0, "No kernels to registers."); - int num_inputs_with_dynamic_type = 0; + int undefined = 0; // Creation of the schema for the first kernel in ops. const OrtCustomOp* op = *ops.begin(); ONNX_NAMESPACE::OpSchema schema(op->GetName(op), "custom op registered at runtime", 0); - auto create_type_constraint = [&ops, &schema, &num_inputs_with_dynamic_type]( - const OrtCustomOp* op, int count, int i, bool is_input) { + auto create_type_constraint = [&ops, &schema, &undefined](const OrtCustomOp* op, int count, int i, bool is_input) { onnx::OpSchema::FormalParameterOption option = onnx::OpSchema::FormalParameterOption::Single; bool is_homogeneous = true; int min_arity = 1; @@ -977,9 +976,7 @@ ONNX_NAMESPACE::OpSchema CreateSchema(const std::string& domain, const std::vect } else { // all_types is empty. As mentioned in the previous loop, all types are allowed. schema.TypeConstraint(name, DataTypeImpl::ToString(SUPPORTED_TENSOR_TYPES), "all types"); - if (is_input) { - ++num_inputs_with_dynamic_type; - } + undefined++; } }; @@ -988,21 +985,19 @@ ONNX_NAMESPACE::OpSchema CreateSchema(const std::string& domain, const std::vect create_type_constraint(op, static_cast(input_count), static_cast(i), true); } - const bool have_shape_infer_fn = op->version >= min_ort_version_with_shape_inference && op->InferOutputShapeFn; - const size_t output_count = op->GetOutputTypeCount(op); for (size_t i = 0; i < output_count; i++) { const auto type = op->GetOutputType(op, i); if (ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED == type) { if (op->GetOutputCharacteristic(op, i) == OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED) { - // if there's a dynamically typed input and output we infer they both have the same type from the input. - // if that isn't the case the user must provide an output shape inference fn which must set the output type. - ORT_ENFORCE(num_inputs_with_dynamic_type == 1 || have_shape_infer_fn, - "The type of a dynamically typed output can be inferred from a single dynamically typed input, " - "or by a user provided OrtCustomOp->InferOutputShapeFn that sets the output type."); + ORT_ENFORCE(1 == undefined, + "There must be one (and only one) dynamic typed input to the custom op. " + "Its type info at runtime will be used to infer the type info of this dynamic typed output " + "which is required for the success of the model loading step. " + "More than one dynamic typed inputs are currently not supported as differing types at runtime " + "means the output type cannot be inferred without which model loading cannot proceed."); } } - create_type_constraint(op, static_cast(output_count), static_cast(i), false); } diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index e5ea562ce3535..a1903898ea7f0 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -38,11 +38,9 @@ #include "core/framework/utils.h" #include "core/graph/graph_viewer.h" #include "core/graph/model.h" -#include "core/graph/model_editor_api_types.h" #include "core/graph/model_saving_options.h" #include "core/optimizer/graph_transformer_utils.h" #include "core/optimizer/graph_transformer.h" -#include "core/optimizer/graph_optimizer_registry.h" #include "core/optimizer/layout_transformation/layout_transformation.h" #include "core/optimizer/insert_cast_transformer.h" #include "core/optimizer/qdq_transformer/ensure_unique_dq_for_node_unit.h" @@ -69,11 +67,11 @@ #include "core/optimizer/stft_decomposition.h" #endif #include "core/session/environment.h" +#include "core/session/user_logging_sink.h" #include "core/session/IOBinding.h" #include "core/session/inference_session_utils.h" #include "core/session/onnxruntime_session_options_config_keys.h" #include "core/session/onnxruntime_run_options_config_keys.h" -#include "core/session/user_logging_sink.h" #include "core/util/protobuf_parsing_utils.h" #include "core/util/thread_utils.h" @@ -1217,56 +1215,6 @@ common::Status InferenceSession::Load() { return LoadWithLoader(loader, "model_loading_from_saved_proto"); } -common::Status InferenceSession::Load(const OrtModel& model_editor_api_model) { - std::lock_guard l(session_mutex_); - - if (is_model_loaded_) { // already loaded - Status status(common::ONNXRUNTIME, common::MODEL_LOADED, "This session already contains a loaded model."); - LOGS(*session_logger_, ERROR) << status.ErrorMessage(); - return status; - } - - if (is_inited_) { - Status status(common::ONNXRUNTIME, common::MODEL_LOADED, "This session has already been initialized."); - LOGS(*session_logger_, ERROR) << status.ErrorMessage(); - return status; - } - - const bool strict_shape_type_inference = session_options_.config_options.GetConfigOrDefault( - kOrtSessionOptionsConfigStrictShapeTypeInference, "0") == "1"; - - // need to go from unique_ptr to shared_ptr when moving into model_ - std::unique_ptr tmp_model; - ORT_RETURN_IF_ERROR(Model::LoadFromModelEditorApiModel(model_editor_api_model, - HasLocalSchema() ? &custom_schema_registries_ : nullptr, - ModelOptions(true, strict_shape_type_inference), - *session_logger_, tmp_model)); - - model_ = std::move(tmp_model); - - is_model_loaded_ = true; - - return Status::OK(); -} - -common::Status InferenceSession::ApplyUpdates(const OrtModel& model_editor_api_model) { - std::lock_guard l(session_mutex_); - - if (!is_model_loaded_) { - Status status(common::ONNXRUNTIME, common::MODEL_LOADED, "This session does not contain a loaded model."); - LOGS(*session_logger_, ERROR) << status.ErrorMessage(); - return status; - } - - if (is_inited_) { - Status status(common::ONNXRUNTIME, common::MODEL_LOADED, "This session has already been initialized."); - LOGS(*session_logger_, ERROR) << status.ErrorMessage(); - return status; - } - - return model_->MainGraph().UpdateUsingModelEditorApiModel(model_editor_api_model); -} - common::Status InferenceSession::TransformGraph(onnxruntime::Graph& graph, bool saving_model_in_ort_format) { // The transformer order: // 1. Ensure we inline as many functions as possible. We refer to it as Ahead Of Time (AOT) function inlining. @@ -1279,13 +1227,8 @@ common::Status InferenceSession::TransformGraph(onnxruntime::Graph& graph, bool // 6. insert cast nodes (required transformer). // 7. insert copy nodes (required transformer). - // Create GraphOptimizerRegistry instance for providing predefined graph optimizers and selection functions for EPs to lookup - auto graph_optimizer_registry = std::make_unique(&session_options_, - execution_providers_.Get(onnxruntime::kCpuExecutionProvider), - session_logger_); - GraphPartitioner partitioner(kernel_registry_manager_, execution_providers_, std::move(graph_optimizer_registry)); - // Run Ahead Of time function inlining + GraphPartitioner partitioner(kernel_registry_manager_, execution_providers_); if (const bool disable_aot_function_inlining = session_options_.config_options.GetConfigOrDefault( kOrtSessionOptionsDisableAheadOfTimeFunctionInlining, "0") == "1"; @@ -1688,7 +1631,7 @@ Status PartitionOrtFormatModel(onnxruntime::Graph& graph, const ExecutionProviders& providers, KernelRegistryManager& kernel_registry_manager, SessionState& session_state, - const SessionOptions& sess_options, + const ConfigOptions& config_options, const logging::Logger& logger) { layout_transformation::TransformLayoutFunction transform_layout_fn = nullptr; @@ -1706,16 +1649,11 @@ Status PartitionOrtFormatModel(onnxruntime::Graph& graph, } #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) - // Create GraphOptimizerRegistry instance for providing predefined graph optimizers and selection functions for EPs to lookup - auto graph_optimizer_registry = std::make_unique(&sess_options, - providers.Get(onnxruntime::kCpuExecutionProvider), - &logger); - - GraphPartitioner partitioner(kernel_registry_manager, providers, std::move(graph_optimizer_registry)); + GraphPartitioner partitioner(kernel_registry_manager, providers); ORT_RETURN_IF_ERROR(partitioner.Partition(graph, session_state.GetMutableFuncMgr(), transform_layout_fn, - sess_options.config_options, + config_options, logger, GraphPartitioner::Mode::kOrtFormatLoad)); @@ -2158,7 +2096,7 @@ common::Status InferenceSession::Initialize() { #endif // !defined(ORT_MINIMAL_BUILD) } else { ORT_RETURN_IF_ERROR_SESSIONID_(PartitionOrtFormatModel(graph, execution_providers_, kernel_registry_manager_, - *session_state_, session_options_, *session_logger_)); + *session_state_, session_options_.config_options, *session_logger_)); #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) const auto& cpu_ep = *execution_providers_.Get(onnxruntime::kCpuExecutionProvider); @@ -3398,10 +3336,6 @@ common::Status InferenceSession::WaitForNotification(Notification* p_executor_do return Status::OK(); } -const Model& InferenceSession::GetModel() const { - return *model_; -} - SessionIOBinding::SessionIOBinding(InferenceSession* session) : sess_(session) { ORT_ENFORCE(session->NewIOBinding(&binding_).IsOK()); } diff --git a/onnxruntime/core/session/inference_session.h b/onnxruntime/core/session/inference_session.h index 5b484103c9ecf..2c0c09dfd3e51 100644 --- a/onnxruntime/core/session/inference_session.h +++ b/onnxruntime/core/session/inference_session.h @@ -47,9 +47,6 @@ namespace ONNX_NAMESPACE { class ModelProto; } // namespace ONNX_NAMESPACE -// OrtModelEditorApi Model. Used to dynamically construct a model via C API at runtime. -struct OrtModel; - namespace onnxruntime { // forward declarations class CustomRegistry; class Environment; @@ -323,27 +320,6 @@ class InferenceSession { * @return OK if success. */ [[nodiscard]] common::Status Load(); - - /** - * Load an OrtModel that was dynamically constructed via OrtModelEditorApi. - * - * @param graph_api_model OrtModel from OrtModelEditorApi - * @return OK if success. - */ - [[nodiscard]] common::Status Load(const OrtModel& graph_api_model); - - /** - * Apply updates from an OrtModel that was created via OrtModelEditorApi. - * This can: - * - add nodes at the start and end of the model - * - add initializers - * - update the graph inputs/outputs - * - * @param graph_api_model OrtModel from OrtModelEditorApi - * @return OK if success. - */ - [[nodiscard]] common::Status ApplyUpdates(const OrtModel& graph_api_model); - #endif // !defined(ORT_MINIMAL_BUILD) /** @@ -595,8 +571,6 @@ class InferenceSession { #endif - const Model& GetModel() const; - protected: #if !defined(ORT_MINIMAL_BUILD) @@ -653,12 +627,6 @@ class InferenceSession { /// convenience pointer to logger. should always be the same as session_state_.Logger(); const logging::Logger* session_logger_; - // The list of execution providers. - // This MUST be prior to model_ in case there are values in the model that were allocated using an allocator - // provided by the EP. If that is the case the allocator's `free` implementation may depend on other parts of the - // EP instance. - ExecutionProviders execution_providers_; - // The model served by this inference session instance. // Currently this has to be a shared ptr because the Model::Load method // returns a shared_ptr only. Ideally factory functions should always return @@ -669,6 +637,9 @@ class InferenceSession { // The file path of where the model was loaded. e.g. /tmp/test_squeezenet/model.onnx PathString model_location_; + // The list of execution providers. + ExecutionProviders execution_providers_; + private: ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(InferenceSession); void SetLoggingManager(const SessionOptions& session_options, diff --git a/onnxruntime/core/session/model_editor_api.h b/onnxruntime/core/session/model_editor_api.h deleted file mode 100644 index 71004866bc867..0000000000000 --- a/onnxruntime/core/session/model_editor_api.h +++ /dev/null @@ -1,65 +0,0 @@ -namespace OrtModelEditorAPI { - -// implementation that returns the API struct -ORT_API(const OrtModelEditorApi*, GetModelEditorApi); - -// APIs to create/edit type info -ORT_API_STATUS_IMPL(CreateTensorTypeInfo, _In_ const OrtTensorTypeAndShapeInfo* tensor_info, - _Out_ OrtTypeInfo** type_info); -ORT_API_STATUS_IMPL(CreateSparseTensorTypeInfo, _In_ const OrtTensorTypeAndShapeInfo* tensor_info, - _Out_ OrtTypeInfo** type_info); -ORT_API_STATUS_IMPL(CreateMapTypeInfo, ONNXTensorElementDataType map_key_type, _In_ const OrtTypeInfo* map_value_type, - _Out_ OrtTypeInfo** type_info); -ORT_API_STATUS_IMPL(CreateSequenceTypeInfo, _In_ const OrtTypeInfo* sequence_type, _Out_ OrtTypeInfo** type_info); -ORT_API_STATUS_IMPL(CreateOptionalTypeInfo, _In_ const OrtTypeInfo* contained_type, _Out_ OrtTypeInfo** type_info); - -ORT_API_STATUS_IMPL(CreateValueInfo, _In_ const char* name, _In_ const OrtTypeInfo* type_info, - _Outptr_ OrtValueInfo** value_info); - -ORT_API_STATUS_IMPL(CreateNode, const char* operator_name, const char* domain_name, _In_ const char* node_name, - _In_reads_(input_names_len) const char* const* input_names, size_t input_names_len, - _In_reads_(output_names_len) const char* const* output_names, size_t output_names_len, - _In_reads_(attribs_len) _Inout_opt_ OrtOpAttr** attributes, _In_opt_ size_t attribs_len, - _Outptr_ OrtNode** node); - -ORT_API_STATUS_IMPL(CreateGraph, _Outptr_ OrtGraph** graph); -ORT_API_STATUS_IMPL(SetGraphInputs, _In_ OrtGraph* graph, - _In_reads_(inputs_len) _In_ OrtValueInfo** inputs, _In_ size_t inputs_len); -ORT_API_STATUS_IMPL(SetGraphOutputs, _In_ OrtGraph* graph, - _In_reads_(outputs_len) _In_ OrtValueInfo** outputs, _In_ size_t outputs_len); -ORT_API_STATUS_IMPL(AddInitializerToGraph, _In_ OrtGraph* graph, _In_ const char* name, _Inout_ OrtValue* tensor, - bool data_is_external); -ORT_API_STATUS_IMPL(AddNodeToGraph, _In_ OrtGraph* graph, _Inout_ OrtNode* node); - -ORT_API_STATUS_IMPL(CreateModel, - _In_reads_(opset_entries_len) const char* const* domain_names, - _In_reads_(opset_entries_len) const int* opset_versions, - size_t opset_entries_len, - _Outptr_ OrtModel** model); -ORT_API_STATUS_IMPL(AddGraphToModel, _In_ OrtModel* model, _Inout_ OrtGraph* graph); - -ORT_API_STATUS_IMPL(CreateSessionFromModel, _In_ const OrtEnv* env, _In_ const OrtModel* model, - _In_ const OrtSessionOptions* options, _Outptr_ OrtSession** out); - -// -// Model editing APIs for updating existing model by adding node/s at start or end. -// -ORT_API_STATUS_IMPL(CreateModelEditorSession, _In_ const OrtEnv* env, - _In_ const ORTCHAR_T* model_path, - _In_ const OrtSessionOptions* options, - _Outptr_ OrtSession** out); - -ORT_API_STATUS_IMPL(CreateModelEditorSessionFromArray, _In_ const OrtEnv* env, - _In_ const void* model_data, size_t model_data_length, - _In_ const OrtSessionOptions* options, - _Outptr_ OrtSession** out); - -ORT_API_STATUS_IMPL(SessionGetOpsetForDomain, _In_ const OrtSession* session, _In_ const char* domain, - _Out_ int* opset); - -ORT_API_STATUS_IMPL(ApplyModelToModelEditorSession, _In_ OrtSession* session, _In_ OrtModel* model); - -ORT_API_STATUS_IMPL(FinalizeModelEditorSession, _In_ OrtSession* session, _In_ const OrtSessionOptions* options, - _Inout_ OrtPrepackedWeightsContainer* prepacked_weights_container); - -} // namespace OrtModelEditorAPI diff --git a/onnxruntime/core/session/model_editor_c_api.cc b/onnxruntime/core/session/model_editor_c_api.cc deleted file mode 100644 index 2f09b903ed941..0000000000000 --- a/onnxruntime/core/session/model_editor_c_api.cc +++ /dev/null @@ -1,358 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#if !defined(ORT_MINIMAL_BUILD) - -#include - -#include "core/framework/error_code_helper.h" -#include "core/framework/ort_value.h" -#include "core/framework/onnxruntime_typeinfo.h" -#include "core/framework/tensor_type_and_shape.h" -#include "core/graph/constants.h" -#include "core/graph/model.h" -#include "core/graph/model_editor_api_types.h" -#include "core/graph/onnx_protobuf.h" -#include "core/session/abi_session_options_impl.h" -#include "core/session/inference_session.h" -#include "core/session/model_editor_api.h" -#include "core/session/ort_apis.h" -#include "core/session/ort_env.h" -#include "core/session/utils.h" - -using namespace onnxruntime; - -ORT_API_STATUS_IMPL(OrtModelEditorAPI::CreateValueInfo, _In_ const char* name, _In_ const OrtTypeInfo* type_info, - _Outptr_ OrtValueInfo** value_info) { - API_IMPL_BEGIN - if (name == nullptr || *name == '\0') { - return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "name cannot be null or empty string"); - } - - if (type_info == nullptr) { - return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "type_info cannot be null"); - } - - if (type_info->type != ONNX_TYPE_TENSOR) { - return OrtApis::CreateStatus(ORT_FAIL, "Only tensor types are supported currently"); - } - - if (type_info->tensor_type_info == nullptr) { - return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "tensor_type_info cannot be null"); - } - - auto vi = std::make_unique(); - vi->name = name; - vi->type_info = type_info->Clone(); - - *value_info = vi.release(); - - return nullptr; - API_IMPL_END -} - -ORT_API_STATUS_IMPL(OrtModelEditorAPI::CreateNode, const char* operator_name, const char* domain_name, - _In_ const char* node_name, - _In_reads_(input_names_len) const char* const* input_names, size_t input_names_len, - _In_reads_(output_names_len) const char* const* output_names, size_t output_names_len, - _In_reads_(attribs_len) _Inout_opt_ OrtOpAttr** attributes, _In_opt_ size_t attribs_len, - _Outptr_ OrtNode** node) { - API_IMPL_BEGIN - auto n = std::make_unique(); - n->operator_name = operator_name; - n->domain_name = domain_name == kOnnxDomainAlias ? kOnnxDomain : domain_name; - n->node_name = node_name; - - n->input_names.reserve(input_names_len); - for (size_t i = 0; i < input_names_len; ++i) { - n->input_names.push_back(input_names[i]); - } - - n->output_names.reserve(output_names_len); - for (size_t i = 0; i < output_names_len; ++i) { - n->output_names.push_back(output_names[i]); - } - - if (attributes != nullptr) { - n->attributes.reserve(attribs_len); - for (size_t i = 0; i < attribs_len; ++i) { - n->attributes.push_back(*reinterpret_cast(attributes[i])); - // take ownership. as we took a copy that means releasing the original value - OrtApis::ReleaseOpAttr(attributes[i]); - attributes[i] = nullptr; - } - } - - *node = n.release(); - return nullptr; - API_IMPL_END -} - -ORT_API_STATUS_IMPL(OrtModelEditorAPI::CreateGraph, _Outptr_ OrtGraph** graph) { - API_IMPL_BEGIN - auto g = std::make_unique(); - - // do some reserves to reduce reallocation. if we had a hint about sizes upfront that would be optimal - g->initializers.reserve(32); - g->external_initializers.reserve(32); - g->nodes.reserve(64); - - *graph = g.release(); - return nullptr; - API_IMPL_END -} - -ORT_API_STATUS_IMPL(OrtModelEditorAPI::SetGraphInputs, _In_ OrtGraph* graph, - _In_reads_(inputs_len) _In_ OrtValueInfo** inputs, _In_ size_t inputs_len) { - API_IMPL_BEGIN - graph->inputs.clear(); - for (size_t i = 0; i < inputs_len; ++i) { - if (inputs[i] == nullptr) { - return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "inputs cannot contain null entries"); - } - - graph->inputs.push_back(std::unique_ptr(inputs[i])); // take ownership - inputs[i] = nullptr; - } - - return nullptr; - API_IMPL_END -} - -ORT_API_STATUS_IMPL(OrtModelEditorAPI::SetGraphOutputs, _In_ OrtGraph* graph, - _In_reads_(outputs_len) _In_ OrtValueInfo** outputs, _In_ size_t outputs_len) { - API_IMPL_BEGIN - graph->outputs.clear(); - for (size_t i = 0; i < outputs_len; ++i) { - if (outputs[i] == nullptr) { - return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "outputs cannot contain null entries"); - } - - graph->outputs.push_back(std::unique_ptr(outputs[i])); // take ownership - outputs[i] = nullptr; - } - - return nullptr; - API_IMPL_END -} - -ORT_API_STATUS_IMPL(OrtModelEditorAPI::AddInitializerToGraph, _In_ OrtGraph* graph, _In_ const char* name, - _Inout_ OrtValue* tensor, bool data_is_external) { - API_IMPL_BEGIN - if (!tensor->IsTensor()) { - return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Only Tensor is currently supported."); - } - - if (!tensor->IsAllocated()) { - return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Tensor must be allocated."); - } - - const auto& t = tensor->Get(); - if (t.Location().device.Type() != OrtDevice::CPU) { - return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Only CPU based tensors are currently supported."); - } - - if (data_is_external) { - // enforce that an external initializer is not used if the data size is < 128 bytes. - // the reason for this is to avoid potential shape inferencing errors if this initializer is providing an - // input involved in that. the ONNX shape inferencing does not support external data for those values. - // e.g. Reshape's `shape` input, Reduce's `axes', Slice's `starts`, `ends`, `steps`, Clip's `min`, `max`, etc. - if (t.SizeInBytes() < 128) { - return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, - "External initializer should only be used for data >= 128 bytes. " - "Please use CreateTensorAsOrtValue instead."); - } - - graph->external_initializers[name] = std::unique_ptr(tensor); // take ownership - } else { - graph->initializers[name] = std::unique_ptr(tensor); // take ownership - } - - return nullptr; - API_IMPL_END -} - -ORT_API_STATUS_IMPL(OrtModelEditorAPI::AddNodeToGraph, _In_ OrtGraph* graph, _Inout_ OrtNode* node) { - API_IMPL_BEGIN - graph->nodes.push_back(std::unique_ptr(node)); // take ownership - return nullptr; - API_IMPL_END -} - -ORT_API_STATUS_IMPL(OrtModelEditorAPI::CreateModel, - _In_reads_(opset_entries_len) const char* const* domain_names, - _In_reads_(opset_entries_len) const int* opset_versions, - size_t opset_entries_len, - _Outptr_ OrtModel** model) { - API_IMPL_BEGIN - auto m = std::make_unique(); - for (size_t i = 0; i < opset_entries_len; ++i) { - m->domain_to_version[domain_names[i]] = opset_versions[i]; - } - - *model = m.release(); - return nullptr; - API_IMPL_END -} - -ORT_API_STATUS_IMPL(OrtModelEditorAPI::AddGraphToModel, _In_ OrtModel* model, _Inout_ OrtGraph* graph) { - API_IMPL_BEGIN - - if (graph == nullptr) { - return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "graph cannot be null"); - } - - model->graph = std::unique_ptr(graph); // take ownership - return nullptr; - API_IMPL_END -} - -ORT_API_STATUS_IMPL(OrtModelEditorAPI::CreateSessionFromModel, _In_ const OrtEnv* env, _In_ const OrtModel* model, - _In_ const OrtSessionOptions* options, _Outptr_ OrtSession** out) { - API_IMPL_BEGIN - - std::unique_ptr sess; - OrtStatus* status = nullptr; - *out = nullptr; - - ORT_TRY { - sess = std::make_unique( - options == nullptr ? onnxruntime::SessionOptions() : options->value, - env->GetEnvironment()); - - ORT_API_RETURN_IF_STATUS_NOT_OK(sess->Load(*model)); - - ORT_API_RETURN_IF_ERROR(InitializeSession(options, *sess)); - - *out = reinterpret_cast(sess.release()); - } - ORT_CATCH(const std::exception& e) { - ORT_HANDLE_EXCEPTION([&]() { - status = OrtApis::CreateStatus(ORT_FAIL, e.what()); - }); - } - - return status; - - API_IMPL_END -} - -ORT_API_STATUS_IMPL(OrtModelEditorAPI::CreateModelEditorSession, - _In_ const OrtEnv* env, _In_ const ORTCHAR_T* model_path, _In_ const OrtSessionOptions* options, - _Outptr_ OrtSession** out) { - API_IMPL_BEGIN - std::unique_ptr session; - OrtStatus* status = nullptr; - *out = nullptr; - - ORT_TRY { - ORT_API_RETURN_IF_ERROR(CreateSessionAndLoadModel(options, env, model_path, nullptr, 0, session)); - *out = reinterpret_cast(session.release()); - } - ORT_CATCH(const std::exception& e) { - ORT_HANDLE_EXCEPTION([&]() { - status = OrtApis::CreateStatus(ORT_FAIL, e.what()); - }); - } - - return status; - API_IMPL_END -} - -ORT_API_STATUS_IMPL(OrtModelEditorAPI::CreateModelEditorSessionFromArray, _In_ const OrtEnv* env, - _In_ const void* model_data, size_t model_data_length, - _In_ const OrtSessionOptions* options, - _Outptr_ OrtSession** out) { - API_IMPL_BEGIN - std::unique_ptr session; - OrtStatus* status = nullptr; - *out = nullptr; - - ORT_TRY { - ORT_API_RETURN_IF_ERROR(CreateSessionAndLoadModel(options, env, nullptr, model_data, model_data_length, session)); - *out = reinterpret_cast(session.release()); - } - ORT_CATCH(const std::exception& e) { - ORT_HANDLE_EXCEPTION([&]() { - status = OrtApis::CreateStatus(ORT_FAIL, e.what()); - }); - } - - return status; - API_IMPL_END -} - -ORT_API_STATUS_IMPL(OrtModelEditorAPI::SessionGetOpsetForDomain, _In_ const OrtSession* ort_session, - _In_ const char* domain, _Out_ int* opset) { - const auto& session = *reinterpret_cast(ort_session); - const auto& domain_opset_map = session.GetModel().MainGraph().DomainToVersionMap(); - - auto it = domain_opset_map.find(domain); - if (it == domain_opset_map.cend()) { - return OrtApis::CreateStatus(ORT_FAIL, "Domain not used by model."); - } - - *opset = it->second; - return nullptr; -} - -ORT_API_STATUS_IMPL(OrtModelEditorAPI::ApplyModelToModelEditorSession, - _In_ OrtSession* session, _In_ OrtModel* model) { - API_IMPL_BEGIN - auto sess = reinterpret_cast(session); - ORT_API_RETURN_IF_STATUS_NOT_OK(sess->ApplyUpdates(*model)); - return nullptr; - API_IMPL_END -} - -ORT_API_STATUS_IMPL(OrtModelEditorAPI::FinalizeModelEditorSession, _In_ OrtSession* session, - _In_ const OrtSessionOptions* options, - _Inout_ OrtPrepackedWeightsContainer* prepacked_weights_container) { - API_IMPL_BEGIN - auto sess = reinterpret_cast(session); - ORT_API_RETURN_IF_ERROR(InitializeSession(options, *sess, prepacked_weights_container)); - return nullptr; - API_IMPL_END -} - -static constexpr OrtModelEditorApi ort_model_editor_api = { - // NOTE: The C# bindings depend on the API order within this struct so all additions must be at the end, - // and no functions can be removed (the implementation needs to change to return an error). - - &OrtModelEditorAPI::CreateTensorTypeInfo, - &OrtModelEditorAPI::CreateSparseTensorTypeInfo, - &OrtModelEditorAPI::CreateMapTypeInfo, - &OrtModelEditorAPI::CreateSequenceTypeInfo, - &OrtModelEditorAPI::CreateOptionalTypeInfo, - - &OrtModelEditorAPI::CreateValueInfo, - - &OrtModelEditorAPI::CreateNode, - - &OrtModelEditorAPI::CreateGraph, - &OrtModelEditorAPI::SetGraphInputs, - &OrtModelEditorAPI::SetGraphOutputs, - &OrtModelEditorAPI::AddInitializerToGraph, - &OrtModelEditorAPI::AddNodeToGraph, - - &OrtModelEditorAPI::CreateModel, - &OrtModelEditorAPI::AddGraphToModel, - - &OrtModelEditorAPI::CreateSessionFromModel, - - &OrtModelEditorAPI::CreateModelEditorSession, - &OrtModelEditorAPI::CreateModelEditorSessionFromArray, - &OrtModelEditorAPI::SessionGetOpsetForDomain, - &OrtModelEditorAPI::ApplyModelToModelEditorSession, - &OrtModelEditorAPI::FinalizeModelEditorSession, -}; - -// checks that we don't violate the rule that the functions must remain in the slots they were originally assigned -static_assert(offsetof(OrtModelEditorApi, FinalizeModelEditorSession) / sizeof(void*) == 19, - "Size of version 21 API cannot change"); // initial version in ORT 1.21 - -ORT_API(const OrtModelEditorApi*, OrtModelEditorAPI::GetModelEditorApi) { - return &ort_model_editor_api; -} - -#endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 0e23d7a791bec..4eedcd591154f 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -1,47 +1,45 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include "core/session/onnxruntime_c_api.h" +#include "core/session/allocator_adapters.h" +#include "core/session/inference_session_utils.h" +#include "core/session/IOBinding.h" +#include "core/framework/allocator.h" +#include "core/framework/error_code_helper.h" +#include "core/framework/execution_provider.h" +#include "core/framework/tensor_type_and_shape.h" +#include "core/framework/utils.h" #include #include #include -#include #include #include "core/common/common.h" #include "core/common/logging/logging.h" #include "core/common/narrow.h" -#include "core/common/safeint.h" #include "core/common/status.h" -#include "core/common/string_helper.h" -#include "core/framework/allocator.h" -#include "core/framework/allocator.h" -#include "core/framework/callback.h" -#include "core/framework/data_types.h" -#include "core/framework/error_code_helper.h" -#include "core/framework/execution_provider.h" -#include "core/framework/onnxruntime_typeinfo.h" -#include "core/framework/ort_value.h" -#include "core/framework/tensor.h" -#include "core/framework/tensor_type_and_shape.h" -#include "core/framework/tensorprotoutils.h" -#include "core/framework/TensorSeq.h" -#include "core/framework/utils.h" +#include "core/common/safeint.h" #include "core/graph/constants.h" #include "core/graph/graph.h" -#include "core/graph/model_editor_api_types.h" +#include "core/framework/allocator.h" +#include "core/framework/tensor.h" +#include "core/framework/ort_value.h" #include "core/providers/get_execution_providers.h" -#include "core/session/abi_session_options_impl.h" -#include "core/session/allocator_adapters.h" #include "core/session/environment.h" +#include "core/framework/callback.h" +#include "core/framework/tensorprotoutils.h" +#include "core/framework/onnxruntime_typeinfo.h" #include "core/session/inference_session.h" -#include "core/session/inference_session_utils.h" -#include "core/session/IOBinding.h" -#include "core/session/lora_adapters.h" -#include "core/session/model_editor_api.h" -#include "core/session/onnxruntime_c_api.h" #include "core/session/ort_apis.h" #include "core/session/ort_env.h" -#include "core/session/utils.h" +#include "core/framework/data_types.h" +#include "abi_session_options_impl.h" +#include "core/framework/TensorSeq.h" +#include +#include "core/common/string_helper.h" + +#include "core/session/lora_adapters.h" #ifdef USE_CUDA #include "core/providers/cuda/cuda_provider_factory.h" @@ -116,72 +114,6 @@ using namespace onnxruntime; auto v = (value); \ auto tensor = v->GetMutable(); -namespace { -// Create tensor. Allocates memory. Tensor owns memory. Allocator is wrapped and stored in a shared_ptr in Tensor. -ORT_STATUS_PTR CreateTensorImpl(MLDataType ml_type, const int64_t* shape, size_t shape_len, - OrtAllocator* allocator, OrtValue& value) { - TensorShape tensor_shape(shape, shape_len); - AllocatorPtr alloc_ptr = std::make_shared(allocator); - Tensor::InitOrtValue(ml_type, tensor_shape, std::move(alloc_ptr), value); - return nullptr; -} - -// Create Tensor with existing data. Tensor does not own memory. -ORT_STATUS_PTR CreateTensorImpl(MLDataType ml_type, - const int64_t* shape, size_t shape_len, - const OrtMemoryInfo* info, - void* p_data, size_t p_data_len, - OrtValue& ort_value) { - TensorShape tensor_shape(shape, shape_len); - if (std::any_of(tensor_shape.GetDims().begin(), tensor_shape.GetDims().end(), [](int64_t v) { return v < 0; })) { - return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "tried creating tensor with negative value in shape"); - } - - size_t size_to_allocate = 0; - Status status = Tensor::CalculateTensorStorageSize(ml_type, tensor_shape, 0 /*alignment*/, size_to_allocate); - if (!status.IsOK()) { - return ToOrtStatus(status); - } - if (size_to_allocate > p_data_len) { - std::ostringstream oss; - oss << "not enough space: expected " << size_to_allocate << ", got " << p_data_len; - return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, oss.str().c_str()); - } - - Tensor::InitOrtValue(ml_type, tensor_shape, p_data, *info, ort_value); - return nullptr; -} - -ORT_STATUS_PTR CreateTensorImpl(MLDataType ml_type, - const int64_t* shape, size_t shape_len, - OrtAllocator* deleter, - void* p_data, size_t p_data_len, - OrtValue& ort_value) { - TensorShape tensor_shape(shape, shape_len); - if (std::any_of(tensor_shape.GetDims().begin(), tensor_shape.GetDims().end(), [](int64_t v) { return v < 0; })) { - return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "tried creating tensor with negative value in shape"); - } - - size_t size_to_allocate = 0; - Status status = Tensor::CalculateTensorStorageSize(ml_type, tensor_shape, 0 /*alignment*/, size_to_allocate); - - if (!status.IsOK()) { - return ToOrtStatus(status); - } - - if (size_to_allocate > p_data_len) { - std::ostringstream oss; - oss << "p_data_len was smaller than expected. Expected:" << size_to_allocate << " Got:" << p_data_len; - return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, oss.str().c_str()); - } - - AllocatorPtr alloc_ptr = std::make_shared(deleter); - Tensor::InitOrtValue(ml_type, tensor_shape, p_data, std::move(alloc_ptr), ort_value); - return nullptr; -} - -} // namespace - ORT_API_STATUS_IMPL(OrtApis::CreateEnvWithCustomLogger, OrtLoggingFunction logging_function, _In_opt_ void* logger_param, OrtLoggingLevel logging_level, _In_ const char* logid, _Outptr_ OrtEnv** out) { @@ -255,6 +187,50 @@ ORT_API_STATUS_IMPL(OrtApis::UpdateEnvWithCustomLogLevel, _In_ OrtEnv* ort_env, API_IMPL_END } +ORT_STATUS_PTR CreateTensorImpl(MLDataType ml_type, const int64_t* shape, size_t shape_len, + _Inout_ OrtAllocator* allocator, OrtValue& value) { + TensorShape tensor_shape(shape, shape_len); + AllocatorPtr alloc_ptr = std::make_shared(allocator); + Tensor::InitOrtValue(ml_type, tensor_shape, std::move(alloc_ptr), value); + return nullptr; +} + +ORT_STATUS_PTR CreateTensorImplForSeq(MLDataType elem_type, const int64_t* shape, size_t shape_len, Tensor& out) { + OrtAllocator* allocator; + // TODO(pranav): what allocator should be used to create the tensor here? + // for the sake of simplicity of the API using the default one here + ORT_API_RETURN_IF_ERROR(OrtApis::GetAllocatorWithDefaultOptions(&allocator)); + AllocatorPtr alloc_ptr = std::make_shared(allocator); + TensorShape tensor_shape(shape, shape_len); + out = Tensor(elem_type, tensor_shape, std::move(alloc_ptr)); + return nullptr; +} + +/** + * + * this function will create a copy of the allocator info + */ +ORT_STATUS_PTR CreateTensorImpl(MLDataType ml_type, const int64_t* shape, size_t shape_len, const OrtMemoryInfo* info, + void* p_data, size_t p_data_len, OrtValue& ort_value) { + TensorShape tensor_shape(shape, shape_len); + if (std::any_of(tensor_shape.GetDims().begin(), tensor_shape.GetDims().end(), [](int64_t v) { return v < 0; })) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "tried creating tensor with negative value in shape"); + } + + size_t size_to_allocate = 0; + Status status = Tensor::CalculateTensorStorageSize(ml_type, tensor_shape, 0 /*alignment*/, size_to_allocate); + if (!status.IsOK()) { + return ToOrtStatus(status); + } + if (size_to_allocate > p_data_len) { + std::ostringstream oss; + oss << "not enough space: expected " << size_to_allocate << ", got " << p_data_len; + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, oss.str().c_str()); + } + Tensor::InitOrtValue(ml_type, tensor_shape, p_data, *info, ort_value); + return nullptr; +} + ORT_API_STATUS_IMPL(OrtApis::CreateTensorWithDataAsOrtValue, _In_ const OrtMemoryInfo* info, _Inout_ void* p_data, size_t p_data_len, _In_ const int64_t* shape, size_t shape_len, ONNXTensorElementDataType type, _Outptr_ OrtValue** out) { @@ -267,20 +243,6 @@ ORT_API_STATUS_IMPL(OrtApis::CreateTensorWithDataAsOrtValue, _In_ const OrtMemor API_IMPL_END } -ORT_API_STATUS_IMPL(OrtApis::CreateTensorWithDataAndDeleterAsOrtValue, _In_ OrtAllocator* deleter, - _In_ void* p_data, size_t p_data_len, - _In_ const int64_t* shape, size_t shape_len, - ONNXTensorElementDataType type, - _Outptr_ OrtValue** out) { - API_IMPL_BEGIN - auto ml_type = DataTypeImpl::TensorTypeFromONNXEnum(type)->GetElementType(); - auto value = std::make_unique(); - ORT_API_RETURN_IF_ERROR(CreateTensorImpl(ml_type, shape, shape_len, deleter, p_data, p_data_len, *value)); - *out = value.release(); - return nullptr; - API_IMPL_END -} - ORT_API_STATUS_IMPL(OrtApis::CreateTensorAsOrtValue, _Inout_ OrtAllocator* allocator, _In_ const int64_t* shape, size_t shape_len, ONNXTensorElementDataType type, _Outptr_ OrtValue** out) { @@ -716,6 +678,97 @@ ORT_API_STATUS_IMPL(OrtApis::EnableOrtCustomOps, _Inout_ OrtSessionOptions* opti API_IMPL_END } +namespace { +// provider either model_path, or modal_data + model_data_length. +static ORT_STATUS_PTR CreateSessionAndLoadModel(_In_ const OrtSessionOptions* options, + _In_ const OrtEnv* env, + _In_opt_z_ const ORTCHAR_T* model_path, + _In_opt_ const void* model_data, + size_t model_data_length, + std::unique_ptr& sess) { + // quick check here to decide load path. InferenceSession will provide error message for invalid values. + // TODO: Could move to a helper + const Env& os_env = Env::Default(); // OS environment (!= ORT environment) + bool load_config_from_model = + os_env.GetEnvironmentVar(inference_session_utils::kOrtLoadConfigFromModelEnvVar) == "1"; + + if (load_config_from_model) { +#if !defined(ORT_MINIMAL_BUILD) + if (model_path != nullptr) { + sess = std::make_unique( + options == nullptr ? onnxruntime::SessionOptions() : options->value, + env->GetEnvironment(), + model_path); + } else { + sess = std::make_unique( + options == nullptr ? onnxruntime::SessionOptions() : options->value, + env->GetEnvironment(), + model_data, static_cast(model_data_length)); + } +#else + return OrtApis::CreateStatus(ORT_FAIL, "Loading config from ONNX models is not supported in this build."); +#endif + } else { + sess = std::make_unique( + options == nullptr ? onnxruntime::SessionOptions() : options->value, + env->GetEnvironment()); + } + +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS) + // Add custom domains + if (options && !options->custom_op_domains_.empty()) { + ORT_API_RETURN_IF_STATUS_NOT_OK(sess->AddCustomOpDomains(options->custom_op_domains_)); + } +#endif + + // Finish load + if (load_config_from_model) { +#if !defined(ORT_MINIMAL_BUILD) + ORT_API_RETURN_IF_STATUS_NOT_OK(sess->Load()); +#endif + } else { + if (model_path != nullptr) { + ORT_API_RETURN_IF_STATUS_NOT_OK(sess->Load(model_path)); + } else { + ORT_API_RETURN_IF_STATUS_NOT_OK(sess->Load(model_data, static_cast(model_data_length))); + } + } + + return nullptr; +} + +static ORT_STATUS_PTR InitializeSession(_In_ const OrtSessionOptions* options, + _In_ std::unique_ptr<::onnxruntime::InferenceSession>& sess, + _Inout_opt_ OrtPrepackedWeightsContainer* prepacked_weights_container = nullptr) { + // we need to disable mem pattern if DML is one of the providers since DML doesn't have the concept of + // byte addressable memory + std::vector> provider_list; + if (options) { + for (auto& factory : options->provider_factories) { + auto provider = factory->CreateProvider(); + provider_list.push_back(std::move(provider)); + } + } + + // register the providers + for (auto& provider : provider_list) { + if (provider) { + ORT_API_RETURN_IF_STATUS_NOT_OK(sess->RegisterExecutionProvider(std::move(provider))); + } + } + + if (prepacked_weights_container != nullptr) { + ORT_API_RETURN_IF_STATUS_NOT_OK(sess->AddPrePackedWeightsContainer( + reinterpret_cast(prepacked_weights_container))); + } + + ORT_API_RETURN_IF_STATUS_NOT_OK(sess->Initialize()); + + return nullptr; +} + +} // namespace + ORT_API_STATUS_IMPL(OrtApis::CreateSession, _In_ const OrtEnv* env, _In_ const ORTCHAR_T* model_path, _In_ const OrtSessionOptions* options, _Outptr_ OrtSession** out) { API_IMPL_BEGIN @@ -725,7 +778,7 @@ ORT_API_STATUS_IMPL(OrtApis::CreateSession, _In_ const OrtEnv* env, _In_ const O ORT_TRY { ORT_API_RETURN_IF_ERROR(CreateSessionAndLoadModel(options, env, model_path, nullptr, 0, sess)); - ORT_API_RETURN_IF_ERROR(InitializeSession(options, *sess)); + ORT_API_RETURN_IF_ERROR(InitializeSession(options, sess)); *out = reinterpret_cast(sess.release()); } @@ -748,7 +801,7 @@ ORT_API_STATUS_IMPL(OrtApis::CreateSessionFromArray, _In_ const OrtEnv* env, _In ORT_TRY { ORT_API_RETURN_IF_ERROR(CreateSessionAndLoadModel(options, env, nullptr, model_data, model_data_length, sess)); - ORT_API_RETURN_IF_ERROR(InitializeSession(options, *sess)); + ORT_API_RETURN_IF_ERROR(InitializeSession(options, sess)); *out = reinterpret_cast(sess.release()); } @@ -1155,6 +1208,7 @@ ORT_API_STATUS_IMPL(OrtApis::GetResizedStringTensorElementBuffer, _Inout_ OrtVal } namespace { + OrtStatusPtr GetTensorStringSpan(const ::OrtValue& v, gsl::span& span) { if (!v.IsAllocated()) { return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "OrtValue should contain a Tensor or a Sparse Tensor"); @@ -2058,6 +2112,7 @@ ORT_API_STATUS_IMPL(OrtApis::GetOpaqueValue, _In_ const char* domain_name, _In_ } namespace { + struct ProviderBuffer { char** buffer_; char* next_write_; @@ -2287,7 +2342,7 @@ ORT_API_STATUS_IMPL(OrtApis::CreateSessionWithPrepackedWeightsContainer, _In_ co ORT_TRY { ORT_API_RETURN_IF_ERROR(CreateSessionAndLoadModel(options, env, model_path, nullptr, 0, sess)); - ORT_API_RETURN_IF_ERROR(InitializeSession(options, *sess, prepacked_weights_container)); + ORT_API_RETURN_IF_ERROR(InitializeSession(options, sess, prepacked_weights_container)); *out = reinterpret_cast(sess.release()); } @@ -2313,7 +2368,7 @@ ORT_API_STATUS_IMPL(OrtApis::CreateSessionFromArrayWithPrepackedWeightsContainer ORT_TRY { ORT_API_RETURN_IF_ERROR(CreateSessionAndLoadModel(options, env, nullptr, model_data, model_data_length, sess)); - ORT_API_RETURN_IF_ERROR(InitializeSession(options, *sess, prepacked_weights_container)); + ORT_API_RETURN_IF_ERROR(InitializeSession(options, sess, prepacked_weights_container)); *out = reinterpret_cast(sess.release()); } @@ -2355,39 +2410,6 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsSetCustomJoinThreadFn, _Inout_ OrtSes API_IMPL_END } -ORT_API(void, OrtApis::ReleaseValueInfo, _Frees_ptr_opt_ OrtValueInfo* value_info) { - delete value_info; -} - -ORT_API(void, OrtApis::ReleaseNode, _Frees_ptr_opt_ OrtNode* node) { - delete node; -} - -ORT_API(void, OrtApis::ReleaseGraph, _Frees_ptr_opt_ OrtGraph* graph) { - delete graph; -} - -ORT_API(void, OrtApis::ReleaseModel, _Frees_ptr_opt_ OrtModel* model) { - delete model; -} - -ORT_API_STATUS_IMPL(OrtApis::GetValueInfoName, _In_ const OrtValueInfo* value_info, - _Out_ const char** name) { - API_IMPL_BEGIN - *name = value_info->name.c_str(); - return nullptr; - API_IMPL_END -} -ORT_API_STATUS_IMPL(OrtApis::GetValueInfoTypeInfo, _In_ const OrtValueInfo* value_info, - _Outptr_ const OrtTypeInfo** type_info) { - API_IMPL_BEGIN - - *type_info = value_info->type_info.get(); - - return nullptr; - API_IMPL_END -} - ORT_API(const OrtTrainingApi*, OrtApis::GetTrainingApi, uint32_t version) { #ifdef ENABLE_TRAINING_APIS if (version >= 13 && version <= ORT_API_VERSION) @@ -2397,17 +2419,9 @@ ORT_API(const OrtTrainingApi*, OrtApis::GetTrainingApi, uint32_t version) { version, ORT_API_VERSION); return nullptr; #else - ORT_UNUSED_PARAMETER(version); - return nullptr; -#endif -} + ORT_UNUSED_PARAMETER(version); -ORT_API(const OrtModelEditorApi*, OrtApis::GetModelEditorApi) { -#if !defined(ORT_MINIMAL_BUILD) - return OrtModelEditorAPI::GetModelEditorApi(); -#else - fprintf(stderr, "The Model Editor API is not supported in a minimal build.\n"); return nullptr; #endif } @@ -2798,18 +2812,6 @@ static constexpr OrtApi ort_api_1_to_22 = { &OrtApis::SetEpDynamicOptions, // End of Version 20 - DO NOT MODIFY ABOVE (see above text for more information) - - &OrtApis::ReleaseValueInfo, - &OrtApis::ReleaseNode, - &OrtApis::ReleaseGraph, - &OrtApis::ReleaseModel, - - &OrtApis::GetValueInfoName, - &OrtApis::GetValueInfoTypeInfo, - - &OrtApis::GetModelEditorApi, - - &OrtApis::CreateTensorWithDataAndDeleterAsOrtValue, }; // OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase. diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index 9d8aeb18a782f..52d3c98d526dc 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -20,10 +20,6 @@ ORT_API(void, ReleaseCustomOpDomain, _Frees_ptr_opt_ OrtCustomOpDomain*); ORT_API(void, ReleaseMapTypeInfo, _Frees_ptr_opt_ OrtMapTypeInfo*); ORT_API(void, ReleaseSequenceTypeInfo, _Frees_ptr_opt_ OrtSequenceTypeInfo*); ORT_API(void, ReleaseModelMetadata, _Frees_ptr_opt_ OrtModelMetadata*); -ORT_API(void, ReleaseValueInfo, _Frees_ptr_opt_ OrtValueInfo*); -ORT_API(void, ReleaseNode, _Frees_ptr_opt_ OrtNode*); -ORT_API(void, ReleaseGraph, _Frees_ptr_opt_ OrtGraph*); -ORT_API(void, ReleaseModel, _Frees_ptr_opt_ OrtModel*); _Check_return_ _Ret_notnull_ [[nodiscard]] OrtStatus* ORT_API_CALL CreateStatus(OrtErrorCode code, _In_z_ const char* msg) NO_EXCEPTION; @@ -537,16 +533,4 @@ ORT_API_STATUS_IMPL(RunOptionsAddActiveLoraAdapter, _Inout_ OrtRunOptions* optio ORT_API_STATUS_IMPL(SetEpDynamicOptions, _Inout_ OrtSession* sess, _In_reads_(kv_len) const char* const* keys, _In_reads_(kv_len) const char* const* values, _In_ size_t kv_len); - -ORT_API_STATUS_IMPL(GetValueInfoName, _In_ const OrtValueInfo* value_info, _Out_ const char** name); -ORT_API_STATUS_IMPL(GetValueInfoTypeInfo, _In_ const OrtValueInfo* value_info, _Outptr_ const OrtTypeInfo** type_info); - -ORT_API(const OrtModelEditorApi*, GetModelEditorApi); - -ORT_API_STATUS_IMPL(CreateTensorWithDataAndDeleterAsOrtValue, _In_ OrtAllocator* deleter, - _In_ void* p_data, size_t p_data_len, - _In_ const int64_t* shape, size_t shape_len, - ONNXTensorElementDataType type, - _Outptr_ OrtValue** out); - } // namespace OrtApis diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index 2ea4a93d21f2e..77c6d4c371f69 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -4,7 +4,6 @@ // This is the Onnxruntime side of the bridge to allow providers to be built as a DLL // It implements onnxruntime::ProviderHost -#include #include "core/common/inlined_containers.h" #include "core/common/path_string.h" #include "core/framework/allocator_utils.h" @@ -36,7 +35,6 @@ #include "core/graph/graph_proto_serializer.h" #include "core/framework/murmurhash3.h" #include "core/framework/model_metadef_id_generator.h" -#include "core/optimizer/graph_optimizer_registry.h" #include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h" #include "core/optimizer/qdq_transformer/selectors_actions/shared/utils.h" @@ -239,21 +237,6 @@ common::Status LoadDynamicLibraryFromProvider(onnxruntime::PathString library_na struct ProviderHostImpl : ProviderHost { const OrtApiBase* OrtGetApiBase() override { return ::OrtGetApiBase(); } - Status GetOptimizerByName(const std::string& name, - const GraphOptimizerRegistry& graph_optimizer_registry, - SelectionFunc& selection_func) override { - std::string optimizer_name(name); - - auto func = graph_optimizer_registry.GetSelectionFunc(optimizer_name); - - if (func.has_value()) { - selection_func = func.value(); - } else { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to get optimizer " + optimizer_name); - } - return Status::OK(); - }; - void* HeapAllocate(size_t size) override { return new uint8_t[size]; } void HeapFree(void* p) override { delete[] reinterpret_cast(p); } @@ -377,9 +360,8 @@ struct ProviderHostImpl : ProviderHost { std::vector> IExecutionProvider__GetCapability( const IExecutionProvider* p, const onnxruntime::GraphViewer& graph_viewer, const IExecutionProvider::IKernelLookup& kernel_lookup, - const GraphOptimizerRegistry& graph_optimizer_registry, IResourceAccountant* resource_accountant) override { - return p->IExecutionProvider::GetCapability(graph_viewer, kernel_lookup, graph_optimizer_registry, resource_accountant); + return p->IExecutionProvider::GetCapability(graph_viewer, kernel_lookup, resource_accountant); } common::Status IExecutionProvider__Compile(IExecutionProvider* p, const std::vector& fused_nodes_and_graphs, std::vector& node_compute_funcs) override { @@ -815,8 +797,6 @@ struct ProviderHostImpl : ProviderHost { std::unique_ptr ComputeCapability__construct(std::unique_ptr t_sub_graph) override { return std::make_unique(std::move(t_sub_graph)); } void ComputeCapability__operator_delete(ComputeCapability* p) override { delete p; } std::unique_ptr& ComputeCapability__SubGraph(ComputeCapability* p) override { return p->sub_graph; } - void ComputeCapability__copy_optimization_func(ComputeCapability* p, ComputeCapability* selection_cc) override { p->optimization_func = selection_cc->optimization_func; } - void ComputeCapability__add_nodes_to_optimize(ComputeCapability* p, std::unique_ptr optimization_cc) override { p->nodes_to_optimize.push_back(std::move(optimization_cc)); } // DataTransferManager (wrapped) Status DataTransferManager__CopyTensor(const DataTransferManager* p, const Tensor& src, Tensor& dst) override { return p->CopyTensor(src, dst); } @@ -1651,7 +1631,6 @@ struct ProviderHostImpl : ProviderHost { Status LoadDynamicLibrary(onnxruntime::PathString library_name) override { return LoadDynamicLibraryFromProvider(library_name); }; #endif } provider_host_; - #if defined(_MSC_VER) && !defined(__clang__) #pragma warning(pop) #endif diff --git a/onnxruntime/core/session/utils.cc b/onnxruntime/core/session/utils.cc deleted file mode 100644 index afb1ed2696c9f..0000000000000 --- a/onnxruntime/core/session/utils.cc +++ /dev/null @@ -1,125 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/session/utils.h" - -#include "core/framework/error_code_helper.h" -#include "core/framework/execution_provider.h" -#include "core/session/abi_session_options_impl.h" -// #include "core/session/environment.h" -#include "core/session/inference_session.h" -#include "core/session/inference_session_utils.h" -#include "core/session/onnxruntime_c_api.h" -#include "core/session/ort_apis.h" -#include "core/session/ort_env.h" - -using namespace onnxruntime; - -common::Status CopyStringToOutputArg(std::string_view str, const char* err_msg, char* out, size_t* size) { - const size_t str_len = str.size(); - const size_t req_size = str_len + 1; - - if (out == nullptr) { // User is querying the total output buffer size - *size = req_size; - return onnxruntime::common::Status::OK(); - } - - if (*size >= req_size) { // User provided a buffer of sufficient size - std::memcpy(out, str.data(), str_len); - out[str_len] = '\0'; - *size = req_size; - return onnxruntime::common::Status::OK(); - } - - // User has provided a buffer that is not large enough - *size = req_size; - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, err_msg); -} - -// provider either model_path, or modal_data + model_data_length. -OrtStatus* CreateSessionAndLoadModel(_In_ const OrtSessionOptions* options, - _In_ const OrtEnv* env, - _In_opt_z_ const ORTCHAR_T* model_path, - _In_opt_ const void* model_data, - size_t model_data_length, - std::unique_ptr& sess) { - // quick check here to decide load path. InferenceSession will provide error message for invalid values. - // TODO: Could move to a helper - const Env& os_env = Env::Default(); // OS environment (!= ORT environment) - bool load_config_from_model = - os_env.GetEnvironmentVar(inference_session_utils::kOrtLoadConfigFromModelEnvVar) == "1"; - - if (load_config_from_model) { -#if !defined(ORT_MINIMAL_BUILD) - if (model_path != nullptr) { - sess = std::make_unique( - options == nullptr ? onnxruntime::SessionOptions() : options->value, - env->GetEnvironment(), - model_path); - } else { - sess = std::make_unique( - options == nullptr ? onnxruntime::SessionOptions() : options->value, - env->GetEnvironment(), - model_data, static_cast(model_data_length)); - } -#else - return OrtApis::CreateStatus(ORT_FAIL, "Loading config from ONNX models is not supported in this build."); -#endif - } else { - sess = std::make_unique( - options == nullptr ? onnxruntime::SessionOptions() : options->value, - env->GetEnvironment()); - } - -#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS) - // Add custom domains - if (options && !options->custom_op_domains_.empty()) { - ORT_API_RETURN_IF_STATUS_NOT_OK(sess->AddCustomOpDomains(options->custom_op_domains_)); - } -#endif - - // Finish load - if (load_config_from_model) { -#if !defined(ORT_MINIMAL_BUILD) - ORT_API_RETURN_IF_STATUS_NOT_OK(sess->Load()); -#endif - } else { - if (model_path != nullptr) { - ORT_API_RETURN_IF_STATUS_NOT_OK(sess->Load(model_path)); - } else { - ORT_API_RETURN_IF_STATUS_NOT_OK(sess->Load(model_data, static_cast(model_data_length))); - } - } - - return nullptr; -} - -OrtStatus* InitializeSession(_In_ const OrtSessionOptions* options, - _In_ onnxruntime::InferenceSession& sess, - _Inout_opt_ OrtPrepackedWeightsContainer* prepacked_weights_container) { - // we need to disable mem pattern if DML is one of the providers since DML doesn't have the concept of - // byte addressable memory - std::vector> provider_list; - if (options) { - for (auto& factory : options->provider_factories) { - auto provider = factory->CreateProvider(); - provider_list.push_back(std::move(provider)); - } - } - - // register the providers - for (auto& provider : provider_list) { - if (provider) { - ORT_API_RETURN_IF_STATUS_NOT_OK(sess.RegisterExecutionProvider(std::move(provider))); - } - } - - if (prepacked_weights_container != nullptr) { - ORT_API_RETURN_IF_STATUS_NOT_OK(sess.AddPrePackedWeightsContainer( - reinterpret_cast(prepacked_weights_container))); - } - - ORT_API_RETURN_IF_STATUS_NOT_OK(sess.Initialize()); - - return nullptr; -} diff --git a/onnxruntime/core/session/utils.h b/onnxruntime/core/session/utils.h deleted file mode 100644 index ac8ad60758b5b..0000000000000 --- a/onnxruntime/core/session/utils.h +++ /dev/null @@ -1,28 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include -#include "core/common/common.h" -#include "core/session/onnxruntime_c_api.h" - -onnxruntime::common::Status CopyStringToOutputArg(std::string_view str, const char* err_msg, char* out, size_t* size); - -struct OrtSessionOptions; -struct OrtStatus; -struct OrtPrepackedWeightsContainer; -namespace onnxruntime { -class InferenceSession; -} - -OrtStatus* CreateSessionAndLoadModel(_In_ const OrtSessionOptions* options, - _In_ const OrtEnv* env, - _In_opt_z_ const ORTCHAR_T* model_path, - _In_opt_ const void* model_data, - size_t model_data_length, - std::unique_ptr& sess); - -OrtStatus* InitializeSession(_In_ const OrtSessionOptions* options, - _In_ onnxruntime::InferenceSession& sess, - _Inout_opt_ OrtPrepackedWeightsContainer* prepacked_weights_container = nullptr); diff --git a/onnxruntime/python/tools/quantization/execution_providers/qnn/quant_config.py b/onnxruntime/python/tools/quantization/execution_providers/qnn/quant_config.py index 50da0025752aa..ea995d4707ba3 100644 --- a/onnxruntime/python/tools/quantization/execution_providers/qnn/quant_config.py +++ b/onnxruntime/python/tools/quantization/execution_providers/qnn/quant_config.py @@ -204,9 +204,9 @@ def get_qnn_qdq_config( calibrate_method=calibrate_method, activation_type=activation_type, weight_type=weight_type, - op_types_to_quantize=( - op_types_to_quantize if op_types_to_quantize else list(op_types.difference(OP_TYPES_TO_EXCLUDE)) - ), + op_types_to_quantize=op_types_to_quantize + if op_types_to_quantize + else list(op_types.difference(OP_TYPES_TO_EXCLUDE)), nodes_to_exclude=nodes_to_exclude, per_channel=per_channel, use_external_data_format=(model_has_external_data or model.ByteSize() >= MODEL_SIZE_THRESHOLD), diff --git a/onnxruntime/python/tools/quantization/quantize.py b/onnxruntime/python/tools/quantization/quantize.py index d19bebad8a12c..fa468a9676a65 100644 --- a/onnxruntime/python/tools/quantization/quantize.py +++ b/onnxruntime/python/tools/quantization/quantize.py @@ -240,8 +240,6 @@ def get_qdq_config( keep_removable_activations: bool = False, min_real_range: float | None = None, tensor_quant_overrides: dict[str, list[dict[str, Any]]] | None = None, - calibration_providers: list[str] | None = None, - op_types_to_quantize: list[str] | None = None, nodes_to_exclude: list[str] | Callable[[onnx.ModelProto, onnx.NodeProto], bool] | None = None, extra_options: dict | None = None, ) -> StaticQuantConfig: @@ -296,10 +294,6 @@ def get_qdq_config( 'convert["recv_nodes"] = Set : Set of node names that consume the converted activation, other nodes get the original type. If not specified, assume all consumer nodes get the converted type. - calibration_providers: Execution providers to run the session during calibration. Default is None which uses - [ "CPUExecutionProvider" ]. - op_types_to_quantize: List of operator types to quantize. If None, all operators other than Cast, DequantizeLinear, - and QuantizeLinear are quantized. nodes_to_exclude: List of nodes names to exclude from quantization. Alternatively, can provide a function that accepts an onnx.ModelProto and onnx.NodeProto as arguments and returns true if the give onnx.NodeProto should be excluded from quantization. @@ -330,20 +324,17 @@ def get_qdq_config( if onnx.external_data_helper.uses_external_data(initializer): model_has_external_data = True - op_types_to_quantize_set = set(op_types_to_quantize) if op_types_to_quantize else None - nodes_to_exclude_set = set(nodes_to_exclude) if isinstance(nodes_to_exclude, list) else set() + final_nodes_to_exclude = [] + if nodes_to_exclude is not None and isinstance(nodes_to_exclude, list): + final_nodes_to_exclude.extend(nodes_to_exclude) # Iterate through nodes to get all operator types in the model and # call user's function to filter out nodes from quantization. for node in model.graph.node: - if op_types_to_quantize_set and node.op_type not in op_types_to_quantize_set: - continue - if node.name in nodes_to_exclude_set: - continue - if callable(nodes_to_exclude) and nodes_to_exclude(model, node): - nodes_to_exclude_set.add(node.name) - else: - op_types.add(node.op_type) + op_types.add(node.op_type) + if nodes_to_exclude is not None and callable(nodes_to_exclude): + if nodes_to_exclude(model, node): + final_nodes_to_exclude.append(node.name) final_extra_options = { "MinimumRealRange": min_real_range, @@ -387,14 +378,11 @@ def get_qdq_config( quant_format=QuantFormat.QDQ, activation_type=activation_type, weight_type=weight_type, - op_types_to_quantize=( - op_types_to_quantize if op_types_to_quantize else list(op_types.difference(op_types_to_exclude)) - ), - nodes_to_exclude=list(nodes_to_exclude_set), + op_types_to_quantize=list(op_types.difference(op_types_to_exclude)), + nodes_to_exclude=final_nodes_to_exclude, per_channel=per_channel, reduce_range=reduce_range, use_external_data_format=(model_has_external_data or model.ByteSize() >= MODEL_SIZE_THRESHOLD), - calibration_providers=calibration_providers, extra_options=final_extra_options, ) @@ -454,7 +442,7 @@ def check_static_quant_arguments(quant_format: QuantFormat, activation_type: Qua if activation_type != QuantType.QFLOAT8E4M3FN and weight_type == QuantType.QFLOAT8E4M3FN: raise ValueError( f"ONNXRuntime quantization doesn't support data format: activation_type={activation_type} " - "!=QuantType.QFLOAT8E4M3FN, weight_type=QuantType.QFLOAT8E4M3FN." + f"!=QuantType.QFLOAT8E4M3FN, weight_type=QuantType.QFLOAT8E4M3FN." ) if activation_type == QuantType.QFLOAT8E4M3FN and weight_type != QuantType.QFLOAT8E4M3FN: diff --git a/onnxruntime/python/tools/transformers/models/sam2/README.md b/onnxruntime/python/tools/transformers/models/sam2/README.md index 463d154525f8f..e7cafeffc6231 100644 --- a/onnxruntime/python/tools/transformers/models/sam2/README.md +++ b/onnxruntime/python/tools/transformers/models/sam2/README.md @@ -96,7 +96,8 @@ We can create a conda environment then run GPU benchmark like the following: conda create -n sam2_gpu python=3.11 -y conda activate sam2_gpu install_dir=$HOME -bash benchmark_sam2.sh $install_dir gpu +profiling=true +bash benchmark_sam2.sh $install_dir gpu $profiling ``` or create a new conda environment for CPU benchmark: @@ -106,28 +107,16 @@ conda activate sam2_cpu bash benchmark_sam2.sh $HOME cpu ``` -The usage of the script like the following: -``` -bash benchmark_sam2.sh [profiling] [benchmarking] [nightly] [dynamo] -``` - -| Parameter| Default | Description | -|----------|----------| ------------| -| install_dir | $HOME | a directory to clone git repositories or install CUDA/cuDNN for benchmark | -| cpu_or_gpu | gpu | the device to run benchmark. The value can be either "gpu" or "cpu" | -| profiling | false | run gpu profiling | -| benchmarking | true | run benchmark | -| nightly | false | install onnxruntime nightly or official release package | -| dynamo | false | export image encoder using dynamo or not. | +The first parameter is a directory to clone git repositories or install CUDA/cuDNN for benchmark. +The second parameter can be either "gpu" or "cpu", which indicates the device to run benchmark. +The third parameter is optional. Value "true" will enable profiling after running benchmarking on GPU. -The dynamo export is experimental since graph optimization still need extra works for this model. +The script will automatically install required packages in current conda environment, download checkpoints, export onnx, +and run demo, benchmark and optionally run profiling. -Output files: -* sam2_cpu_[timestamp].csv or sam2_gpu_[timestamp].csv has benchmark results. Use Excel to load the file to view it. -* onnxruntime_image_[encoder|decoder].json has ONNX Runtime profiling results. Use `chrome://tracing` in Chrome browser to view it. -* torch_image_[encoder|decoder].json has PyTorch profiling results. Use `chrome://tracing` in Chrome browser to view it. -* sam2_fp16_profile_image_[encoder|decoder]_[ort|torch]_gpu.[nsys-rep|sqlite] has NVTX profiling. Use Nvidia NSight System to view it. -* torch_image_encoder_compiled_code.txt has the compiled kernel code from Pytorch. +* The performance test result is in sam2_gpu.csv or sam2_cpu.csv, which can be loaded into Excel. +* The demo output is sam2_demo_fp16_gpu.png or sam2_demo_fp32_cpu.png. +* The profiling results are in *.nsys-rep or *.json files in current directory. Use Nvidia NSight System to view the *.nsys-rep file. ## Limitations - The exported image_decoder model does not support batch mode for now. diff --git a/onnxruntime/python/tools/transformers/models/sam2/benchmark_sam2.py b/onnxruntime/python/tools/transformers/models/sam2/benchmark_sam2.py index 3fc24d157b0cf..16d71d5057b02 100644 --- a/onnxruntime/python/tools/transformers/models/sam2/benchmark_sam2.py +++ b/onnxruntime/python/tools/transformers/models/sam2/benchmark_sam2.py @@ -46,7 +46,6 @@ def __init__( prefer_nhwc: bool = False, warm_up: int = 5, enable_nvtx_profile: bool = False, - enable_ort_profile: bool = False, enable_torch_profile: bool = False, repeats: int = 1000, verbose: bool = False, @@ -75,7 +74,6 @@ def __init__( self.prefer_nhwc = prefer_nhwc self.warm_up = warm_up self.enable_nvtx_profile = enable_nvtx_profile - self.enable_ort_profile = enable_ort_profile self.enable_torch_profile = enable_torch_profile self.repeats = repeats self.verbose = verbose @@ -319,7 +317,6 @@ def run_test( repeats=args.repeats, warm_up=args.warm_up, enable_nvtx_profile=args.enable_nvtx_profile, - enable_ort_profile=args.enable_ort_profile, enable_torch_profile=args.enable_torch_profile, torch_compile_mode=args.torch_compile_mode, verbose=False, @@ -328,7 +325,7 @@ def run_test( if args.engine == "ort": sess_options = SessionOptions() sess_options.intra_op_num_threads = args.intra_op_num_threads - if config.enable_ort_profile: + if config.enable_nvtx_profile: sess_options.enable_profiling = True sess_options.log_severity_level = 4 sess_options.log_verbosity_level = 0 @@ -352,8 +349,6 @@ def run_test( with nvtx.annotate("one_run"): _ = session.infer(input_dict) cudart.cudaProfilerStop() - - if config.enable_ort_profile: session.ort_session.end_profiling() if repeats == 0: @@ -559,14 +554,6 @@ def _parse_arguments(): help="Enable nvtx profiling. It will add an extra run for profiling before performance test.", ) - parser.add_argument( - "--enable_ort_profile", - required=False, - default=False, - action="store_true", - help="Enable ORT profiling.", - ) - parser.add_argument( "--enable_torch_profile", required=False, diff --git a/onnxruntime/python/tools/transformers/models/sam2/benchmark_sam2.sh b/onnxruntime/python/tools/transformers/models/sam2/benchmark_sam2.sh index c82b1ed31796e..9e97867657ab9 100644 --- a/onnxruntime/python/tools/transformers/models/sam2/benchmark_sam2.sh +++ b/onnxruntime/python/tools/transformers/models/sam2/benchmark_sam2.sh @@ -5,17 +5,7 @@ # ------------------------------------------------------------------------- # Please refer to README.md for the prerequisites and usage of this script. -# bash benchmark_sam2.sh [profiling] [benchmarking] [nightly] [dynamo] -# Note that dynamo need onnxruntime 1.21 or later, or nightly build. -# Example: -# bash benchmark_sam2.sh $HOME gpu true true true false - -install_dir="${1:-$HOME}" -cpu_or_gpu="${2:-gpu}" -profiling="${3:-false}" -benchmarking="${4:-true}" -nightly="${5:-false}" -dynamo="${6:-false}" +# bash benchmark_sam2.sh [profiling] python="$CONDA_PREFIX/bin/python3" @@ -23,6 +13,9 @@ python="$CONDA_PREFIX/bin/python3" dir="$(cd "$(dirname "$0")" && pwd)" onnx_dir="$dir/sam2_onnx_models" +# Installation directory (default: $HOME) +install_dir="${1:-$HOME}" + if [ ! -d "$install_dir" ]; then echo "Error: install_dir '$install_dir' does not exist." exit 1 @@ -33,6 +26,7 @@ sam2_dir="$install_dir/segment-anything-2" model="sam2_hiera_large" # Default to GPU, switch to CPU if specified +cpu_or_gpu="${2:-gpu}" if [ "$cpu_or_gpu" != "gpu" ] && [ "$cpu_or_gpu" != "cpu" ]; then echo "Invalid option: $2. Please specify 'cpu' or 'gpu'." exit 1 @@ -41,97 +35,52 @@ fi echo "install_dir: $install_dir" echo "cpu_or_gpu: $cpu_or_gpu" -# Function to check if a command exists -command_exists() { - command -v "$1" >/dev/null 2>&1 -} - -# Ensure necessary tools are installed -if ! command_exists wget; then - echo "wget is not installed. Please install it and try again." - exit 1 -fi - -if ! command_exists git; then - echo "git is not installed. Please install it and try again." - exit 1 -fi - -if ! command_exists pip; then - echo "pip is not installed. Please install it and try again." - exit 1 -fi - -cuda_version=12.6 -cudnn_version=9.5 +install_cuda_12() +{ + pushd $install_dir + wget https://developer.download.nvidia.com/compute/cuda/12.6.2/local_installers/cuda_12.6.2_560.35.03_linux.run + sh cuda_12.6.2_560.35.03_linux.run --toolkit --toolkitpath=$install_dir/cuda12.6 --silent --override --no-man-page -# Install CUDA 12.6 -install_cuda_12() { - if ! [ -d "$install_dir/cuda${cuda_version}" ]; then - pushd "$install_dir" || exit - wget https://developer.download.nvidia.com/compute/cuda/12.6.2/local_installers/cuda_12.6.2_560.35.03_linux.run - sh cuda_12.6.2_560.35.03_linux.run --toolkit --toolkitpath="$install_dir/cuda${cuda_version}" --silent --override --no-man-page - popd || exit - fi - export PATH="$install_dir/cuda${cuda_version}/bin:$PATH" - export LD_LIBRARY_PATH="$install_dir/cuda${cuda_version}/lib64:$LD_LIBRARY_PATH" + export PATH="$install_dir/cuda12.6/bin:$PATH" + export LD_LIBRARY_PATH="$install_dir/cuda12.6/lib64:$LD_LIBRARY_PATH" + popd } -# Install cuDNN 9.5 +# Function to install cuDNN 9.4 install_cudnn_9() { - if ! [ -d "$install_dir/cudnn${cudnn_version}" ]; then - pushd "$install_dir" || exit - wget -q https://developer.download.nvidia.com/compute/cudnn/redist/cudnn/linux-x86_64/cudnn-linux-x86_64-9.5.0.50_cuda12-archive.tar.xz - mkdir -p "$install_dir/cudnn${cudnn_version}" - tar -Jxvf cudnn-linux-x86_64-9.5.0.50_cuda12-archive.tar.xz -C "$install_dir/cudnn${cudnn_version}" --strip=1 - popd || exit - fi - export LD_LIBRARY_PATH="$install_dir/cudnn${cudnn_version}/lib:$LD_LIBRARY_PATH" -} - -install_ort() { - local ort="$1" - pip uninstall onnxruntime onnxruntime-gpu -y - - if [ "$nightly" = "true" ]; then - pip install flatbuffers numpy packaging protobuf sympy - pip install --pre --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/ORT-Nightly/pypi/simple/ "$ort" - else - pip install "$ort" - fi - - pip install onnx onnxscript opencv-python matplotlib + pushd "$install_dir" + wget -q https://developer.download.nvidia.com/compute/cudnn/redist/cudnn/linux-x86_64/cudnn-linux-x86_64-9.5.0.50_cuda12-archive.tar.xz + mkdir -p "$install_dir/cudnn9.5" + tar -Jxvf cudnn-linux-x86_64-9.5.0.50_cuda12-archive.tar.xz -C "$install_dir/cudnn9.5" --strip=1 + export LD_LIBRARY_PATH="$install_dir/cudnn9.5/lib:$LD_LIBRARY_PATH" + popd } # Install GPU dependencies install_gpu() { - install_cuda_12 - install_cudnn_9 - echo "PATH: $PATH" - echo "LD_LIBRARY_PATH: $LD_LIBRARY_PATH" - - # The dynamo export need torch 2.6.0 or later. Use the latest one. - pip install torch torchvision --index-url https://download.pytorch.org/whl/cu124 --upgrade + [ ! -d "$install_dir/cuda12.6" ] && install_cuda_12 + [ ! -d "$install_dir/cudnn9.5" ] && install_cudnn_9 - install_ort "onnxruntime-gpu" + pip install torch torchvision --index-url https://download.pytorch.org/whl/cu124 + pip install onnxruntime-gpu onnx opencv-python matplotlib } # Install CPU dependencies install_cpu() { pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu - install_ort "onnxruntime" + pip install onnxruntime onnx opencv-python matplotlib } # Clone and install SAM2 if not already installed install_sam2() { - pushd "$install_dir" || exit + pushd "$install_dir" if [ ! -d "$sam2_dir" ]; then git clone https://github.com/facebookresearch/segment-anything-2.git fi - cd "$sam2_dir" || exit + cd "$sam2_dir" pip show SAM-2 > /dev/null 2>&1 || pip install -e . [ ! -f checkpoints/sam2_hiera_large.pt ] && (cd checkpoints && sh ./download_ckpts.sh) - popd || exit + popd } # Download test image if not available @@ -141,12 +90,7 @@ download_test_image() { run_cpu_benchmark() { local repeats="$1" - - if [ "$dynamo" = "true" ]; then - $python convert_to_onnx.py --sam2_dir "$sam2_dir" --optimize --demo --dynamo - else - $python convert_to_onnx.py --sam2_dir "$sam2_dir" --optimize --demo - fi + $python convert_to_onnx.py --sam2_dir "$sam2_dir" --optimize --demo for component in image_encoder image_decoder; do $python benchmark_sam2.py --model_type "$model" --engine torch --sam2_dir "$sam2_dir" --repeats "$repeats" --dtype fp32 --component "$component" @@ -159,75 +103,65 @@ run_cpu_benchmark() { done } -run_ort_gpu_benchmark() { +run_gpu_benchmark() { local repeats="$1" + $python convert_to_onnx.py --sam2_dir "$sam2_dir" --optimize --use_gpu --dtype fp32 + $python convert_to_onnx.py --sam2_dir "$sam2_dir" --optimize --use_gpu --dtype fp16 --demo - if [ "$dynamo" = "true" ]; then - $python convert_to_onnx.py --sam2_dir "$sam2_dir" --optimize --use_gpu --dtype fp32 --dynamo - $python convert_to_onnx.py --sam2_dir "$sam2_dir" --optimize --use_gpu --dtype fp16 --demo --dynamo - else - $python convert_to_onnx.py --sam2_dir "$sam2_dir" --optimize --use_gpu --dtype fp32 - $python convert_to_onnx.py --sam2_dir "$sam2_dir" --optimize --use_gpu --dtype fp16 --demo - fi + for component in image_encoder image_decoder; do + for dtype in bf16 fp32 fp16; do + $python benchmark_sam2.py --model_type "$model" --engine torch --sam2_dir "$sam2_dir" --repeats "$repeats" --use_gpu --dtype $dtype --component "$component" + done + done component="image_encoder" for dtype in fp32 fp16; do - $python benchmark_sam2.py --model_type "$model" --engine ort --sam2_dir "$sam2_dir" --repeats "$repeats" --use_gpu --dtype "$dtype" --component "$component" --onnx_path "${onnx_dir}/${model}_${component}_${dtype}_gpu.onnx" --use_cuda_graph + #TODO: --prefer_nhwc does not help with performance + $python benchmark_sam2.py --model_type "$model" --engine ort --sam2_dir "$sam2_dir" --repeats "$repeats" --use_gpu --dtype $dtype --component "$component" --onnx_path "${onnx_dir}/${model}_${component}_${dtype}_gpu.onnx" --use_cuda_graph done - # Test prefer_nhwc. - $python benchmark_sam2.py --model_type "$model" --engine ort --sam2_dir "$sam2_dir" --repeats "$repeats" --use_gpu --dtype fp16 --component "$component" --onnx_path "${onnx_dir}/${model}_${component}_${dtype}_gpu.onnx" --use_cuda_graph --prefer_nhwc component="image_decoder" for dtype in fp32 fp16; do # TODO: decoder does not work with cuda graph - $python benchmark_sam2.py --model_type "$model" --engine ort --sam2_dir "$sam2_dir" --repeats "$repeats" --use_gpu --dtype "$dtype" --component "$component" --onnx_path "${onnx_dir}/${model}_${component}_${dtype}_gpu.onnx" + $python benchmark_sam2.py --model_type "$model" --engine ort --sam2_dir "$sam2_dir" --repeats "$repeats" --use_gpu --dtype $dtype --component "$component" --onnx_path "${onnx_dir}/${model}_${component}_${dtype}_gpu.onnx" done - # Test prefer_nhwc. - $python benchmark_sam2.py --model_type "$model" --engine ort --sam2_dir "$sam2_dir" --repeats "$repeats" --use_gpu --dtype fp16 --component "$component" --onnx_path "${onnx_dir}/${model}_${component}_${dtype}_gpu.onnx" --prefer_nhwc } -run_torch_gpu_benchmark() { +run_torch_compile_gpu_benchmark() { local repeats="$1" - # Test PyTorch eager mode. - for component in image_encoder image_decoder; do - for dtype in bf16 fp32 fp16; do - $python benchmark_sam2.py --model_type "$model" --engine torch --sam2_dir "$sam2_dir" --repeats "$repeats" --use_gpu --dtype "$dtype" --component "$component" - done - done - # Test different torch compile modes on image encoder for torch_compile_mode in none max-autotune reduce-overhead max-autotune-no-cudagraphs do - $python benchmark_sam2.py --model_type "$model" --engine torch --sam2_dir "$sam2_dir" --repeats "$repeats" --use_gpu --dtype fp16 --component image_encoder --torch_compile_mode $torch_compile_mode + $python benchmark_sam2.py --model_type $model --engine torch --sam2_dir "$sam2_dir" --repeats "$repeats" --use_gpu --dtype fp16 --component image_encoder --torch_compile_mode $torch_compile_mode done } -install_all() { - if [ "$cpu_or_gpu" = "gpu" ]; then - install_gpu - else - install_cpu + +# Main script +run_benchmarks() { + if [ ! -v CONDA_PREFIX ]; then + echo "Please activate conda environment before running this script." + exit 1 fi + + # Install dependencies + [ "$cpu_or_gpu" = "gpu" ] && install_gpu || install_cpu install_sam2 download_test_image -} -run_benchmarks() { - suffix=$(date +"%Y_%m_%d_%H_%M_%S") - [ "$dynamo" = "true" ] && suffix="${suffix}_dynamo" - output_csv="sam2_${cpu_or_gpu}_${suffix}.csv" + # Run benchmarks + output_csv="sam2_${cpu_or_gpu}.csv" if [ ! -f "$output_csv" ]; then echo "Running $cpu_or_gpu benchmark..." if [ "$cpu_or_gpu" = "gpu" ]; then - run_ort_gpu_benchmark 1000 - run_torch_gpu_benchmark 1000 + run_gpu_benchmark 1000 + run_torch_compile_gpu_benchmark 1000 else run_cpu_benchmark 100 fi cat benchmark*.csv > combined_csv awk '!x[$0]++' combined_csv > "$output_csv" - rm benchmark*.csv rm combined_csv echo "Benchmark results saved in $output_csv" else @@ -235,16 +169,7 @@ run_benchmarks() { fi } -if [ ! -v CONDA_PREFIX ]; then - echo "Please activate conda environment before running this script." - exit 1 -fi - -install_all - -if [ "$benchmarking" = "true" ]; then - run_benchmarks -fi +run_benchmarks #-------------------------------------------------------------------------- # Below are for profiling @@ -252,100 +177,79 @@ fi # Build onnxruntime-gpu from source for profiling build_onnxruntime_gpu_for_profiling() { - pushd "$install_dir" || exit + pushd "$install_dir" if ! [ -d onnxruntime ]; then git clone https://github.com/microsoft/onnxruntime fi - cd onnxruntime || exit - pip install --upgrade pip cmake psutil setuptools wheel packaging ninja numpy - build_dir=build/cuda${cuda_version} - rm -rf ${build_dir}/Release/dist - sh build.sh --config Release --build_dir "${build_dir}" --build_shared_lib --parallel \ - --use_cuda --cuda_version ${cuda_version} --cuda_home "$install_dir/cuda${cuda_version}" \ - --cudnn_home "$install_dir/cudnn${cudnn_version}" \ - --build_wheel --skip_tests \ - --cmake_generator Ninja \ - --compile_no_warning_as_error \ - --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=native \ - --cmake_extra_defines onnxruntime_ENABLE_NVTX_PROFILE=ON \ - --enable_cuda_line_info - pip uninstall onnxruntime-gpu -y - pip install "${build_dir}/Release/dist/onnxruntime_gpu-*-linux_x86_64.whl" - popd || exit + cd onnxruntime + CUDA_ARCH=$(python3 -c "import torch; cc = torch.cuda.get_device_capability(); print(f'{cc[0]}{cc[1]}')") + if [ -n "$CUDA_ARCH" ]; then + pip install --upgrade pip cmake psutil setuptools wheel packaging ninja numpy==1.26.4 + sh build.sh --config Release --build_dir build/cuda12 --build_shared_lib --parallel \ + --use_cuda --cuda_version 12.6 --cuda_home $install_dir/cuda12.6 \ + --cudnn_home $install_dir/cudnn9.5 \ + --build_wheel --skip_tests \ + --cmake_generator Ninja \ + --compile_no_warning_as_error \ + --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=$CUDA_ARCH \ + --cmake_extra_defines onnxruntime_ENABLE_NVTX_PROFILE=ON \ + --enable_cuda_line_info + + pip install build/cuda12/Release/dist/onnxruntime_gpu-*-linux_x86_64.whl numpy==1.26.4 + else + echo "No CUDA device found." + exit 1 + fi + popd } # Run profiling with NVTX. -run_nvtx_profile() { - local engine="$1" +run_nvtx_profile() +{ + pip install nvtx cuda-python==12.6.0 + # Only trace one device to avoid huge output file size. device_id=0 - envs="CUDA_VISIBLE_DEVICES=$device_id,ORT_ENABLE_CUDNN_FLASH_ATTENTION=1,LD_LIBRARY_PATH=$LD_LIBRARY_PATH,TORCHINDUCTOR_UNIQUE_KERNEL_NAMES=1" + envs="CUDA_VISIBLE_DEVICES=$device_id,ORT_ENABLE_CUDNN_FLASH_ATTENTION=1,LD_LIBRARY_PATH=$LD_LIBRARY_PATH" cuda_graph_trace=node - for component in image_encoder image_decoder; do - sudo "$install_dir/cuda${cuda_version}/bin/nsys" profile --capture-range=nvtx --nvtx-capture='one_run' \ - --gpu-metrics-devices $device_id --force-overwrite true \ - --sample process-tree --backtrace fp --stats true \ - -t cuda,cudnn,cublas,osrt,nvtx --cuda-memory-usage true --cudabacktrace all \ - --cuda-graph-trace "$cuda_graph_trace" \ - -e "$envs,NSYS_NVTX_PROFILER_REGISTER_ONLY=0" \ - -o "sam2_fp16_profile_${component}_${engine}_${cpu_or_gpu}" \ - $python benchmark_sam2.py --model_type "$model" --engine "$engine" \ - --sam2_dir "$sam2_dir" --warm_up 1 --repeats 0 \ - --onnx_path "${onnx_dir}/${model}_${component}_fp16_gpu.onnx" \ - --component "$component" \ - --use_gpu --dtype fp16 --enable_nvtx_profile - done -} - -run_ort_profile() { - export ORT_ENABLE_CUDNN_FLASH_ATTENTION=1 - rm -f onnxruntime_*.json - for component in image_encoder image_decoder; do - $python benchmark_sam2.py --model_type "$model" --engine ort \ - --sam2_dir "$sam2_dir" --warm_up 1 --repeats 0 \ - --onnx_path "${onnx_dir}/${model}_${component}_fp16_gpu.onnx" \ - --component "$component" \ - --use_gpu --dtype fp16 --enable_ort_profile - mv onnxruntime_profile*.json onnxruntime_$component.json + for engine in ort torch; do + for component in image_encoder image_decoder; do + sudo $install_dir/cuda12.6/bin/nsys profile --capture-range=nvtx --nvtx-capture='one_run' \ + --gpu-metrics-device $device_id --force-overwrite true \ + --sample process-tree --backtrace fp --stats true \ + -t cuda,cudnn,cublas,osrt,nvtx --cuda-memory-usage true --cudabacktrace all \ + --cuda-graph-trace $cuda_graph_trace \ + -e $envs,NSYS_NVTX_PROFILER_REGISTER_ONLY=0 \ + -o sam2_fp16_profile_${component}_${engine}_${cpu_or_gpu} \ + $python benchmark_sam2.py --model_type $model --engine $engine \ + --sam2_dir $sam2_dir --warm_up 1 --repeats 0 \ + --onnx_path ${onnx_dir}/${model}_${component}_fp16_gpu.onnx \ + --component $component \ + --use_gpu --dtype fp16 --enable_nvtx_profile + done done } # Run profiling with PyTorch run_torch_profile() { - # Enable logging might could help get the code of compiled kernels. You can turn it off to reduce overhead. - export TORCH_LOGS="+inductor,+output_code" - export TORCHINDUCTOR_UNIQUE_KERNEL_NAMES=1 - component=image_encoder - $python benchmark_sam2.py --model_type "$model" --engine torch \ - --sam2_dir "$sam2_dir" --warm_up 1 --repeats 0 \ - --component "$component" \ - --torch_compile_mode max-autotune \ - --use_gpu --dtype fp16 --enable_torch_profile > "torch_${component}_compiled_code.txt" - - component=image_decoder - $python benchmark_sam2.py --model_type "$model" --engine torch \ - --sam2_dir "$sam2_dir" --warm_up 1 --repeats 0 \ - --component "$component" \ - --torch_compile_mode none \ - --use_gpu --dtype fp16 --enable_torch_profile + for component in image_encoder image_decoder; do + $python benchmark_sam2.py --model_type $model --engine torch \ + --sam2_dir $sam2_dir --warm_up 1 --repeats 0 \ + --component $component \ + --use_gpu --dtype fp16 --enable_torch_profile + done } -run_nvtx_profilings() { +run_profilings() { build_onnxruntime_gpu_for_profiling + rm -f *.nsys-rep *.sqlite - run_nvtx_profile ort - run_nvtx_profile torch -} + run_nvtx_profile -run_profilings() { - pip install nvtx cuda-python==${cuda_version}.0 - run_ort_profile run_torch_profile - - # NVTX profiling need to build onnxruntime-gpu from source so it is put as the last step. - run_nvtx_profilings } +profiling="${3:-false}" if [ "$profiling" = "true" ] && [ "$cpu_or_gpu" = "gpu" ]; then run_profilings fi diff --git a/onnxruntime/python/tools/transformers/models/sam2/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/sam2/convert_to_onnx.py index 3533a274b9972..cacad717faf9c 100644 --- a/onnxruntime/python/tools/transformers/models/sam2/convert_to_onnx.py +++ b/onnxruntime/python/tools/transformers/models/sam2/convert_to_onnx.py @@ -113,14 +113,6 @@ def parse_arguments(): help="Optimize onnx models for GPU", ) - parser.add_argument( - "--dynamo", - required=False, - default=False, - action="store_true", - help="Use dynamo for exporting onnx model. Only image_encoder supports dynamo right now.", - ) - parser.add_argument( "--verbose", required=False, @@ -159,10 +151,8 @@ def main(): onnx_model_path = sam2_onnx_path(args.output_dir, args.model_type, component, args.multimask_output) if component == "image_encoder": if args.overwrite or not os.path.exists(onnx_model_path): - export_image_encoder_onnx( - sam2_model, onnx_model_path, args.dynamic_batch_axes, args.verbose, args.dynamo - ) - test_image_encoder_onnx(sam2_model, onnx_model_path, dynamic_batch_axes=args.dynamic_batch_axes) + export_image_encoder_onnx(sam2_model, onnx_model_path, args.dynamic_batch_axes, args.verbose) + test_image_encoder_onnx(sam2_model, onnx_model_path, dynamic_batch_axes=False) elif component == "mask_decoder": if args.overwrite or not os.path.exists(onnx_model_path): diff --git a/onnxruntime/python/tools/transformers/models/sam2/image_decoder.py b/onnxruntime/python/tools/transformers/models/sam2/image_decoder.py index 376e6ba7d802c..07ed150631f50 100644 --- a/onnxruntime/python/tools/transformers/models/sam2/image_decoder.py +++ b/onnxruntime/python/tools/transformers/models/sam2/image_decoder.py @@ -246,7 +246,7 @@ def test_decoder_onnx( import onnxruntime - ort_session = onnxruntime.InferenceSession(onnx_model_path, providers=["CPUExecutionProvider"]) + ort_session = onnxruntime.InferenceSession(onnx_model_path, providers=onnxruntime.get_available_providers()) model_inputs = ort_session.get_inputs() input_names = [model_inputs[i].name for i in range(len(model_inputs))] diff --git a/onnxruntime/python/tools/transformers/models/sam2/image_encoder.py b/onnxruntime/python/tools/transformers/models/sam2/image_encoder.py index 79e9297788c36..c5ce339732063 100644 --- a/onnxruntime/python/tools/transformers/models/sam2/image_encoder.py +++ b/onnxruntime/python/tools/transformers/models/sam2/image_encoder.py @@ -90,8 +90,6 @@ def export_image_encoder_onnx( onnx_model_path: str, dynamic_batch_axes: bool = False, verbose: bool = False, - dynamo: bool = False, - clear_dynamo_metadata: bool = False, ): image = random_sam2_input_image() @@ -115,65 +113,17 @@ def export_image_encoder_onnx( if not verbose: warnings.filterwarnings("ignore", category=torch.jit.TracerWarning) warnings.filterwarnings("ignore", category=UserWarning) - - if not dynamo: - torch.onnx.export( - sam2_encoder, - image, - onnx_model_path, - export_params=True, - opset_version=17, - do_constant_folding=True, - input_names=["image"], - output_names=["image_features_0", "image_features_1", "image_embeddings"], - dynamic_axes=dynamic_axes, - ) - else: - torch._dynamo.config.capture_scalar_outputs = True - ep = torch.export.export( - sam2_encoder, - args=(image,), - strict=False, - dynamic_shapes=[ - {0: torch.export.Dim.AUTO}, - ], - ) - - onnx_program = torch.onnx.export( - ep, - (), - opset_version=17, - input_names=["image"], - output_names=["image_features_0", "image_features_1", "image_embeddings"], - dynamo=True, - ) - onnx_program.optimize() - onnx_program.save(onnx_model_path + ".dynamo.onnx", external_data=False) - import onnx - - from onnxruntime.transformers.dynamo_onnx_helper import DynamoOnnxHelper - - onnx_model = onnx.load_model(onnx_model_path + ".dynamo.onnx", load_external_data=True) - if dynamic_batch_axes: - # Fix labels of dynamic axes since they can't be specified during Dynamo export currently - onnx_model.graph.input[0].type.tensor_type.shape.dim[0].dim_param = "batch_size" - for i in range(3): - onnx_model.graph.output[i].type.tensor_type.shape.dim[0].dim_param = "batch_size" - - onnx_model_helper = DynamoOnnxHelper(onnx_model) - onnx_model_helper.convert_constants_to_initializers() - if clear_dynamo_metadata: - onnx_model_helper.clear_metadata() - - import os - - if os.path.exists(onnx_model_path): - os.remove(onnx_model_path) - if os.path.exists(onnx_model_path + ".data"): - os.remove(onnx_model_path + ".data") - onnx_model_helper.model.save_model_to_file( - onnx_model_path, use_external_data_format=True, all_tensors_to_one_file=True, convert_attribute=True - ) + torch.onnx.export( + sam2_encoder, + image, + onnx_model_path, + export_params=True, + opset_version=17, + do_constant_folding=True, + input_names=["image"], + output_names=["image_features_0", "image_features_1", "image_embeddings"], + dynamic_axes=dynamic_axes, + ) print("encoder onnx model saved to", onnx_model_path) @@ -183,7 +133,7 @@ def test_image_encoder_onnx( onnx_model_path: str, dynamic_batch_axes=False, ): - ort_session = onnxruntime.InferenceSession(onnx_model_path, providers=["CPUExecutionProvider"]) + ort_session = onnxruntime.InferenceSession(onnx_model_path, providers=onnxruntime.get_available_providers()) model_inputs = ort_session.get_inputs() input_names = [model_inputs[i].name for i in range(len(model_inputs))] diff --git a/onnxruntime/python/tools/transformers/models/sam2/mask_decoder.py b/onnxruntime/python/tools/transformers/models/sam2/mask_decoder.py index fa83e2f666d06..56473c002d4ae 100644 --- a/onnxruntime/python/tools/transformers/models/sam2/mask_decoder.py +++ b/onnxruntime/python/tools/transformers/models/sam2/mask_decoder.py @@ -177,7 +177,7 @@ def test_mask_decoder_onnx( import onnxruntime - ort_session = onnxruntime.InferenceSession(onnx_model_path, providers=["CPUExecutionProvider"]) + ort_session = onnxruntime.InferenceSession(onnx_model_path, providers=onnxruntime.get_available_providers()) model_inputs = ort_session.get_inputs() input_names = [model_inputs[i].name for i in range(len(model_inputs))] diff --git a/onnxruntime/python/tools/transformers/models/sam2/prompt_encoder.py b/onnxruntime/python/tools/transformers/models/sam2/prompt_encoder.py index f25e6ff23324b..883c51858346c 100644 --- a/onnxruntime/python/tools/transformers/models/sam2/prompt_encoder.py +++ b/onnxruntime/python/tools/transformers/models/sam2/prompt_encoder.py @@ -146,7 +146,7 @@ def test_prompt_encoder_onnx( import onnxruntime - ort_session = onnxruntime.InferenceSession(onnx_model_path, providers=["CPUExecutionProvider"]) + ort_session = onnxruntime.InferenceSession(onnx_model_path, providers=onnxruntime.get_available_providers()) model_inputs = ort_session.get_inputs() input_names = [model_inputs[i].name for i in range(len(model_inputs))] diff --git a/onnxruntime/test/ep_weight_sharing_ctx_gen/main.cc b/onnxruntime/test/ep_weight_sharing_ctx_gen/main.cc deleted file mode 100644 index 104cdbdfd5abc..0000000000000 --- a/onnxruntime/test/ep_weight_sharing_ctx_gen/main.cc +++ /dev/null @@ -1,247 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "test_configuration.h" -#include "command_args_parser.h" - -// onnxruntime dependencies -#include "core/session/onnxruntime_cxx_api.h" -#include "core/session/onnxruntime_session_options_config_keys.h" - -// onnx dependencies -#include "onnx/onnx_pb.h" -#include - -using namespace onnxruntime; -using ProviderOptions = std::unordered_map; - -// from the last context cache Onnx model, find the EPContext node with main_context=1, -// and get the QNN context binary file name, this context binary contains all graphs from all Onnx models -// get the max spill fill buffer size -static void GetLastContextBinaryFileName(const std::basic_string last_onnx_ctx_file, - std::string& last_ctx_bin_file, - int64_t& max_size) { - max_size = 0; - - onnx::ModelProto model; - std::ifstream onnx_file_stream(last_onnx_ctx_file, std::ios_base::binary); - model.ParseFromIstream(&onnx_file_stream); - - for (auto& node : model.graph().node()) { - if (node.op_type() == "EPContext") { - int64_t is_main_context = 0; - for (auto& attr : node.attribute()) { - if (attr.name() == "main_context") { - is_main_context = attr.i(); - } - if (attr.name() == "max_size") { - max_size = attr.i(); - } - if (attr.name() == "ep_cache_context") { - last_ctx_bin_file = attr.s(); - } - } - if (is_main_context) { - return; - } - } - } - - onnx_file_stream.close(); -} - -// Update generated context cache Onnx model to make the main EPContext node point to -// the last QNN context binary file -// Remove not used QNN context binary file, only keep the last one which contains all graphs -static void UpdateEpContextModel(const std::vector>& ep_ctx_files, - const std::string& last_qnn_ctx_binary_file_name, - int64_t max_size) { - for (auto ep_ctx_file : ep_ctx_files) { - onnx::ModelProto model; - std::ifstream onnx_file_stream(ep_ctx_file, std::ios_base::binary); - model.ParseFromIstream(&onnx_file_stream); - onnx_file_stream.close(); - - for (auto& node : *(model.mutable_graph()->mutable_node())) { - if (node.op_type() == "EPContext") { - int64_t is_main_context = 0; - std::string old_qnn_ctx_binary_file_name; - int max_size_index = 0; - int ep_context_index = 0; - for (auto i = 0; i < node.attribute_size(); ++i) { - auto& attr = node.attribute()[i]; - if (attr.name() == "main_context") { - is_main_context = attr.i(); - } - if (attr.name() == "max_size") { - max_size = attr.i(); - max_size_index = i; - } - if (attr.name() == "ep_cache_context") { - old_qnn_ctx_binary_file_name = attr.s(); - ep_context_index = 0; - } - } - if (is_main_context) { - auto path_str = ToPathString(ep_ctx_file); - auto path = std::filesystem::path(path_str); - auto file_path = path.replace_filename(old_qnn_ctx_binary_file_name); - std::remove(file_path.string().c_str()); - - node.mutable_attribute(max_size_index)->set_i(max_size); - node.mutable_attribute(ep_context_index)->set_s(last_qnn_ctx_binary_file_name); - } - } - } - - // re-write the onnx ctx file - std::ofstream onnx_file_ostream(ep_ctx_file, std::ios_base::binary); - model.SerializeToOstream(&onnx_file_ostream); - onnx_file_ostream.close(); - } -} - -#ifdef _WIN32 -int real_main(int argc, wchar_t* argv[]) { -#else -int real_main(int argc, char* argv[]) { -#endif - qnnctxgen::TestConfig test_config; - if (!qnnctxgen::CommandLineParser::ParseArguments(test_config, argc, argv)) { - qnnctxgen::CommandLineParser::ShowUsage(); - return -1; - } - - OrtLoggingLevel logging_level = test_config.run_config.f_verbose - ? ORT_LOGGING_LEVEL_VERBOSE - : ORT_LOGGING_LEVEL_ERROR; - Ort::Env env(logging_level, "ep_weight_sharing"); - - ORT_TRY { - Ort::SessionOptions so; - so.SetLogId("ep_weight_sharing_ctx_gen_session_logger"); - // Set default session option to dump EPContext model with non-embed mode - so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1"); - so.AddConfigEntry(kOrtSessionOptionEpContextEmbedMode, "0"); - // enable ep.share_ep_contexts - so.AddConfigEntry(kOrtSessionOptionShareEpContexts, "1"); - - ProviderOptions provider_options; - - for (auto it : test_config.run_config.provider_options) { - provider_options[it.first] = it.second; - } - - for (auto it : test_config.run_config.session_config_entries) { - if (it.first == kOrtSessionOptionEpContextEnable && it.second != "1") { - std::cerr << "Need to enable ep context cache." << std::endl; - continue; - } - if (it.first == kOrtSessionOptionEpContextEmbedMode && it.second != "0") { - std::cerr << "Only support non-embed model for weight sharing." << std::endl; - continue; - } - if (it.first == kOrtSessionOptionEpContextFilePath) { - std::cout << "Not support to specify the generated Onnx context cache file name." << std::endl; - continue; - } - so.AddConfigEntry(it.first.c_str(), it.second.c_str()); - } - - for (auto model_path : test_config.model_file_paths) { - std::cout << "Model file path: " << ToUTF8String(model_path) << std::endl; - } - - // Generate context cache model files with QNN context binary files - // The context binary file generated later includes all graphs from previous models - { - std::string provider_name_ = test_config.machine_config.provider_type_name; - if (provider_name_ == onnxruntime::kQnnExecutionProvider) { -#ifdef USE_QNN -#if defined(_WIN32) - provider_options["backend_path"] = "QnnHtp.dll"; -#else - provider_options["backend_path"] = "libQnnHtp.so"; -#endif - - // set default QNN EP option to enable weight sharing if not set by user - const std::string enable_htp_weight_sharing = "enable_htp_weight_sharing"; - if (provider_options.find(enable_htp_weight_sharing) == provider_options.end()) { - provider_options[enable_htp_weight_sharing] = "1"; - } - so.AppendExecutionProvider("QNN", provider_options); -#else - ORT_THROW("QNN is not supported in this build\n"); -#endif - } else if (!provider_name_.empty()) { - ORT_THROW("This execution provider is not included in this tool.\n"); - } - - size_t total_file_count = test_config.model_file_paths.size(); - for (size_t i = 0; i < total_file_count; ++i) { - auto model_path = test_config.model_file_paths[i]; - std::cout << "Generating context cache model for: " << ToUTF8String(model_path) << std::endl; - if (i == total_file_count - 1) { - so.AddConfigEntry(kOrtSessionOptionStopShareEpContexts, "1"); - } - Ort::Session session(env, model_path.c_str(), so); - } - } - - std::cout << "Start to update the generated Onnx model." << std::endl; - std::vector> ep_ctx_files; - ep_ctx_files.reserve(test_config.model_file_paths.size()); - for (auto model_path : test_config.model_file_paths) { - auto pos = model_path.find_last_of(ORT_TSTR(".")); - if (pos != std::string::npos) { - model_path = model_path.substr(0, pos) + ORT_TSTR("_ctx.onnx"); - } else { - model_path = model_path + ORT_TSTR("_ctx.onnx"); - } - ep_ctx_files.push_back(model_path); - } - - // Get the last context binary file name - std::string last_qnn_ctx_binary_file_name; - int64_t max_size = 0; - GetLastContextBinaryFileName(ep_ctx_files.back(), last_qnn_ctx_binary_file_name, max_size); - std::cout << "The last context binary file: " << last_qnn_ctx_binary_file_name << std::endl; - if (last_qnn_ctx_binary_file_name.empty()) { - throw Ort::Exception("Can't find QNN context binary file from the Onnx model.", OrtErrorCode::ORT_FAIL); - } - ep_ctx_files.pop_back(); - - // Update generated context cache Onnx model to make the main EPContext node point to - // the last QNN context binary file - // Remove not used QNN context binary file, only keep the last one only which contains all graphs - UpdateEpContextModel(ep_ctx_files, last_qnn_ctx_binary_file_name, max_size); - } - ORT_CATCH(const Ort::Exception& e) { - std::cerr << "Failed to generate context cache file: " << e.what(); - return -1; - } - - std::cout << "Generation done!"; - return 0; -} - -#ifdef _WIN32 -int wmain(int argc, wchar_t* argv[]) { -#else -int main(int argc, char* argv[]) { -#endif - int retval = -1; - ORT_TRY { - retval = real_main(argc, argv); - } - ORT_CATCH(const std::exception& ex) { - ORT_HANDLE_EXCEPTION([&]() { - fprintf(stderr, "%s\n", ex.what()); - retval = -1; - }); - } - - ::google::protobuf::ShutdownProtobufLibrary(); - - return retval; -} diff --git a/onnxruntime/test/framework/inference_session_test.cc b/onnxruntime/test/framework/inference_session_test.cc index 95101c8075fc2..1b06eb55afbd2 100644 --- a/onnxruntime/test/framework/inference_session_test.cc +++ b/onnxruntime/test/framework/inference_session_test.cc @@ -138,7 +138,6 @@ class FuseExecutionProvider : public IExecutionProvider { std::vector> GetCapability(const onnxruntime::GraphViewer& graph, const IKernelLookup& /*kernel_lookup*/, - const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* /* resource_accountant */) const override { // Fuse two add into one. std::vector> result; diff --git a/onnxruntime/test/framework/session_state_test.cc b/onnxruntime/test/framework/session_state_test.cc index 8f4eede76b905..b6b915f90d99a 100644 --- a/onnxruntime/test/framework/session_state_test.cc +++ b/onnxruntime/test/framework/session_state_test.cc @@ -27,7 +27,6 @@ #include "test/util/include/default_providers.h" #include "test/util/include/file_util.h" #include "core/optimizer/layout_transformation/layout_transformation.h" -#include "core/optimizer/graph_optimizer_registry.h" using namespace ONNX_NAMESPACE; namespace onnxruntime { @@ -265,11 +264,7 @@ TEST_P(SessionStateTestP, TestInitializerProcessing) { SessionState session_state(graph, execution_providers, tp.get(), nullptr, dtm, edlm, DefaultLoggingManager().DefaultLogger(), profiler, sess_options); - // Create GraphOptimizerRegistry instance for providing predefined graph optimizers and selection functions for EPs to lookup - auto graph_optimizer_registry = std::make_unique(&sess_options, - execution_providers.Get(onnxruntime::kCpuExecutionProvider), - &DefaultLoggingManager().DefaultLogger()); - GraphPartitioner partitioner(krm, execution_providers, std::move(graph_optimizer_registry)); + GraphPartitioner partitioner(krm, execution_providers); ASSERT_STATUS_OK( partitioner.Partition( graph, session_state.GetMutableFuncMgr(), @@ -355,12 +350,8 @@ TEST(SessionStateTest, TestInitializerMemoryAllocatedUsingNonArenaMemory) { SessionState session_state(graph, execution_providers, nullptr, nullptr, dtm, edlm, DefaultLoggingManager().DefaultLogger(), profiler, sess_options); - // Create GraphOptimizerRegistry instance for providing predefined graph optimizers and selection functions for EPs to lookup - auto graph_optimizer_registry = std::make_unique(&sess_options, - execution_providers.Get(onnxruntime::kCpuExecutionProvider), - &DefaultLoggingManager().DefaultLogger()); // Partition the graph - GraphPartitioner partitioner(krm, execution_providers, std::move(graph_optimizer_registry)); + GraphPartitioner partitioner(krm, execution_providers); ASSERT_STATUS_OK(partitioner.Partition( graph, session_state.GetMutableFuncMgr(), [&cpu_allocator](Graph& graph, bool& modified, const IExecutionProvider& execution_provider, @@ -418,13 +409,8 @@ TEST(SessionStateTest, TestInitializerMemoryAllocatedUsingNonArenaMemory) { SessionState session_state(graph, execution_providers, nullptr, nullptr, dtm, edlm, DefaultLoggingManager().DefaultLogger(), profiler, sess_options); - // Create GraphOptimizerRegistry instance for providing predefined graph optimizers and selection functions for EPs to lookup - auto graph_optimizer_registry = std::make_unique(&sess_options, - execution_providers.Get(onnxruntime::kCpuExecutionProvider), - &DefaultLoggingManager().DefaultLogger()); - // Partition the graph - GraphPartitioner partitioner(krm, execution_providers, std::move(graph_optimizer_registry)); + GraphPartitioner partitioner(krm, execution_providers); ASSERT_STATUS_OK(partitioner.Partition( graph, session_state.GetMutableFuncMgr(), [&cpu_allocator](Graph& graph, bool& modified, @@ -493,12 +479,7 @@ void LoadWithResourceAwarePartitioning(const ORTCHAR_T* model_path, SessionState session_state(model->MainGraph(), execution_providers, tp.get(), nullptr, dtm, edlm, default_logger, profiler, sess_options); - // Create GraphOptimizerRegistry instance for providing predefined graph optimizers and selection functions for EPs to lookup - auto graph_optimizer_registry = std::make_unique(&sess_options, - execution_providers.Get(onnxruntime::kCpuExecutionProvider), - &DefaultLoggingManager().DefaultLogger()); - - GraphPartitioner partitioner(krm, execution_providers, std::move(graph_optimizer_registry)); + GraphPartitioner partitioner(krm, execution_providers); layout_transformation::TransformLayoutFunction transform_layout_fn; layout_transformation::DebugGraphFn debug_graph_fn; ASSERT_STATUS_OK( diff --git a/onnxruntime/test/framework/type_info_test.cc b/onnxruntime/test/framework/type_info_test.cc index d8ef668bf1c7e..ee787fb071d97 100644 --- a/onnxruntime/test/framework/type_info_test.cc +++ b/onnxruntime/test/framework/type_info_test.cc @@ -22,9 +22,9 @@ TEST(TypeInfoTests, TensorProto) { auto tensor_type_info = OrtTypeInfo::FromTypeProto(tensor_type.value); ASSERT_EQ(ONNX_TYPE_TENSOR, tensor_type_info->type); - ASSERT_NE(nullptr, tensor_type_info->tensor_type_info); - ASSERT_EQ(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, tensor_type_info->tensor_type_info->type); - ASSERT_TRUE(SpanEq(AsSpan({1, 2, 3, 4}), tensor_type_info->tensor_type_info->shape.GetDims())); + ASSERT_NE(nullptr, tensor_type_info->data); + ASSERT_EQ(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, tensor_type_info->data->type); + ASSERT_TRUE(SpanEq(AsSpan({1, 2, 3, 4}), tensor_type_info->data->shape.GetDims())); } TEST(TypeInfoTests, SequenceWithTensorElement) { @@ -37,9 +37,9 @@ TEST(TypeInfoTests, SequenceWithTensorElement) { const auto& tensor_type_info = *seq_type_info->sequence_type_info->sequence_key_type_; ASSERT_EQ(ONNX_TYPE_TENSOR, tensor_type_info.type); - ASSERT_NE(nullptr, tensor_type_info.tensor_type_info); - ASSERT_EQ(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, tensor_type_info.tensor_type_info->type); - ASSERT_TRUE(SpanEq(AsSpan({1, 2, 3, 4}), tensor_type_info.tensor_type_info->shape.GetDims())); + ASSERT_NE(nullptr, tensor_type_info.data); + ASSERT_EQ(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, tensor_type_info.data->type); + ASSERT_TRUE(SpanEq(AsSpan({1, 2, 3, 4}), tensor_type_info.data->shape.GetDims())); } TEST(TypeInfoTests, OptionalWithTensorProto) { @@ -54,9 +54,9 @@ TEST(TypeInfoTests, OptionalWithTensorProto) { const auto& contained_type = *optional_type_info->optional_type_info->contained_type_; ASSERT_EQ(ONNX_TYPE_TENSOR, contained_type.type); - ASSERT_NE(nullptr, contained_type.tensor_type_info); - ASSERT_EQ(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, contained_type.tensor_type_info->type); - ASSERT_TRUE(SpanEq(AsSpan({1, 2, 3, 4}), contained_type.tensor_type_info->shape.GetDims())); + ASSERT_NE(nullptr, contained_type.data); + ASSERT_EQ(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, contained_type.data->type); + ASSERT_TRUE(SpanEq(AsSpan({1, 2, 3, 4}), contained_type.data->shape.GetDims())); } #if !defined(DISABLE_ML_OPS) @@ -74,11 +74,11 @@ TEST(TypeInfoTests, MapWithTensorValue) { const auto& tensor_type_info = *map_info.map_value_type_; ASSERT_EQ(ONNX_TYPE_TENSOR, tensor_type_info.type); - ASSERT_NE(nullptr, tensor_type_info.tensor_type_info); - ASSERT_EQ(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, tensor_type_info.tensor_type_info->type); - ASSERT_TRUE(SpanEq(AsSpan({1, 2, 3, 4}), tensor_type_info.tensor_type_info->shape.GetDims())); + ASSERT_NE(nullptr, tensor_type_info.data); + ASSERT_EQ(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, tensor_type_info.data->type); + ASSERT_TRUE(SpanEq(AsSpan({1, 2, 3, 4}), tensor_type_info.data->shape.GetDims())); } #endif } // namespace test -} // namespace onnxruntime +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/test/providers/base_tester.cc b/onnxruntime/test/providers/base_tester.cc index eecff3fa4d8ff..6bfe7bc3856ba 100644 --- a/onnxruntime/test/providers/base_tester.cc +++ b/onnxruntime/test/providers/base_tester.cc @@ -174,7 +174,7 @@ static std::unique_ptr MakeSparseTensor(MLDataType data_type, cons return p_tensor; } -void BaseTester::CopyDataToTensor(gsl::span data, Tensor& dst) { +void BaseTester::CopyDataToTensor(gsl::span data, Tensor& dst) { ORT_ENFORCE(dst.SizeInBytes() >= data.size_bytes(), "Not enough space in the destination tensor"); memcpy(dst.MutableDataRaw(), data.data(), data.size_bytes()); } @@ -203,7 +203,7 @@ void BaseTester::AddSparseCooTensorData(std::vector& data, MLDataType data_type, const char* name, gsl::span dims, - gsl::span values, + gsl::span values, gsl::span indices, const ValidateOutputParams& check_params, const std::vector* dim_params) { @@ -247,7 +247,7 @@ void BaseTester::AddSparseCsrTensorData(std::vector& data, MLDataType data_type, const char* name, gsl::span dims, - gsl::span values, + gsl::span values, gsl::span inner_indices, gsl::span outer_indices, const ValidateOutputParams& check_params, diff --git a/onnxruntime/test/providers/base_tester.h b/onnxruntime/test/providers/base_tester.h index d39cc3c750dec..512b3402c5986 100644 --- a/onnxruntime/test/providers/base_tester.h +++ b/onnxruntime/test/providers/base_tester.h @@ -868,7 +868,7 @@ class BaseTester { void AddShapeToTensorData(NodeArg& node_arg, gsl::span dims, const std::vector* dim_params); - void CopyDataToTensor(gsl::span data, Tensor& dst); + void CopyDataToTensor(gsl::span data, Tensor& dst); #if !defined(DISABLE_SPARSE_TENSORS) NodeArg MakeSparseNodeArg(int32_t dtype, const char* name, @@ -879,7 +879,7 @@ class BaseTester { MLDataType data_type, const char* name, gsl::span dims, - gsl::span values, + gsl::span values, gsl::span indices, const ValidateOutputParams& check_params, const std::vector* dim_params = nullptr); @@ -895,7 +895,7 @@ class BaseTester { MLDataType data_type, const char* name, gsl::span dims, - gsl::span values, + gsl::span values, gsl::span inner_indices, gsl::span outer_indices, const ValidateOutputParams& check_params, diff --git a/onnxruntime/test/providers/cpu/math/softmax_test.cc b/onnxruntime/test/providers/cpu/math/softmax_test.cc index 1c6375ebdb0b1..6f7930f722564 100644 --- a/onnxruntime/test/providers/cpu/math/softmax_test.cc +++ b/onnxruntime/test/providers/cpu/math/softmax_test.cc @@ -170,11 +170,11 @@ TEST(SoftmaxOperator, ThreeAndFourDimsAxis0) { RunTest(input_vals_60, expected_vals, three_dimensions, /*opset*/ 7, /*axis*/ 0, // axis=0 is not supported by TensorRT - {kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kDnnlExecutionProvider, kWebGpuExecutionProvider}); + {kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kDnnlExecutionProvider}); RunTest(input_vals_60, expected_vals, four_dimensions, /*opset*/ 7, /*axis*/ 0, // axis=0 is not supported by TensorRT - {kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kDnnlExecutionProvider, kWebGpuExecutionProvider}); + {kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kDnnlExecutionProvider}); } TEST(SoftmaxOperator, ThreeAndFourDimsSecondLastAxis) { @@ -201,10 +201,10 @@ TEST(SoftmaxOperator, ThreeAndFourDimsSecondLastAxis) { 0.040478885f, 0.033857856f, 0.080346674f, 0.06199841f, 0.040481992f}; RunTest(input_vals_60, expected_vals, three_dimensions, /*opset*/ 7, /*axis*/ 1, - {kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kDnnlExecutionProvider, kWebGpuExecutionProvider}); + {kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kDnnlExecutionProvider}); RunTest(input_vals_60, expected_vals, four_dimensions, /*opset*/ 7, /*axis*/ 2, - {kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kDnnlExecutionProvider, kWebGpuExecutionProvider}); + {kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kDnnlExecutionProvider}); } TEST(SoftmaxOperator, ThreeAndFourDimsSecondLastAxis_opset13) { @@ -376,9 +376,8 @@ TEST(SoftmaxOperator, DimWithZero) { RunTest(x_vals, expected_vals, dimensions, /*opset*/ -1, /*axis*/ 0, {kTensorrtExecutionProvider, - kNnapiExecutionProvider, // NNAPI softmax does not support empty input - kWebGpuExecutionProvider, // WebGPU does not support dim 0 - kQnnExecutionProvider} // QNN doesn't support dim 0 + kNnapiExecutionProvider, // NNAPI softmax does not support empty input + kQnnExecutionProvider} // QNN doesn't support dim 0 ); } diff --git a/onnxruntime/test/providers/cpu/nn/conv_integer_test.cc b/onnxruntime/test/providers/cpu/nn/conv_integer_test.cc index c98d9e28b2f46..a5378fa3cefd7 100644 --- a/onnxruntime/test/providers/cpu/nn/conv_integer_test.cc +++ b/onnxruntime/test/providers/cpu/nn/conv_integer_test.cc @@ -254,45 +254,5 @@ TEST(ConvIntegerTest, WithStride3_2D_u8u8) { test.Run(); } -TEST(ConvIntegerTest, NoXZeroPoint) { - OpTester test("ConvInteger", 10); - std::vector x_dims{1, 1, 3, 3}; - test.AddInput("x", x_dims, - {2, 3, 4, - 5, 6, 7, - 8, 9, 10}); - std::vector w_dims{1, 1, 2, 2}; - test.AddInput("w", w_dims, - {2, 2, - 2, 2}); - test.AddOptionalInputEdge(); - test.AddInput("w_zero_point", {}, {1}); - std::vector y_dims{1, 1, 2, 2}; - test.AddOutput("y", y_dims, - {16, 20, - 28, 32}); - test.Run(); -} - -// provide optional input with empty name for w. tests that input args == 4 but the w_zero_point does not exist. -TEST(ConvIntegerTest, NoWZeroPoint) { - OpTester test("ConvInteger", 10); - std::vector x_dims{1, 1, 3, 3}; - test.AddInput("x", x_dims, - {2, 3, 4, - 5, 6, 7, - 8, 9, 10}); - std::vector w_dims{1, 1, 2, 2}; - test.AddInput("w", w_dims, - {2, 2, - 2, 2}); - test.AddInput("x_zero_point", {}, {1}); - test.AddOptionalInputEdge(); - std::vector y_dims{1, 1, 2, 2}; - test.AddOutput("y", y_dims, - {24, 32, - 48, 56}); - test.Run(); -} } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.cc b/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.cc index ee0aff6d26444..b753bc386d722 100644 --- a/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.cc +++ b/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.cc @@ -111,7 +111,6 @@ DataLayout InternalTestingExecutionProvider::GetPreferredLayout() const { std::vector> InternalTestingExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer, const IKernelLookup& kernel_lookup, - const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* /* resource_accountant */) const { // find nodes that have ops in our supported list std::unordered_set supported_static_nodes; diff --git a/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.h b/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.h index 0caa0febc2796..d2ed8259ee974 100644 --- a/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.h +++ b/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.h @@ -20,7 +20,6 @@ class InternalTestingExecutionProvider : public IExecutionProvider { std::vector> GetCapability(const onnxruntime::GraphViewer& graph_view, const IKernelLookup& /*kernel_lookup*/, - const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* /* resource_accountant */) const override; common::Status Compile(const std::vector& fused_nodes, diff --git a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc index 3dec74599abdf..07843c30a61df 100644 --- a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc +++ b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc @@ -43,35 +43,6 @@ static const std::string& GetNodeAttr(const Node& node, const std::string& attr_ return default_val; } -// from the context cache Onnx model, find the EPContext node with main_context=1, -// and get the QNN context binary file name -static void GetContextBinaryFileName(const std::string onnx_ctx_file, - std::string& last_ctx_bin_file, - const Logger& logger) { - std::shared_ptr ctx_model; - ASSERT_STATUS_OK(Model::Load(ToPathString(onnx_ctx_file), ctx_model, nullptr, logger)); - auto& ctx_graph = ctx_model->MainGraph(); - for (auto& node : ctx_graph.Nodes()) { - if (node.OpType() == "EPContext") { - int64_t is_main_context = GetNodeAttr(node, "main_context", static_cast(0)); - if (1 == is_main_context) { - last_ctx_bin_file = GetNodeAttr(node, "ep_cache_context", ""); - return; - } - } - } -} - -// Get context binary file name from Context model file and remove it with the context model file -void CleanUpCtxFile(std::string context_file_path) { - std::string qnn_ctx_binary_file_name; - GetContextBinaryFileName(context_file_path, qnn_ctx_binary_file_name, - DefaultLoggingManager().DefaultLogger()); - - ASSERT_EQ(std::remove(qnn_ctx_binary_file_name.c_str()), 0); - ASSERT_EQ(std::remove(context_file_path.c_str()), 0); -} - // Create a model with FusedMatMul + Add (quantized) // input1 -> Add -> Q -> DQ ---- // | @@ -152,22 +123,22 @@ void QnnContextBinaryMultiPartitionTestBody(bool single_ep_node = true) { const auto model_data_span = AsByteSpan(model_data.data(), model_data.size()); - const std::string context_model_file = "./qnn_context_binary_multi_partition_test.onnx"; - std::remove(context_model_file.c_str()); + const std::string context_binary_file = "./qnn_context_binary_multi_partition_test.onnx"; + std::remove(context_binary_file.c_str()); Ort::SessionOptions so; so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1"); - so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, context_model_file.c_str()); + so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, context_binary_file.c_str()); so.AppendExecutionProvider("QNN", provider_options); Ort::Session session(*ort_env, model_data_span.data(), model_data_span.size(), so); // Make sure the Qnn context cache binary file is generated - EXPECT_TRUE(std::filesystem::exists(context_model_file.c_str())); + EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str())); int ep_context_node_count = 0; int non_ep_context_node_count = 0; std::shared_ptr ctx_model; - ASSERT_STATUS_OK(Model::Load(ToPathString(context_model_file), ctx_model, nullptr, DefaultLoggingManager().DefaultLogger())); + ASSERT_STATUS_OK(Model::Load(ToPathString(context_binary_file), ctx_model, nullptr, DefaultLoggingManager().DefaultLogger())); auto& ctx_graph = ctx_model->MainGraph(); for (auto& node : ctx_graph.Nodes()) { if (node.OpType() == "EPContext") { @@ -185,7 +156,7 @@ void QnnContextBinaryMultiPartitionTestBody(bool single_ep_node = true) { Ort::SessionOptions so2; // context file path is required if it's non-embed mode and the model is loaded from memory - so2.AddConfigEntry(kOrtSessionOptionEpContextFilePath, context_model_file.c_str()); + so2.AddConfigEntry(kOrtSessionOptionEpContextFilePath, context_binary_file.c_str()); so2.AppendExecutionProvider("QNN", provider_options); std::string ctx_model_data; @@ -193,7 +164,7 @@ void QnnContextBinaryMultiPartitionTestBody(bool single_ep_node = true) { Ort::Session session2(*ort_env, ctx_model_data.data(), ctx_model_data.size(), so2); // clean up - CleanUpCtxFile(context_model_file); + ASSERT_EQ(std::remove(context_binary_file.c_str()), 0); } // Test that models with 1 non-quantized FusedMatMul node and 1 quantized Add node can still generate the context binary @@ -266,7 +237,7 @@ void EpCtxCpuNodeWithExternalIniFileTestBody(bool expect_external_ini_file) { // clean up ASSERT_EQ(std::remove(model_with_ext.c_str()), 0); ASSERT_EQ(std::remove(model_ext_file_full_path.c_str()), 0); - CleanUpCtxFile(ep_context_model_file); + ASSERT_EQ(std::remove(ep_context_model_file.c_str()), 0); } // Set the external initializer size threshold to 1024 so FusedMatMul (which fallback on CPU) @@ -362,7 +333,7 @@ TEST_F(QnnHTPBackendTests, QnnContextBinaryGenerationNoOverWrite) { const auto model_data_span = AsByteSpan(model_data.data(), model_data.size()); const std::string ep_context_onnx_file = "./ep_context_no_over_write.onnx"; - const std::string ep_context_binary_file = "./ep_context_no_over_write_QNN_10880527342279992768_1_0.bin"; + const std::string ep_context_binary_file = "./ep_context_no_over_write.onnx_QNNExecutionProvider_QNN_10880527342279992768_1_0.bin"; std::remove(ep_context_onnx_file.c_str()); Ort::SessionOptions so; @@ -473,21 +444,21 @@ TEST_F(QnnHTPBackendTests, QnnContextBinaryGeneration2InputTypes) { const auto model_data_span = AsByteSpan(model_data.data(), model_data.size()); - const std::string context_model_file = "./qnn_context_binary_int32_fp32_inputs_test.onnx"; - std::remove(context_model_file.c_str()); + const std::string context_binary_file = "./qnn_context_binary_int32_fp32_inputs_test.onnx"; + std::remove(context_binary_file.c_str()); Ort::SessionOptions so; so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1"); - so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, context_model_file.c_str()); + so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, context_binary_file.c_str()); so.AppendExecutionProvider("QNN", provider_options); Ort::Session session(*ort_env, model_data_span.data(), model_data_span.size(), so); // Make sure the Qnn context cache binary file is generated - EXPECT_TRUE(std::filesystem::exists(context_model_file.c_str())); + EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str())); // clean up - CleanUpCtxFile(context_model_file); + ASSERT_EQ(std::remove(context_binary_file.c_str()), 0); } // Generate context cache model from the ONNX models with 2 inputs. @@ -510,26 +481,26 @@ TEST_F(QnnHTPBackendTests, QnnContextGeneration2InputsOrderIssue) { auto& logging_manager = DefaultLoggingManager(); logging_manager.SetDefaultLoggerSeverity(logging::Severity::kERROR); - const std::string context_model_file = "./qnn_ctx_2_inputs_order_test_gen.onnx"; + const std::string context_binary_file = "./qnn_ctx_2_inputs_order_test_gen.onnx"; Ort::SessionOptions so; so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1"); - so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, context_model_file.c_str()); + so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, context_binary_file.c_str()); so.AppendExecutionProvider("QNN", provider_options); Ort::Session session(*ort_env, ORT_TSTR("testdata/qnn_ctx_2_inputs_order_test.onnx"), so); // Make sure the Qnn context cache binary file is generated - EXPECT_TRUE(std::filesystem::exists(context_model_file.c_str())); + EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str())); std::shared_ptr model; - ASSERT_STATUS_OK(Model::Load(ToPathString(context_model_file), model, nullptr, DefaultLoggingManager().DefaultLogger())); + ASSERT_STATUS_OK(Model::Load(ToPathString(context_binary_file), model, nullptr, DefaultLoggingManager().DefaultLogger())); auto inputs = model->MainGraph().GetInputs(); EXPECT_TRUE(inputs.size() == 2); EXPECT_TRUE(inputs[0]->Name() == "attention_mask"); EXPECT_TRUE(inputs[1]->Name() == "Add_input_0"); // clean up - CleanUpCtxFile(context_model_file); + ASSERT_EQ(std::remove(context_binary_file.c_str()), 0); } TEST_F(QnnHTPBackendTests, QnnContextGenerationNodeNamePrefix) { @@ -548,20 +519,20 @@ TEST_F(QnnHTPBackendTests, QnnContextGenerationNodeNamePrefix) { auto& logging_manager = DefaultLoggingManager(); logging_manager.SetDefaultLoggerSeverity(logging::Severity::kERROR); - const std::string context_model_file = "./qnn_ctx_2_inputs_order_test_gen.onnx"; + const std::string context_binary_file = "./qnn_ctx_2_inputs_order_test_gen.onnx"; Ort::SessionOptions so; so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1"); - so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, context_model_file.c_str()); + so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, context_binary_file.c_str()); so.AddConfigEntry(kOrtSessionOptionEpContextNodeNamePrefix, node_name_prefix.c_str()); so.AppendExecutionProvider("QNN", provider_options); Ort::Session session(*ort_env, ORT_TSTR("testdata/qnn_ctx_2_inputs_order_test.onnx"), so); // Make sure the Qnn context cache binary file is generated - EXPECT_TRUE(std::filesystem::exists(context_model_file.c_str())); + EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str())); std::shared_ptr model; - ASSERT_STATUS_OK(Model::Load(ToPathString(context_model_file), model, nullptr, DefaultLoggingManager().DefaultLogger())); + ASSERT_STATUS_OK(Model::Load(ToPathString(context_binary_file), model, nullptr, DefaultLoggingManager().DefaultLogger())); for (auto& node : model->MainGraph().Nodes()) { if (node.OpType() == "EPContext") { EXPECT_TRUE(node.Name().find(node_name_prefix) != std::string::npos); @@ -569,7 +540,7 @@ TEST_F(QnnHTPBackendTests, QnnContextGenerationNodeNamePrefix) { } // clean up - CleanUpCtxFile(context_model_file); + ASSERT_EQ(std::remove(context_binary_file.c_str()), 0); } // Run QDQ model on HTP 3 times @@ -583,12 +554,12 @@ TEST_F(QnnHTPBackendTests, QnnContextBinaryCacheEmbedModeTest) { provider_options["backend_path"] = "libQnnHtp.so"; #endif provider_options["offload_graph_io_quantization"] = "0"; - const std::string context_model_file = "./qnn_context_binary_test.onnx"; - std::remove(context_model_file.c_str()); + const std::string context_binary_file = "./qnn_context_binary_test.onnx"; + std::remove(context_binary_file.c_str()); std::unordered_map session_option_pairs; session_option_pairs.emplace(kOrtSessionOptionEpContextEnable, "1"); - session_option_pairs.emplace(kOrtSessionOptionEpContextFilePath, context_model_file); + session_option_pairs.emplace(kOrtSessionOptionEpContextFilePath, context_binary_file); const TestInputDef input_def({1, 2, 3}, false, -10.0f, 10.0f); const std::string op_type = "Atan"; @@ -606,11 +577,9 @@ TEST_F(QnnHTPBackendTests, QnnContextBinaryCacheEmbedModeTest) { session_option_pairs); // Make sure the Qnn context cache binary file is generated - EXPECT_TRUE(std::filesystem::exists(context_model_file.c_str())); + EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str())); // 2nd run directly loads and run from Qnn context cache model - std::unordered_map session_option_pairs2; - session_option_pairs2.emplace(kOrtSessionOptionEpContextFilePath, context_model_file); TestQDQModelAccuracy(BuildOpTestCase(op_type, {input_def}, {}, {}), BuildQDQOpTestCase(op_type, {input_def}, {}, {}), provider_options, @@ -618,10 +587,9 @@ TEST_F(QnnHTPBackendTests, QnnContextBinaryCacheEmbedModeTest) { ExpectedEPNodeAssignment::All, QDQTolerance(), logging::Severity::kERROR, - context_model_file, - session_option_pairs2); + context_binary_file); // Clean up - CleanUpCtxFile(context_model_file); + ASSERT_EQ(std::remove(context_binary_file.c_str()), 0); } // Run QDQ model on HTP 3 times @@ -636,7 +604,7 @@ TEST_F(QnnHTPBackendTests, QnnContextBinaryCacheNonEmbedModeTest) { #endif provider_options["offload_graph_io_quantization"] = "0"; const std::string context_binary_file = "./testdata/qnn_context_cache_non_embed.onnx"; - std::string qnn_ctx_bin = "./testdata/qnn_context_cache_non_embed_QNN_8283143575221199085_1_0.bin"; + std::string qnn_ctx_bin = "./testdata/qnn_context_cache_non_embed.onnx_QNNExecutionProvider_QNN_8283143575221199085_1_0.bin"; std::unordered_map session_option_pairs; session_option_pairs.emplace(kOrtSessionOptionEpContextEnable, "1"); @@ -718,7 +686,7 @@ TEST_F(QnnHTPBackendTests, QnnContextBinaryCache_InvalidGraph) { #endif provider_options["offload_graph_io_quantization"] = "0"; const std::string context_binary_file = "./qnn_context_cache_non_embed.onnx"; - std::filesystem::path context_bin = "qnn_context_cache_non_embed_QNN_8283143575221199085_1_0.bin"; + std::filesystem::path context_bin = "qnn_context_cache_non_embed.onnx_QNNExecutionProvider_QNN_8283143575221199085_1_0.bin"; std::remove(context_binary_file.c_str()); std::remove(context_bin.string().c_str()); @@ -860,7 +828,6 @@ TEST_F(QnnHTPBackendTests, QnnContextBinaryFileNotExistTest) { SessionOptions so; so.session_logid = "qnn_ctx_model_logger"; - ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kOrtSessionOptionEpContextFilePath, "./qnn_context_not_exist.onnx")); RunOptions run_options; run_options.run_tag = so.session_logid; @@ -874,7 +841,7 @@ TEST_F(QnnHTPBackendTests, QnnContextBinaryFileNotExistTest) { #endif provider_options["offload_graph_io_quantization"] = "0"; - ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(QnnExecutionProviderWithOptions(provider_options, &so))); + ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(QnnExecutionProviderWithOptions(provider_options))); ASSERT_STATUS_OK(session_object.Load(model_data.data(), static_cast(model_data.size()))); // Verify the return status with code INVALID_GRAPH ASSERT_TRUE(session_object.Initialize().Code() == common::StatusCode::INVALID_GRAPH); @@ -887,7 +854,6 @@ TEST_F(QnnHTPBackendTests, QnnContextBinaryFileEmptyStringTest) { SessionOptions so; so.session_logid = "qnn_ctx_model_logger"; - ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kOrtSessionOptionEpContextFilePath, "./test_ctx.onnx")); RunOptions run_options; run_options.run_tag = so.session_logid; @@ -901,7 +867,7 @@ TEST_F(QnnHTPBackendTests, QnnContextBinaryFileEmptyStringTest) { #endif provider_options["offload_graph_io_quantization"] = "0"; - ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(QnnExecutionProviderWithOptions(provider_options, &so))); + ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(QnnExecutionProviderWithOptions(provider_options))); ASSERT_STATUS_OK(session_object.Load(model_data.data(), static_cast(model_data.size()))); // Verify the return status with code INVALID_GRAPH ASSERT_TRUE(session_object.Initialize().Code() == common::StatusCode::INVALID_GRAPH); @@ -918,12 +884,12 @@ TEST_F(QnnHTPBackendTests, QnnContextBinary2InputsTest) { provider_options["backend_path"] = "libQnnHtp.so"; #endif provider_options["offload_graph_io_quantization"] = "0"; - const std::string context_model_file = "./qnn_context_binary_2inputs_test.onnx"; - std::remove(context_model_file.c_str()); + const std::string context_binary_file = "./qnn_context_binary_2inputs_test.onnx"; + std::remove(context_binary_file.c_str()); std::unordered_map session_option_pairs; session_option_pairs.emplace(kOrtSessionOptionEpContextEnable, "1"); - session_option_pairs.emplace(kOrtSessionOptionEpContextFilePath, context_model_file); + session_option_pairs.emplace(kOrtSessionOptionEpContextFilePath, context_binary_file); const TestInputDef input_def1({1, 2, 3}, false, -10.0f, 10.0f); const TestInputDef input_def2({1, 2, 3}, false, -10.0f, 10.0f); @@ -942,11 +908,9 @@ TEST_F(QnnHTPBackendTests, QnnContextBinary2InputsTest) { session_option_pairs); // Make sure the Qnn context cache binary file is generated - EXPECT_TRUE(std::filesystem::exists(context_model_file.c_str())); + EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str())); // 2nd run directly loads and run from Qnn context cache model - std::unordered_map session_option_pairs2; - session_option_pairs2.emplace(kOrtSessionOptionEpContextFilePath, context_model_file); TestQDQModelAccuracy(BuildOpTestCase(op_type, {input_def1, input_def2}, {}, {}), BuildQDQOpTestCase(op_type, {input_def1, input_def2}, {}, {}), provider_options, @@ -954,10 +918,9 @@ TEST_F(QnnHTPBackendTests, QnnContextBinary2InputsTest) { ExpectedEPNodeAssignment::All, QDQTolerance(), logging::Severity::kERROR, - context_model_file, - session_option_pairs2); + context_binary_file); // Clean up - CleanUpCtxFile(context_model_file); + ASSERT_EQ(std::remove(context_binary_file.c_str()), 0); } // Context binary only contains a single QNN graph, generated context cache model (detached mode) only has 1 EPContext node @@ -973,14 +936,14 @@ TEST_F(QnnHTPBackendTests, QnnContextBinaryCache_SingleNodeNameNotMatchGraphName provider_options["backend_path"] = "libQnnHtp.so"; #endif provider_options["offload_graph_io_quantization"] = "0"; - const std::string context_model_file = "./qnn_context_cache_non_embed.onnx"; - std::filesystem::path context_bin = "qnn_context_cache_non_embed_QNN_8283143575221199085_1_0.bin"; - std::remove(context_model_file.c_str()); + const std::string context_binary_file = "./qnn_context_cache_non_embed.onnx"; + std::filesystem::path context_bin = "qnn_context_cache_non_embed.onnx_QNNExecutionProvider_QNN_8283143575221199085_1_0.bin"; + std::remove(context_binary_file.c_str()); std::remove(context_bin.string().c_str()); std::unordered_map session_option_pairs; session_option_pairs.emplace(kOrtSessionOptionEpContextEnable, "1"); - session_option_pairs.emplace(kOrtSessionOptionEpContextFilePath, context_model_file); + session_option_pairs.emplace(kOrtSessionOptionEpContextFilePath, context_binary_file); session_option_pairs.emplace(kOrtSessionOptionEpContextEmbedMode, "0"); const TestInputDef input_def({1, 2, 3}, false, -10.0f, 10.0f); @@ -999,7 +962,7 @@ TEST_F(QnnHTPBackendTests, QnnContextBinaryCache_SingleNodeNameNotMatchGraphName session_option_pairs); // Check the Onnx skeleton file is generated - EXPECT_TRUE(std::filesystem::exists(context_model_file.c_str())); + EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str())); // Check the Qnn context cache binary file is generated EXPECT_TRUE(std::filesystem::exists(context_bin)); @@ -1027,19 +990,18 @@ TEST_F(QnnHTPBackendTests, QnnContextBinaryCache_SingleNodeNameNotMatchGraphName SessionOptions so; so.session_logid = "qnn_ctx_model_logger"; - ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kOrtSessionOptionEpContextFilePath, context_model_file.c_str())); RunOptions run_options; run_options.run_tag = so.session_logid; InferenceSessionWrapper session_object{so, GetEnvironment()}; - ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(QnnExecutionProviderWithOptions(provider_options, &so))); + ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(QnnExecutionProviderWithOptions(provider_options))); ASSERT_STATUS_OK(session_object.Load(model_data.data(), static_cast(model_data.size()))); // Verify the return status with code INVALID_GRAPH ASSERT_TRUE(session_object.Initialize().Code() == common::StatusCode::OK); // Clean up - ASSERT_EQ(std::remove(context_model_file.c_str()), 0); + ASSERT_EQ(std::remove(context_binary_file.c_str()), 0); ASSERT_EQ(std::remove(context_bin.string().c_str()), 0); } @@ -1091,20 +1053,44 @@ static void CreateQdqModel(const std::string& model_file_name, const Logger& log static void DumpModelWithSharedCtx(const ProviderOptions& provider_options, const std::string& onnx_model_path1, const std::string& onnx_model_path2) { - Ort::SessionOptions so; - so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1"); - so.AddConfigEntry(kOrtSessionOptionEpContextEmbedMode, "0"); - // enable ep.share_ep_contexts so that QNNEP share the QnnBackendManager across sessions - so.AddConfigEntry(kOrtSessionOptionShareEpContexts, "1"); + SessionOptions so; + so.session_logid = "qnn_ctx_model_logger"; + ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1")); + ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kOrtSessionOptionEpContextEmbedMode, "0")); + RunOptions run_options; + run_options.run_tag = so.session_logid; - so.AppendExecutionProvider("QNN", provider_options); + auto qnn_ep = QnnExecutionProviderWithOptions(provider_options, &so); + std::shared_ptr qnn_ep_shared(std::move(qnn_ep)); + + InferenceSessionWrapper session_object1{so, GetEnvironment()}; + ASSERT_STATUS_OK(session_object1.RegisterExecutionProvider(qnn_ep_shared)); + ASSERT_STATUS_OK(session_object1.Load(ToPathString(onnx_model_path1))); + ASSERT_STATUS_OK(session_object1.Initialize()); - // Create 2 sessions to generate context binary models, the 1st session will share the QnnBackendManager - // to the 2nd session, so graphs from these 2 models are all included in the 2nd context binary - Ort::Session session1(*ort_env, ToPathString(onnx_model_path1).c_str(), so); + InferenceSessionWrapper session_object2{so, GetEnvironment()}; + ASSERT_STATUS_OK(session_object2.RegisterExecutionProvider(qnn_ep_shared)); + ASSERT_STATUS_OK(session_object2.Load(ToPathString(onnx_model_path2))); + ASSERT_STATUS_OK(session_object2.Initialize()); +} - so.AddConfigEntry(kOrtSessionOptionStopShareEpContexts, "1"); - Ort::Session session2(*ort_env, ToPathString(onnx_model_path2).c_str(), so); +// from the last context ache Onnx model, find the EPContext node with main_context=1, +// and get the QNN context binary file name, thie context binary contains all graphs from all Onnx models +static void GetLastContextBinaryFileName(const std::string last_onnx_ctx_file, + std::string& last_ctx_bin_file, + const Logger& logger) { + std::shared_ptr ctx_model; + ASSERT_STATUS_OK(Model::Load(ToPathString(last_onnx_ctx_file), ctx_model, nullptr, logger)); + auto& ctx_graph = ctx_model->MainGraph(); + for (auto& node : ctx_graph.Nodes()) { + if (node.OpType() == "EPContext") { + int64_t is_main_context = GetNodeAttr(node, "main_context", static_cast(0)); + if (1 == is_main_context) { + last_ctx_bin_file = GetNodeAttr(node, "ep_cache_context", ""); + return; + } + } + } } // Update generated context cache Onnx model to make the main EPContext node point to @@ -1181,21 +1167,15 @@ TEST_F(QnnHTPBackendTests, QnnContextShareAcrossSessions1) { for (auto model_path : onnx_model_paths) { CreateQdqModel(model_path, DefaultLoggingManager().DefaultLogger()); EXPECT_TRUE(std::filesystem::exists(model_path.c_str())); - auto pos = model_path.find_last_of("."); - if (pos != std::string::npos) { - model_path = model_path.substr(0, pos) + "_ctx.onnx"; - } else { - model_path = model_path + "_ctx.onnx"; - } - ctx_model_paths.push_back(model_path); + ctx_model_paths.push_back(model_path + "_ctx.onnx"); } DumpModelWithSharedCtx(provider_options, onnx_model_paths[0], onnx_model_paths[1]); - // Get the last context binary file name, the latest context binary file holds all graphs generated from all models + // Get the last context binary file name std::string last_qnn_ctx_binary_file_name; - GetContextBinaryFileName(ctx_model_paths.back(), last_qnn_ctx_binary_file_name, - DefaultLoggingManager().DefaultLogger()); + GetLastContextBinaryFileName(ctx_model_paths.back(), last_qnn_ctx_binary_file_name, + DefaultLoggingManager().DefaultLogger()); EXPECT_TRUE(!last_qnn_ctx_binary_file_name.empty()); // Update generated context cache Onnx model to make the main EPContext node point to @@ -1285,21 +1265,15 @@ TEST_F(QnnHTPBackendTests, QnnContextShareAcrossSessions2) { for (auto model_path : onnx_model_paths) { CreateQdqModel(model_path, DefaultLoggingManager().DefaultLogger()); EXPECT_TRUE(std::filesystem::exists(model_path.c_str())); - auto pos = model_path.find_last_of("."); - if (pos != std::string::npos) { - model_path = model_path.substr(0, pos) + "_ctx.onnx"; - } else { - model_path = model_path + "_ctx.onnx"; - } - ctx_model_paths.push_back(model_path); + ctx_model_paths.push_back(model_path + "_ctx.onnx"); } DumpModelWithSharedCtx(provider_options, onnx_model_paths[0], onnx_model_paths[1]); // Get the last context binary file name std::string last_qnn_ctx_binary_file_name; - GetContextBinaryFileName(ctx_model_paths.back(), last_qnn_ctx_binary_file_name, - DefaultLoggingManager().DefaultLogger()); + GetLastContextBinaryFileName(ctx_model_paths.back(), last_qnn_ctx_binary_file_name, + DefaultLoggingManager().DefaultLogger()); EXPECT_TRUE(!last_qnn_ctx_binary_file_name.empty()); // Update generated context cache Onnx model to make the main EPContext node point to @@ -1362,69 +1336,6 @@ TEST_F(QnnHTPBackendTests, QnnContextShareAcrossSessions2) { } std::remove(last_qnn_ctx_binary_file_name.c_str()); } - -// For Ort sessions to generate the context binary, with session option ep.share_ep_contexts enabled -// Ort sessions will share the QnnBackendManager, so that all graphs from all models compile into the same Qnn context -TEST_F(QnnHTPBackendTests, QnnContextGenWeightSharingSessionAPI) { - ProviderOptions provider_options; -#if defined(_WIN32) - provider_options["backend_path"] = "QnnHtp.dll"; -#else - provider_options["backend_path"] = "libQnnHtp.so"; -#endif - provider_options["offload_graph_io_quantization"] = "0"; - - // Create QDQ models - std::vector onnx_model_paths{"./weight_share1.onnx", "./weight_share2.onnx"}; - std::vector ctx_model_paths; - for (auto model_path : onnx_model_paths) { - CreateQdqModel(model_path, DefaultLoggingManager().DefaultLogger()); - EXPECT_TRUE(std::filesystem::exists(model_path.c_str())); - auto pos = model_path.find_last_of("."); - if (pos != std::string::npos) { - model_path = model_path.substr(0, pos) + "_ctx.onnx"; - } else { - model_path = model_path + "_ctx.onnx"; - } - ctx_model_paths.push_back(model_path); - } - - Ort::SessionOptions so; - so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1"); - so.AddConfigEntry(kOrtSessionOptionEpContextEmbedMode, "0"); - // enable ep.share_ep_contexts so that QNNEP share the QnnBackendManager across sessions - so.AddConfigEntry(kOrtSessionOptionShareEpContexts, "1"); - - so.AppendExecutionProvider("QNN", provider_options); - - Ort::Session session1(*ort_env, ToPathString(onnx_model_paths[0]).c_str(), so); - std::string qnn_ctx_binary_file_name1; - GetContextBinaryFileName(ctx_model_paths[0], qnn_ctx_binary_file_name1, - DefaultLoggingManager().DefaultLogger()); - EXPECT_TRUE(!qnn_ctx_binary_file_name1.empty()); - - // Tell the EP stop share the QnnBackendManager from this session then on - so.AddConfigEntry(kOrtSessionOptionStopShareEpContexts, "1"); - Ort::Session session2(*ort_env, ToPathString(onnx_model_paths[1]).c_str(), so); - std::string qnn_ctx_binary_file_name2; - GetContextBinaryFileName(ctx_model_paths[1], qnn_ctx_binary_file_name2, - DefaultLoggingManager().DefaultLogger()); - EXPECT_TRUE(!qnn_ctx_binary_file_name2.empty()); - - auto file_size_1 = std::filesystem::file_size(qnn_ctx_binary_file_name1); - auto file_size_2 = std::filesystem::file_size(qnn_ctx_binary_file_name2); - EXPECT_TRUE(file_size_2 > file_size_1); - - // clean up - for (auto model_path : onnx_model_paths) { - ASSERT_EQ(std::remove(model_path.c_str()), 0); - } - for (auto ctx_model_path : ctx_model_paths) { - ASSERT_EQ(std::remove(ctx_model_path.c_str()), 0); - } - ASSERT_EQ(std::remove(qnn_ctx_binary_file_name1.c_str()), 0); - ASSERT_EQ(std::remove(qnn_ctx_binary_file_name2.c_str()), 0); -} #endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) } // namespace test diff --git a/onnxruntime/test/providers/qnn/qnn_test_utils.cc b/onnxruntime/test/providers/qnn/qnn_test_utils.cc index 2361e179d1cf1..e2deccc4fff0f 100644 --- a/onnxruntime/test/providers/qnn/qnn_test_utils.cc +++ b/onnxruntime/test/providers/qnn/qnn_test_utils.cc @@ -14,7 +14,6 @@ #include "core/framework/compute_capability.h" #include "core/graph/graph.h" #include "core/session/onnxruntime_session_options_config_keys.h" -#include "core/optimizer/graph_optimizer_registry.h" namespace onnxruntime { namespace test { @@ -280,10 +279,9 @@ static BackendSupport GetHTPSupport(const onnxruntime::logging::Logger& logger) onnxruntime::GraphViewer graph_viewer(graph); std::unique_ptr qnn_ep = QnnExecutionProviderWithOptions( {{"backend_path", "QnnHtp.dll"}, {"offload_graph_io_quantization", "0"}}); - GraphOptimizerRegistry graph_optimizer_registry(nullptr, nullptr, nullptr); // as a placeholder to feed into GetCapability qnn_ep->SetLogger(&logger); - auto result = qnn_ep->GetCapability(graph_viewer, kernel_lookup, graph_optimizer_registry, nullptr); + auto result = qnn_ep->GetCapability(graph_viewer, kernel_lookup, nullptr); return result.empty() ? BackendSupport::UNSUPPORTED : BackendSupport::SUPPORTED; } @@ -344,10 +342,9 @@ static BackendSupport GetCPUSupport(const onnxruntime::logging::Logger& logger) onnxruntime::GraphViewer graph_viewer(graph); std::unique_ptr qnn_ep = QnnExecutionProviderWithOptions( {{"backend_path", "QnnCpu.dll"}, {"offload_graph_io_quantization", "0"}}); - GraphOptimizerRegistry graph_optimizer_registry(nullptr, nullptr, nullptr); // as a placeholder to feed into GetCapability qnn_ep->SetLogger(&logger); - auto result = qnn_ep->GetCapability(graph_viewer, kernel_lookup, graph_optimizer_registry, nullptr); + auto result = qnn_ep->GetCapability(graph_viewer, kernel_lookup, nullptr); return result.empty() ? BackendSupport::UNSUPPORTED : BackendSupport::SUPPORTED; } diff --git a/onnxruntime/test/python/quantization/test_get_qdq_config.py b/onnxruntime/test/python/quantization/test_get_qdq_config.py index 4a71b3694822c..25f058d8f6eac 100644 --- a/onnxruntime/test/python/quantization/test_get_qdq_config.py +++ b/onnxruntime/test/python/quantization/test_get_qdq_config.py @@ -156,62 +156,6 @@ def should_exclude_node_(model: onnx.ModelProto, node: onnx.NodeProto) -> bool: self.assertTrue(bool(expected_excluded_nodes)) self.assertEqual(set(qdq_config.nodes_to_exclude), expected_excluded_nodes) - def test_op_types_to_quantize(self): - """ - Test that get_qdq_config() returns a config that sets the op_types_to_quantize arg. - """ - shape = [1, 8, 8] - tensor_type = onnx.TensorProto.FLOAT - np_dtype = onnx.helper.tensor_dtype_to_np_dtype(tensor_type) - weight = onnx.numpy_helper.from_array(np.ones(shape, dtype=np_dtype), "weight") - float_model = self.build_add_model(shape, tensor_type, weight) - - input_data_list = [ - {"input_0": np.ones(shape, dtype=np_dtype) * np.array(-2, dtype=np_dtype)}, - {"input_0": np.ones(shape, dtype=np_dtype) * np.array(2, dtype=np_dtype)}, - ] - data_reader = TestDataFeeds(input_data_list) - - # No op_types_to_quantize arg means all ops are quantized. - qdq_config = get_qdq_config(float_model, data_reader, op_types_to_quantize=None) - self.assertEqual(set(qdq_config.op_types_to_quantize), {"Add"}) - - # specify custom op_types_to_quantize arg. - qdq_config = get_qdq_config(float_model, data_reader, op_types_to_quantize=["Mul"]) - self.assertEqual(set(qdq_config.op_types_to_quantize), {"Mul"}) - - # exclude op_type indirectly by specifying nodes_to_exclude arg. - qdq_config = get_qdq_config( - float_model, - data_reader, - nodes_to_exclude=[node.name for node in float_model.graph.node if node.op_type == "Add"], - ) - self.assertEqual(set(qdq_config.op_types_to_quantize), set()) - - def test_calibration_providers(self): - """ - Test that get_qdq_config() returns a config that sets the calibration providers arg. - """ - - shape = [1, 8, 8] - tensor_type = onnx.TensorProto.FLOAT - np_dtype = onnx.helper.tensor_dtype_to_np_dtype(tensor_type) - weight = onnx.numpy_helper.from_array(np.ones(shape, dtype=np_dtype), "weight") - float_model = self.build_add_model(shape, tensor_type, weight) - - input_data_list = [ - {"input_0": np.ones(shape, dtype=np_dtype) * np.array(-2, dtype=np_dtype)}, - {"input_0": np.ones(shape, dtype=np_dtype) * np.array(2, dtype=np_dtype)}, - ] - data_reader = TestDataFeeds(input_data_list) - - qdq_config = get_qdq_config( - float_model, - data_reader, - calibration_providers=["CPUExecutionProvider"], - ) - self.assertEqual(qdq_config.calibration_providers, ["CPUExecutionProvider"]) - def test_external_data(self): """ Test that get_qdq_config() returns a config that enables external data diff --git a/onnxruntime/test/ep_weight_sharing_ctx_gen/README.md b/onnxruntime/test/qnn_ctx_gen/README.md similarity index 82% rename from onnxruntime/test/ep_weight_sharing_ctx_gen/README.md rename to onnxruntime/test/qnn_ctx_gen/README.md index be1a1fe039366..97ab89d79cbd2 100644 --- a/onnxruntime/test/ep_weight_sharing_ctx_gen/README.md +++ b/onnxruntime/test/qnn_ctx_gen/README.md @@ -2,19 +2,17 @@ This tool provides the way to generate Onnx models that wraps QNN context binary warpt with weight sharing enabled. The options to use with the tool are listed below: -`ep_weight_sharing_ctx_gen [options...] model_1_path,model_2_path` +`onnxruntime_qnn_ctx_gen [options...] model_path,model_path` -./ep_weight_sharing_ctx_gen -e qnn -v -i "soc_model|60 htp_graph_finalization_optimization_mode|3" /mnt/c/model1.onnx,/mnt/c/model2.onnx +./onnxruntime_qnn_ctx_gen -v -i "soc_model|60 htp_graph_finalization_optimization_mode|3" -C "ep.context_enable|1 ep.context_embed_mode|0" /mnt/c/model1.onnx,/mnt/c/model2.onnx Options: - - -e [qnn|tensorrt|openvino|vitisai]: Specifies the compile based provider qnn, tensorrt, openvino, vitisai. Default is qnn. - + -v: Show verbose information. -C: [session_config_entries]: Specify session configuration entries as key-value pairs: -C "| |" Refer to onnxruntime_session_options_config_keys.h for valid keys and values. - [Example] -C "ep.context_enable|1 ep.context_embed_mode|0". These are set as default so can be ignored. + [Example] -C "ep.context_enable|1 ep.context_embed_mode|0" -i: [provider_options]: Specify QNN EP specific runtime options as key value pairs. Different runtime options available are: [Usage]: -i '| |' diff --git a/onnxruntime/test/ep_weight_sharing_ctx_gen/command_args_parser.cc b/onnxruntime/test/qnn_ctx_gen/command_args_parser.cc similarity index 68% rename from onnxruntime/test/ep_weight_sharing_ctx_gen/command_args_parser.cc rename to onnxruntime/test/qnn_ctx_gen/command_args_parser.cc index bf21d54ccde41..24c343c7b9541 100644 --- a/onnxruntime/test/ep_weight_sharing_ctx_gen/command_args_parser.cc +++ b/onnxruntime/test/qnn_ctx_gen/command_args_parser.cc @@ -1,4 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) 2023 NVIDIA Corporation. // Licensed under the MIT License. #include "command_args_parser.h" @@ -28,30 +29,28 @@ namespace qnnctxgen { /*static*/ void CommandLineParser::ShowUsage() { printf( - "ep_weight_sharing_ctx_gen [options...] model1_path,model2_path\n" - "Example: ./ep_weight_sharing_ctx_gen -i \"soc_model|60 htp_graph_finalization_optimization_mode|3\" -C \"ep.context_node_name_prefix|_part1\" ./model1.onnx,./model2.onnx\n" + "onnxruntime_qnn_ctx_gen [options...] model1_path,model2_path\n" + "Example: ./onnxruntime_qnn_ctx_gen -i \"soc_model|60 htp_graph_finalization_optimization_mode|3\" -C \"ep.context_node_name_prefix|_part1\" ./model1.onnx,./model2.onnx\n" "Options:\n" - "\t-e [qnn|tensorrt|openvino|vitisai]: Specifies the compile based provider 'qnn','tensorrt','openvino', 'vitisai'. " - "Default:'qnn'.\n" "\t-v: Show verbose information.\n" "\t-C: Specify session configuration entries as key-value pairs: -C \"| |\" \n" "\t Refer to onnxruntime_session_options_config_keys.h for valid keys and values. \n" "\t Force ep.context_enable to 1 and ep.context_embed_mode to 0. Change ep.context_file_path is not allowed." "\t [Example] -C \"ep.context_node_name_prefix|_part1\" \n" - "\t-i: Specify EP specific runtime options as key value pairs. Different runtime options available are: \n" + "\t-i: Specify QNN EP specific runtime options as key value pairs. Different runtime options available are: \n" "\t [Usage]: -i '| |'\n" "\n" - "\t [QNN only] [backend_path]: QNN backend path. e.g '/folderpath/libQnnHtp.so', '/winfolderpath/QnnHtp.dll'. default to HTP backend\n" - "\t [QNN only] [vtcm_mb]: QNN VTCM size in MB. default to 0(not set).\n" - "\t [QNN only] [htp_graph_finalization_optimization_mode]: QNN graph finalization optimization mode, options: '0', '1', '2', '3', default is '0'.\n" - "\t [QNN only] [soc_model]: The SoC Model number. Refer to QNN SDK documentation for specific values. Defaults to '0' (unknown). \n" - "\t [QNN only] [htp_arch]: The minimum HTP architecture. The driver will use ops compatible with this architecture. eg: '0', '68', '69', '73', '75'. Defaults to '0' (none). \n" - "\t [QNN only] [enable_htp_fp16_precision]: Enable the HTP_FP16 precision so that the float32 model will be inferenced with fp16 precision. \n" + "\t [backend_path]: QNN backend path. e.g '/folderpath/libQnnHtp.so', '/winfolderpath/QnnHtp.dll'. default to HTP backend\n" + "\t [vtcm_mb]: QNN VTCM size in MB. default to 0(not set).\n" + "\t [htp_graph_finalization_optimization_mode]: QNN graph finalization optimization mode, options: '0', '1', '2', '3', default is '0'.\n" + "\t [soc_model]: The SoC Model number. Refer to QNN SDK documentation for specific values. Defaults to '0' (unknown). \n" + "\t [htp_arch]: The minimum HTP architecture. The driver will use ops compatible with this architecture. eg: '0', '68', '69', '73', '75'. Defaults to '0' (none). \n" + "\t [enable_htp_fp16_precision]: Enable the HTP_FP16 precision so that the float32 model will be inferenced with fp16 precision. \n" "\t Otherwise, it will be fp32 precision. Works for float32 model for HTP backend. Defaults to '1' (with FP16 precision.). \n" - "\t [QNN only] [enable_htp_weight_sharing]: Allows common weights across graphs to be shared and stored in a single context binary. Defaults to '1' (enabled).\n" - "\t [QNN only] [offload_graph_io_quantization]: Offload graph input quantization and graph output dequantization to another EP (typically CPU EP). \n" - "\t Defaults to '1' (QNN EP handles the graph I/O quantization and dequantization). \n" - "\t [QNN only] [enable_htp_spill_fill_buffer]: Enable HTP spill file buffer, used while generating QNN context binary." + "\t [enable_htp_weight_sharing]: Allows common weights across graphs to be shared and stored in a single context binary. Defaults to '1' (enabled).\n" + "\t [offload_graph_io_quantization]: Offload graph input quantization and graph output dequantization to another EP (typically CPU EP). \n" + "\t Defaults to '0' (QNN EP handles the graph I/O quantization and dequantization). \n" + "\t [enable_htp_spill_fill_buffer]: Enable HTP spill file buffer, used while generating QNN context binary." "\t [Example] -i \"vtcm_mb|8 htp_arch|73\" \n" "\n" "\t-h: help\n"); @@ -110,22 +109,8 @@ static bool ParseSessionConfigs(const std::string& configs_string, /*static*/ bool CommandLineParser::ParseArguments(TestConfig& test_config, int argc, ORTCHAR_T* argv[]) { int ch; - while ((ch = getopt(argc, argv, ORT_TSTR("e:o:u:i:C:vh"))) != -1) { + while ((ch = getopt(argc, argv, ORT_TSTR("o:u:i:C:vh"))) != -1) { switch (ch) { - case 'e': - if (!CompareCString(optarg, ORT_TSTR("qnn"))) { - test_config.machine_config.provider_type_name = onnxruntime::kQnnExecutionProvider; - } else if (!CompareCString(optarg, ORT_TSTR("openvino"))) { - test_config.machine_config.provider_type_name = onnxruntime::kOpenVINOExecutionProvider; - } else if (!CompareCString(optarg, ORT_TSTR("tensorrt"))) { - test_config.machine_config.provider_type_name = onnxruntime::kTensorrtExecutionProvider; - } else if (!CompareCString(optarg, ORT_TSTR("vitisai"))) { - test_config.machine_config.provider_type_name = onnxruntime::kVitisAIExecutionProvider; - } else { - fprintf(stderr, "The execution provider is not included in this tool.\n"); - return false; - } - break; case 'v': test_config.run_config.f_verbose = true; break; @@ -177,7 +162,7 @@ static bool ParseSessionConfigs(const std::string& configs_string, 'offload_graph_io_quantization', 'enable_htp_spill_fill_buffer'])"); } - test_config.run_config.provider_options[key] = value; + test_config.run_config.qnn_options[key] = value; } break; } diff --git a/onnxruntime/test/ep_weight_sharing_ctx_gen/command_args_parser.h b/onnxruntime/test/qnn_ctx_gen/command_args_parser.h similarity index 100% rename from onnxruntime/test/ep_weight_sharing_ctx_gen/command_args_parser.h rename to onnxruntime/test/qnn_ctx_gen/command_args_parser.h diff --git a/onnxruntime/test/qnn_ctx_gen/main.cc b/onnxruntime/test/qnn_ctx_gen/main.cc new file mode 100644 index 0000000000000..bb5007b40b072 --- /dev/null +++ b/onnxruntime/test/qnn_ctx_gen/main.cc @@ -0,0 +1,250 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// onnxruntime dependencies +#include "test_configuration.h" +#include +#include +#include +#include "command_args_parser.h" +#include + +#include "core/session/onnxruntime_session_options_config_keys.h" +#include "core/session/inference_session.h" +#include "core/session/ort_env.h" +#include "core/providers/provider_factory_creators.h" +#include "core/common/logging/sinks/clog_sink.h" + +#include "core/graph/model.h" +#include "core/session/environment.h" +#include "core/common/logging/logging.h" + +using namespace onnxruntime; +const OrtApi* g_ort = NULL; +std::unique_ptr ort_env; + +static void CheckStatus(const Status& status) { + if (status.Code() != common::StatusCode::OK) { + std::string msg = status.ErrorMessage(); + throw Ort::Exception(std::move(msg), OrtErrorCode::ORT_FAIL); + } +} + +static int64_t GetNodeAttr(const Node& node, const std::string& attr_name, int64_t default_val) { + const auto& attributes = node.GetAttributes(); + if (auto entry = attributes.find(attr_name); entry != attributes.end()) { + return entry->second.i(); + } + + return default_val; +} + +static const std::string& GetNodeAttr(const Node& node, const std::string& attr_name, const std::string& default_val) { + const auto& attributes = node.GetAttributes(); + if (auto entry = attributes.find(attr_name); entry != attributes.end()) { + return entry->second.s(); + } + + return default_val; +} + +// from the last context cache Onnx model, find the EPContext node with main_context=1, +// and get the QNN context binary file name, this context binary contains all graphs from all Onnx models +// get the max spill fill buffer size +static void GetLastContextBinaryFileName(const std::basic_string last_onnx_ctx_file, + std::string& last_ctx_bin_file, + int64_t& max_size) { + max_size = 0; + std::shared_ptr ctx_model; + CheckStatus(Model::Load(ToPathString(last_onnx_ctx_file), ctx_model, nullptr, + (*((OrtEnv*)*ort_env.get())->GetEnvironment().GetLoggingManager()).DefaultLogger())); + auto& ctx_graph = ctx_model->MainGraph(); + for (auto& node : ctx_graph.Nodes()) { + if (node.OpType() == "EPContext") { + int64_t is_main_context = GetNodeAttr(node, "main_context", static_cast(0)); + max_size = GetNodeAttr(node, "max_size", static_cast(0)); + if (1 == is_main_context) { + last_ctx_bin_file = GetNodeAttr(node, "ep_cache_context", ""); + return; + } + } + } +} + +// Update generated context cache Onnx model to make the main EPContext node point to +// the last QNN context binary file +// Remove not used QNN context binary file, only keep the last one which contains all graphs +static void UpdateEpContextModel(const std::vector>& ep_ctx_files, + const std::string& last_qnn_ctx_binary_file_name, + int64_t max_size) { + for (auto ep_ctx_file : ep_ctx_files) { + std::shared_ptr ctx_model; + auto path_str = ToPathString(ep_ctx_file); + CheckStatus(Model::Load(path_str, ctx_model, nullptr, + (*((OrtEnv*)*ort_env.get())->GetEnvironment().GetLoggingManager()).DefaultLogger())); + auto& ctx_graph = ctx_model->MainGraph(); + GraphViewer graph_viewer(ctx_graph); + auto path = std::filesystem::path(path_str); + + for (auto& node : ctx_graph.Nodes()) { + if (node.OpType() == "EPContext") { + int64_t is_main_context = GetNodeAttr(node, "main_context", static_cast(0)); + if (1 == is_main_context) { + std::string old_qnn_ctx_binary_file_name = GetNodeAttr(node, "ep_cache_context", ""); + auto file_path = path.replace_filename(old_qnn_ctx_binary_file_name); + std::remove(file_path.string().c_str()); + node.ClearAttribute("ep_cache_context"); + node.AddAttribute("ep_cache_context", last_qnn_ctx_binary_file_name); + node.ClearAttribute("max_size"); + node.AddAttribute("max_size", max_size); + } + } + } + std::remove(ToUTF8String(ep_ctx_file).c_str()); + CheckStatus(Model::Save(*ctx_model.get(), ToPathString(ep_ctx_file))); + } +} + +#ifdef _WIN32 +int real_main(int argc, wchar_t* argv[]) { +#else +int real_main(int argc, char* argv[]) { +#endif + g_ort = OrtGetApiBase()->GetApi(ORT_API_VERSION); + qnnctxgen::TestConfig test_config; + if (!qnnctxgen::CommandLineParser::ParseArguments(test_config, argc, argv)) { + qnnctxgen::CommandLineParser::ShowUsage(); + return -1; + } + + { + bool failed = false; + ORT_TRY { + OrtLoggingLevel logging_level = test_config.run_config.f_verbose + ? ORT_LOGGING_LEVEL_VERBOSE + : ORT_LOGGING_LEVEL_WARNING; + + ort_env = std::make_unique(logging_level, "Default"); + } + ORT_CATCH(const Ort::Exception& e) { + ORT_HANDLE_EXCEPTION([&]() { + fprintf(stderr, "Error creating environment. Error-> %s \n", e.what()); + failed = true; + }); + } + + if (failed) + return -1; + } + + ORT_TRY { + SessionOptions so; + so.session_logid = "qnn_ctx_gen_session_logger"; + // Set default session option to dump QNN context model with non-embed mode + CheckStatus(so.config_options.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1")); + CheckStatus(so.config_options.AddConfigEntry(kOrtSessionOptionEpContextEmbedMode, "0")); + RunOptions run_options; + run_options.run_tag = so.session_logid; + + ProviderOptions provider_options; +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + // set default QNN EP option to enable weight sharing + provider_options["enable_htp_weight_sharing"] = "1"; + + for (auto it : test_config.run_config.qnn_options) { + provider_options[it.first] = it.second; + } + + for (auto it : test_config.run_config.session_config_entries) { + if (it.first == kOrtSessionOptionEpContextEnable && it.second != "1") { + std::cerr << "Need to enable ep context cache." << std::endl; + continue; + } + if (it.first == kOrtSessionOptionEpContextEmbedMode && it.second != "0") { + std::cerr << "Only support non-embed model for weight sharing." << std::endl; + continue; + } + if (it.first == kOrtSessionOptionEpContextFilePath) { + std::cout << "Not support to specify the generated Onnx context cache file name." << std::endl; + continue; + } + CheckStatus(so.config_options.AddConfigEntry(it.first.c_str(), it.second.c_str())); + } + + for (auto model_path : test_config.model_file_paths) { + std::cout << "Model file path: " << ToUTF8String(model_path) << std::endl; + } + + // Generate context cache model files with QNN context binary files + // The context binary file generated later includes all graphs from previous models + { + auto ep = QNNProviderFactoryCreator::Create(provider_options, &so)->CreateProvider(); + std::shared_ptr qnn_ep(std::move(ep)); + + for (auto model_path : test_config.model_file_paths) { + std::cout << "Generate context cache model for: " << ToUTF8String(model_path) << std::endl; + InferenceSession session_object{so, ((OrtEnv*)*ort_env.get())->GetEnvironment()}; + CheckStatus(session_object.RegisterExecutionProvider(qnn_ep)); + CheckStatus(session_object.Load(ToPathString(model_path))); + CheckStatus(session_object.Initialize()); + } + } + + std::cout << "Start to update the generated Onnx model." << std::endl; + std::vector> ep_ctx_files; + ep_ctx_files.reserve(test_config.model_file_paths.size()); + for (auto model_path : test_config.model_file_paths) { + ep_ctx_files.push_back(model_path + ORT_TSTR("_ctx.onnx")); + } + + // Get the last context binary file name + std::string last_qnn_ctx_binary_file_name; + int64_t max_size = 0; + GetLastContextBinaryFileName(ep_ctx_files.back(), last_qnn_ctx_binary_file_name, max_size); + std::cout << "The last context binary file: " << last_qnn_ctx_binary_file_name << std::endl; + if (last_qnn_ctx_binary_file_name.empty()) { + throw Ort::Exception("Can't find QNN context binary file from the Onnx model.", OrtErrorCode::ORT_FAIL); + } + ep_ctx_files.pop_back(); + + // Update generated context cache Onnx model to make the main EPContext node point to + // the last QNN context binary file + // Remove not used QNN context binary file, only keep the last one which contains all graphs + UpdateEpContextModel(ep_ctx_files, last_qnn_ctx_binary_file_name, max_size); + } + ORT_CATCH(const Ort::Exception& e) { + fprintf(stderr, "Failed to generate context cache file: %s \n", e.what()); + + ort_env.reset(); + return -1; + } + + ort_env.reset(); + + return 0; +} + +#ifdef _WIN32 +int wmain(int argc, wchar_t* argv[]) { +#else +int main(int argc, char* argv[]) { +#endif + int retval = -1; + ORT_TRY { + retval = real_main(argc, argv); + } + ORT_CATCH(const std::exception& ex) { + ORT_HANDLE_EXCEPTION([&]() { + fprintf(stderr, "%s\n", ex.what()); + retval = -1; + }); + } + + ::google::protobuf::ShutdownProtobufLibrary(); + + return retval; +} diff --git a/onnxruntime/test/ep_weight_sharing_ctx_gen/test_configuration.h b/onnxruntime/test/qnn_ctx_gen/test_configuration.h similarity index 75% rename from onnxruntime/test/ep_weight_sharing_ctx_gen/test_configuration.h rename to onnxruntime/test/qnn_ctx_gen/test_configuration.h index 198d03211f561..bf4c7061a3484 100644 --- a/onnxruntime/test/ep_weight_sharing_ctx_gen/test_configuration.h +++ b/onnxruntime/test/qnn_ctx_gen/test_configuration.h @@ -14,20 +14,15 @@ namespace onnxruntime { namespace qnnctxgen { -struct MachineConfig { - std::string provider_type_name{onnxruntime::kQnnExecutionProvider}; -}; - struct RunConfig { bool f_verbose{false}; std::unordered_map session_config_entries; - std::unordered_map provider_options; + std::unordered_map qnn_options; }; struct TestConfig { std::vector> model_file_paths; RunConfig run_config; - MachineConfig machine_config; }; } // namespace qnnctxgen diff --git a/onnxruntime/test/shared_lib/custom_op_utils.cc b/onnxruntime/test/shared_lib/custom_op_utils.cc index a624479bcd00b..bf7efacdbb505 100644 --- a/onnxruntime/test/shared_lib/custom_op_utils.cc +++ b/onnxruntime/test/shared_lib/custom_op_utils.cc @@ -1,7 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include #include "gtest/gtest.h" #include "custom_op_utils.h" @@ -640,22 +639,3 @@ void StandaloneCustomKernel::Compute(OrtKernelContext* context) { StandaloneCustomKernel::~StandaloneCustomKernel() { } - -OrtStatusPtr CustomCastKernel::ComputeV2(OrtKernelContext* context) { - Ort::KernelContext ctx(context); - - auto in = ctx.GetInput(0); - std::vector shape = in.GetTensorTypeAndShapeInfo().GetShape(); - int64_t num_elements = std::accumulate(shape.cbegin(), shape.cend(), int64_t(1), std::multiplies()); - - // CustomCast::GetInputType constraint ensures we only get float input - const float* data = in.GetTensorData(); - double* out_data = ctx.GetOutput(0, shape).GetTensorMutableData(); - gsl::span input_span(data, num_elements); - gsl::span output_span(out_data, num_elements); - - std::transform(input_span.begin(), input_span.end(), output_span.begin(), - [](float val) { return static_cast(val); }); - - return nullptr; -} diff --git a/onnxruntime/test/shared_lib/custom_op_utils.h b/onnxruntime/test/shared_lib/custom_op_utils.h index 424c2e2fe3a08..e11540aaa5691 100644 --- a/onnxruntime/test/shared_lib/custom_op_utils.h +++ b/onnxruntime/test/shared_lib/custom_op_utils.h @@ -8,6 +8,12 @@ #include #endif +struct Input { + const char* name = nullptr; + std::vector dims; + std::vector values; +}; + struct MyCustomKernel { MyCustomKernel(const OrtApi& ort_api, const OrtKernelInfo* /*info*/) : ort_(ort_api) { @@ -458,63 +464,4 @@ struct MulTopOpFloat16 : Ort::CustomOpBase OrtCustomOpInputOutputCharacteristic GetInputCharacteristic(size_t) const { return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_OPTIONAL; } -}; - -// -// Example overriding an operator where type inference is required for the output so kernel matching works correctly -// -struct CustomCastKernel { - CustomCastKernel(const OrtApi& /*ort_api*/, const OrtKernelInfo* /*info*/) - /*: ort_(ort_api)*/ { - } - - OrtStatusPtr ComputeV2(OrtKernelContext* context); - - private: - // const OrtApi& ort_; -}; - -// Custom Cast op that takes float input and converts based on 'to' attribute. -// Example implementation only supports cast to double. -struct CustomCast : Ort::CustomOpBase { - explicit CustomCast(const char* provider) : provider_(provider) { - // if overriding an ONNX op you need to set the opset versions you are overriding - start_ver_ = 7; // should match minimum ONNX schema you implement - // end_ver_ = ...; should match maximum ONNX schema you implement or unset for unlimited. - } - - // static method used by Ort::CustomOpBase::SetShapeInferFn - static OrtStatusPtr InferOutputShape(Ort::ShapeInferContext& context) { - auto shape = context.GetInputShape(0); - - // infer output type based on 'to'. - auto to = context.GetAttrInt("to"); - if (to != ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE) { - return Ort::Status("Unexpected type", ORT_INVALID_ARGUMENT).release(); - } - - context.SetOutputShape(0, shape, ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE); - return nullptr; - } - - OrtStatusPtr CreateKernelV2(const OrtApi& api, const OrtKernelInfo* info, void** op_kernel) const { - Ort::ConstKernelInfo ki(info); - *op_kernel = new CustomCastKernel(api, info); - return nullptr; - }; - - const char* GetName() const { return "Cast"; }; - const char* GetExecutionProviderType() const { return provider_; }; - - size_t GetInputTypeCount() const { return 1; }; - ONNXTensorElementDataType GetInputType(size_t /*index*/) const { - // example only accepts float input - return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; - }; - - size_t GetOutputTypeCount() const { return 1; }; - ONNXTensorElementDataType GetOutputType(size_t /*index*/) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; }; - - private: - const char* provider_{"CPUExecutionProvider"}; -}; +}; \ No newline at end of file diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc index b517ba7032886..ca9ca0f82a25a 100644 --- a/onnxruntime/test/shared_lib/test_inference.cc +++ b/onnxruntime/test/shared_lib/test_inference.cc @@ -1,19 +1,17 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include -#include -#include -#include #include -#include +#include +#include +#include #include +#include +#include +#include #include -#include #include -#include - #include "gtest/gtest.h" #include "gmock/gmock.h" @@ -27,13 +25,13 @@ #include "core/session/onnxruntime_run_options_config_keys.h" #include "core/util/thread_utils.h" -#include "test/shared_lib/custom_op_utils.h" -#include "test/shared_lib/test_fixture.h" -#include "test/shared_lib/utils.h" -#include "test/util/include/providers.h" -#include "test/util/include/test_allocator.h" - -#include "onnxruntime_config.h" // generated file in build output dir +#include "onnxruntime_config.h" +#include "providers.h" +#include "test_allocator.h" +#include "test_fixture.h" +#include "utils.h" +#include "custom_op_utils.h" +#include #ifdef _WIN32 #include @@ -65,6 +63,48 @@ constexpr size_t countof(T (&)[N]) { return N; } extern std::unique_ptr ort_env; +template +void RunSession(OrtAllocator* allocator, Ort::Session& session_object, + const std::vector& inputs, + const char* output_name, + const std::vector& dims_y, + const std::vector& values_y, + Ort::Value* output_tensor) { + std::vector ort_inputs; + std::vector input_names; + for (size_t i = 0; i < inputs.size(); i++) { + input_names.emplace_back(inputs[i].name); + ort_inputs.emplace_back( + Ort::Value::CreateTensor(allocator->Info(allocator), const_cast(inputs[i].values.data()), + inputs[i].values.size(), inputs[i].dims.data(), inputs[i].dims.size())); + } + + std::vector ort_outputs; + if (output_tensor) + session_object.Run(Ort::RunOptions{nullptr}, input_names.data(), ort_inputs.data(), ort_inputs.size(), + &output_name, output_tensor, 1); + else { + ort_outputs = session_object.Run(Ort::RunOptions{}, input_names.data(), ort_inputs.data(), ort_inputs.size(), + &output_name, 1); + ASSERT_EQ(ort_outputs.size(), 1u); + output_tensor = &ort_outputs[0]; + } + + auto type_info = output_tensor->GetTensorTypeAndShapeInfo(); + ASSERT_EQ(type_info.GetShape(), dims_y); + size_t total_len = type_info.GetElementCount(); + ASSERT_EQ(values_y.size(), total_len); + + OutT* f = output_tensor->GetTensorMutableData(); + for (size_t i = 0; i != total_len; ++i) { + if constexpr (std::is_same::value || std::is_same::value) { + ASSERT_NEAR(values_y[i], f[i], 1e-3); + } else { + ASSERT_EQ(values_y[i], f[i]); + } + } +} + #ifdef USE_DML struct DmlObjects { ComPtr d3d12_device; @@ -260,12 +300,12 @@ Ort::Value CreateTensorValueFromExistingD3DResource( #endif -template > +template static void TestInference(Ort::Env& env, const std::basic_string& model_uri, const std::vector& inputs, const char* output_name, const std::vector& expected_dims_y, - const std::vector& expected_values_y, + const std::vector& expected_values_y, int provider_type, OrtCustomOpDomain* custom_op_domain_ptr, const ORTCHAR_T* custom_op_library_filename, @@ -322,26 +362,26 @@ static void TestInference(Ort::Env& env, const std::basic_string& mod auto default_allocator = std::make_unique(); // without preallocated output tensor - RunSession(default_allocator.get(), - session, - inputs, - output_name, - expected_dims_y, - expected_values_y, - nullptr); + RunSession(default_allocator.get(), + session, + inputs, + output_name, + expected_dims_y, + expected_values_y, + nullptr); // with preallocated output tensor - Ort::Value value_y = Ort::Value::CreateTensor(default_allocator.get(), - expected_dims_y.data(), expected_dims_y.size()); + Ort::Value value_y = Ort::Value::CreateTensor(default_allocator.get(), + expected_dims_y.data(), expected_dims_y.size()); // test it twice for (int i = 0; i != 2; ++i) - RunSession(default_allocator.get(), - session, - inputs, - output_name, - expected_dims_y, - expected_values_y, - &value_y); + RunSession(default_allocator.get(), + session, + inputs, + output_name, + expected_dims_y, + expected_values_y, + &value_y); } } @@ -410,8 +450,8 @@ class CApiTestWithProvider : public testing::Test, public ::testing::WithParamIn TEST_P(CApiTestWithProvider, simple) { // simple inference test // prepare inputs - std::vector> inputs(1); - auto& input = inputs.back(); + std::vector inputs(1); + Input& input = inputs.back(); input.name = "X"; input.dims = {3, 2}; input.values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; @@ -581,8 +621,8 @@ TEST(CApiTest, SparseInputModel) { TEST(CApiTest, custom_op_handler) { std::cout << "Running custom op inference" << std::endl; - std::vector> inputs(1); - auto& input = inputs[0]; + std::vector inputs(1); + Input& input = inputs[0]; input.name = "X"; input.dims = {3, 2}; input.values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; @@ -617,8 +657,8 @@ TEST(CApiTest, custom_op_handler) { TEST(CApiTest, custom_op_set_input_memory_type) { std::cout << "Running custom op inference" << std::endl; - std::vector> inputs(1); - auto& input = inputs[0]; + std::vector inputs(1); + Input& input = inputs[0]; input.name = "X"; input.dims = {3, 2}; input.values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; @@ -647,8 +687,8 @@ TEST(CApiTest, custom_op_set_input_memory_type) { #if !defined(ORT_MINIMAL_BUILD) TEST(CApiTest, StandaloneOpHandler) { - std::vector> inputs(1); - auto& input = inputs[0]; + std::vector inputs(1); + Input& input = inputs[0]; input.name = "X"; input.dims = {3, 2}; input.values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; @@ -771,7 +811,7 @@ TEST(CApiTest, test_enable_ort_customops_stringlower) { // test custom op which accepts float and double as inputs TEST(CApiTest, varied_input_custom_op_handler) { - std::vector> inputs(2); + std::vector inputs(2); inputs[0].name = "X"; inputs[0].dims = {3}; inputs[0].values = {2.0f, 3.0f, 4.0f}; @@ -1382,8 +1422,8 @@ TEST(CApiTest, custom_op_with_attributes_handler) { TEST(CApiTest, RegisterCustomOpForCPUAndCUDA) { std::cout << "Tests registration of a custom op of the same name for both CPU and CUDA EPs" << std::endl; - std::vector> inputs(1); - auto& input = inputs[0]; + std::vector inputs(1); + Input& input = inputs[0]; input.name = "X"; input.dims = {3, 2}; input.values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; @@ -1491,7 +1531,7 @@ TEST(CApiTest, test_custom_op_openvino_wrapper_library) { // The custom op extracts the serialized .xml/.bin bytes and creates an in-memory OpenVINO model // during kernel creation. The custom op is passed an image of a hand-drawn "1" as an input during computation, which // is then inferenced using OpenVINO C++ APIs. - std::vector> inputs(1); + std::vector inputs(1); inputs[0].name = "Input3"; inputs[0].dims = {1, 1, 28, 28}; @@ -1590,7 +1630,7 @@ TEST(CApiTest, test_custom_op_library) { #endif std::cout << "Running inference using custom op shared library" << std::endl; - std::vector> inputs(2); + std::vector inputs(2); inputs[0].name = "input_1"; inputs[0].dims = {3, 5}; inputs[0].values = {1.1f, 2.2f, 3.3f, 4.4f, 5.5f, @@ -1642,7 +1682,7 @@ TEST(CApiTest, DISABLED_test_custom_op_shape_infer_attr) { #else TEST(CApiTest, test_custom_op_shape_infer_attr) { #endif - std::vector> inputs(1); + std::vector inputs(1); inputs[0].name = "input_0"; inputs[0].dims = {5}; inputs[0].values = {1.f, 2.f, 3.f, 4.f, 5.f}; @@ -1675,7 +1715,7 @@ TEST(CApiTest, test_custom_op_library_copy_variadic) { #endif std::cout << "Running inference using custom op shared library" << std::endl; - std::vector> inputs(2); + std::vector inputs(2); inputs[0].name = "input_0"; inputs[0].dims = {15}; inputs[0].values = {1.1f, 2.2f, 3.3f, 4.4f, 5.5f, @@ -1829,8 +1869,8 @@ void PrepareModule() { TEST(CApiTest, test_pyop) { std::call_once(my_module_flag, PrepareModule); - std::vector> inputs(1); - auto& input = inputs[0]; + std::vector inputs(1); + Input& input = inputs[0]; input.name = "X"; input.dims = {2, 2}; input.values = {1.0f, 2.0f, 3.0f, 4.0f}; @@ -1842,8 +1882,8 @@ TEST(CApiTest, test_pyop) { TEST(CApiTest, test_pyop_multi) { std::call_once(my_module_flag, PrepareModule); - std::vector> inputs(1); - auto& input = inputs[0]; + std::vector inputs(1); + Input& input = inputs[0]; input.name = "X"; input.dims = {2, 2}; input.values = {1.0f, 2.0f, 3.0f, 4.0f}; @@ -1855,8 +1895,8 @@ TEST(CApiTest, test_pyop_multi) { TEST(CApiTest, test_pyop_kwarg) { std::call_once(my_module_flag, PrepareModule); - std::vector> inputs(1); - auto& input = inputs[0]; + std::vector inputs(1); + Input& input = inputs[0]; input.name = "X"; input.dims = {2, 2}; input.values = {1.0f, 2.0f, 3.0f, 4.0f}; @@ -1880,7 +1920,7 @@ TEST(ReducedOpsBuildTest, test_excluded_ops) { // In reduced ops build, test a model containing ops not included in required_ops.config cannot be loaded. // See onnxruntime/test/testdata/reduced_build_test.readme.txt for more details of the setup constexpr PATH_TYPE model_uri = TSTR("testdata/reduced_build_test.onnx_model_with_excluded_ops"); - std::vector> inputs = {{"X", {3}, {-1.0f, 2.0f, -3.0f}}}; + std::vector inputs = {{"X", {3}, {-1.0f, 2.0f, -3.0f}}}; std::vector expected_dims_y = {3}; std::vector expected_values_y = {0.1f, 0.1f, 0.1f}; bool failed = false; @@ -3282,8 +3322,8 @@ TEST(CApiTest, TestSharedAllocators) { OrtEnv* env_ptr = (OrtEnv*)(*ort_env); // prepare inputs - std::vector> inputs(1); - auto& input = inputs.back(); + std::vector inputs(1); + Input& input = inputs.back(); input.name = "X"; input.dims = {3, 2}; input.values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; @@ -3469,8 +3509,8 @@ TEST(CApiTest, TestSharedAllocators) { TEST(CApiTest, TestSharingOfInitializerAndItsPrepackedVersion) { // simple inference test // prepare inputs - std::vector> inputs(1); - auto& input = inputs.back(); + std::vector inputs(1); + Input& input = inputs.back(); input.name = "X"; input.dims = {3, 2}; input.values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; @@ -3865,8 +3905,8 @@ TEST_P(CApiTensorRTTest, TestConfigureTensorRTProviderOptions) { // simple inference test // prepare inputs - std::vector> inputs(1); - auto& input = inputs.back(); + std::vector inputs(1); + Input& input = inputs.back(); input.name = "X"; input.dims = {3, 2}; input.values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; @@ -4805,32 +4845,4 @@ TEST(CApiTest, GenerateNodeStatsFile) { output_names, 1); } -#endif - -// Test that creates a custom Cast kernel which requires type inference of the output type to work. -// Also demonstrates overriding an ONNX operator as we register the custom op in the ONNX domain. -TEST(CApiTest, custom_cast) { - std::vector> inputs(1); - auto& input = inputs[0]; - input.name = "input"; - input.dims = {3, 4}; - input.values = {1.0f, 2.0f, 3.0f, 4.0f, - -1.0f, -2.0f, -3.0f, -4.0f, - 1.0f, 2.0f, 3.0f, 4.0f}; - - // prepare expected inputs and outputs - std::vector expected_dims_y = {3, 4}; - std::vector expected_values_y = {1.0, 2.0, 3.0, 4.0, - -1.0, -2.0, -3.0, -4.0, - 1.0, 2.0, 3.0, 4.0}; - - CustomCast custom_op{onnxruntime::kCpuExecutionProvider}; - - Ort::CustomOpDomain custom_op_domain(""); // onnx domain is empty string - custom_op_domain.Add(&custom_op); - - // model with Cast from ONNX test data - TestInference(*ort_env, TSTR("testdata/cast_float_to_double.onnx"), - inputs, "output", expected_dims_y, expected_values_y, 0, - custom_op_domain, nullptr); -} +#endif \ No newline at end of file diff --git a/onnxruntime/test/shared_lib/test_model_builder_api.cc b/onnxruntime/test/shared_lib/test_model_builder_api.cc deleted file mode 100644 index 9807fcca06ed4..0000000000000 --- a/onnxruntime/test/shared_lib/test_model_builder_api.cc +++ /dev/null @@ -1,701 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include - -#include "gtest/gtest.h" -#include "gmock/gmock.h" - -#include "core/common/narrow.h" -#include "core/graph/constants.h" -#include "core/session/onnxruntime_c_api.h" -#include "core/session/onnxruntime_cxx_api.h" -#include "core/session/onnxruntime_lite_custom_op.h" -#include "core/session/onnxruntime_session_options_config_keys.h" - -#include "test/shared_lib/test_fixture.h" -#include "test/shared_lib/utils.h" -#include "test/util/include/test_allocator.h" - -#include "onnxruntime_config.h" // generated file in build output dir - -extern std::unique_ptr ort_env; - -using namespace Ort; - -namespace { - -Ort::Session CreateSession(Ort::Env& env, - Model& graph_api_model, - Ort::SessionOptions* session_options_for_test = nullptr) { - Ort::SessionOptions default_session_options; - Ort::SessionOptions& session_options = session_options_for_test ? *session_options_for_test - : default_session_options; - - // Set this to save the model if you want to debug. - // session_options.SetOptimizedModelFilePath(ORT_TSTR("model_builder_output.onnx")); - - Ort::Session session(env, graph_api_model, session_options); - - // Session should not require the model to stay alive so free it now to validate. - graph_api_model = Model(nullptr); - - return session; -} - -template -void TestInference(Ort::Session& session, - const std::vector>& inputs, - const char* output_name, - const std::vector& expected_dims, - const std::vector& expected_values) { - auto default_allocator = std::make_unique(); - - // without preallocated output tensor - RunSession(default_allocator.get(), - session, - inputs, - output_name, - expected_dims, - expected_values, - nullptr); -} - -// Create OrtNode using the C API -OrtNode* CreateNode(const OrtModelEditorApi& api, - const char* operator_name, const char* node_name, - const gsl::span input_names, - const gsl::span output_names, - const gsl::span attributes = {}, - const char* domain_name = onnxruntime::kOnnxDomain) { - OrtNode* node = nullptr; - Ort::ThrowOnError(api.CreateNode(operator_name, domain_name, node_name, - input_names.data(), input_names.size(), - output_names.data(), output_names.size(), - attributes.data(), attributes.size(), - &node)); - return node; -} - -// convenience func to convert initalizer lists to gsl::span -OrtNode* CreateNode(const OrtModelEditorApi& api, - const char* operator_name, const char* node_name, - const std::initializer_list input_names, - const std::initializer_list output_names, - const std::initializer_list attributes = {}, - const char* domain_name = onnxruntime::kOnnxDomain) { - std::vector inputs(input_names); - std::vector outputs(output_names); - std::vector attrs(attributes); - return CreateNode(api, operator_name, node_name, inputs, outputs, attrs, domain_name); -} -} // namespace - -struct TestAllocator : public OrtAllocator { - TestAllocator() { - version = ORT_API_VERSION; - Info = [](const struct OrtAllocator* this_ptr) -> const struct OrtMemoryInfo* { - auto* test_allocator = static_cast(this_ptr); - return test_allocator->memory_info; - }; - - Free = [](struct OrtAllocator* allocator, void* p) -> void { - auto* test_allocator = static_cast(allocator); - // find the matching pointer and remove it - auto it = std::find_if(test_allocator->weights.begin(), test_allocator->weights.end(), - [p](const std::unique_ptr>& v) { return v->data() == p; }); - if (it == test_allocator->weights.end()) { - throw std::runtime_error("Free called with unknown pointer"); - } - - test_allocator->weights.erase(it); - }; - - Alloc = [](struct OrtAllocator* /*this*/, size_t /*size*/) -> void* { - throw std::runtime_error("This should not be used"); - }; - - Reserve = [](struct OrtAllocator* /*this*/, size_t /*size*/) -> void* { - throw std::runtime_error("This should not be used"); - }; - } - - // initializers that are used directly by the model. as there's no copy they must remain valid. - // we store them in the test allocator so we can validate that Free is called - std::vector>> weights; - Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtAllocatorType::OrtDeviceAllocator, - OrtMemType::OrtMemTypeDefault); -}; - -// Test the ModelEditorAPI C api -// Uses the ORT C++ api for the rest for simplicity -TEST(ModelEditorAPITest, Basic_CApi) { - const auto& api = Ort::GetApi(); - const auto& model_editor_api = Ort::GetModelEditorApi(); - - TestAllocator deleter; - - // return void so we can use ASSERT_* in the lambda - const auto build_model = [&](bool use_constant_node, OrtModel*& model) -> void { - OrtGraph* graph = nullptr; - Ort::ThrowOnError(model_editor_api.CreateGraph(&graph)); - - // - // Create OrtModel with a Gemm. X input is 3x4, Y input is 4x8, Z output is 3x8. - // X is model input. Y is initializer. - // Set the alpha attribute of the Gemm node to 2.0 to test attribute handling. - // - - // model input - OrtTensorTypeAndShapeInfo* tensor_type_info = nullptr; - std::vector input_dims = {3, 4}; - // can use api.SetSymbolicDimensions to set symbolic dimensions. - // the input array should have the same rank as the call to SetDimensions. - // e.g. call SetDimensions with {-1, 3, 2} and SetSymbolicDimensions with {"N", nullptr, nullptr} to create - // a shape of {"N", 3, 2} - - Ort::ThrowOnError(api.CreateTensorTypeAndShapeInfo(&tensor_type_info)); - Ort::ThrowOnError(api.SetTensorElementType(tensor_type_info, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT)); - Ort::ThrowOnError(api.SetDimensions(tensor_type_info, input_dims.data(), input_dims.size())); - - OrtTypeInfo* input_type_info = nullptr; - Ort::ThrowOnError(model_editor_api.CreateTensorTypeInfo(tensor_type_info, &input_type_info)); - api.ReleaseTensorTypeAndShapeInfo(tensor_type_info); // input_type_info took a copy - - // create ValueInfo and release the type info as CreateValueInfo takes a copy. - OrtValueInfo* input_value_info = nullptr; - Ort::ThrowOnError(model_editor_api.CreateValueInfo("X", input_type_info, &input_value_info)); - api.ReleaseTypeInfo(input_type_info); // input_value_info took a copy - tensor_type_info = nullptr; - - // model outputs - OrtTypeInfo* output_type_info = nullptr; - std::vector output_dims = {3, 8}; - - Ort::ThrowOnError(api.CreateTensorTypeAndShapeInfo(&tensor_type_info)); - Ort::ThrowOnError(api.SetTensorElementType(tensor_type_info, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT)); - Ort::ThrowOnError(api.SetDimensions(tensor_type_info, output_dims.data(), output_dims.size())); - - Ort::ThrowOnError(model_editor_api.CreateTensorTypeInfo(tensor_type_info, &output_type_info)); - api.ReleaseTensorTypeAndShapeInfo(tensor_type_info); // input_type_info took a copy - - OrtValueInfo* output_value_info = nullptr; - Ort::ThrowOnError(model_editor_api.CreateValueInfo("Z", output_type_info, &output_value_info)); - api.ReleaseTypeInfo(output_type_info); - - std::vector graph_inputs = {input_value_info}; - std::vector graph_outputs = {output_value_info}; - Ort::ThrowOnError(model_editor_api.SetGraphInputs(graph, graph_inputs.data(), graph_inputs.size())); - Ort::ThrowOnError(model_editor_api.SetGraphOutputs(graph, graph_outputs.data(), graph_outputs.size())); - input_value_info = nullptr; // graph now owns the input/output values - output_value_info = nullptr; - - // - // Gemm node - // - - OrtOpAttr* alpha_attr = nullptr; - float alpha_value = 2.0; - Ort::ThrowOnError(api.CreateOpAttr("alpha", &alpha_value, 1, OrtOpAttrType::ORT_OP_ATTR_FLOAT, &alpha_attr)); - - std::vector node_input_names = {"X", "Y"}; - const std::string gemm_output_name = use_constant_node ? "Z_temp" : "Z"; - std::vector node_output_names = {gemm_output_name.c_str()}; - std::vector node_attributes{alpha_attr}; - OrtNode* node = CreateNode(model_editor_api, "Gemm", "Gemm1", node_input_names, node_output_names, node_attributes); - alpha_attr = nullptr; // Node now owns - - Ort::ThrowOnError(model_editor_api.AddNodeToGraph(graph, node)); - node = nullptr; // graph now owns node - - // Y input - // As it's 128 bytes it could either be allocated using CreateTensorAsOrtValue or use existing memory. - // Under 128 bytes must use CreateTensorAsOrtValue. - std::vector y_dims = {4, 8}; - - deleter.weights.emplace_back(std::make_unique>(32)); - auto& y_values = *deleter.weights.back(); - std::iota(y_values.begin(), y_values.end(), 1.0f); - - // create an initializer for the Y input. add to `weights` so the memory remains valid. - OrtValue* y_tensor = nullptr; - Ort::ThrowOnError( - api.CreateTensorWithDataAndDeleterAsOrtValue(&deleter, - y_values.data(), y_values.size() * sizeof(y_values[0]), - y_dims.data(), y_dims.size(), - ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, - &y_tensor)); - - Ort::ThrowOnError(model_editor_api.AddInitializerToGraph(graph, "Y", y_tensor, /*data is external*/ true)); - y_tensor = nullptr; // graph now owns - - if (use_constant_node) { - // Test that a Constant node is converted to an initializer - - // create Constant nodes for min/max to limit output range - OrtOpAttr* min_attr = nullptr; - float min = 400.0f; - Ort::ThrowOnError(api.CreateOpAttr("value", &min, sizeof(min), ORT_OP_ATTR_FLOAT, &min_attr)); - node = CreateNode(model_editor_api, "Constant", "clip_min", {}, {"min"}, {min_attr}); - Ort::ThrowOnError(model_editor_api.AddNodeToGraph(graph, node)); - node = nullptr; // graph now owns node - - OrtOpAttr* max_attr = nullptr; - float max = 900.0f; - Ort::ThrowOnError(api.CreateOpAttr("value", &max, sizeof(max), ORT_OP_ATTR_FLOAT, &max_attr)); - node = CreateNode(model_editor_api, "Constant", "clip_max", {}, {"max"}, {max_attr}); - Ort::ThrowOnError(model_editor_api.AddNodeToGraph(graph, node)); - node = nullptr; // graph now owns node - - node = CreateNode(model_editor_api, "Clip", "Clip1", {gemm_output_name.c_str(), "min", "max"}, {"Z"}); - Ort::ThrowOnError(model_editor_api.AddNodeToGraph(graph, node)); - node = nullptr; // graph now owns node - } - - std::vector domain_names = {onnxruntime::kOnnxDomain}; - std::vector opset_versions = {18}; - Ort::ThrowOnError(model_editor_api.CreateModel(domain_names.data(), opset_versions.data(), domain_names.size(), - &model)); - Ort::ThrowOnError(model_editor_api.AddGraphToModel(model, graph)); - graph = nullptr; // model now owns - }; - - auto run_test = [&](bool use_constant_node) -> void { - OrtModel* model = nullptr; - build_model(use_constant_node, model); - - ASSERT_NE(model, nullptr) << "build_model should have created a model"; - - std::vector> inputs(1); - auto& input = inputs[0]; - input.name = "X"; - input.dims = {3, 4}; - input.values = {1.0f, 2.0f, 3.0f, 4.0f, - 8.0f, 7.0f, 6.0f, 5.0f, - 9.0f, 3.0f, 5.0f, 7.0f}; - - std::vector expected_dims = {3, 8}; - Model cxx_model(model); - auto session = CreateSession(*ort_env, cxx_model); - - std::vector expected_output; - if (use_constant_node) { - // clipped with min 400 and max 900 - expected_output = {400.0f, 400.0f, 400.0f, 400.0f, 420.0f, 440.0f, 460.0f, 480.0f, - 596.0f, 648.0f, 700.0f, 752.0f, 804.0f, 856.0f, 900.0f, 900.0f, - 592.0f, 640.0f, 688.0f, 736.0f, 784.0f, 832.0f, 880.0f, 900.0f}; - } else { - expected_output = {340.0f, 360.0f, 380.0f, 400.0f, 420.0f, 440.0f, 460.0f, 480.0f, - 596.0f, 648.0f, 700.0f, 752.0f, 804.0f, 856.0f, 908.0f, 960.0f, - 592.0f, 640.0f, 688.0f, 736.0f, 784.0f, 832.0f, 880.0f, 928.0f}; - } - - TestInference(session, inputs, "Z", expected_dims, expected_output); - - api.ReleaseSession(session.release()); - - ASSERT_EQ(deleter.weights.size(), size_t(0)) << "All weights should have been freed"; - }; - - run_test(false); - run_test(true); // use Constant node for initializer -} - -TEST(ModelEditorAPITest, Basic_CxxApi) { - // initializers that are used directly by the model. as there's no copy they must remain valid - std::vector>> weights; - - Ort::Graph graph; - - // - // Create OrtModel with a Gemm. X input is 3x4, Y input is 4x8, Z output is 3x8. - // X is model input. Y is initializer. - // Set the alpha attribute of the Gemm node to 2.0 to test attribute handling. - // - - std::vector graph_inputs; - std::vector graph_outputs; - - // model input. it's {3, 4} but use a symbolic dim to test that works. - std::vector input_dims({-1, 4}); - std::vector input_symbolic_dims({"multiple_of_3", ""}); - TensorTypeAndShapeInfo input_tensor_info(ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, - input_dims, - &input_symbolic_dims); - auto input_type_info = TypeInfo::CreateTensorInfo(input_tensor_info.GetConst()); - graph_inputs.emplace_back("X", input_type_info.GetConst()); - - // model outputs - std::vector output_dims = {-1, 8}; - std::vector output_symbolic_dims({"multiple_of_3", ""}); - TensorTypeAndShapeInfo output_tensor_info(ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, - output_dims, - &output_symbolic_dims); - auto output_type_info = TypeInfo::CreateTensorInfo(output_tensor_info.GetConst()); - graph_outputs.emplace_back("Z", output_type_info.GetConst()); - - graph.SetInputs(graph_inputs); - graph.SetOutputs(graph_outputs); - - // - // Gemm node - // - - std::vector attributes; - float alpha_value = 2.0; - attributes.push_back(OpAttr("alpha", &alpha_value, 1, OrtOpAttrType::ORT_OP_ATTR_FLOAT)); - - Node node("Gemm", onnxruntime::kOnnxDomain, "Gemm1", {"X", "Y"}, {"Z"}, attributes); - - graph.AddNode(node); - - // create an initializer for the Y input. - // add to `weights` so it remains valid for the lifetime of the session and we can avoid copying the data. - // As it's 128 bytes it could either be allocated using CreateTensorAsOrtValue or use existing memory. - // Under 128 bytes must use CreateTensorAsOrtValue. - std::vector y_dims = {4, 8}; - - weights.emplace_back(std::make_unique>(32)); - auto& y_values = *weights.back(); - std::iota(y_values.begin(), y_values.end(), 1.0f); - - auto info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); - - // if you use this API the initializer data MUST remain valid for the lifetime of the InferenceSession - auto y_tensor = Value::CreateTensor(info, y_values.data(), y_values.size(), y_dims.data(), y_dims.size()); - graph.AddInitializer("Y", y_tensor, /*data is external*/ true); - - std::vector opsets{{onnxruntime::kOnnxDomain, 18}}; - Model model(opsets); - model.AddGraph(graph); - - std::vector> inputs(1); - auto& input = inputs[0]; - input.name = "X"; - input.dims = {3, 4}; - input.values = {1.0f, 2.0f, 3.0f, 4.0f, - 8.0f, 7.0f, 6.0f, 5.0f, - 9.0f, 3.0f, 5.0f, 7.0f}; - - std::vector expected_dims = {3, 8}; - - auto session = CreateSession(*ort_env, model); - TestInference(session, inputs, "Z", expected_dims, - {340.0f, 360.0f, 380.0f, 400.0f, 420.0f, 440.0f, 460.0f, 480.0f, - 596.0f, 648.0f, 700.0f, 752.0f, 804.0f, 856.0f, 908.0f, 960.0f, - 592.0f, 640.0f, 688.0f, 736.0f, 784.0f, 832.0f, 880.0f, 928.0f}); -} - -TEST(ModelEditorAPITest, BasicModelEdit_CxxApi) { - // - // Load existing model - // Add Cast to change the model input from float to int64 - // Update model inputs to match - // Run - // - - SessionOptions so; - - // Set this to save the model if you want to debug. - // so.SetOptimizedModelFilePath(ORT_TSTR("model_builder_edited.onnx")); - - Session session = Session::CreateModelEditorSession(*ort_env, TSTR("testdata/mnist.onnx"), so); - - ASSERT_EQ(session.GetOpset(""), 8); // ONNX domain is empty string - - // we augment the original model with nodes, initializers and the updated model inputs/outputs from this model. - // the original graph is unchanged. nodes can be added before/after it. initializers can be added. - // new nodes must conform to the original domain:opset of the model. - // additional operator domain:opset pairs can be added. - std::vector opsets; // no additional opsets required - Model model(opsets); - - std::vector graph_inputs = session.GetInputs(); - ASSERT_EQ(graph_inputs.size(), size_t(1)); - ASSERT_EQ(graph_inputs[0].TypeInfo().GetTensorTypeAndShapeInfo().GetElementType(), - ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT); - - // typically this isn't needed. we replace this input but need to read info from it later on in the test - // validation so we save the info locally to keep it accessible. - auto orig_input_name = graph_inputs[0].Name(); - auto input_shape = graph_inputs[0].TypeInfo().GetTensorTypeAndShapeInfo().GetShape(); - const std::string new_input_name = "Int64Input"; - - // Add Cast node to convert input from float to int64 - std::vector attributes; - int64_t to = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; - attributes.push_back(OpAttr("to", &to, 1, OrtOpAttrType::ORT_OP_ATTR_INT)); - - Ort::Node node("Cast", onnxruntime::kOnnxDomain, new_input_name, {"Int64Input"}, - // the existing node will now consume the output from the Cast instead of a graph input - {orig_input_name}, - attributes); - - // we're replacing the only input. the shape is the same but the name and data type change. - TensorTypeAndShapeInfo input_tensor_info(ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, - input_shape); - auto input_type_info = TypeInfo::CreateTensorInfo(input_tensor_info.GetConst()); - graph_inputs[0] = ValueInfo(new_input_name, input_type_info.GetConst()); - - Graph graph; // new info to augment the model with - - graph.AddNode(node); - graph.SetInputs(graph_inputs); - - // the node we added does not require any new opsets. - model.AddGraph(graph); - session.FinalizeModelEditorSession(model, so); - - std::vector> inputs(1); - auto& input = inputs[0]; - input.name = new_input_name.c_str(); - input.dims = input_shape; - - auto num_values = std::accumulate(input.dims.begin(), input.dims.end(), int64_t(1), std::multiplies()); - input.values.resize(size_t(num_values)); - std::iota(input.values.begin(), input.values.end(), 1); - - std::vector expected_dims = {1, 10}; - std::vector expected_output = {-48.5088f, -1040.2948f, -347.0959f, 101.7392f, 421.3352f, - 750.92145f, 231.5060f, -1694.4152f, 681.5623f, 378.1689f}; - - TestInference(session, inputs, session.GetOutputNames()[0].c_str(), expected_dims, expected_output); - - // double check with original model - { - SessionOptions expected_so; - Session expected_session = Session(*ort_env, TSTR("testdata/mnist.onnx"), expected_so); - std::vector> expected_inputs(1); - auto& expected_input = expected_inputs[0]; - expected_input.name = orig_input_name.c_str(); - expected_input.dims = input_shape; - expected_input.values.reserve(size_t(num_values)); - std::transform(input.values.begin(), input.values.end(), std::back_inserter(expected_input.values), - [&](int64_t value) { return float(value); }); - - TestInference(expected_session, expected_inputs, session.GetOutputNames()[0].c_str(), - expected_dims, expected_output); - } -} - -TEST(ModelEditorAPITest, InvalidDimension) { - try { - std::vector input_dims = {-2, 2}; - TensorTypeAndShapeInfo tensor_type_info(ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, - input_dims); - // invalid dim of -2 should cause exception - TypeInfo::CreateTensorInfo(tensor_type_info.GetConst()); - FAIL() << "Expected exception for invalid dimension"; - } catch (const Ort::Exception& e) { - ASSERT_STREQ(e.what(), "dim_values must be -1 (symbolic dimension) or larger."); - } -} - -TEST(ModelEditorAPITest, CreateInvalidModel_NoOpsets) { - Ort::Graph graph; - std::vector graph_inputs; - std::vector graph_outputs; - - std::vector dims({4}); - TensorTypeAndShapeInfo tensor_info(ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, dims); - auto type_info = TypeInfo::CreateTensorInfo(tensor_info.GetConst()); - graph_inputs.emplace_back("X", type_info.GetConst()); - graph_outputs.emplace_back("Z", type_info.GetConst()); - - graph.SetInputs(graph_inputs); - graph.SetOutputs(graph_outputs); - - Ort::Node node("Add", onnxruntime::kOnnxDomain, "Add1", {"X", "X"}, {"Z"}); - - graph.AddNode(node); - - std::vector opsets; - Model model(opsets); - model.AddGraph(graph); - - try { - auto session = CreateSession(*ort_env, model); - FAIL(); - } catch (const Ort::Exception& e) { - ASSERT_THAT(e.what(), ::testing::HasSubstr("Error No opset import for domain")); - } -} - -TEST(ModelEditorAPITest, CreateInvalidModel_MissingValue) { - Ort::Graph graph; - - std::vector graph_inputs; - std::vector graph_outputs; - - std::vector dims({4}); - TensorTypeAndShapeInfo tensor_info(ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, dims); - auto type_info = TypeInfo::CreateTensorInfo(tensor_info.GetConst()); - graph_inputs.emplace_back("X", type_info.GetConst()); - graph_outputs.emplace_back("Z", type_info.GetConst()); - - graph.SetInputs(graph_inputs); - graph.SetOutputs(graph_outputs); - - Ort::Node node("Add", onnxruntime::kOnnxDomain, "Add1", {"X", "missing"}, {"Z"}); - graph.AddNode(node); - - std::vector opsets{{onnxruntime::kOnnxDomain, 18}}; - Model model(opsets); - model.AddGraph(graph); - - try { - auto session = CreateSession(*ort_env, model); - FAIL(); - } catch (const Ort::Exception& e) { - ASSERT_THAT(e.what(), ::testing::HasSubstr("Node input 'missing' is not a graph input, " - "initializer, or output of a previous node.")); - } -} - -TEST(ModelEditorAPITest, InvalidModelEdit) { - // Add a node but make the edit invalid in various ways - // - add node but don't update graph inputs - // - add node with invalid domain - const auto edit_model = [](bool invalid_domain) { - SessionOptions so; - - // Set this to save the model if you want to debug. - // so.SetOptimizedModelFilePath(ORT_TSTR("model_builder_edited.onnx")); - - Session session = Session::CreateModelEditorSession(*ort_env, TSTR("testdata/mnist.onnx"), so); - - ASSERT_EQ(session.GetOpset(""), 8); // ONNX domain is empty string - - std::vector opsets; // no additional opsets required - Model model(opsets); - Graph graph; // new info to augment the model with - - const char* domain = invalid_domain ? "invalid_domain" : onnxruntime::kOnnxDomain; - - std::vector graph_inputs = session.GetInputs(); - ASSERT_EQ(graph_inputs.size(), size_t(1)); - ASSERT_EQ(graph_inputs[0].TypeInfo().GetTensorTypeAndShapeInfo().GetElementType(), - ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT); - - const std::string new_input_name = "Int64Input"; - - // Add Cast node to convert input from float to int64 - std::vector attributes; - int64_t to = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; - attributes.push_back(OpAttr("to", &to, 1, OrtOpAttrType::ORT_OP_ATTR_INT)); - - Node node("Cast", domain, "NewInputNode", {new_input_name}, - // the existing node will now consume the output from the Cast instead of a graph input - {graph_inputs[0].Name()}, - attributes); - graph.AddNode(node); - - if (invalid_domain) { - // we're replacing the only input. the shape is the same but the name and data type change. - TensorTypeAndShapeInfo input_tensor_info(ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, - graph_inputs[0].TypeInfo().GetTensorTypeAndShapeInfo().GetShape()); - auto input_type_info = TypeInfo::CreateTensorInfo(input_tensor_info.GetConst()); - graph_inputs[0] = ValueInfo(new_input_name, input_type_info.GetConst()); - graph.SetInputs(graph_inputs); - } else { - // model should be invalid as we didn't connect the new node up to the graph inputs - } - - // the node we added does not require any new opsets. - model.AddGraph(graph); - - try { - session.FinalizeModelEditorSession(model, so); - FAIL() << "Should have failed to resolve graph due to invalid edits."; - } catch (const Ort::Exception& e) { - if (invalid_domain) { - ASSERT_THAT(e.what(), ::testing::HasSubstr("Error No opset import for domain 'invalid_domain'")); - } else { - ASSERT_THAT(e.what(), ::testing::HasSubstr("This is an invalid model")); - } - } - }; - - edit_model(false); - edit_model(true); // add node with invalid domain -} - -TEST(ModelEditorAPITest, CreateTypeInfo) { - const auto& api = Ort::GetApi(); - const auto& model_editor_api = Ort::GetModelEditorApi(); - - TensorTypeAndShapeInfo base_tensor_info(ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, - {2, 4}); - - OrtTypeInfo* base_tensor_type_info = nullptr; - Ort::ThrowOnError(model_editor_api.CreateTensorTypeInfo(base_tensor_info, &base_tensor_type_info)); - - ONNXType onnx_type = ONNX_TYPE_UNKNOWN; - const OrtTensorTypeAndShapeInfo* tensor_info = nullptr; - ONNXTensorElementDataType onnx_element_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; - - // sparse tensor - OrtTypeInfo* sparse_tensor_type_info = nullptr; - Ort::ThrowOnError(model_editor_api.CreateSparseTensorTypeInfo(base_tensor_info, &sparse_tensor_type_info)); - Ort::ThrowOnError(api.GetOnnxTypeFromTypeInfo(sparse_tensor_type_info, &onnx_type)); - ASSERT_EQ(onnx_type, ONNXType::ONNX_TYPE_SPARSETENSOR); - Ort::ThrowOnError(api.CastTypeInfoToTensorInfo(sparse_tensor_type_info, &tensor_info)); - Ort::ThrowOnError(api.GetTensorElementType(tensor_info, &onnx_element_type)); - ASSERT_EQ(onnx_element_type, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT); - api.ReleaseTypeInfo(sparse_tensor_type_info); - - // sequence - OrtTypeInfo* sequence_type_info = nullptr; - const OrtSequenceTypeInfo* sequence_info = nullptr; - OrtTypeInfo* sequence_element_type_info = nullptr; - - Ort::ThrowOnError(model_editor_api.CreateSequenceTypeInfo(base_tensor_type_info, &sequence_type_info)); - Ort::ThrowOnError(api.GetOnnxTypeFromTypeInfo(sequence_type_info, &onnx_type)); - ASSERT_EQ(onnx_type, ONNXType::ONNX_TYPE_SEQUENCE); - Ort::ThrowOnError(api.CastTypeInfoToSequenceTypeInfo(sequence_type_info, &sequence_info)); - Ort::ThrowOnError(api.GetSequenceElementType(sequence_info, &sequence_element_type_info)); - Ort::ThrowOnError(api.GetOnnxTypeFromTypeInfo(sequence_element_type_info, &onnx_type)); - ASSERT_EQ(onnx_type, ONNXType::ONNX_TYPE_TENSOR); - Ort::ThrowOnError(api.CastTypeInfoToTensorInfo(sequence_element_type_info, &tensor_info)); - Ort::ThrowOnError(api.GetTensorElementType(tensor_info, &onnx_element_type)); - ASSERT_EQ(onnx_element_type, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT); - api.ReleaseTypeInfo(sequence_element_type_info); - api.ReleaseTypeInfo(sequence_type_info); - - // map - OrtTypeInfo* map_type_info = nullptr; - const OrtMapTypeInfo* map_info = nullptr; - ONNXTensorElementDataType map_key_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; - OrtTypeInfo* map_value_type_info = nullptr; - Ort::ThrowOnError(model_editor_api.CreateMapTypeInfo(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, base_tensor_type_info, - &map_type_info)); // clones map_type_info - Ort::ThrowOnError(api.GetOnnxTypeFromTypeInfo(map_type_info, &onnx_type)); - ASSERT_EQ(onnx_type, ONNXType::ONNX_TYPE_MAP); - Ort::ThrowOnError(api.CastTypeInfoToMapTypeInfo(map_type_info, &map_info)); - Ort::ThrowOnError(api.GetMapKeyType(map_info, &map_key_type)); - ASSERT_EQ(map_key_type, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64); - Ort::ThrowOnError(api.GetMapValueType(map_info, &map_value_type_info)); - Ort::ThrowOnError(api.GetOnnxTypeFromTypeInfo(map_value_type_info, &onnx_type)); - ASSERT_EQ(onnx_type, ONNXType::ONNX_TYPE_TENSOR); - Ort::ThrowOnError(api.CastTypeInfoToTensorInfo(map_value_type_info, &tensor_info)); - Ort::ThrowOnError(api.GetTensorElementType(tensor_info, &onnx_element_type)); - ASSERT_EQ(onnx_element_type, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT); - api.ReleaseTypeInfo(map_value_type_info); - api.ReleaseTypeInfo(map_type_info); - - // optional - OrtTypeInfo* optional_type_info = nullptr; - const OrtOptionalTypeInfo* optional_info = nullptr; - OrtTypeInfo* optional_contained_type_info = nullptr; - Ort::ThrowOnError(model_editor_api.CreateOptionalTypeInfo(base_tensor_type_info, &optional_type_info)); - Ort::ThrowOnError(api.GetOnnxTypeFromTypeInfo(optional_type_info, &onnx_type)); - ASSERT_EQ(onnx_type, ONNXType::ONNX_TYPE_OPTIONAL); - Ort::ThrowOnError(api.CastTypeInfoToOptionalTypeInfo(optional_type_info, &optional_info)); - Ort::ThrowOnError(api.GetOptionalContainedTypeInfo(optional_info, &optional_contained_type_info)); - Ort::ThrowOnError(api.GetOnnxTypeFromTypeInfo(optional_contained_type_info, &onnx_type)); - ASSERT_EQ(onnx_type, ONNXType::ONNX_TYPE_TENSOR); - api.ReleaseTypeInfo(optional_contained_type_info); - api.ReleaseTypeInfo(optional_type_info); - - api.ReleaseTypeInfo(base_tensor_type_info); -} diff --git a/onnxruntime/test/shared_lib/test_ort_format_models.cc b/onnxruntime/test/shared_lib/test_ort_format_models.cc index b3491e3476f23..99a9ebc3362ae 100644 --- a/onnxruntime/test/shared_lib/test_ort_format_models.cc +++ b/onnxruntime/test/shared_lib/test_ort_format_models.cc @@ -17,7 +17,7 @@ extern std::unique_ptr ort_env; [[maybe_unused]] static void TestInference(Ort::Env& env, const std::basic_string& model_uri, - const std::vector>& inputs, const char* output_name, + const std::vector& inputs, const char* output_name, const std::vector& expected_dims_y, const std::vector& expected_values_y, Ort::CustomOpDomain& custom_op_domain, void* cuda_compute_stream = nullptr) { Ort::SessionOptions session_options; @@ -100,8 +100,8 @@ TEST(OrtFormatCustomOpTests, ConvertOnnxModelToOrt) { } // now load the ORT format model and execute it - std::vector> inputs(1); - auto& input = inputs[0]; + std::vector inputs(1); + Input& input = inputs[0]; input.name = "X"; input.dims = {3, 2}; input.values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; @@ -130,8 +130,8 @@ TEST(OrtFormatCustomOpTests, LoadOrtModel) { custom_op_domain.Add(&custom_op); // load the ORT format model and execute it - std::vector> inputs(1); - auto& input = inputs[0]; + std::vector inputs(1); + Input& input = inputs[0]; input.name = "X"; input.dims = {3, 2}; input.values = {6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f}; @@ -151,8 +151,8 @@ TEST(OrtFormatCustomOpTests, LoadOrtModelStandaloneCustomOpImplementation) { custom_op_domain.Add(&standalone_op); // load the ORT format model and execute it - std::vector> inputs(1); - auto& input = inputs[0]; + std::vector inputs(1); + Input& input = inputs[0]; input.name = "X"; input.dims = {3, 2}; input.values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; diff --git a/onnxruntime/test/shared_lib/utils.h b/onnxruntime/test/shared_lib/utils.h index 5d15582b86cb9..483753f2ae6b2 100644 --- a/onnxruntime/test/shared_lib/utils.h +++ b/onnxruntime/test/shared_lib/utils.h @@ -5,56 +5,4 @@ #include "core/session/onnxruntime_cxx_api.h" -#include "gtest/gtest.h" - OrtCUDAProviderOptions CreateDefaultOrtCudaProviderOptionsWithCustomStream(void* cuda_compute_stream = nullptr); - -template -struct Input { - const char* name = nullptr; - std::vector dims; - std::vector values; -}; - -template > -void RunSession(OrtAllocator* allocator, - Ort::Session& session_object, - const std::vector& inputs, - const char* output_name, - const std::vector& output_dims, - const std::vector& expected_output, - Ort::Value* output_tensor) { - std::vector ort_inputs; - std::vector input_names; - for (size_t i = 0; i < inputs.size(); i++) { - input_names.emplace_back(inputs[i].name); - ort_inputs.emplace_back( - Ort::Value::CreateTensor(allocator->Info(allocator), const_cast(inputs[i].values.data()), - inputs[i].values.size(), inputs[i].dims.data(), inputs[i].dims.size())); - } - - std::vector ort_outputs; - if (output_tensor) - session_object.Run(Ort::RunOptions{nullptr}, input_names.data(), ort_inputs.data(), ort_inputs.size(), - &output_name, output_tensor, 1); - else { - ort_outputs = session_object.Run(Ort::RunOptions{}, input_names.data(), ort_inputs.data(), ort_inputs.size(), - &output_name, 1); - ASSERT_EQ(ort_outputs.size(), 1u); - output_tensor = &ort_outputs[0]; - } - - auto type_info = output_tensor->GetTensorTypeAndShapeInfo(); - ASSERT_EQ(type_info.GetShape(), output_dims); - size_t total_len = type_info.GetElementCount(); - ASSERT_EQ(expected_output.size(), total_len); - - auto* actual = output_tensor->GetTensorMutableData(); - for (size_t i = 0; i != total_len; ++i) { - if constexpr (std::is_same::value || std::is_same::value) { - EXPECT_NEAR(expected_output[i], actual[i], 1e-3) << "i=" << i; - } else { - EXPECT_EQ(expected_output[i], actual[i]) << "i=" << i; - } - } -} diff --git a/onnxruntime/test/testdata/cast_float_to_double.onnx b/onnxruntime/test/testdata/cast_float_to_double.onnx deleted file mode 100644 index dc7997cddd8a8c762e354316662fb0d734e25e86..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 136 zcmdfpOwLZtOVKS!EiSPt;8NgX&CDw(EfHeNFD(JmN-WNa#U)ytTudeT65I-kD&v!ZqVaA%{*EE>CHe6#{-I7ju2JGJ&3s%u9E?I7TudCyK+KXP!38x=2qeRe Mka1$+Vh|7o0L&R4`v3p{ diff --git a/onnxruntime/test/testdata/custom_execution_provider_library/my_execution_provider.cc b/onnxruntime/test/testdata/custom_execution_provider_library/my_execution_provider.cc index 27a4b06a99e64..57471f7c029c2 100644 --- a/onnxruntime/test/testdata/custom_execution_provider_library/my_execution_provider.cc +++ b/onnxruntime/test/testdata/custom_execution_provider_library/my_execution_provider.cc @@ -1,5 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. +// Confidential and Proprietary. #include "my_execution_provider.h" #include "my_allocator.h" diff --git a/onnxruntime/test/testdata/custom_execution_provider_library/my_execution_provider.h b/onnxruntime/test/testdata/custom_execution_provider_library/my_execution_provider.h index efb359a9e5e43..ff0c7e80c4eeb 100644 --- a/onnxruntime/test/testdata/custom_execution_provider_library/my_execution_provider.h +++ b/onnxruntime/test/testdata/custom_execution_provider_library/my_execution_provider.h @@ -1,5 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. +// Confidential and Proprietary. #pragma once diff --git a/onnxruntime/wasm/api.cc b/onnxruntime/wasm/api.cc index 1ad35b51bb1c1..7adfc6a2b2ccb 100644 --- a/onnxruntime/wasm/api.cc +++ b/onnxruntime/wasm/api.cc @@ -8,14 +8,6 @@ #include "core/session/onnxruntime_cxx_api.h" #include "api.h" -#ifdef USE_WEBGPU -namespace onnxruntime { -namespace webgpu { -WGPUDevice GetDevice(int); -} -} // namespace onnxruntime -#endif - #include #include #include @@ -172,12 +164,8 @@ OrtSessionOptions* OrtCreateSessionOptions(size_t graph_optimization_level, return UNREGISTER_AUTO_RELEASE(session_options); } -int OrtAppendExecutionProvider(ort_session_options_handle_t session_options, - const char* name, - const char* const* provider_options_keys, - const char* const* provider_options_values, - size_t num_keys) { - return CHECK_STATUS(SessionOptionsAppendExecutionProvider, session_options, name, provider_options_keys, provider_options_values, num_keys); +int OrtAppendExecutionProvider(ort_session_options_handle_t session_options, const char* name) { + return CHECK_STATUS(SessionOptionsAppendExecutionProvider, session_options, name, nullptr, nullptr, 0); } int OrtAddFreeDimensionOverride(ort_session_options_handle_t session_options, @@ -519,16 +507,6 @@ char* OrtEndProfiling(ort_session_handle_t session) { : nullptr; } -// WebGPU API Section - -#ifdef USE_WEBGPU - -WGPUDevice OrtGetWebGpuDevice(int device_id) { - return onnxruntime::webgpu::GetDevice(device_id); -} - -#endif - // Training API Section #ifdef ENABLE_TRAINING_APIS diff --git a/onnxruntime/wasm/api.h b/onnxruntime/wasm/api.h index 9ff1eb55ecedc..f44c515d98f6b 100644 --- a/onnxruntime/wasm/api.h +++ b/onnxruntime/wasm/api.h @@ -10,10 +10,6 @@ #include -#ifdef USE_WEBGPU -#include -#endif - #include struct OrtSession; @@ -89,10 +85,7 @@ ort_session_options_handle_t EMSCRIPTEN_KEEPALIVE OrtCreateSessionOptions(size_t * @returns ORT error code. If not zero, call OrtGetLastError() to get detailed error message. */ int EMSCRIPTEN_KEEPALIVE OrtAppendExecutionProvider(ort_session_options_handle_t session_options, - const char* name, - const char* const* provider_options_keys, - const char* const* provider_options_values, - size_t num_keys); + const char* name); /** * add a free dimension override for one dimension of a session's input. @@ -301,21 +294,6 @@ int EMSCRIPTEN_KEEPALIVE OrtRun(ort_session_handle_t session, */ char* EMSCRIPTEN_KEEPALIVE OrtEndProfiling(ort_session_handle_t session); -// WebGPU API Section - -#ifdef USE_WEBGPU - -/** - * get the GPU Device by device ID. - * - * This function is only available after the GPU Device is initialized in WebGpuContextFactory. - * - * @returns a WGPUDevice handle. - */ -WGPUDevice EMSCRIPTEN_KEEPALIVE OrtGetWebGpuDevice(int device_id); - -#endif - // Training API Section #ifdef ENABLE_TRAINING_APIS diff --git a/onnxruntime/wasm/js_post_js.js b/onnxruntime/wasm/js_post_js.js index 56d3246fd07f0..b77d82fbd7d10 100644 --- a/onnxruntime/wasm/js_post_js.js +++ b/onnxruntime/wasm/js_post_js.js @@ -2,4 +2,6 @@ // Licensed under the MIT License. +'use strict'; + Module["PTR_SIZE"] = 4; diff --git a/onnxruntime/wasm/js_post_js_64.js b/onnxruntime/wasm/js_post_js_64.js index cfd79523f7900..b140df927ebbd 100644 --- a/onnxruntime/wasm/js_post_js_64.js +++ b/onnxruntime/wasm/js_post_js_64.js @@ -2,4 +2,6 @@ // Licensed under the MIT License. +'use strict'; + Module["PTR_SIZE"] = 8; diff --git a/onnxruntime/wasm/post-webgpu.js b/onnxruntime/wasm/post-webgpu.js deleted file mode 100644 index 146355f6a44d3..0000000000000 --- a/onnxruntime/wasm/post-webgpu.js +++ /dev/null @@ -1,261 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -// -// This file contains the post-run code for the ORT WebAssembly module. The code in this file will be injected into the -// final module using Emscripten's `--post-js` option. -// -// This file will only be used in build with flag `--use_webgpu`. - -/** - * This function is called only once when initializing the WebGPU backend. - * - * @param {(gpuDevice: GPUDevice) => void} setDefaultDevice A callback function to set the default device. - */ -Module["webgpuInit"] = (setDefaultDevice) => { - /** - * a map from GPUDevice to [deviceId, instanceHandle, deviceHandle] - * - * only stores custom devices (ie. devices created by the user, not the default device created by ORT) - * - * key is the GPUDevice object. - * - * value is a tuple of 3 elements: - * - deviceId: a unique ID for the device. Must be positive integer. - * - instanceHandle: the instance handle(pointer) of the device. - * - deviceHandle: the device handle(pointer) of the device. - * - * @type {WeakMap} - */ - const webgpuActiveDevices = new WeakMap(); - /** - * a number that is used to assign a unique ID to the next custom device. - */ - let webgpuNextDeviceId = 1; - /** - * a function to set the default device. - * - * @type {(gpuDevice: GPUDevice) => void} - */ - const webgpuSetDefaultDevice = setDefaultDevice; - /** - * the current device that is being used to create a WebGPU EP inference session. - * - * the value of this variable is only valid during the creation of a WebGPU EP inference session. - * - * @type {GPUDevice|undefined} - */ - let webgpuCurrentDevice = undefined; - /** - * the current device ID that is being used to create a WebGPU EP inference session. - * - * the value of this variable is only valid during the creation of a WebGPU EP inference session. - * - * @type {number|undefined} - */ - let webgpuCurrentDeviceId = undefined; - - /** - * This function is called only when a custom device is used, during preparation of session options. - * - * @param {GPUDevice} device the user provided device object. - * @returns {undefined|[number, number, number]} a tuple of device id, instance handle, and device handle. - */ - Module["webgpuRegisterDevice"] = (device) => { - if (webgpuCurrentDeviceId !== undefined) { - throw new Error("another WebGPU EP inference session is being created."); - } - - if (device) { - let deviceInfo = webgpuActiveDevices.get(device); - if (!deviceInfo) { - const instanceHandle = _wgpuCreateInstance(0); - const deviceHandle = WebGPU.importJsDevice(device, instanceHandle); - deviceInfo = [webgpuNextDeviceId++, instanceHandle, deviceHandle]; - webgpuActiveDevices.set(device, deviceInfo); - } - - // The current device ID is a temporary storage for the device ID to be used in the session that is being created. - // - // Soon after `webgpuRegisterDevice` (this function) is called, `webgpuOnCreateSession` will be called so that the - // value of `webgpuCurrentDeviceId` is used and reset then. - webgpuCurrentDevice = device; - webgpuCurrentDeviceId = deviceInfo[0]; - return deviceInfo; - } else { - webgpuCurrentDevice = undefined; - webgpuCurrentDeviceId = 0; - return undefined; - } - }; - - const webgpuActiveSessions = new Map(); - Module["webgpuOnCreateSession"] = (sessionHandle) => { - if (webgpuCurrentDeviceId === undefined) { - // do nothing if webgpuCurrentDeviceId is undefined. - // this means no WebGPU EP is being created. - return; - } - - const deviceId = webgpuCurrentDeviceId; - webgpuCurrentDeviceId = undefined; - - if (sessionHandle) { - // when session created successfully - const deviceHandle = _OrtGetWebGpuDevice(deviceId); - webgpuActiveSessions.set(sessionHandle, deviceHandle); - - if (deviceId === 0) { - const device = webgpuCurrentDevice ?? WebGPU.getJsObject(deviceHandle); - webgpuSetDefaultDevice(device); - } - } - webgpuCurrentDevice = undefined; - }; - - Module["webgpuOnReleaseSession"] = (sessionHandle) => { - webgpuActiveSessions.delete(sessionHandle); - }; - - const gpuBufferMetadataSymbol = Symbol("gpuBufferMetadata"); - - Module["webgpuRegisterBuffer"] = (buffer, sessionHandle, bufferHandle) => { - if (bufferHandle) { - // This is a buffer that was created by ORT. Metadata is [bufferHandle, NaN] - - buffer[gpuBufferMetadataSymbol] = [bufferHandle, NaN]; - return bufferHandle; - } else { - // This is a buffer that was created by the user. Metadata is [bufferHandle, refCount] - - const metadata = buffer[gpuBufferMetadataSymbol]; - if (metadata) { - metadata[1]++; - return metadata[0]; - } - - const deviceHandle = webgpuActiveSessions.get(sessionHandle); - if (deviceHandle === undefined) { - throw new Error( - "Invalid session handle passed to webgpuRegisterBuffer" - ); - } - - const bufferHandle = WebGPU.importJsBuffer(buffer, deviceHandle); - buffer[gpuBufferMetadataSymbol] = [bufferHandle, 1]; - return bufferHandle; - } - }; - - Module["webgpuUnregisterBuffer"] = (buffer) => { - const metadata = buffer[gpuBufferMetadataSymbol]; - if (!metadata) { - throw new Error("Buffer is not registered"); - } - metadata[1]--; - // For buffers created by ORT, metadata[1] will always be NaN. This function will not release the buffer. - // Instead, the buffer will be released when user calls `Tensor.dispose()` in JavaScript. - if (metadata[1] === 0) { - _wgpuBufferRelease(metadata[0]); - delete buffer[gpuBufferMetadataSymbol]; - } - }; - - Module["webgpuGetBuffer"] = (bufferHandle) => { - return WebGPU.getJsObject(bufferHandle); - }; - - Module["webgpuCreateDownloader"] = (gpuBuffer, bufferSize, sessionHandle) => { - const deviceHandle = webgpuActiveSessions.get(sessionHandle); - if (deviceHandle === undefined) { - throw new Error("Invalid session handle passed to webgpuRegisterBuffer"); - } - - const buffer = gpuBuffer; - const device = WebGPU.getJsObject(deviceHandle); - const originalSize = bufferSize; - const size = Math.ceil(Number(originalSize) / 16) * 16; - - return async () => { - // prettier-ignore - // - // the line above is used to force prettier to skip formatting the next statement. - // this is because prettier will remove the quotes around the property names, but we need to keep them - // because otherwise closure compiler may rename them and break the code. - const gpuReadBufferDescriptor = { - "size": size, - "usage": 9 /* GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ */, - }; - const gpuReadBuffer = device.createBuffer(gpuReadBufferDescriptor); - try { - const commandEncoder = device.createCommandEncoder(); - commandEncoder.copyBufferToBuffer( - buffer /* source buffer */, - 0 /* source offset */, - gpuReadBuffer /* destination buffer */, - 0 /* destination offset */, - size /* size */ - ); - device.queue.submit([commandEncoder.finish()]); - - await gpuReadBuffer.mapAsync(GPUMapMode.READ); - - const arrayBuffer = gpuReadBuffer.getMappedRange(); - return arrayBuffer.slice(0, originalSize); - } finally { - gpuReadBuffer.destroy(); - } - }; - }; - - // Setup a callback function for loading external buffers (model weights). - Module.webgpuUploadExternalBuffer = (bufferHandle, data) => { - const srcArrayBuffer = data.buffer; - const srcOffset = data.byteOffset; - const srcLength = data.byteLength; - const size = Math.ceil(Number(srcLength) / 16) * 16; - - const gpuBuffer = WebGPU.getJsObject(bufferHandle); - - // get current device - if (!webgpuCurrentDevice) { - const deviceHandle = _OrtGetWebGpuDevice(webgpuCurrentDeviceId); - webgpuCurrentDevice = WebGPU.getJsObject(deviceHandle); - } - - // create gpu buffer - - // prettier-ignore - // - // the line above is used to force prettier to skip formatting the next statement. - // this is because prettier will remove the quotes around the property names, but we need to keep them - // because otherwise closure compiler may rename them and break the code. - const gpuBufferForUploadingDescriptor = { - "mappedAtCreation": true, - "size": size, - "usage": 6 /* GPUBufferUsage.MAP_WRITE | GPUBufferUsage.COPY_SRC */, - }; - const gpuBufferForUploading = webgpuCurrentDevice.createBuffer( - gpuBufferForUploadingDescriptor - ); - - // copy (upload) data - const arrayBuffer = gpuBufferForUploading.getMappedRange(); - new Uint8Array(arrayBuffer).set( - new Uint8Array(srcArrayBuffer, srcOffset, srcLength) - ); - gpuBufferForUploading.unmap(); - - // GPU copy - const commandEncoder = webgpuCurrentDevice.createCommandEncoder(); - commandEncoder.copyBufferToBuffer( - gpuBufferForUploading, - 0, - gpuBuffer, - 0, - size - ); - webgpuCurrentDevice.queue.submit([commandEncoder.finish()]); - gpuBufferForUploading.destroy(); - }; -}; diff --git a/onnxruntime/wasm/pre-async.js b/onnxruntime/wasm/pre-async.js deleted file mode 100644 index 8c75dc7c5cf1e..0000000000000 --- a/onnxruntime/wasm/pre-async.js +++ /dev/null @@ -1,132 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -// -// This file contains the pre-run code for the ORT WebAssembly module. The code in this file will be injected into the -// final module using Emscripten's `--pre-js` option. -// -// This file will only be used in build with flag `-s ASYNCIFY=1`. - -/** - * initialize for asyncify support. - */ -let initAsyncImpl = () => { - // This is a simplified version of cwrap() with options.async === true (-sASYNCIFY=1) - // It removes some overhead in cwarp() and ccall() that we don't need. - // - // Currently in ASYNCIFY build, we only use this for the following functions: - // - OrtCreateSession() - // - OrtRun() - // - OrtRunWithBinding() - // - OrtBindInput() - // - // Note: about parameters "getFunc" and "setFunc": - // - Emscripten has different behaviors for Debug and Release builds for generating exported function wrapper. - // - // - In Debug build, it will generate a wrapper function for each exported function. For example, it generates a - // wrapper for OrtRun() like this (minified): - // ``` - // var _OrtRun = Module["_OrtRun"] = createExportWrapper("OrtRun"); - // ``` - // - // - In Release build, it will generate a lazy loading wrapper for each exported function. For example, it generates - // a wrapper for OrtRun() like this (minified): - // ``` - // d._OrtRun = (a, b, c, e, f, h, l, q) => (d._OrtRun = J.ka)(a, b, c, e, f, h, l, q); - // ``` - // - // The behavior of these two wrappers are different. The debug build will assign `Module["_OrtRun"]` only once - // because `createExportWrapper()` does not reset `Module["_OrtRun"]` inside. The release build, however, will - // reset d._OrtRun to J.ka when the first time it is called. - // - // The difference is important because we need to design the async wrapper in a way that it can handle both cases. - // - // Now, let's look at how the async wrapper is designed to work for both cases: - // - // - Debug build: - // 1. When Web assembly is being loaded, `Module["_OrtRun"]` is assigned to `createExportWrapper("OrtRun")`. - // 2. When the first time `Module["initAsync"]` is called, `Module["_OrtRun"]` is re-assigned to a new async - // wrapper function. - // Value of `Module["_OrtRun"]` will not be changed again. - // - // - Release build: - // 1. When Web assembly is being loaded, `Module["_OrtRun"]` is assigned to a lazy loading wrapper function. - // 2. When the first time `Module["initAsync"]` is called, `Module["_OrtRun"]` is re-assigned to a new async - // wrapper function. - // 3. When the first time `Module["_OrtRun"]` is called, the async wrapper will be called. It will call into this - // function: - // ``` - // (a, b, c, e, f, h, l, q) => (d._OrtRun = J.ka)(a, b, c, e, f, h, l, q); - // ``` - // This function will assign d._OrtRun (ie. the minimized `Module["_OrtRun"]`) to the real function (J.ka). - // 4. Since d._OrtRun is re-assigned, we need to update the async wrapper to re-assign its stored - // function to the updated value (J.ka), and re-assign the value of `d._OrtRun` back to the async wrapper. - // Value of `Module["_OrtRun"]` will not be changed again. - // - // The value of `Module["_OrtRun"]` will need to be assigned for 2 times for debug build and 4 times for release - // build. - // - // This is why we need this `getFunc` and `setFunc` parameters. They are used to get the current value of an - // exported function and set the new value of an exported function. - // - const wrapAsync = (func, getFunc, setFunc) => { - return (...args) => { - // cache the async data before calling the function. - const previousAsync = Asyncify.currData; - - const previousFunc = getFunc?.(); - const ret = func(...args); - const newFunc = getFunc?.(); - if (previousFunc !== newFunc) { - // The exported function has been updated. - // Set the sync function reference to the new function. - func = newFunc; - // Set the exported function back to the async wrapper. - setFunc(previousFunc); - // Remove getFunc and setFunc. They are no longer needed. - setFunc = null; - getFunc = null; - } - - // If the async data has been changed, it means that the function started an async operation. - if (Asyncify.currData != previousAsync) { - // returns the promise - return Asyncify.whenDone(); - } - // the function is synchronous. returns the result. - return ret; - }; - }; - - // replace the original functions with asyncified versions - const wrapAsyncAPIs = (funcNames) => { - for (const funcName of funcNames) { - Module[funcName] = wrapAsync( - Module[funcName], - () => Module[funcName], - (v) => (Module[funcName] = v) - ); - } - }; - - wrapAsyncAPIs([ - "_OrtAppendExecutionProvider", - "_OrtCreateSession", - "_OrtRun", - "_OrtRunWithBinding", - "_OrtBindInput", - ]); - - // If JSEP is enabled, wrap OrtRun() and OrtRunWithBinding() with asyncify. - if (typeof jsepRunAsync !== "undefined") { - Module["_OrtRun"] = jsepRunAsync(Module["_OrtRun"]); - Module["_OrtRunWithBinding"] = jsepRunAsync(Module["_OrtRunWithBinding"]); - } - - // remove this function to make sure it is called only once. - initAsyncImpl = undefined; -}; - -Module["asyncInit"] = () => { - initAsyncImpl?.(); -}; diff --git a/onnxruntime/wasm/pre-jsep.js b/onnxruntime/wasm/pre-jsep.js index 5b2f044d4c27b..0c83e71a921cb 100644 --- a/onnxruntime/wasm/pre-jsep.js +++ b/onnxruntime/wasm/pre-jsep.js @@ -1,157 +1,255 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +'use strict'; + // // This file contains the pre-run code for the ORT WebAssembly module. The code in this file will be injected into the // final module using Emscripten's `--pre-js` option. // // This file will only be used in build with flag `--use_jsep`. -// This is a wrapper for OrtRun() and OrtRunWithBinding() to ensure that Promises are handled correctly. -const jsepRunAsync = (runAsyncFunc) => { - return async (...args) => { - try { - // Module.jsepSessionState should be null, unless we are in the middle of a session. - // If it is not null, it means that the previous session has not finished yet. - if (Module.jsepSessionState) { - throw new Error("Session already started"); - } - const state = (Module.jsepSessionState = { - sessionHandle: args[0], - errors: [], - }); - // Run the acyncified function: OrtRun() or OrtRunWithBinding() - const ret = await runAsyncFunc(...args); +/** + * initialize JSEP for asyncify support. + */ +let jsepInitAsync = () => { + // This is a simplified version of cwrap() with options.async === true (-sASYNCIFY=1) + // It removes some overhead in cwarp() and ccall() that we don't need. + // + // Currently in JSEP build, we only use this for the following functions: + // - OrtRun() + // - OrtRunWithBinding() + // - OrtBindInput() + // + // Note: about parameters "getFunc" and "setFunc": + // - Emscripten has different behaviors for Debug and Release builds for generating exported function wrapper. + // + // - In Debug build, it will generate a wrapper function for each exported function. For example, it generates a + // wrapper for OrtRun() like this (minified): + // ``` + // var _OrtRun = Module["_OrtRun"] = createExportWrapper("OrtRun"); + // ``` + // + // - In Release build, it will generate a lazy loading wrapper for each exported function. For example, it generates + // a wrapper for OrtRun() like this (minified): + // ``` + // d._OrtRun = (a, b, c, e, f, h, l, q) => (d._OrtRun = J.ka)(a, b, c, e, f, h, l, q); + // ``` + // + // The behavior of these two wrappers are different. The debug build will assign `Module["_OrtRun"]` only once + // because `createExportWrapper()` does not reset `Module["_OrtRun"]` inside. The release build, however, will + // reset d._OrtRun to J.ka when the first time it is called. + // + // The difference is important because we need to design the async wrapper in a way that it can handle both cases. + // + // Now, let's look at how the async wrapper is designed to work for both cases: + // + // - Debug build: + // 1. When Web assembly is being loaded, `Module["_OrtRun"]` is assigned to `createExportWrapper("OrtRun")`. + // 2. When the first time `Module["jsepInit"]` is called, `Module["_OrtRun"]` is re-assigned to a new async + // wrapper function. + // Value of `Module["_OrtRun"]` will not be changed again. + // + // - Release build: + // 1. When Web assembly is being loaded, `Module["_OrtRun"]` is assigned to a lazy loading wrapper function. + // 2. When the first time `Module["jsepInit"]` is called, `Module["_OrtRun"]` is re-assigned to a new async + // wrapper function. + // 3. When the first time `Module["_OrtRun"]` is called, the async wrapper will be called. It will call into this + // function: + // ``` + // (a, b, c, e, f, h, l, q) => (d._OrtRun = J.ka)(a, b, c, e, f, h, l, q); + // ``` + // This function will assign d._OrtRun (ie. the minimized `Module["_OrtRun"]`) to the real function (J.ka). + // 4. Since d._OrtRun is re-assigned, we need to update the async wrapper to re-assign its stored + // function to the updated value (J.ka), and re-assign the value of `d._OrtRun` back to the async wrapper. + // Value of `Module["_OrtRun"]` will not be changed again. + // + // The value of `Module["_OrtRun"]` will need to be assigned for 2 times for debug build and 4 times for release + // build. + // + // This is why we need this `getFunc` and `setFunc` parameters. They are used to get the current value of an + // exported function and set the new value of an exported function. + // + const jsepWrapAsync = (func, getFunc, setFunc) => { + return (...args) => { + // cache the async data before calling the function. + const previousAsync = Asyncify.currData; - // Check if the session is still valid. this object should be the same as the one we set above. - if (Module.jsepSessionState !== state) { - throw new Error("Session mismatch"); + const previousFunc = getFunc?.(); + const ret = func(...args); + const newFunc = getFunc?.(); + if (previousFunc !== newFunc) { + // The exported function has been updated. + // Set the sync function reference to the new function. + func = newFunc; + // Set the exported function back to the async wrapper. + setFunc(previousFunc); + // Remove getFunc and setFunc. They are no longer needed. + setFunc = null; + getFunc = null; } - // Flush the backend. This will submit all pending commands to the GPU. - Module.jsepBackend?.["flush"](); + // If the async data has been changed, it means that the function started an async operation. + if (Asyncify.currData != previousAsync) { + // returns the promise + return Asyncify.whenDone(); + } + // the function is synchronous. returns the result. + return ret; + }; + }; - // Await all pending promises. This includes GPU validation promises for diagnostic purposes. - const errorPromises = state.errors; - if (errorPromises.length > 0) { - let errors = await Promise.all(errorPromises); - errors = errors.filter((e) => e); - if (errors.length > 0) { - throw new Error(errors.join("\n")); + // This is a wrapper for OrtRun() and OrtRunWithBinding() to ensure that Promises are handled correctly. + const runAsync = (runAsyncFunc) => { + return async (...args) => { + try { + // Module.jsepSessionState should be null, unless we are in the middle of a session. + // If it is not null, it means that the previous session has not finished yet. + if (Module.jsepSessionState) { + throw new Error('Session already started'); } - } + const state = Module.jsepSessionState = {sessionHandle: args[0], errors: []}; - return ret; - } finally { - Module.jsepSessionState = null; - } + // Run the acyncified function: OrtRun() or OrtRunWithBinding() + const ret = await runAsyncFunc(...args); + + // Check if the session is still valid. this object should be the same as the one we set above. + if (Module.jsepSessionState !== state) { + throw new Error('Session mismatch'); + } + + // Flush the backend. This will submit all pending commands to the GPU. + Module.jsepBackend?.['flush'](); + + // Await all pending promises. This includes GPU validation promises for diagnostic purposes. + const errorPromises = state.errors; + if (errorPromises.length > 0) { + let errors = await Promise.all(errorPromises); + errors = errors.filter(e => e); + if (errors.length > 0) { + throw new Error(errors.join('\n')); + } + } + + return ret; + } finally { + Module.jsepSessionState = null; + } + }; }; + + // replace the original functions with asyncified versions + Module['_OrtCreateSession'] = jsepWrapAsync( + Module['_OrtCreateSession'], + () => Module['_OrtCreateSession'], + v => Module['_OrtCreateSession'] = v); + Module['_OrtRun'] = runAsync(jsepWrapAsync( + Module['_OrtRun'], + () => Module['_OrtRun'], + v => Module['_OrtRun'] = v)); + Module['_OrtRunWithBinding'] = runAsync(jsepWrapAsync( + Module['_OrtRunWithBinding'], + () => Module['_OrtRunWithBinding'], + v => Module['_OrtRunWithBinding'] = v)); + Module['_OrtBindInput'] = jsepWrapAsync( + Module['_OrtBindInput'], + () => Module['_OrtBindInput'], + v => Module['_OrtBindInput'] = v); + + // remove this function to make sure it is called only once. + jsepInitAsync = undefined; }; + /** - * initialize JSEP for WebGPU and WebNN. + * initialize JSEP for WebGPU. */ -Module["jsepInit"] = (name, params) => { - if (name === "webgpu") { - [ - Module.jsepBackend, - Module.jsepAlloc, - Module.jsepFree, - Module.jsepCopy, - Module.jsepCopyAsync, - Module.jsepCreateKernel, - Module.jsepReleaseKernel, - Module.jsepRunKernel, - Module.jsepCaptureBegin, - Module.jsepCaptureEnd, - Module.jsepReplay, - ] = params; +Module['jsepInit'] = (name, params) => { + jsepInitAsync?.(); + + if (name === 'webgpu') { + [Module.jsepBackend, + Module.jsepAlloc, + Module.jsepFree, + Module.jsepCopy, + Module.jsepCopyAsync, + Module.jsepCreateKernel, + Module.jsepReleaseKernel, + Module.jsepRunKernel, + Module.jsepCaptureBegin, + Module.jsepCaptureEnd, + Module.jsepReplay] = params; // expose webgpu backend functions const backend = Module.jsepBackend; - Module["jsepRegisterBuffer"] = (sessionId, index, buffer, size) => { - return backend["registerBuffer"](sessionId, index, buffer, size); + Module['jsepRegisterBuffer'] = (sessionId, index, buffer, size) => { + return backend['registerBuffer'](sessionId, index, buffer, size); }; - Module["jsepGetBuffer"] = (dataId) => { - return backend["getBuffer"](dataId); + Module['jsepGetBuffer'] = (dataId) => { + return backend['getBuffer'](dataId); }; - Module["jsepCreateDownloader"] = (gpuBuffer, size, type) => { - return backend["createDownloader"](gpuBuffer, size, type); + Module['jsepCreateDownloader'] = (gpuBuffer, size, type) => { + return backend['createDownloader'](gpuBuffer, size, type); }; - Module["jsepOnCreateSession"] = (sessionId) => { - backend["onCreateSession"](sessionId); + Module['jsepOnCreateSession'] = sessionId => { + backend['onCreateSession'](sessionId); }; - Module["jsepOnReleaseSession"] = (sessionId) => { - backend["onReleaseSession"](sessionId); + Module['jsepOnReleaseSession'] = sessionId => { + backend['onReleaseSession'](sessionId); }; - Module["jsepOnRunStart"] = (sessionId) => { - return backend["onRunStart"](sessionId); + Module['jsepOnRunStart'] = sessionId => { + return backend['onRunStart'](sessionId); }; Module.jsepUploadExternalBuffer = (dataId, buffer) => { - backend["upload"](dataId, buffer); + backend['upload'](dataId, buffer); }; - } else if (name === "webnn") { + } else if (name === 'webnn') { // Functions called from EM_ASM need to be assigned in a way that can be minified. // Functions called via emscripten::val::module_property need to be assigned by name so that the minifier doesn't // change the name. - [ - Module.jsepBackend, - Module.jsepReserveTensorId, - Module.jsepReleaseTensorId, - Module["jsepEnsureTensor"], - Module.jsepUploadTensor, - Module["jsepDownloadTensor"], + [Module.jsepBackend, + Module.jsepReserveTensorId, + Module.jsepReleaseTensorId, + Module['jsepEnsureTensor'], + Module.jsepUploadTensor, + Module['jsepDownloadTensor'], ] = params; // This function is called from both JS and an EM_ASM block, it needs both a minifiable name and an explicit name. - Module["jsepReleaseTensorId"] = Module.jsepReleaseTensorId; - Module["jsepUploadTensor"] = Module.jsepUploadTensor; + Module['jsepReleaseTensorId'] = Module.jsepReleaseTensorId; + Module['jsepUploadTensor'] = Module.jsepUploadTensor; // Functions called from JS also need to have explicit names. const backend = Module.jsepBackend; - Module["jsepOnRunStart"] = (sessionId) => { - return backend["onRunStart"](sessionId); - }; - Module["jsepOnRunEnd"] = backend["onRunEnd"].bind(backend); - Module["jsepRegisterMLContext"] = (sessionId, mlContext) => { - backend["registerMLContext"](sessionId, mlContext); + Module['jsepOnRunStart'] = sessionId => { + return backend['onRunStart'](sessionId); }; - Module["jsepOnReleaseSession"] = (sessionId) => { - backend["onReleaseSession"](sessionId); + Module['jsepOnRunEnd'] = backend['onRunEnd'].bind(backend); + Module['jsepRegisterMLContext'] = (sessionId, mlContext) => { + backend['registerMLContext'](sessionId, mlContext); }; - Module["jsepCreateMLTensorDownloader"] = (tensorId, type) => { - return backend["createMLTensorDownloader"](tensorId, type); + Module['jsepOnReleaseSession'] = sessionId => { + backend['onReleaseSession'](sessionId); }; - Module["jsepRegisterMLTensor"] = (sessionId, tensor, dataType, shape) => { - return backend["registerMLTensor"](sessionId, tensor, dataType, shape); + Module['jsepCreateMLTensorDownloader'] = (tensorId, type) => { + return backend['createMLTensorDownloader'](tensorId, type); + } + Module['jsepRegisterMLTensor'] = (sessionId, tensor, dataType, shape) => { + return backend['registerMLTensor'](sessionId, tensor, dataType, shape); }; - Module["jsepCreateMLContext"] = (optionsOrGpuDevice) => { - return backend["createMLContext"](optionsOrGpuDevice); + Module['jsepCreateMLContext'] = (optionsOrGpuDevice) => { + return backend['createMLContext'](optionsOrGpuDevice); }; - Module["jsepRegisterMLConstant"] = ( - externalFilePath, - dataOffset, - dataLength, - builder, - desc - ) => { - return backend["registerMLConstant"]( - externalFilePath, - dataOffset, - dataLength, - builder, - desc, - Module.MountedFiles - ); + Module['jsepRegisterMLConstant'] = (externalFilePath, dataOffset, dataLength, builder, desc) => { + return backend['registerMLConstant']( + externalFilePath, dataOffset, dataLength, builder, desc, Module.MountedFiles); }; - Module["jsepRegisterGraphInput"] = - backend["registerGraphInput"].bind(backend); - Module["jsepIsGraphInput"] = backend["isGraphInput"].bind(backend); + Module['jsepRegisterGraphInput'] = backend['registerGraphInput'].bind(backend); + Module['jsepIsGraphInput'] = backend['isGraphInput'].bind(backend); - Module["jsepCreateTemporaryTensor"] = - backend["createTemporaryTensor"].bind(backend); + Module['jsepCreateTemporaryTensor'] = backend['createTemporaryTensor'].bind(backend); } }; diff --git a/onnxruntime/wasm/pre.js b/onnxruntime/wasm/pre.js index 636a9713519a7..9b5f3ce545b78 100644 --- a/onnxruntime/wasm/pre.js +++ b/onnxruntime/wasm/pre.js @@ -1,18 +1,21 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +'use strict'; + // // This file contains the pre-run code for the ORT WebAssembly module. The code in this file will be injected into the // final module using Emscripten's `--pre-js` option. + /** * Mount external data files of a model to an internal map, which will be used during session initialization. * * @param {string} externalDataFilesPath * @param {Uint8Array} externalDataFilesData */ -Module["mountExternalData"] = (externalDataFilePath, externalDataFileData) => { - if (externalDataFilePath.startsWith("./")) { +Module['mountExternalData'] = (externalDataFilePath, externalDataFileData) => { + if (externalDataFilePath.startsWith('./')) { externalDataFilePath = externalDataFilePath.substring(2); } const files = Module.MountedFiles || (Module.MountedFiles = new Map()); @@ -22,7 +25,7 @@ Module["mountExternalData"] = (externalDataFilePath, externalDataFileData) => { /** * Unmount external data files of a model. */ -Module["unmountExternalData"] = () => { +Module['unmountExternalData'] = () => { delete Module.MountedFiles; }; @@ -45,7 +48,5 @@ Module["unmountExternalData"] = () => { * * @suppress {checkVars} */ -var SharedArrayBuffer = - globalThis.SharedArrayBuffer ?? - new WebAssembly.Memory({ initial: 0, maximum: 0, shared: true }).buffer - .constructor; +var SharedArrayBuffer = globalThis.SharedArrayBuffer ?? + new WebAssembly.Memory({'initial': 0, 'maximum': 0, 'shared': true}).buffer.constructor; diff --git a/setup.py b/setup.py index 53e533050b245..ced2f28e38778 100644 --- a/setup.py +++ b/setup.py @@ -356,7 +356,7 @@ def finalize_options(self): "libQnnSaver.so", "libQnnSystem.so", "libHtpPrepare.so", - "ep_weight_sharing_ctx_gen", + "onnxruntime_qnn_ctx_gen", ] dl_libs.extend(qnn_deps) if nightly_build: diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index db7dbed23a2d2..8607887072347 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -35,8 +35,7 @@ def version_to_tuple(version: str) -> tuple: import util.android as android # noqa: E402 from util import ( # noqa: E402 generate_android_triplets, - generate_linux_triplets, - generate_macos_triplets, + generate_posix_triplets, generate_vcpkg_triplets_for_emscripten, generate_windows_triplets, get_logger, @@ -1116,6 +1115,7 @@ def generate_build_tree( cmake_extra_args, ): log.info("Generating CMake build tree") + cmake_dir = os.path.join(source_dir, "cmake") cmake_args = [cmake_path, cmake_dir] if not use_dev_mode(args): @@ -1330,16 +1330,8 @@ def generate_build_tree( generate_android_triplets(build_dir, args.android_cpp_shared, args.android_api) elif is_windows(): generate_windows_triplets(build_dir) - elif is_macOS(): - osx_target = args.apple_deploy_target - if args.apple_deploy_target is None: - osx_target = os.environ.get("MACOSX_DEPLOYMENT_TARGET") - if osx_target is not None: - log.info(f"Setting VCPKG_OSX_DEPLOYMENT_TARGET to {osx_target}") - generate_macos_triplets(build_dir, osx_target) else: - # Linux, *BSD, AIX or other platforms - generate_linux_triplets(build_dir) + generate_posix_triplets(build_dir) add_default_definition(cmake_extra_defines, "CMAKE_TOOLCHAIN_FILE", str(vcpkg_toolchain_path)) vcpkg_install_options = generate_vcpkg_install_options(build_dir, args) @@ -1600,11 +1592,8 @@ def generate_build_tree( raise BuildError("WebNN is only available for WebAssembly build.") cmake_args += ["-Donnxruntime_USE_WEBNN=ON"] - # TODO: currently we allows building with both --use_jsep and --use_webgpu in this working branch. - # This situation is temporary. Eventually, those two flags will be mutually exclusive. - # - # if args.use_jsep and args.use_webgpu: - # raise BuildError("JSEP (--use_jsep) and WebGPU (--use_webgpu) cannot be enabled at the same time.") + if args.use_jsep and args.use_webgpu: + raise BuildError("JSEP (--use_jsep) and WebGPU (--use_webgpu) cannot be enabled at the same time.") if args.use_external_dawn and not args.use_webgpu: raise BuildError("External Dawn (--use_external_dawn) must be enabled with WebGPU (--use_webgpu).") diff --git a/tools/ci_build/github/azure-pipelines/custom-nuget-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/custom-nuget-packaging-pipeline.yml deleted file mode 100644 index 8aaaa0e85585a..0000000000000 --- a/tools/ci_build/github/azure-pipelines/custom-nuget-packaging-pipeline.yml +++ /dev/null @@ -1,142 +0,0 @@ -parameters: -- name: CudaVersion - type: string - default: '12.2' - -- name: QnnSdk - displayName: QNN SDK Version - type: string - default: 2.31.0.250130 - -- name: IsReleaseBuild - displayName: Is a release build? Set it to true if you are doing an Onnx Runtime release. - type: boolean - default: false - -- name: PackageName - displayName: What is the package name? - type: string - default: 'Microsoft.ML.OnnxRuntime.Flamingo' - -variables: - - template: templates/common-variables.yml - - name: ReleaseVersionSuffix - value: '' - - name: win_cuda_home - ${{ if eq(parameters.CudaVersion, '11.8') }}: - value: $(Agent.TempDirectory)\v11.8 - ${{ if eq(parameters.CudaVersion, '12.2') }}: - value: $(Agent.TempDirectory)\v12.2 - -stages: - - template: templates/win-ci.yml - parameters: - ort_build_pool_name: 'onnxruntime-Win2022-GPU-A10' - DoCompliance: false - DoEsrp: true - stage_name_suffix: CUDA - buildArch: x64 - msbuildPlatform: x64 - packageName: x64-cuda - CudaVersion: ${{ parameters.CudaVersion }} - buildparameter: --use_cuda --cuda_home=${{ variables.win_cuda_home }} --enable_onnx_tests --enable_wcos --use_webgpu --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=52-real;61-real;75-real;86-real;89-real;90-virtual" - runTests: false - buildJava: false - java_artifact_id: onnxruntime_gpu - UseIncreasedTimeoutForTests: false - SpecificArtifact: false - BuildId: '0' - - - template: templates/qnn-ep-win.yml - parameters: - qnn_ep_build_pool_name: 'Onnxruntime-QNNEP-Windows-2022-CPU' - QnnSdk: ${{ parameters.QnnSdk }} - IsReleaseBuild: ${{ parameters.IsReleaseBuild }} - DoEsrp: true - ArtifactName: 'drop-nuget-qnn-arm64' - # Add --use_webgpu to enable WebGPU - buildParameter: '--arm64' - buildPlatform: 'ARM64' - buildArch: 'ARM64' - StageName: 'OnnxRuntime_QNN_Nuget_Win_Arm64' - build_config: 'RelWithDebInfo' - Is1ES: false - PublishArchive: true - - - stage: NugetPackaging - dependsOn: [Windows_Packaging_CUDA, OnnxRuntime_QNN_Nuget_Win_Arm64] - jobs: - - job: CreateNugetPackage - pool: 'Onnxruntime-Win2022-GPU-A10' - timeoutInMinutes: 120 - steps: - - checkout: self - clean: true - submodules: none - - - task: UsePythonVersion@0 - inputs: - versionSpec: '3.12' - addToPath: true - - - task: DownloadPipelineArtifact@0 - displayName: 'Download Pipeline Artifact - managed nuget' - inputs: - artifactName: 'drop-nuget-qnn-arm64' - targetPath: '$(Build.BinariesDirectory)/managed-nuget' - - - task: DownloadPipelineArtifact@0 - displayName: 'Download Pipeline Artifact - win-x64' - inputs: - artifactName: 'onnxruntime-win-x64-cuda' - targetPath: '$(Build.BinariesDirectory)/win-x64' - - - task: DownloadPipelineArtifact@0 - displayName: 'Download Pipeline Artifact - win-arm64' - inputs: - artifactName: 'onnxruntime-win-ARM64-qnn' - targetPath: '$(Build.BinariesDirectory)/win-arm64' - - - task: PowerShell@2 - displayName: 'Extract Nuget Package Version' - inputs: - targetType: 'inline' - script: | - $nupkgs = (Get-ChildItem $(Build.BinariesDirectory)/managed-nuget -Filter Microsoft.ML.OnnxRuntime.Managed.*.nupkg -Recurse) - $package_name = $nupkgs[0].Name - $version_length = $package_name.Length - "Microsoft.ML.OnnxRuntime.Managed.".Length - ".nupkg".Length - $package_version = $package_name.Substring("Microsoft.ML.OnnxRuntime.Managed.".Length, $version_length) - Write-Host "##vso[task.setvariable variable=package_version;]$package_version" - workingDirectory: $(Build.BinariesDirectory) - - - task: PowerShell@2 - displayName: 'Extract Archives' - inputs: - targetType: 'inline' - script: | - Expand-Archive -Path $(Build.BinariesDirectory)/win-x64/onnxruntime-win-x64-cuda*.zip -DestinationPath $(Build.BinariesDirectory)/win-x64 - Expand-Archive -Path $(Build.BinariesDirectory)/win-arm64/onnxruntime-win-ARM64-qnn*.zip -DestinationPath $(Build.BinariesDirectory)/win-arm64 - $win_x64 = (Get-ChildItem -Path $(Build.BinariesDirectory)/win-x64 -Filter onnxruntime-win-x64-cuda*)[0].FullName - $win_arm64 = (Get-ChildItem -Path $(Build.BinariesDirectory)/win-arm64 -Filter onnxruntime-win-ARM64-qnn*)[0].FullName - Write-Host "##vso[task.setvariable variable=win_x64;]$win_x64" - Write-Host "##vso[task.setvariable variable=win_arm64;]$win_arm64" - workingDirectory: $(Build.BinariesDirectory) - - - task: PythonScript@0 - displayName: 'Generate Nuget Package' - inputs: - scriptPath: '$(Build.SourcesDirectory)/tools/nuget/generate_nuspec_for_custom_nuget.py' - arguments: '--nuspec_path "$(Build.BinariesDirectory)/${{ parameters.PackageName }}.nuspec" --root_dir "$(Build.SourcesDirectory)" --commit_id "$(Build.SourceVersion)" --win_arm64 "$(win_arm64)" --win_x64 "$(win_x64)" --package_version "$(package_version)" --package_name "${{ parameters.PackageName }}"' - - - task: NuGetCommand@2 - displayName: 'Pack Nuget Package' - inputs: - command: 'pack' - packagesToPack: '$(Build.BinariesDirectory)/${{ parameters.PackageName }}.nuspec' - packDestination: $(Build.ArtifactStagingDirectory)\ - - - task: PublishBuildArtifacts@1 - displayName: 'Publish Artifact: Nuget' - inputs: - pathtoPublish: '$(Build.ArtifactStagingDirectory)' - artifactName: '${{ parameters.PackageName }}' diff --git a/tools/ci_build/github/azure-pipelines/py-package-test-pipeline.yml b/tools/ci_build/github/azure-pipelines/py-package-test-pipeline.yml index 7a78c6ba0fcdf..a0e49692220f9 100644 --- a/tools/ci_build/github/azure-pipelines/py-package-test-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/py-package-test-pipeline.yml @@ -31,12 +31,10 @@ stages: machine_pool: vmImage: 'macOS-13' itemPattern: '*/*mac*x86_64.whl' - arch: 'x86_64' - template: templates/py-package-smoking-test.yml parameters: job_name: Test_LINUX_x86_64_Wheels itemPattern: '*/*manylinux*x86_64.whl' - arch: 'x86_64' machine_pool: name: 'onnxruntime-Ubuntu2204-AMD-CPU' diff --git a/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml index 28ddd29ec63e6..01d30d0e1ba86 100644 --- a/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml @@ -50,10 +50,10 @@ parameters: displayName: 'Linux packages cmake build type. Linux Only.' default: 'Release' values: - - Debug - - Release - - RelWithDebInfo - - MinSizeRel + - Debug + - Release + - RelWithDebInfo + - MinSizeRel # Only applies to QNN packages. - name: qnn_sdk_version @@ -63,33 +63,17 @@ parameters: trigger: none -resources: - repositories: - - repository: 1esPipelines - type: git - name: 1ESPipelineTemplates/1ESPipelineTemplates - ref: refs/tags/release -extends: - # The pipeline extends the 1ES PT which will inject different SDL and compliance tasks. - # For non-production pipelines, use "Unofficial" as defined below. - # For productions pipelines, use "Official". - template: v1/1ES.Official.PipelineTemplate.yml@1esPipelines +stages: +- template: stages/py-cpu-packaging-stage.yml parameters: - sdl: - sourceAnalysisPool: - name: onnxruntime-Win-CPU-2022 - os: windows - stages: - - template: stages/py-cpu-packaging-stage.yml - parameters: - enable_linux_cpu: ${{ parameters.enable_linux_cpu }} - enable_windows_cpu: ${{ parameters.enable_windows_cpu }} - enable_mac_cpu: ${{ parameters.enable_mac_cpu }} - enable_linux_arm: ${{ parameters.enable_linux_arm }} - enable_windows_arm64_qnn: ${{ parameters.enable_windows_arm64_qnn }} - enable_windows_arm64ec_qnn: ${{ parameters.enable_windows_arm64ec_qnn }} - enable_windows_x64_qnn: ${{ parameters.enable_windows_x64_qnn }} - enable_linux_x64_qnn: ${{ parameters.enable_linux_x64_qnn }} - build_py_parameters: ${{ parameters.build_py_parameters }} - cmake_build_type: ${{ parameters.cmake_build_type }} - qnn_sdk_version: ${{ parameters.qnn_sdk_version }} + enable_linux_cpu: ${{ parameters.enable_linux_cpu }} + enable_windows_cpu: ${{ parameters.enable_windows_cpu }} + enable_mac_cpu: ${{ parameters.enable_mac_cpu }} + enable_linux_arm: ${{ parameters.enable_linux_arm }} + enable_windows_arm64_qnn: ${{ parameters.enable_windows_arm64_qnn }} + enable_windows_arm64ec_qnn: ${{ parameters.enable_windows_arm64ec_qnn }} + enable_windows_x64_qnn: ${{ parameters.enable_windows_x64_qnn }} + enable_linux_x64_qnn: ${{ parameters.enable_linux_x64_qnn }} + build_py_parameters: ${{ parameters.build_py_parameters }} + cmake_build_type: ${{ parameters.cmake_build_type }} + qnn_sdk_version: ${{ parameters.qnn_sdk_version }} diff --git a/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml index cfca998e0f06c..055ef58e4524a 100644 --- a/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml @@ -29,58 +29,108 @@ parameters: displayName: Pipeline BuildId, you could find it in the URL type: string default: '0' -resources: - repositories: - - repository: 1esPipelines - type: git - name: 1ESPipelineTemplates/1ESPipelineTemplates - ref: refs/tags/release -extends: - # The pipeline extends the 1ES PT which will inject different SDL and compliance tasks. - # For non-production pipelines, use "Unofficial" as defined below. - # For productions pipelines, use "Official". - template: v1/1ES.Official.PipelineTemplate.yml@1esPipelines + +stages: + +- template: templates/qnn-ep-win.yml parameters: - sdl: - sourceAnalysisPool: - name: onnxruntime-Win-CPU-2022 - os: windows - stages: + qnn_ep_build_pool_name: 'Onnxruntime-QNNEP-Windows-2022-CPU' + QnnSdk: ${{ parameters.QnnSdk }} + IsReleaseBuild: ${{ parameters.IsReleaseBuild }} + DoEsrp: ${{ parameters.DoEsrp }} + ArtifactName: 'drop-nuget-qnn-x64' + StageName: 'OnnxRuntime_QNN_Nuget_Win_x64' + build_config: ${{ parameters.build_config }} - - template: templates/qnn-ep-win.yml - parameters: - qnn_ep_build_pool_name: 'Onnxruntime-QNNEP-Windows-2022-CPU' - QnnSdk: ${{ parameters.QnnSdk }} - IsReleaseBuild: ${{ parameters.IsReleaseBuild }} - DoEsrp: ${{ parameters.DoEsrp }} - ArtifactName: 'drop-nuget-qnn-x64' - StageName: 'OnnxRuntime_QNN_Nuget_Win_x64' - build_config: ${{ parameters.build_config }} +- template: templates/qnn-ep-win.yml + parameters: + qnn_ep_build_pool_name: 'Onnxruntime-QNNEP-Windows-2022-CPU' + QnnSdk: ${{ parameters.QnnSdk }} + IsReleaseBuild: ${{ parameters.IsReleaseBuild }} + DoEsrp: ${{ parameters.DoEsrp }} + ArtifactName: 'drop-nuget-qnn-arm64' + buildParameter: '--arm64' + buildPlatform: 'ARM64' + buildArch: 'ARM64' + StageName: 'OnnxRuntime_QNN_Nuget_Win_Arm64' + build_config: ${{ parameters.build_config }} - - template: templates/qnn-ep-win.yml - parameters: - qnn_ep_build_pool_name: 'Onnxruntime-QNNEP-Windows-2022-CPU' - QnnSdk: ${{ parameters.QnnSdk }} - IsReleaseBuild: ${{ parameters.IsReleaseBuild }} - DoEsrp: ${{ parameters.DoEsrp }} - ArtifactName: 'drop-nuget-qnn-arm64' - buildParameter: '--arm64' - buildPlatform: 'ARM64' - buildArch: 'ARM64' - StageName: 'OnnxRuntime_QNN_Nuget_Win_Arm64' - build_config: ${{ parameters.build_config }} - - - template: stages/nuget-qnn-packaging-stage.yml +- stage: NuGet_Packaging_QNN + pool: 'Onnxruntime-QNNEP-Windows-2022-CPU' + dependsOn: + - OnnxRuntime_QNN_Nuget_Win_x64 + - OnnxRuntime_QNN_Nuget_Win_Arm64 + condition: succeeded() + jobs: + - job: NuGet_Packaging_QNN + workspace: + clean: all + steps: + - task: DownloadPipelineArtifact@0 + displayName: 'Download Pipeline Artifact - QNN NuGet x64' + inputs: + artifactName: 'drop-nuget-qnn-x64' + targetPath: '$(Build.BinariesDirectory)/nuget-artifact-x64' + + - task: DownloadPipelineArtifact@0 + displayName: 'Download Pipeline Artifact - QNN NuGet arm64' + inputs: + artifactName: 'drop-nuget-qnn-arm64' + targetPath: '$(Build.BinariesDirectory)/nuget-artifact-arm64' + + - task: PowerShell@2 + displayName: 'Bundle NuGet' + inputs: + targetType: 'inline' + script: | + + $x64_nupkgs = (Get-ChildItem $(Build.BinariesDirectory)/nuget-artifact-x64 -Filter Microsoft.ML.OnnxRuntime.QNN*.nupkg -Recurse) + $nuget_package_name = $x64_nupkgs[0].Name + $x64_nuget_package = $x64_nupkgs[0].FullName + + $nupkg_unzipped_directory = [System.IO.Path]::Combine($Env:BUILD_ARTIFACTSTAGINGDIRECTORY, 'nuget_unzip_merged', [System.IO.Path]::GetFileNameWithoutExtension($nuget_package_name)) + + $x64_unzip_cmd = "7z.exe x $x64_nuget_package -y -o$nupkg_unzipped_directory" + Invoke-Expression -Command $x64_unzip_cmd + + $arm64_nupkgs = (Get-ChildItem $(Build.BinariesDirectory)/nuget-artifact-arm64 -Filter Microsoft.ML.OnnxRuntime.QNN*.nupkg -Recurse) + $arm64_nuget_package = $arm64_nupkgs[0].FullName + + $arm64_unzip_cmd = "7z.exe x $arm64_nuget_package -y -o$nupkg_unzipped_directory" + Invoke-Expression -Command $arm64_unzip_cmd + + $merged_nuget_path = [System.IO.Path]::Combine($Env:BUILD_ARTIFACTSTAGINGDIRECTORY, 'nuget-artifact-merged') + if (!(Test-Path $merged_nuget_path)) { + New-Item -Path $merged_nuget_path -ItemType Directory + } + + $merged_zip = [System.IO.Path]::Combine($merged_nuget_path, 'qnn_nuget.zip') + $zip_cmd = "7z.exe a -r $merged_zip $nupkg_unzipped_directory/*" + Invoke-Expression -Command $zip_cmd + + $merged_nuget = [System.IO.Path]::Combine($merged_nuget_path, $nuget_package_name) + move $merged_zip $merged_nuget + workingDirectory: $(Build.BinariesDirectory) + + - template: templates/esrp_nuget.yml parameters: + DisplayName: 'ESRP - sign NuGet package' + FolderPath: '$(Build.ArtifactStagingDirectory)/nuget-artifact-merged' DoEsrp: ${{ parameters.DoEsrp }} - - template: templates/publish-nuget-steps.yml - parameters: - download_artifacts_steps: - - template: templates/flex-downloadPipelineArtifact.yml - parameters: - StepName: 'Download Pipeline Artifact - Signed NuGet Qnn Package' - ArtifactName: 'drop-signed-nuget-qnn' - targetPath: '$(Build.BinariesDirectory)/nuget-artifact/final-package' - SpecificArtifact: ${{ parameters.specificArtifact }} - BuildId: ${{ parameters.BuildId }} + - task: PublishPipelineArtifact@0 + displayName: 'Publish Pipeline NuGet Artifact' + inputs: + artifactName: 'drop-signed-nuget-qnn' + targetPath: '$(Build.ArtifactStagingDirectory)/nuget-artifact-merged' + +- template: templates/publish-nuget-steps.yml + parameters: + download_artifacts_steps: + - template: templates/flex-downloadPipelineArtifact.yml + parameters: + StepName: 'Download Pipeline Artifact - Signed NuGet Qnn Package' + ArtifactName: 'drop-signed-nuget-qnn' + targetPath: '$(Build.BinariesDirectory)/nuget-artifact/final-package' + SpecificArtifact: ${{ parameters.specificArtifact }} + BuildId: ${{ parameters.BuildId }} diff --git a/tools/ci_build/github/azure-pipelines/rocm-nuget-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/rocm-nuget-packaging-pipeline.yml new file mode 100644 index 0000000000000..f7f5c7b1494e8 --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/rocm-nuget-packaging-pipeline.yml @@ -0,0 +1,339 @@ +parameters: +- name: RunOnnxRuntimeTests + displayName: Run Tests? + type: boolean + default: true + +- name: UseIncreasedTimeoutForTests + displayName: Increase timeout for tests? Set it to false if you are doing an Onnx Runtime release. + type: boolean + default: false + +- name: DoCompliance + displayName: Run Compliance Tasks? + type: boolean + default: true + +- name: DoEsrp + displayName: Run code sign tasks? Must be true if you are doing an ONNX Runtime release + type: boolean + default: true + +- name: IsReleaseBuild + displayName: Is a release build? Set it to true if you are doing an ONNX Runtime release. + type: boolean + default: false + +- name: PreReleaseVersionSuffixString + displayName: Suffix added to pre-release package version. Only used if IsReleaseBuild is true. Denotes the type of pre-release package. + type: string + values: + - alpha + - beta + - rc + - none + default: none + +- name: PreReleaseVersionSuffixNumber + displayName: Number added to pre-release package version. Only used if IsReleaseBuild is true. Denotes the sequence of a pre-release package. + type: number + default: 0 + +# these 2 parameters are used for debugging. +- name: SpecificArtifact + displayName: Use Specific Artifact (Debugging only) + type: boolean + default: false + +- name: BuildId + displayName: Pipeline BuildId, you could find it in the URL + type: string + default: '0' + +- name: NugetPackageSuffix + displayName: Suffix to append to nuget package + type: string + default: 'NONE' + +resources: + repositories: + - repository: onnxruntime-inference-examples # The name used to reference this repository in the checkout step + type: github + endpoint: ort-examples + name: microsoft/onnxruntime-inference-examples + - repository: manylinux + type: Github + endpoint: Microsoft + name: pypa/manylinux + ref: 5eda9aded5462201e6310105728d33016e637ea7 + +variables: +- name: ReleaseVersionSuffix + value: '' + +stages: +- template: stages/set_packaging_variables_stage.yml + parameters: + IsReleaseBuild: ${{ parameters.IsReleaseBuild }} + PreReleaseVersionSuffixString: ${{ parameters.PreReleaseVersionSuffixString }} + PreReleaseVersionSuffixNumber: ${{ parameters.PreReleaseVersionSuffixNumber }} + +# ROCm +- stage: Linux_C_API_Packaging_ROCm_x64 + dependsOn: [] + jobs: + - job: Linux_C_API_Packaging_ROCm_x64 + workspace: + clean: all + timeoutInMinutes: 480 + pool: onnxruntime-Ubuntu2204-AMD-CPU + variables: + RocmVersion: '6.2' + RocmVersionPatchSuffix: '' + steps: + - checkout: self # due to checkout multiple repos, the root directory is $(Build.SourcesDirectory)/onnxruntime + submodules: recursive + - checkout: manylinux # due to checkout multiple repos, the root directory is $(Build.SourcesDirectory)/manylinux, for get-docker-image-steps.yml + submodules: false + + # get-docker-image-steps.yml will move the $(Build.SourcesDirectory)/manylinux into $(Build.SourcesDirectory)/onnxruntime, + # then rename $(Build.SourcesDirectory)/onnxruntime as $(Build.SourcesDirectory) + - template: templates/get-docker-image-steps.yml + parameters: + Dockerfile: tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_rocm + Context: tools/ci_build/github/linux/docker + DockerBuildArgs: >- + --build-arg INSTALL_DEPS_EXTRA_ARGS=-tmur + --build-arg BUILD_UID=$(id -u) + --network=host --build-arg POLICY=manylinux_2_28 --build-arg PLATFORM=x86_64 + --build-arg ROCM_VERSION=$(RocmVersion)$(RocmVersionPatchSuffix) + --build-arg DEVTOOLSET_ROOTPATH=/opt/rh/gcc-toolset-12/root + --build-arg PREPEND_PATH=/opt/rh/gcc-toolset-12/root/usr/bin: + --build-arg LD_LIBRARY_PATH_ARG=/opt/rh/gcc-toolset-12/root/usr/lib64:/opt/rh/gcc-toolset-12/root/usr/lib:/opt/rh/gcc-toolset-12/root/usr/lib64/dyninst:/opt/rh/gcc-toolset-12/root/usr/lib/dyninst:/usr/local/lib64:/usr/local/lib + Repository: onnxruntimetrainingrocmbuild-rocm$(RocmVersion) + CheckOutManyLinux: true + + - template: templates/set-version-number-variables-step.yml + + - task: Bash@3 + displayName: 'Build' + inputs: + targetType: filePath + filePath: tools/ci_build/github/linux/build_rocm_c_api_package.sh + arguments: >- + -S $(Build.SourcesDirectory) + -B $(Build.BinariesDirectory) + -V $(RocmVersion) + -I onnxruntimetrainingrocmbuild-rocm$(RocmVersion) + -P python3.10 + + - script: | + set -e -x + mkdir $(Build.ArtifactStagingDirectory)/testdata + cp $(Build.BinariesDirectory)/Release/libcustom_op_library.so* $(Build.ArtifactStagingDirectory)/testdata + ls -al $(Build.ArtifactStagingDirectory) + displayName: 'Create Artifacts for CustomOp' # libcustom_op_library.so from cpu build is built with fp8, ROCm does not support it. + + - template: templates/c-api-artifacts-package-and-publish-steps-posix.yml + parameters: + buildConfig: 'Release' + artifactName: 'onnxruntime-linux-x64-rocm-$(OnnxRuntimeVersion)' + artifactNameNoVersionString: 'onnxruntime-linux-x64-rocm' + libraryName: 'libonnxruntime.so.$(OnnxRuntimeVersion)' + + - template: templates/component-governance-component-detection-steps.yml + parameters: + condition: 'succeeded' + - template: templates/clean-agent-build-directory-step.yml + +- stage: NuGet_Packaging_ROCm + dependsOn: + - Setup + - Linux_C_API_Packaging_ROCm_x64 + condition: succeeded() + jobs: + - job: NuGet_Packaging_ROCm + workspace: + clean: all + # we need to use a 2022 pool to create the nuget package with MAUI targets. + # VS2019 has no support for net6/MAUI and we need to use msbuild (from the VS install) to do the packing + pool: 'Onnxruntime-Win-CPU-2022' + variables: + breakCodesignValidationInjection: ${{ parameters.DoEsrp }} + ReleaseVersionSuffix: $[stageDependencies.Setup.Set_Variables.outputs['Set_Release_Version_Suffix.ReleaseVersionSuffix']] + BuildDate : $[stageDependencies.Setup.Set_Variables.outputs['Set_Build_Date.BuildDate']] + BuildTime : $[stageDependencies.Setup.Set_Variables.outputs['Set_Build_Time.BuildTime']] + + steps: + - checkout: self + submodules: true + fetchDepth: 1 + + - template: templates/flex-downloadPipelineArtifact.yml + parameters: + StepName: 'Download Pipeline Artifact - NuGet' + ArtifactName: 'onnxruntime-linux-x64-rocm' + targetPath: '$(Build.BinariesDirectory)/nuget-artifact' + SpecificArtifact: ${{ parameters.specificArtifact }} + BuildId: ${{ parameters.BuildId }} + + - task: PowerShell@2 + displayName: 'Reconstruct Build Directory' + inputs: + targetType: inline + script: | + Get-ChildItem $(Build.BinariesDirectory)\nuget-artifact -Filter *.tgz | % { + # *.tar will be created after *.tgz is extracted + $cmd = "7z.exe x $($_.FullName) -y -o$(Build.BinariesDirectory)\nuget-artifact" + Write-Output $cmd + Invoke-Expression -Command $cmd + } + + Get-ChildItem $(Build.BinariesDirectory)\nuget-artifact -Filter *.tar | % { + $cmd = "7z.exe x $($_.FullName) -y -o$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\nuget-artifacts" + Write-Output $cmd + Invoke-Expression -Command $cmd + } + + $ort_dirs = Get-ChildItem -Path $(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\nuget-artifacts\onnxruntime-* -Directory + foreach ($ort_dir in $ort_dirs) + { + $dirname = Split-Path -Path $ort_dir -Leaf + $dirname = $dirname.SubString(0, $dirname.LastIndexOf('-')) + Write-Output "Renaming $ort_dir to $dirname" + Rename-Item -Path $ort_dir -NewName $(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\nuget-artifacts\$dirname + } + + Copy-Item -Path $(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\nuget-artifacts\onnxruntime-linux-x64-rocm\lib\* -Destination $(Build.BinariesDirectory)\RelWithDebInfo + + - script: | + tree /F + workingDirectory: '$(Build.BinariesDirectory)' + displayName: 'Inspect Build Binaries Directory' + + - script: | + mklink /D /J models C:\local\models + workingDirectory: '$(Build.BinariesDirectory)' + displayName: 'Create models link' + + - task: NuGetToolInstaller@0 + displayName: Use Nuget 6.10.x + inputs: + versionSpec: 6.10.x + + - task: MSBuild@1 + displayName: 'Restore NuGet Packages and create project.assets.json' + inputs: + solution: '$(Build.SourcesDirectory)\csharp\OnnxRuntime.CSharp.sln' + platform: 'Any CPU' + configuration: RelWithDebInfo + msbuildArguments: '-t:restore -p:OrtPackageId="Microsoft.ML.OnnxRuntime.ROCm"' + workingDirectory: '$(Build.SourcesDirectory)\csharp' + + - task: MSBuild@1 + displayName: 'Build C# bindings' + inputs: + solution: '$(Build.SourcesDirectory)\csharp\OnnxRuntime.CSharp.sln' + platform: 'Any CPU' + configuration: RelWithDebInfo + msbuildArguments: > + -p:OnnxRuntimeBuildDirectory="$(Build.BinariesDirectory)" + -p:OrtPackageId="Microsoft.ML.OnnxRuntime.ROCm" + -p:IsReleaseBuild=${{ parameters.IsReleaseBuild }} + -p:ReleaseVersionSuffix=$(ReleaseVersionSuffix) + -p:IsLinuxBuild=true + -p:IsWindowsBuild=false + -p:IsMacOSBuild=false + workingDirectory: '$(Build.SourcesDirectory)\csharp' + + - template: templates/win-esrp-dll.yml + parameters: + FolderPath: '$(Build.SourcesDirectory)\csharp\src\Microsoft.ML.OnnxRuntime\bin\RelWithDebInfo' + DisplayName: 'ESRP - Sign C# dlls' + DoEsrp: ${{ parameters.DoEsrp }} + + - task: UsePythonVersion@0 + displayName: 'Use Python' + inputs: + versionSpec: 3.12 + + - task: MSBuild@1 + displayName: 'Build Nuget Packages' + inputs: + solution: '$(Build.SourcesDirectory)\csharp\OnnxRuntime.CSharp.proj' + configuration: RelWithDebInfo + platform: 'Any CPU' + msbuildArguments: > + -t:CreatePackage + -p:OnnxRuntimeBuildDirectory="$(Build.BinariesDirectory)" + -p:OrtPackageId=Microsoft.ML.OnnxRuntime.ROCm + -p:IsReleaseBuild=${{ parameters.IsReleaseBuild }} + -p:ReleaseVersionSuffix=$(ReleaseVersionSuffix) + -p:CurrentTime=$(BuildTime) + -p:CurrentDate=$(BuildDate) + workingDirectory: '$(Build.SourcesDirectory)\csharp' + + - task: CopyFiles@2 + displayName: 'Copy nuget packages to: $(Build.ArtifactStagingDirectory)' + inputs: + SourceFolder: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo' + Contents: '*.snupkg' + TargetFolder: '$(Build.ArtifactStagingDirectory)' + + - task: CopyFiles@2 + displayName: 'Copy nuget packages to: $(Build.ArtifactStagingDirectory)' + inputs: + SourceFolder: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo' + Contents: '*.nupkg' + TargetFolder: '$(Build.ArtifactStagingDirectory)' + + - task: CopyFiles@2 + displayName: 'Copy nuget packages to: $(Build.ArtifactStagingDirectory)' + inputs: + SourceFolder: '$(Build.SourcesDirectory)\csharp\src\Microsoft.ML.OnnxRuntime\bin\RelWithDebInfo' + Contents: '*.nupkg' + TargetFolder: '$(Build.ArtifactStagingDirectory)' + + - template: templates/esrp_nuget.yml + parameters: + DisplayName: 'ESRP - sign NuGet package' + FolderPath: '$(Build.ArtifactStagingDirectory)' + DoEsrp: ${{ parameters.DoEsrp }} + + - template: templates/validate-package.yml + parameters: + PackageType: 'nuget' + PackagePath: '$(Build.ArtifactStagingDirectory)' + PackageName: 'Microsoft.ML.OnnxRuntime.*nupkg' + PlatformsSupported: 'linux-x64' + VerifyNugetSigning: false + + - task: PublishPipelineArtifact@0 + displayName: 'Publish Pipeline NuGet Artifact' + inputs: + artifactName: 'drop-signed-nuget-ROCm' + targetPath: '$(Build.ArtifactStagingDirectory)' + + - task: MSBuild@1 + displayName: 'Clean C#' + inputs: + solution: '$(Build.SourcesDirectory)\csharp\OnnxRuntime.CSharp.sln' + platform: 'Any CPU' + configuration: RelWithDebInfo + msbuildArguments: '-t:Clean -p:OnnxRuntimeBuildDirectory="$(Build.BinariesDirectory)" -p:OrtPackageId=Microsoft.ML.OnnxRuntime.ROCm' + workingDirectory: '$(Build.SourcesDirectory)\csharp' + + - template: templates/component-governance-component-detection-steps.yml + parameters : + condition : 'succeeded' + + +- template: nuget/templates/test_linux.yml + parameters: + AgentPool: AMD-GPU + ArtifactSuffix: 'ROCm' + StageSuffix: 'ROCm' + NugetPackageName: 'Microsoft.ML.OnnxRuntime.ROCm' + SpecificArtifact: ${{ parameters.specificArtifact }} + CustomOpArtifactName: 'onnxruntime-linux-x64-rocm' + BuildId: ${{ parameters.BuildId }} diff --git a/tools/ci_build/github/azure-pipelines/rocm-publish-nuget-pipeline.yml b/tools/ci_build/github/azure-pipelines/rocm-publish-nuget-pipeline.yml new file mode 100644 index 0000000000000..1d2393d8f96d5 --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/rocm-publish-nuget-pipeline.yml @@ -0,0 +1,21 @@ +resources: + pipelines: + - pipeline: build + source: 'Nuget ROCM Packaging pipeline' + trigger: + branches: + include: + - main + - rel-* + branch: main + +# ROCm +stages: +- template: templates/publish-nuget-steps.yml + parameters: + stage_name: 'Publish_ROCM_NuGet_Package' + download_artifacts_steps: + - download: build + displayName: 'Download Pipeline Artifact - Signed NuGet Package' + artifact: 'drop-signed-nuget-ROCm' + - script: move "$(Pipeline.Workspace)\build\drop-signed-nuget-ROCm\*" $(Build.BinariesDirectory)\nuget-artifact\final-package diff --git a/tools/ci_build/github/azure-pipelines/stages/nuget-cuda-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/nuget-cuda-packaging-stage.yml index 5ae60aac8f9b4..8fabb80a73869 100644 --- a/tools/ci_build/github/azure-pipelines/stages/nuget-cuda-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/nuget-cuda-packaging-stage.yml @@ -96,10 +96,18 @@ stages: inputs: versionSpec: 6.10.x + - task: PowerShell@2 + displayName: Install MAUI workloads + inputs: + targetType: 'inline' + script: | + dotnet workload install android ios maccatalyst + workingDirectory: '$(Build.SourcesDirectory)\csharp' + - task: MSBuild@1 displayName: 'Restore NuGet Packages and create project.assets.json' inputs: - solution: '$(Build.SourcesDirectory)\csharp\OnnxRuntime.DesktopOnly.CSharp.sln' + solution: '$(Build.SourcesDirectory)\csharp\OnnxRuntime.CSharp.sln' platform: 'Any CPU' configuration: RelWithDebInfo msbuildArguments: '-t:restore -p:OrtPackageId="Microsoft.ML.OnnxRuntime.Gpu"' @@ -108,7 +116,7 @@ stages: - task: MSBuild@1 displayName: 'Build C# bindings' inputs: - solution: '$(Build.SourcesDirectory)\csharp\OnnxRuntime.DesktopOnly.CSharp.sln' + solution: '$(Build.SourcesDirectory)\csharp\OnnxRuntime.CSharp.sln' configuration: RelWithDebInfo platform: 'Any CPU' msbuildArguments: > @@ -200,7 +208,7 @@ stages: - task: MSBuild@1 displayName: 'Clean C#' inputs: - solution: '$(Build.SourcesDirectory)\csharp\OnnxRuntime.DesktopOnly.CSharp.sln' + solution: '$(Build.SourcesDirectory)\csharp\OnnxRuntime.CSharp.sln' platform: 'Any CPU' configuration: RelWithDebInfo msbuildArguments: '-t:Clean -p:OnnxRuntimeBuildDirectory="$(Build.BinariesDirectory)" -p:OrtPackageId=Microsoft.ML.OnnxRuntime.Gpu' @@ -215,3 +223,4 @@ stages: inputs: artifactName: 'drop-signed-nuget-GPU' targetPath: '$(Build.ArtifactStagingDirectory)' + diff --git a/tools/ci_build/github/azure-pipelines/stages/nuget-qnn-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/nuget-qnn-packaging-stage.yml deleted file mode 100644 index 03802746cec3d..0000000000000 --- a/tools/ci_build/github/azure-pipelines/stages/nuget-qnn-packaging-stage.yml +++ /dev/null @@ -1,76 +0,0 @@ -parameters: -- name: DoEsrp - displayName: Run code sign tasks? Must be true if you are doing an Onnx Runtime release. - type: boolean - default: true - -stages: -- stage: NuGet_Packaging_QNN - pool: - name: 'Onnxruntime-QNNEP-Windows-2022-CPU' - dependsOn: - - OnnxRuntime_QNN_Nuget_Win_x64 - - OnnxRuntime_QNN_Nuget_Win_Arm64 - condition: succeeded() - jobs: - - job: NuGet_Packaging_QNN - workspace: - clean: all - steps: - - task: DownloadPipelineArtifact@0 - displayName: 'Download Pipeline Artifact - QNN NuGet x64' - inputs: - artifactName: 'drop-nuget-qnn-x64' - targetPath: '$(Build.BinariesDirectory)/nuget-artifact-x64' - - - task: DownloadPipelineArtifact@0 - displayName: 'Download Pipeline Artifact - QNN NuGet arm64' - inputs: - artifactName: 'drop-nuget-qnn-arm64' - targetPath: '$(Build.BinariesDirectory)/nuget-artifact-arm64' - - - task: PowerShell@2 - displayName: 'Bundle NuGet' - inputs: - targetType: 'inline' - script: | - - $x64_nupkgs = (Get-ChildItem $(Build.BinariesDirectory)/nuget-artifact-x64 -Filter Microsoft.ML.OnnxRuntime.QNN*.nupkg -Recurse) - $nuget_package_name = $x64_nupkgs[0].Name - $x64_nuget_package = $x64_nupkgs[0].FullName - - $nupkg_unzipped_directory = [System.IO.Path]::Combine($Env:BUILD_ARTIFACTSTAGINGDIRECTORY, 'nuget_unzip_merged', [System.IO.Path]::GetFileNameWithoutExtension($nuget_package_name)) - - $x64_unzip_cmd = "7z.exe x $x64_nuget_package -y -o$nupkg_unzipped_directory" - Invoke-Expression -Command $x64_unzip_cmd - - $arm64_nupkgs = (Get-ChildItem $(Build.BinariesDirectory)/nuget-artifact-arm64 -Filter Microsoft.ML.OnnxRuntime.QNN*.nupkg -Recurse) - $arm64_nuget_package = $arm64_nupkgs[0].FullName - - $arm64_unzip_cmd = "7z.exe x $arm64_nuget_package -y -o$nupkg_unzipped_directory" - Invoke-Expression -Command $arm64_unzip_cmd - - $merged_nuget_path = [System.IO.Path]::Combine($Env:BUILD_ARTIFACTSTAGINGDIRECTORY, 'nuget-artifact-merged') - if (!(Test-Path $merged_nuget_path)) { - New-Item -Path $merged_nuget_path -ItemType Directory - } - - $merged_zip = [System.IO.Path]::Combine($merged_nuget_path, 'qnn_nuget.zip') - $zip_cmd = "7z.exe a -r $merged_zip $nupkg_unzipped_directory/*" - Invoke-Expression -Command $zip_cmd - - $merged_nuget = [System.IO.Path]::Combine($merged_nuget_path, $nuget_package_name) - move $merged_zip $merged_nuget - workingDirectory: $(Build.BinariesDirectory) - - - template: ../templates/esrp_nuget.yml - parameters: - DisplayName: 'ESRP - sign NuGet package' - FolderPath: '$(Build.ArtifactStagingDirectory)/nuget-artifact-merged' - DoEsrp: ${{ parameters.DoEsrp }} - - - task: 1ES.PublishPipelineArtifact@1 - displayName: 'Publish Pipeline NuGet Artifact' - inputs: - artifactName: 'drop-signed-nuget-qnn' - targetPath: '$(Build.ArtifactStagingDirectory)/nuget-artifact-merged' diff --git a/tools/ci_build/github/azure-pipelines/stages/py-cpu-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/py-cpu-packaging-stage.yml index 5e783607e3622..4ff539df9f914 100644 --- a/tools/ci_build/github/azure-pipelines/stages/py-cpu-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/py-cpu-packaging-stage.yml @@ -123,7 +123,7 @@ stages: --skip_submodule_sync --cmake_generator "Visual Studio 17 2022" --enable_pybind - --enable_onnx_tests --use_vcpkg --use_vcpkg_ms_internal_asset_cache + --enable_onnx_tests ${{ parameters.build_py_parameters }} --parallel --use_binskim_compliant_compile_flags --update --build $(TelemetryOption) @@ -151,11 +151,10 @@ stages: Contents: '*.whl' TargetFolder: '$(Build.ArtifactStagingDirectory)' - - task: 1ES.PublishPipelineArtifact@1 + - task: PublishBuildArtifacts@1 displayName: 'Publish Artifact: ONNXRuntime python wheel' inputs: - artifactName: onnxruntime-win-$(PythonVersion) - targetPath: '$(Build.ArtifactStagingDirectory)' + ArtifactName: onnxruntime - script: | 7z x *.whl @@ -200,9 +199,7 @@ stages: workspace: clean: all pool: - name: "Azure Pipelines" - image: "macOS-13" - os: macOS + vmImage: 'macOS-13' variables: MACOSX_DEPLOYMENT_TARGET: '13.3' strategy: @@ -254,81 +251,74 @@ stages: Contents: '*.whl' TargetFolder: '$(Build.ArtifactStagingDirectory)' - - task: 1ES.PublishPipelineArtifact@1 + - task: PublishBuildArtifacts@1 displayName: 'Publish Artifact: ONNXRuntime python wheel' inputs: - artifactName: onnxruntime-macos-$(PythonVersion) - targetPath: '$(Build.ArtifactStagingDirectory)' + ArtifactName: onnxruntime - template: ../templates/component-governance-component-detection-steps.yml parameters: condition: 'succeeded' -- ${{ if eq(parameters.enable_linux_arm, true) }}: - - stage: Python_Packaging_Linux_ARM - dependsOn: [] - jobs: + - ${{ if eq(parameters.enable_linux_arm, true) }}: + - stage: Python_Packaging_Linux_ARM + dependsOn: [] + jobs: + - template: ../templates/py-linux.yml + parameters: + arch: 'aarch64' + machine_pool: 'onnxruntime-linux-ARM64-CPU-2019' + extra_build_arg: ${{ parameters.build_py_parameters }} + cmake_build_type: ${{ parameters.cmake_build_type }} + + - ${{ if eq(parameters.enable_linux_cpu, true) }}: + - stage: Python_Packaging_Linux_CPU + dependsOn: [] + jobs: - template: ../templates/py-linux.yml parameters: - arch: 'aarch64' - machine_pool: 'onnxruntime-linux-ARM64-CPU-2019' + arch: 'x86_64' + machine_pool: 'onnxruntime-Ubuntu2204-AMD-CPU-Large' extra_build_arg: ${{ parameters.build_py_parameters }} cmake_build_type: ${{ parameters.cmake_build_type }} - is1ES: true -- ${{ if eq(parameters.enable_linux_cpu, true) }}: - - stage: Python_Packaging_Linux_CPU - dependsOn: [] - jobs: - - template: ../templates/py-linux.yml - parameters: - arch: 'x86_64' - machine_pool: 'onnxruntime-Ubuntu2204-AMD-CPU-Large' - extra_build_arg: ${{ parameters.build_py_parameters }} - cmake_build_type: ${{ parameters.cmake_build_type }} - is1ES: true - -- ${{ if eq(parameters.enable_windows_arm64_qnn, true) }}: - - stage: Python_Packaging_Windows_ARM64_QNN - dependsOn: [] - jobs: - - template: ../templates/py-win-arm64-qnn.yml - parameters: - MACHINE_POOL: 'onnxruntime-qnn-windows-vs-2022-arm64' - QNN_SDK: ${{ parameters.qnn_sdk_version }} - BUILD_PY_PARAMETERS: ${{ parameters.build_py_parameters }} - is1ES: true - -- ${{ if eq(parameters.enable_windows_arm64ec_qnn, true) }}: - - stage: Python_Packaging_Windows_arm64ec_QNN - dependsOn: [] - jobs: - - template: ../templates/py-win-arm64ec-qnn.yml + - ${{ if eq(parameters.enable_windows_arm64_qnn, true) }}: + - stage: Python_Packaging_Windows_ARM64_QNN + dependsOn: [] + jobs: + - template: ../templates/py-win-arm64-qnn.yml parameters: - MACHINE_POOL: 'Onnxruntime-QNNEP-Windows-2022-CPU' + MACHINE_POOL: 'onnxruntime-qnn-windows-vs-2022-arm64' QNN_SDK: ${{ parameters.qnn_sdk_version }} BUILD_PY_PARAMETERS: ${{ parameters.build_py_parameters }} - is1ES: true -- ${{ if eq(parameters.enable_windows_x64_qnn, true) }}: - - stage: Python_Packaging_Windows_x64_QNN - dependsOn: [] - jobs: - - template: ../templates/py-win-x64-qnn.yml + - ${{ if eq(parameters.enable_windows_arm64ec_qnn, true) }}: + - stage: Python_Packaging_Windows_arm64ec_QNN + dependsOn: [] + jobs: + - template: ../templates/py-win-arm64ec-qnn.yml + parameters: + MACHINE_POOL: 'Onnxruntime-QNNEP-Windows-2022-CPU' + QNN_SDK: ${{ parameters.qnn_sdk_version }} + BUILD_PY_PARAMETERS: ${{ parameters.build_py_parameters }} + + - ${{ if eq(parameters.enable_windows_x64_qnn, true) }}: + - stage: Python_Packaging_Windows_x64_QNN + dependsOn: [] + jobs: + - template: ../templates/py-win-x64-qnn.yml + parameters: + MACHINE_POOL: 'Onnxruntime-QNNEP-Windows-2022-CPU' + QNN_SDK: ${{ parameters.qnn_sdk_version }} + BUILD_PY_PARAMETERS: ${{ parameters.build_py_parameters }} + + - ${{ if eq(parameters.enable_linux_x64_qnn, true) }}: + - stage: Python_Packaging_Linux_x64_QNN + dependsOn: [] + jobs: + - template: ../templates/py-linux-qnn.yml parameters: - MACHINE_POOL: 'Onnxruntime-QNNEP-Windows-2022-CPU' - QNN_SDK: ${{ parameters.qnn_sdk_version }} - BUILD_PY_PARAMETERS: ${{ parameters.build_py_parameters }} - is1ES: true - -- ${{ if eq(parameters.enable_linux_x64_qnn, true) }}: - - stage: Python_Packaging_Linux_x64_QNN - dependsOn: [] - jobs: - - template: ../templates/py-linux-qnn.yml - parameters: - machine_pool: 'onnxruntime-Ubuntu2204-AMD-CPU' - extra_build_arg: ${{ parameters.build_py_parameters }} - cmake_build_type: ${{ parameters.cmake_build_type }} - is1ES: true + machine_pool: 'onnxruntime-Ubuntu2204-AMD-CPU' + extra_build_arg: ${{ parameters.build_py_parameters }} + cmake_build_type: ${{ parameters.cmake_build_type }} diff --git a/tools/ci_build/github/azure-pipelines/templates/c-api-artifacts-package-and-publish-steps-windows.yml b/tools/ci_build/github/azure-pipelines/templates/c-api-artifacts-package-and-publish-steps-windows.yml index e1a514ea54123..5ee425405ac70 100644 --- a/tools/ci_build/github/azure-pipelines/templates/c-api-artifacts-package-and-publish-steps-windows.yml +++ b/tools/ci_build/github/azure-pipelines/templates/c-api-artifacts-package-and-publish-steps-windows.yml @@ -57,22 +57,6 @@ steps: copy $(Build.BinariesDirectory)\${{parameters.buildConfig}}\${{parameters.buildConfig}}\onnxruntime_providers_cuda.pdb $(Build.BinariesDirectory)\${{parameters.artifactName}}\lib copy $(Build.BinariesDirectory)\${{parameters.buildConfig}}\${{parameters.buildConfig}}\onnxruntime_providers_cuda.lib $(Build.BinariesDirectory)\${{parameters.artifactName}}\lib - # Copy WebGPU dependencies if required - copy $(Build.BinariesDirectory)\${{parameters.buildConfig}}\${{parameters.buildConfig}}\dxcompiler.dll $(Build.BinariesDirectory)\${{parameters.artifactName}}\lib - copy $(Build.BinariesDirectory)\${{parameters.buildConfig}}\${{parameters.buildConfig}}\dxil.dll $(Build.BinariesDirectory)\${{parameters.artifactName}}\lib - - # Copy QNN dependencies if required - copy $(Build.BinariesDirectory)\${{parameters.buildConfig}}\${{parameters.buildConfig}}\onnxruntime_providers_qnn.dll $(Build.BinariesDirectory)\${{parameters.artifactName}}\lib - copy $(Build.BinariesDirectory)\${{parameters.buildConfig}}\${{parameters.buildConfig}}\libQnnHtp*.so $(Build.BinariesDirectory)\${{parameters.artifactName}}\lib /Y - copy $(Build.BinariesDirectory)\${{parameters.buildConfig}}\${{parameters.buildConfig}}\libqnnhtp*.cat $(Build.BinariesDirectory)\${{parameters.artifactName}}\lib /Y - copy $(Build.BinariesDirectory)\${{parameters.buildConfig}}\${{parameters.buildConfig}}\QnnCpu.dll $(Build.BinariesDirectory)\${{parameters.artifactName}}\lib - copy $(Build.BinariesDirectory)\${{parameters.buildConfig}}\${{parameters.buildConfig}}\QnnHtp.dll $(Build.BinariesDirectory)\${{parameters.artifactName}}\lib - copy $(Build.BinariesDirectory)\${{parameters.buildConfig}}\${{parameters.buildConfig}}\QnnHtpPrepare.dll $(Build.BinariesDirectory)\${{parameters.artifactName}}\lib - copy $(Build.BinariesDirectory)\${{parameters.buildConfig}}\${{parameters.buildConfig}}\QnnHtpV68Stub.dll $(Build.BinariesDirectory)\${{parameters.artifactName}}\lib - copy $(Build.BinariesDirectory)\${{parameters.buildConfig}}\${{parameters.buildConfig}}\QnnHtpV73Stub.dll $(Build.BinariesDirectory)\${{parameters.artifactName}}\lib - copy $(Build.BinariesDirectory)\${{parameters.buildConfig}}\${{parameters.buildConfig}}\QnnSaver.dll $(Build.BinariesDirectory)\${{parameters.artifactName}}\lib - copy $(Build.BinariesDirectory)\${{parameters.buildConfig}}\${{parameters.buildConfig}}\QnnSystem.dll $(Build.BinariesDirectory)\${{parameters.artifactName}}\lib - # copy trt ep libraries only when trt ep is enabled copy $(Build.BinariesDirectory)\${{parameters.buildConfig}}\${{parameters.buildConfig}}\onnxruntime_providers_tensorrt.dll $(Build.BinariesDirectory)\${{parameters.artifactName}}\lib copy $(Build.BinariesDirectory)\${{parameters.buildConfig}}\${{parameters.buildConfig}}\onnxruntime_providers_tensorrt.pdb $(Build.BinariesDirectory)\${{parameters.artifactName}}\lib diff --git a/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_openvino.yml b/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_openvino.yml deleted file mode 100644 index f6956b426ddfc..0000000000000 --- a/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_openvino.yml +++ /dev/null @@ -1,64 +0,0 @@ -parameters: - - name: OpenVINOVersion - type: string - default: '2025.0.0' - -steps: - - powershell: | - $Url = "https://storage.openvinotoolkit.org/repositories/openvino/packages/2025.0/windows/openvino_toolkit_windows_2025.0.0.17942.1f68be9f594_x86_64.zip" - $OutputPath = "$env:Agent_TempDirectory\openvino.zip" - $ExtractPath = "$env:Agent_TempDirectory\openvino-v$env:OpenVINOVersion" - $TempExtractPath = "$env:Agent_TempDirectory\openvino_temp" - - # Ensure directories exist - if (Test-Path $ExtractPath) { - Remove-Item -Recurse -Force $ExtractPath - } - New-Item -ItemType Directory -Path $ExtractPath | Out-Null - New-Item -ItemType Directory -Path $TempExtractPath | Out-Null - - # Download OpenVINO ZIP - Write-Output "Downloading OpenVINO" - Invoke-WebRequest -Uri $Url -OutFile $OutputPath - - # Extract to temporary directory first - Write-Output "Extracting OpenVINO to a temporary directory" - Expand-Archive -Path $OutputPath -DestinationPath $TempExtractPath -Force - - # Locate the nested subdirectory - $InnerFolder = Get-ChildItem -Path $TempExtractPath -Directory | Select-Object -First 1 - - if ($InnerFolder) { - Write-Output "Moving extracted files to final destination" - Move-Item -Path "$($InnerFolder.FullName)\*" -Destination $ExtractPath -Force - } else { - Write-Error "Extraction failed: No expected subdirectory found in $TempExtractPath." - Write-Error "The archive may not have extracted correctly, or its structure is different than expected." - exit 1 - } - - # Clean up temporary files - Remove-Item -Recurse -Force $TempExtractPath - Remove-Item -Force $OutputPath - - # Confirm success - Write-Output "OpenVINO extracted to $ExtractPath" - displayName: 'Download OpenVINO Toolkit v${{ parameters.OpenVINOVersion }}' - env: - OpenVINOVersion: ${{ parameters.OpenVINOVersion }} - - - powershell: | - echo "##vso[task.setvariable variable=OpenVINORootDir]$(Agent.TempDirectory)\openvino-v${{ parameters.OpenVINOVersion }}" - displayName: 'Set OpenVINORootDir' - - - task: CmdLine@2 - inputs: - script: | - echo $(OpenVINORootDir) - displayName: 'Print OpenVINORootDir after downloading OpenVINO' - - - task: CmdLine@2 - displayName: 'Print contents of OpenVINO Toolkit' - inputs: - script: | - dir $(OpenVINORootDir) diff --git a/tools/ci_build/github/azure-pipelines/templates/linux-web-init-and-check.yml b/tools/ci_build/github/azure-pipelines/templates/linux-web-init-and-check.yml index 2b73f82615bba..a4d5a73118ea2 100644 --- a/tools/ci_build/github/azure-pipelines/templates/linux-web-init-and-check.yml +++ b/tools/ci_build/github/azure-pipelines/templates/linux-web-init-and-check.yml @@ -1,8 +1,4 @@ steps: -- task: NodeTool@0 - inputs: - # requires Node.js v22 for float16 testing (the V8 flag "--js-float16array") - versionSpec: '22.x' - script: | npm ci workingDirectory: '$(Build.SourcesDirectory)/js' @@ -15,10 +11,6 @@ steps: npm test workingDirectory: '$(Build.SourcesDirectory)/js/common' displayName: 'run onnxruntime-common tests' -- script: | - npm run test:f16 - workingDirectory: '$(Build.SourcesDirectory)/js/common' - displayName: 'run onnxruntime-common tests (enable Float16Array)' - script: | npm ci workingDirectory: '$(Build.SourcesDirectory)/js/web' diff --git a/tools/ci_build/github/azure-pipelines/templates/py-linux-qnn.yml b/tools/ci_build/github/azure-pipelines/templates/py-linux-qnn.yml index 8126cda449daa..347a3145e8c70 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-linux-qnn.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-linux-qnn.yml @@ -6,10 +6,10 @@ parameters: type: string default: 'Release' values: - - Debug - - Release - - RelWithDebInfo - - MinSizeRel + - Debug + - Release + - RelWithDebInfo + - MinSizeRel - name: device type: string @@ -27,82 +27,68 @@ parameters: displayName: QNN SDK version type: string default: 2.31.0.250130 - -- name: is1ES - displayName: 'Whether the pipeline is running in 1ES' - type: boolean - default: false jobs: - job: Linux_py_qnn_Wheels_x64 timeoutInMinutes: 240 workspace: clean: all - pool: - name: ${{ parameters.machine_pool }} - os: linux + pool: ${{ parameters.machine_pool }} variables: - # The build machine pool doesn't have dotnet, so it can't run CG. - - name: skipComponentGovernanceDetection - value: true - - name: ORT_CACHE_DIR - value: $(Agent.TempDirectory)/ort_ccache - - name: TODAY - value: $[format('{0:dd}{0:MM}{0:yyyy}', pipeline.startTime)] - - name: extra_build_args - ${{ if ne(parameters.extra_build_arg, '') }}: - value: -x ${{ parameters.extra_build_arg }} - ${{ if eq(parameters.extra_build_arg, '') }}: - value: '' + # The build machine pool doesn't have dotnet, so it can't run CG. + - name: skipComponentGovernanceDetection + value: true + - name: ORT_CACHE_DIR + value: $(Agent.TempDirectory)/ort_ccache + - name: TODAY + value: $[format('{0:dd}{0:MM}{0:yyyy}', pipeline.startTime)] + - name: extra_build_args + ${{ if ne(parameters.extra_build_arg, '') }}: + value: -x ${{ parameters.extra_build_arg }} + ${{ if eq(parameters.extra_build_arg, '') }}: + value: '' steps: - - checkout: self - clean: true - submodules: none + - checkout: self + clean: true + submodules: none - - template: jobs/download_linux_qnn_sdk.yml - parameters: - QnnSDKVersion: ${{ parameters.QnnSdk }} + - template: jobs/download_linux_qnn_sdk.yml + parameters: + QnnSDKVersion: ${{ parameters.QnnSdk }} - - template: set-nightly-build-option-variable-step.yml + - template: set-nightly-build-option-variable-step.yml - - template: get-docker-image-steps.yml - parameters: - Dockerfile: tools/ci_build/github/linux/docker/inference/x86_64/python/cpu/Dockerfile - Context: tools/ci_build/github/linux/docker/inference/x86_64/python/cpu - DockerBuildArgs: "--build-arg BUILD_UID=$( id -u )" - Repository: onnxruntimecpubuildpythonx86_64_qnn + - template: get-docker-image-steps.yml + parameters: + Dockerfile: tools/ci_build/github/linux/docker/inference/x86_64/python/cpu/Dockerfile + Context: tools/ci_build/github/linux/docker/inference/x86_64/python/cpu + DockerBuildArgs: "--build-arg BUILD_UID=$( id -u )" + Repository: onnxruntimecpubuildpythonx86_64_qnn - - template: linux-build-step-with-cache.yml - parameters: - WithCache: ${{parameters.with_cache}} - Today: $(TODAY) - AdditionalKey: Linux_py_qnn_Wheels_x64 - CacheDir: $(ORT_CACHE_DIR) - ChangeEveryCommit: true - BuildStep: - - task: Bash@3 - displayName: 'Build Python Wheel' - inputs: - targetType: filePath - filePath: tools/ci_build/github/linux/run_python_dockerbuild.sh - arguments: -i onnxruntimecpubuildpythonx86_64_qnn -d "${{ parameters.device }}" -c ${{ parameters.cmake_build_type }} $(extra_build_args) - env: - ADDITIONAL_DOCKER_PARAMETER: "--volume $(QnnSDKRootDir):/qnn_sdk" - - ${{ if eq(parameters.is1ES, true) }}: - - task: 1ES.PublishPipelineArtifact@1 - displayName: 'Publish Artifact: Linux ONNXRuntime QNN python wheel' - inputs: - targetPath: '$(Build.BinariesDirectory)/dist' - artifactName: onnxruntime-linux-qnn-x64 + - template: linux-build-step-with-cache.yml + parameters: + WithCache: ${{parameters.with_cache}} + Today: $(TODAY) + AdditionalKey: Linux_py_qnn_Wheels_x64 + CacheDir: $(ORT_CACHE_DIR) + ChangeEveryCommit: true + BuildStep: + - task: Bash@3 + displayName: 'Build Python Wheel' + inputs: + targetType: filePath + filePath: tools/ci_build/github/linux/run_python_dockerbuild.sh + arguments: -i onnxruntimecpubuildpythonx86_64_qnn -d "${{ parameters.device }}" -c ${{ parameters.cmake_build_type }} $(extra_build_args) + env: + ADDITIONAL_DOCKER_PARAMETER: "--volume $(QnnSDKRootDir):/qnn_sdk" - - ${{ if eq(parameters.is1ES, false) }}: - - task: PublishPipelineArtifact@1 + - task: PublishBuildArtifacts@1 displayName: 'Publish Artifact: Linux ONNXRuntime QNN python wheel' inputs: - targetPath: '$(Build.BinariesDirectory)/dist' - artifactName: onnxruntime-linux-qnn-x64 + PathtoPublish: '$(Build.BinariesDirectory)/dist' + ArtifactName: onnxruntime-linux-qnn-x64 - - template: component-governance-component-detection-steps.yml - parameters: - condition: 'succeeded' + - template: component-governance-component-detection-steps.yml + parameters : + condition : 'succeeded' diff --git a/tools/ci_build/github/azure-pipelines/templates/py-linux.yml b/tools/ci_build/github/azure-pipelines/templates/py-linux.yml index 8d0c4334f4874..e591b719ecfa9 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-linux.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-linux.yml @@ -9,10 +9,10 @@ parameters: type: string default: 'Release' values: - - Debug - - Release - - RelWithDebInfo - - MinSizeRel + - Debug + - Release + - RelWithDebInfo + - MinSizeRel - name: device type: string @@ -34,98 +34,76 @@ parameters: type: string default: '' -- name: is1ES - displayName: 'Whether the pipeline is running in 1ES' - type: boolean - default: false - jobs: - job: Linux_py_Wheels_${{ parameters.arch }}_${{parameters.ep}} timeoutInMinutes: 240 workspace: clean: all - pool: - name: ${{ parameters.machine_pool }} - os: 'linux' - ${{ if eq(parameters.arch, 'aarch64') }}: - hostArchitecture: Arm64 + pool: ${{ parameters.machine_pool }} variables: - # The build machine pool doesn't have dotnet, so it can't run CG. - - name: skipComponentGovernanceDetection - value: true - - name: ORT_CACHE_DIR - value: $(Agent.TempDirectory)/ort_ccache - - name: TODAY - value: $[format('{0:dd}{0:MM}{0:yyyy}', pipeline.startTime)] - - name: extra_build_args - ${{ if ne(parameters.extra_build_arg, '') }}: - value: '-x ${{ parameters.extra_build_arg }}' - ${{ if eq(parameters.extra_build_arg, '') }}: - value: '' - - name: python_exe_path - ${{ if ne(parameters.python_exe_path, '') }}: - value: '-p ${{ parameters.python_exe_path }}' - ${{ if eq(parameters.python_exe_path, '') }}: - value: '' + # The build machine pool doesn't have dotnet, so it can't run CG. + - name: skipComponentGovernanceDetection + value: true + - name: ORT_CACHE_DIR + value: $(Agent.TempDirectory)/ort_ccache + - name: TODAY + value: $[format('{0:dd}{0:MM}{0:yyyy}', pipeline.startTime)] + - name: extra_build_args + ${{ if ne(parameters.extra_build_arg, '') }}: + value: '-x ${{ parameters.extra_build_arg }}' + ${{ if eq(parameters.extra_build_arg, '') }}: + value: '' + - name: python_exe_path + ${{ if ne(parameters.python_exe_path, '') }}: + value: '-p ${{ parameters.python_exe_path }}' + ${{ if eq(parameters.python_exe_path, '') }}: + value: '' steps: - - checkout: self - clean: true - submodules: none - - - template: set-nightly-build-option-variable-step.yml - - - template: get-docker-image-steps.yml - parameters: - Dockerfile: tools/ci_build/github/linux/docker/inference/${{ parameters.arch }}/python/cpu/Dockerfile - Context: tools/ci_build/github/linux/docker/inference/${{ parameters.arch }}/python/cpu - DockerBuildArgs: "--build-arg BUILD_UID=$( id -u )" - Repository: onnxruntimecpubuildpython${{ parameters.arch }} - - - template: linux-build-step-with-cache.yml - parameters: - WithCache: ${{parameters.with_cache}} - Today: $(TODAY) - AdditionalKey: Linux_py_Wheels_${{ parameters.arch }} - CacheDir: $(ORT_CACHE_DIR) - ChangeEveryCommit: true - BuildStep: - - task: Bash@3 - displayName: 'Build Python Wheel' - inputs: - targetType: filePath - filePath: tools/ci_build/github/linux/run_python_dockerbuild.sh - arguments: -i onnxruntimecpubuildpython${{ parameters.arch }} -d "${{ parameters.device }}" -c ${{ parameters.cmake_build_type }} $(extra_build_args) $(python_exe_path) - ${{ if eq(parameters.with_cache, 'true') }}: - env: - ADDITIONAL_DOCKER_PARAMETER: "--volume $(ORT_CACHE_DIR):/cache -e CCACHE_DIR=/cache -e ORT_BUILD_WITH_CACHE=1" - - - ${{ if eq(parameters.is1ES, true) }}: - - task: 1ES.PublishPipelineArtifact@1 - displayName: 'Publish Artifact: ONNXRuntime python wheel' - inputs: - targetPath: '$(Build.BinariesDirectory)/dist' - artifactName: onnxruntime-${{ parameters.arch }}-${{ parameters.ep }} - - task: 1ES.PublishPipelineArtifact@1 - displayName: 'Publish Test Binaries' - inputs: - artifactName: 'drop-linux-cpu-${{ parameters.arch }}-${{ parameters.ep }}' - targetPath: '$(Build.BinariesDirectory)/${{ parameters.cmake_build_type }}' - - ${{ if eq(parameters.is1ES, false) }}: - - task: PublishPipelineArtifact@1 + - checkout: self + clean: true + submodules: none + + - template: set-nightly-build-option-variable-step.yml + + - template: get-docker-image-steps.yml + parameters: + Dockerfile: tools/ci_build/github/linux/docker/inference/${{ parameters.arch }}/python/cpu/Dockerfile + Context: tools/ci_build/github/linux/docker/inference/${{ parameters.arch }}/python/cpu + DockerBuildArgs: "--build-arg BUILD_UID=$( id -u )" + Repository: onnxruntimecpubuildpython${{ parameters.arch }} + + - template: linux-build-step-with-cache.yml + parameters: + WithCache: ${{parameters.with_cache}} + Today: $(TODAY) + AdditionalKey: Linux_py_Wheels_${{ parameters.arch }} + CacheDir: $(ORT_CACHE_DIR) + ChangeEveryCommit: true + BuildStep: + - task: Bash@3 + displayName: 'Build Python Wheel' + inputs: + targetType: filePath + filePath: tools/ci_build/github/linux/run_python_dockerbuild.sh + arguments: -i onnxruntimecpubuildpython${{ parameters.arch }} -d "${{ parameters.device }}" -c ${{ parameters.cmake_build_type }} $(extra_build_args) $(python_exe_path) + ${{ if eq(parameters.with_cache, 'true') }}: + env: + ADDITIONAL_DOCKER_PARAMETER: "--volume $(ORT_CACHE_DIR):/cache -e CCACHE_DIR=/cache -e ORT_BUILD_WITH_CACHE=1" + + - task: PublishBuildArtifacts@1 displayName: 'Publish Artifact: ONNXRuntime python wheel' inputs: - targetPath: '$(Build.BinariesDirectory)/dist' - artifactName: onnxruntime-${{ parameters.arch }}-${{ parameters.ep }} - - task: PublishPipelineArtifact@1 + PathtoPublish: '$(Build.BinariesDirectory)/dist' + ArtifactName: onnxruntime-${{ parameters.ep }} + + - task: PublishPipelineArtifact@0 displayName: 'Publish Test Binaries' inputs: artifactName: 'drop-linux-cpu-${{ parameters.arch }}-${{ parameters.ep }}' targetPath: '$(Build.BinariesDirectory)/${{ parameters.cmake_build_type }}' - - - - template: component-governance-component-detection-steps.yml - parameters: - condition: 'succeeded' + - template: component-governance-component-detection-steps.yml + parameters : + condition : 'succeeded' diff --git a/tools/ci_build/github/azure-pipelines/templates/py-package-smoking-test.yml b/tools/ci_build/github/azure-pipelines/templates/py-package-smoking-test.yml index c0bd740b2d483..3a3da0f8f5afa 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-package-smoking-test.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-package-smoking-test.yml @@ -9,13 +9,9 @@ parameters: - name: machine_pool type: object -- name: ep +- name: python_arch type: string - default: 'cpu' - -- name: arch - type: string - default: 'x86_64' + default: 'x64' jobs: - job: ${{ parameters.job_name }} @@ -41,9 +37,10 @@ jobs: displayName: 'Use Python' inputs: versionSpec: $(PythonVersion) + architecture: ${{ parameters.python_arch }} - download: build # pipeline resource identifier. - artifact: 'onnxruntime-${{ parameters.arch }}-${{ parameters.ep }}' + artifact: 'onnxruntime' - task: Bash@3 inputs: @@ -54,7 +51,7 @@ jobs: FILE_NAME="${files[0]}" FILE_NAME=$(basename $FILE_NAME) PYTHON_PACKAGE_NAME=$(echo "$FILE_NAME" | cut -f 1 -d '-') - python3 -m pip install --find-links "$(Pipeline.Workspace)/build/onnxruntime-${{ parameters.arch }}-${{ parameters.ep }}" $PYTHON_PACKAGE_NAME + python3 -m pip install --find-links "$(Pipeline.Workspace)/build/onnxruntime" $PYTHON_PACKAGE_NAME python3 -m pip show $PYTHON_PACKAGE_NAME python3 -c "import onnxruntime as ort; print(ort.__version__)" workingDirectory: $(Pipeline.Workspace)/build/onnxruntime diff --git a/tools/ci_build/github/azure-pipelines/templates/py-packaging-linux-test-cpu.yml b/tools/ci_build/github/azure-pipelines/templates/py-packaging-linux-test-cpu.yml index eef97341b8d53..c475feaef0018 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-packaging-linux-test-cpu.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-packaging-linux-test-cpu.yml @@ -19,10 +19,10 @@ parameters: type: string default: 'Release' values: - - Debug - - Release - - RelWithDebInfo - - MinSizeRel + - Debug + - Release + - RelWithDebInfo + - MinSizeRel - name: timeout type: number @@ -50,31 +50,29 @@ jobs: artifact: 'drop-linux-cpu-${{ parameters.arch }}-${{parameters.ep}}' - download: current # pipeline resource identifier. - artifact: 'onnxruntime-${{ parameters.arch }}-${{ parameters.ep }}' + artifact: 'onnxruntime${{ parameters.python_wheel_suffix }}-${{ parameters.ep }}' - bash: | set -e -x mv "$(Pipeline.Workspace)/drop-linux-cpu-${{ parameters.arch }}-${{parameters.ep}}" $(Build.BinariesDirectory)/${{parameters.cmake_build_type}} - mv "$(Pipeline.Workspace)/onnxruntime-${{ parameters.arch }}-${{ parameters.ep }}" "$(Build.BinariesDirectory)/whl" + mv "$(Pipeline.Workspace)/onnxruntime${{ parameters.python_wheel_suffix }}-${{parameters.ep}}" "$(Build.BinariesDirectory)/whl" cp -r "$(Build.BinariesDirectory)/whl" $(Build.BinariesDirectory)/tmp find "$(Build.BinariesDirectory)/tmp" -name '*.whl' -exec bash -c 'unzip -d "${1%.*}" "$1"' _ {} \; - displayName: 'Move the artifacts to the binaries directory' # The private ADO project - ${{ if eq(variables['System.CollectionId'], 'bc038106-a83b-4dab-9dd3-5a41bc58f34c') }}: - download: build # pipeline resource identifier. artifact: 'drop-linux-cpu-${{ parameters.arch }}-${{parameters.ep}}' - download: build # pipeline resource identifier. - artifact: 'onnxruntime-${{ parameters.arch }}-${{ parameters.ep }}' + artifact: 'onnxruntime${{ parameters.python_wheel_suffix }}-${{ parameters.ep }}' - bash: | set -e -x ls $(Pipeline.Workspace)/build mv "$(Pipeline.Workspace)/build/drop-linux-cpu-${{ parameters.arch }}-${{parameters.ep}}" $(Build.BinariesDirectory)/${{parameters.cmake_build_type}} - mv "$(Pipeline.Workspace)/build/onnxruntime-${{ parameters.arch }}-${{ parameters.ep }}" "$(Build.BinariesDirectory)/whl" + mv "$(Pipeline.Workspace)/build/onnxruntime${{ parameters.python_wheel_suffix }}-${{parameters.ep}}" "$(Build.BinariesDirectory)/whl" cp -r "$(Build.BinariesDirectory)/whl" $(Build.BinariesDirectory)/tmp find "$(Build.BinariesDirectory)/tmp" -name '*.whl' -exec bash -c 'unzip -d "${1%.*}" "$1"' _ {} \; - displayName: 'Move the artifacts to the binaries directory' # The BinSkim task uses a dotnet program which doesn't support ARM CPUs yet - ${{ if eq(parameters.arch, 'x86_64') }}: diff --git a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml index 10ea7f6203bb1..4c9d0dccaf48d 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml @@ -19,11 +19,6 @@ parameters: type: string default: '' -- name: is1ES - displayName: 'Whether the pipeline is running in 1ES' - type: boolean - default: false - jobs: - job: Win_py_arm64_qnn_Wheels timeoutInMinutes: 210 @@ -31,8 +26,6 @@ jobs: clean: all pool: name: ${{ parameters.MACHINE_POOL }} - os: windows - hostArchitecture: Arm64 strategy: matrix: Python311_arm64: @@ -48,140 +41,132 @@ jobs: GRADLE_OPTS: '-Dorg.gradle.daemon=false' VSGenerator: 'Visual Studio 17 2022' steps: - - checkout: self - clean: true - submodules: recursive - - - template: telemetry-steps.yml - - - script: | - MKDIR $(Agent.ToolsDirectory)\Python\$(PythonVersion)\arm64 - XCOPY /s /y /h /e /c /q "$(LocalPythonDir)\*.*" $(Agent.ToolsDirectory)\Python\$(PythonVersion)\arm64\ - COPY NUL $(Agent.ToolsDirectory)\Python\$(PythonVersion)\arm64.complete - DIR $(Agent.ToolsDirectory)\Python - DIR $(Agent.ToolsDirectory)\Python\$(PythonVersion) - DIR $(Agent.ToolsDirectory)\Python\$(PythonVersion)\arm64 - displayName: Copy python $(PythonVersion) version to agent tools directory - - - task: UsePythonVersion@0 - inputs: - versionSpec: $(PythonVersion) - addToPath: true - architecture: 'arm64' - - - task: PipAuthenticate@1 - displayName: 'Pip Authenticate' - inputs: - artifactFeeds: 'Lotus' - - - task: onebranch.pipeline.tsaoptions@1 - displayName: 'OneBranch TSAOptions' - inputs: - tsaConfigFilePath: '$(Build.SourcesDirectory)\.config\tsaoptions.json' - appendSourceBranchName: false - - - task: PythonScript@0 - inputs: - scriptSource: inline - script: | - import subprocess - subprocess.call(['pip', 'install', '-q', 'setuptools', 'wheel']) - workingDirectory: '$(Build.BinariesDirectory)' - displayName: 'Install python modules' - - - template: set-nightly-build-option-variable-step.yml - - - template: jobs/download_win_qnn_sdk.yml - parameters: - QnnSDKVersion: ${{ parameters.QNN_SDK }} - - - task: PythonScript@0 - displayName: 'Generate cmake config' - inputs: - scriptPath: '$(Build.SourcesDirectory)\tools\ci_build\build.py' - arguments: > - --config RelWithDebInfo - --build_dir $(Build.BinariesDirectory) - --skip_submodule_sync - --cmake_generator "$(VSGenerator)" - --build_shared_lib - --use_qnn - --qnn_home $(QnnSDKRootDir) - --enable_pybind - --parallel --use_vcpkg --use_vcpkg_ms_internal_asset_cache --update - $(TelemetryOption) ${{ parameters.BUILD_PY_PARAMETERS }} - workingDirectory: '$(Build.BinariesDirectory)' - - - task: VSBuild@1 - displayName: 'Build' - inputs: - solution: '$(Build.BinariesDirectory)\RelWithDebInfo\onnxruntime.sln' - platform: 'arm64' - configuration: RelWithDebInfo - msbuildArchitecture: 'arm64' - maximumCpuCount: true - logProjectEvents: true - workingFolder: '$(Build.BinariesDirectory)\RelWithDebInfo' - createLogFile: true - - # Esrp signing - - template: win-esrp-dll.yml - parameters: - FolderPath: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\onnxruntime\capi' - DisplayName: 'ESRP - Sign Native dlls' - DoEsrp: true - Pattern: '*.pyd' - - - task: PythonScript@0 - displayName: 'Build wheel' - inputs: - scriptPath: '$(Build.SourcesDirectory)\setup.py' - arguments: 'bdist_wheel $(NightlyBuildOption) --wheel_name_suffix=qnn' - workingDirectory: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo' - - - task: CopyFiles@2 - displayName: 'Copy Python Wheel to: $(Build.ArtifactStagingDirectory)' - inputs: - SourceFolder: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\dist' - Contents: '*.whl' - TargetFolder: '$(Build.ArtifactStagingDirectory)' - - - ${{ if eq(parameters.is1ES, true) }}: - - task: 1ES.PublishPipelineArtifact@1 - displayName: 'Publish Artifact: ONNXRuntime python wheel' - inputs: - artifactName: onnxruntime_qnn_arm64_$(PythonVersion) - targetPath: '$(Build.ArtifactStagingDirectory)' - - ${{ if eq(parameters.is1ES, false) }}: - - task: PublishPipelineArtifact@1 - displayName: 'Publish Artifact: ONNXRuntime python wheel' - input: - artifactName: onnxruntime_qnn_arm64_$(PythonVersion) - targetPath: '$(Build.ArtifactStagingDirectory)' - - - script: | - 7z x *.whl - workingDirectory: '$(Build.ArtifactStagingDirectory)' - displayName: 'unzip the package' - - - task: CredScan@3 - displayName: 'Run CredScan' - inputs: - debugMode: false - continueOnError: true - - - task: BinSkim@4 - displayName: 'Run BinSkim' - inputs: - AnalyzeTargetGlob: '+:file|$(Build.ArtifactStagingDirectory)\**\*.dll' - - - task: TSAUpload@2 - displayName: 'TSA upload' - condition: and (succeeded(), eq(variables['Build.SourceBranch'], 'refs/heads/main')) - inputs: - GdnPublishTsaOnboard: false - GdnPublishTsaConfigFile: '$(Build.sourcesDirectory)\.gdn\.gdntsa' - - - template: component-governance-component-detection-steps.yml - parameters: - condition: 'succeeded' + - checkout: self + clean: true + submodules: recursive + + - template: telemetry-steps.yml + + - script: | + MKDIR $(Agent.ToolsDirectory)\Python\$(PythonVersion)\arm64 + XCOPY /s /y /h /e /c /q "$(LocalPythonDir)\*.*" $(Agent.ToolsDirectory)\Python\$(PythonVersion)\arm64\ + COPY NUL $(Agent.ToolsDirectory)\Python\$(PythonVersion)\arm64.complete + DIR $(Agent.ToolsDirectory)\Python + DIR $(Agent.ToolsDirectory)\Python\$(PythonVersion) + DIR $(Agent.ToolsDirectory)\Python\$(PythonVersion)\arm64 + displayName: Copy python $(PythonVersion) version to agent tools directory + + - task: UsePythonVersion@0 + inputs: + versionSpec: $(PythonVersion) + addToPath: true + architecture: 'arm64' + + - task: PipAuthenticate@1 + displayName: 'Pip Authenticate' + inputs: + artifactFeeds: 'Lotus' + + - task: onebranch.pipeline.tsaoptions@1 + displayName: 'OneBranch TSAOptions' + inputs: + tsaConfigFilePath: '$(Build.SourcesDirectory)\.config\tsaoptions.json' + appendSourceBranchName: false + + - task: PythonScript@0 + inputs: + scriptSource: inline + script: | + import subprocess + subprocess.call(['pip', 'install', '-q', 'setuptools', 'wheel']) + workingDirectory: '$(Build.BinariesDirectory)' + displayName: 'Install python modules' + + - template: set-nightly-build-option-variable-step.yml + + - template: jobs/download_win_qnn_sdk.yml + parameters: + QnnSDKVersion: ${{ parameters.QNN_SDK }} + + - task: PythonScript@0 + displayName: 'Generate cmake config' + inputs: + scriptPath: '$(Build.SourcesDirectory)\tools\ci_build\build.py' + arguments: > + --config RelWithDebInfo + --build_dir $(Build.BinariesDirectory) + --skip_submodule_sync + --cmake_generator "$(VSGenerator)" + --build_shared_lib + --use_qnn + --qnn_home $(QnnSDKRootDir) + --enable_pybind + --parallel --use_vcpkg --use_vcpkg_ms_internal_asset_cache --update + $(TelemetryOption) ${{ parameters.BUILD_PY_PARAMETERS }} + workingDirectory: '$(Build.BinariesDirectory)' + + - task: VSBuild@1 + displayName: 'Build' + inputs: + solution: '$(Build.BinariesDirectory)\RelWithDebInfo\onnxruntime.sln' + platform: 'arm64' + configuration: RelWithDebInfo + msbuildArchitecture: 'arm64' + maximumCpuCount: true + logProjectEvents: true + workingFolder: '$(Build.BinariesDirectory)\RelWithDebInfo' + createLogFile: true + + # Esrp signing + - template: win-esrp-dll.yml + parameters: + FolderPath: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\onnxruntime\capi' + DisplayName: 'ESRP - Sign Native dlls' + DoEsrp: true + Pattern: '*.pyd' + + - task: PythonScript@0 + displayName: 'Build wheel' + inputs: + scriptPath: '$(Build.SourcesDirectory)\setup.py' + arguments: 'bdist_wheel $(NightlyBuildOption) --wheel_name_suffix=qnn' + workingDirectory: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo' + + - task: CopyFiles@2 + displayName: 'Copy Python Wheel to: $(Build.ArtifactStagingDirectory)' + inputs: + SourceFolder: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\dist' + Contents: '*.whl' + TargetFolder: '$(Build.ArtifactStagingDirectory)' + + - task: PublishBuildArtifacts@1 + displayName: 'Publish Artifact: ONNXRuntime python wheel' + inputs: + ArtifactName: onnxruntime_qnn_arm64 + + - script: | + 7z x *.whl + workingDirectory: '$(Build.ArtifactStagingDirectory)' + displayName: 'unzip the package' + + - task: CredScan@3 + displayName: 'Run CredScan' + inputs: + debugMode: false + continueOnError: true + + - task: BinSkim@4 + displayName: 'Run BinSkim' + inputs: + AnalyzeTargetGlob: '+:file|$(Build.ArtifactStagingDirectory)\**\*.dll' + + - task: TSAUpload@2 + displayName: 'TSA upload' + condition: and (succeeded(), eq(variables['Build.SourceBranch'], 'refs/heads/main')) + inputs: + GdnPublishTsaOnboard: false + GdnPublishTsaConfigFile: '$(Build.sourcesDirectory)\.gdn\.gdntsa' + + - template: component-governance-component-detection-steps.yml + parameters: + condition: 'succeeded' diff --git a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64ec-qnn.yml b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64ec-qnn.yml index 24321d2a3e1ec..ed29f1e67515e 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64ec-qnn.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64ec-qnn.yml @@ -19,11 +19,6 @@ parameters: type: string default: '' -- name: is1ES - displayName: 'Whether the pipeline is running in 1ES' - type: boolean - default: false - jobs: - job: Win_py_x64_qnn_Wheels timeoutInMinutes: 210 @@ -31,7 +26,6 @@ jobs: clean: all pool: name: ${{ parameters.MACHINE_POOL }} - os: windows strategy: matrix: Python310_x64: @@ -46,124 +40,117 @@ jobs: GRADLE_OPTS: '-Dorg.gradle.daemon=false' VSGenerator: 'Visual Studio 17 2022' steps: - - checkout: self - clean: true - submodules: recursive - - - template: telemetry-steps.yml - - - task: UsePythonVersion@0 - inputs: - versionSpec: $(PythonVersion) - addToPath: true - architecture: 'x64' - - - task: PipAuthenticate@1 - displayName: 'Pip Authenticate' - inputs: - artifactFeeds: 'Lotus' - - - task: onebranch.pipeline.tsaoptions@1 - displayName: 'OneBranch TSAOptions' - inputs: - tsaConfigFilePath: '$(Build.SourcesDirectory)\.config\tsaoptions.json' - appendSourceBranchName: fals - - - script: python -m pip install -r $(Build.SourcesDirectory)\tools\ci_build\github\linux\python\requirements.txt - - - - template: set-nightly-build-option-variable-step.yml - - - template: jobs/download_win_qnn_sdk.yml - parameters: - QnnSDKVersion: ${{ parameters.QNN_SDK }} - - - task: PythonScript@0 - displayName: 'Generate cmake config' - inputs: - scriptPath: '$(Build.SourcesDirectory)\tools\ci_build\build.py' - arguments: > - --config RelWithDebInfo - --build_dir $(Build.BinariesDirectory) - --skip_submodule_sync - --cmake_generator "$(VSGenerator)" - --build_shared_lib - --use_qnn - --qnn_home $(QnnSDKRootDir) - --enable_pybind - --parallel --use_vcpkg --use_vcpkg_ms_internal_asset_cache --update --arm64ec - $(TelemetryOption) ${{ parameters.BUILD_PY_PARAMETERS }} - workingDirectory: '$(Build.BinariesDirectory)' - - - task: VSBuild@1 - displayName: 'Build' - inputs: - solution: '$(Build.BinariesDirectory)\RelWithDebInfo\onnxruntime.sln' - platform: 'arm64ec' - configuration: RelWithDebInfo - msbuildArchitecture: 'x64' - maximumCpuCount: true - logProjectEvents: true - workingFolder: '$(Build.BinariesDirectory)\RelWithDebInfo' - createLogFile: true - - # Esrp signing - - template: win-esrp-dll.yml - parameters: - FolderPath: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\onnxruntime\capi' - DisplayName: 'ESRP - Sign Native dlls' - DoEsrp: true - Pattern: '*.pyd' - - - task: PythonScript@0 - displayName: 'Build wheel' - inputs: - scriptPath: '$(Build.SourcesDirectory)\setup.py' - arguments: 'bdist_wheel $(NightlyBuildOption) --wheel_name_suffix=qnn' - workingDirectory: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo' - - - task: CopyFiles@2 - displayName: 'Copy Python Wheel to: $(Build.ArtifactStagingDirectory)' - inputs: - SourceFolder: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\dist' - Contents: '*.whl' - TargetFolder: '$(Build.ArtifactStagingDirectory)' - - - ${{ if eq(parameters.is1ES, true) }}: - - task: 1ES.PublishPipelineArtifact@1 - displayName: 'Publish Artifact: ONNXRuntime python wheel' - inputs: - artifactName: onnxruntime_qnn_arm64ec_$(PythonVersion) - targetPath: '$(Build.ArtifactStagingDirectory)' - - ${{ if eq(parameters.is1ES, false) }}: - - task: PublishPipelineArtifact@1 - displayName: 'Publish Artifact: ONNXRuntime python wheel' - inputs: - artifactName: onnxruntime_qnn_arm64ec_$(PythonVersion) - targetPath: '$(Build.ArtifactStagingDirectory)' - - script: | - 7z x *.whl - workingDirectory: '$(Build.ArtifactStagingDirectory)' - displayName: 'unzip the package' - - - task: CredScan@3 - displayName: 'Run CredScan' - inputs: - debugMode: false - continueOnError: true - - - task: BinSkim@4 - displayName: 'Run BinSkim' - inputs: - AnalyzeTargetGlob: '+:file|$(Build.ArtifactStagingDirectory)\**\*.dll' - - - task: TSAUpload@2 - displayName: 'TSA upload' - condition: and (succeeded(), eq(variables['Build.SourceBranch'], 'refs/heads/main')) - inputs: - GdnPublishTsaOnboard: false - GdnPublishTsaConfigFile: '$(Build.sourcesDirectory)\.gdn\.gdntsa' - - - template: component-governance-component-detection-steps.yml - parameters: - condition: 'succeeded' + - checkout: self + clean: true + submodules: recursive + + - template: telemetry-steps.yml + + - task: UsePythonVersion@0 + inputs: + versionSpec: $(PythonVersion) + addToPath: true + architecture: 'x64' + + - task: PipAuthenticate@1 + displayName: 'Pip Authenticate' + inputs: + artifactFeeds: 'Lotus' + + - task: onebranch.pipeline.tsaoptions@1 + displayName: 'OneBranch TSAOptions' + inputs: + tsaConfigFilePath: '$(Build.SourcesDirectory)\.config\tsaoptions.json' + appendSourceBranchName: fals + + - script: python -m pip install -r $(Build.SourcesDirectory)\tools\ci_build\github\linux\python\requirements.txt + + + - template: set-nightly-build-option-variable-step.yml + + - template: jobs/download_win_qnn_sdk.yml + parameters: + QnnSDKVersion: ${{ parameters.QNN_SDK }} + + - task: PythonScript@0 + displayName: 'Generate cmake config' + inputs: + scriptPath: '$(Build.SourcesDirectory)\tools\ci_build\build.py' + arguments: > + --config RelWithDebInfo + --build_dir $(Build.BinariesDirectory) + --skip_submodule_sync + --cmake_generator "$(VSGenerator)" + --build_shared_lib + --use_qnn + --qnn_home $(QnnSDKRootDir) + --enable_pybind + --parallel --use_vcpkg --use_vcpkg_ms_internal_asset_cache --update --arm64ec + $(TelemetryOption) ${{ parameters.BUILD_PY_PARAMETERS }} + workingDirectory: '$(Build.BinariesDirectory)' + + - task: VSBuild@1 + displayName: 'Build' + inputs: + solution: '$(Build.BinariesDirectory)\RelWithDebInfo\onnxruntime.sln' + platform: 'arm64ec' + configuration: RelWithDebInfo + msbuildArchitecture: 'x64' + maximumCpuCount: true + logProjectEvents: true + workingFolder: '$(Build.BinariesDirectory)\RelWithDebInfo' + createLogFile: true + + # Esrp signing + - template: win-esrp-dll.yml + parameters: + FolderPath: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\onnxruntime\capi' + DisplayName: 'ESRP - Sign Native dlls' + DoEsrp: true + Pattern: '*.pyd' + + - task: PythonScript@0 + displayName: 'Build wheel' + inputs: + scriptPath: '$(Build.SourcesDirectory)\setup.py' + arguments: 'bdist_wheel $(NightlyBuildOption) --wheel_name_suffix=qnn' + workingDirectory: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo' + + - task: CopyFiles@2 + displayName: 'Copy Python Wheel to: $(Build.ArtifactStagingDirectory)' + inputs: + SourceFolder: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\dist' + Contents: '*.whl' + TargetFolder: '$(Build.ArtifactStagingDirectory)' + + - task: PublishBuildArtifacts@1 + displayName: 'Publish Artifact: ONNXRuntime python wheel' + inputs: + ArtifactName: onnxruntime_qnn_arm64ec + + - script: | + 7z x *.whl + workingDirectory: '$(Build.ArtifactStagingDirectory)' + displayName: 'unzip the package' + + - task: CredScan@3 + displayName: 'Run CredScan' + inputs: + debugMode: false + continueOnError: true + + - task: BinSkim@4 + displayName: 'Run BinSkim' + inputs: + AnalyzeTargetGlob: '+:file|$(Build.ArtifactStagingDirectory)\**\*.dll' + + - task: TSAUpload@2 + displayName: 'TSA upload' + condition: and (succeeded(), eq(variables['Build.SourceBranch'], 'refs/heads/main')) + inputs: + GdnPublishTsaOnboard: false + GdnPublishTsaConfigFile: '$(Build.sourcesDirectory)\.gdn\.gdntsa' + + - template: component-governance-component-detection-steps.yml + parameters: + condition: 'succeeded' diff --git a/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml b/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml index 175b343e55d57..13069846da342 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml @@ -19,11 +19,6 @@ parameters: type: string default: '' -- name: is1ES - displayName: 'Whether the pipeline is running in 1ES' - type: boolean - default: false - jobs: - job: Win_py_x64_qnn_Wheels timeoutInMinutes: 210 @@ -121,18 +116,10 @@ jobs: Contents: '*.whl' TargetFolder: '$(Build.ArtifactStagingDirectory)' - - ${{ if eq(parameters.is1ES, true) }}: - - task: 1ES.PublishPipelineArtifact@1 - displayName: 'Publish Artifact: ONNXRuntime python wheel' - inputs: - artifactName: onnxruntime_qnn_x64_$(PythonVersion) - targetPath: '$(Build.ArtifactStagingDirectory)' - - ${{ if eq(parameters.is1ES, false) }}: - - task: PublishPipelineArtifact@1 - displayName: 'Publish Artifact: ONNXRuntime python wheel' - inputs: - artifactName: onnxruntime_qnn_x64_$(PythonVersion) - targetPath: '$(Build.ArtifactStagingDirectory)' + - task: PublishBuildArtifacts@1 + displayName: 'Publish Artifact: ONNXRuntime python wheel' + inputs: + ArtifactName: onnxruntime_qnn_x64 - script: | 7z x *.whl diff --git a/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml b/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml index 3fa4799ec9c0e..a93d6b5ff8419 100644 --- a/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml +++ b/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml @@ -10,8 +10,6 @@ parameters: buildPlatform: 'x64' buildArch: 'x64' StageName: 'OnnxRuntime_QNN_Nuget_Win_x64' - Is1ES: true - PublishArchive: false stages: - stage: ${{ parameters.StageName }} @@ -20,8 +18,7 @@ stages: - job: ${{ parameters.StageName }} timeoutInMinutes: 120 - pool: - name: ${{ parameters.qnn_ep_build_pool_name }} + pool: ${{ parameters.qnn_ep_build_pool_name }} variables: ${{ if eq(parameters.buildArch, 'ARM64') }}: targetArchitecture: 'arm64' @@ -31,148 +28,133 @@ stages: commonBuildArgs: '--update --compile_no_warning_as_error --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --build_shared_lib --cmake_generator "Visual Studio 17 2022" --config ${{ parameters.build_config }} --parallel --use_binskim_compliant_compile_flags ${{ parameters.buildParameter }} ' steps: - - template: set-version-number-variables-step.yml - - - task: UsePythonVersion@0 - inputs: - versionSpec: '3.12' - addToPath: true - - - template: jobs/download_win_qnn_sdk.yml - parameters: - QnnSDKVersion: ${{ parameters.QnnSdk }} - - - task: PythonScript@0 - displayName: 'Generate project' - inputs: - scriptPath: '$(Build.SourcesDirectory)\tools\ci_build\build.py' - arguments: '--use_qnn --qnn_home $(QnnSDKRootDir) $(commonBuildArgs)' - - - task: VSBuild@1 - displayName: 'Build onnxruntime' - inputs: - solution: '$(Build.BinariesDirectory)\${{ parameters.build_config }}\onnxruntime.vcxproj' - platform: ${{ parameters.buildPlatform }} - configuration: ${{ parameters.build_config }} - msbuildArchitecture: ${{ parameters.buildArch }} - maximumCpuCount: true - logProjectEvents: true - workingFolder: '$(Build.BinariesDirectory)\${{ parameters.build_config }}' - createLogFile: true - - - task: VSBuild@1 - displayName: 'Build onnx_test_runner' - inputs: - solution: '$(Build.BinariesDirectory)\${{ parameters.build_config }}\onnx_test_runner.vcxproj' - platform: ${{ parameters.buildPlatform }} - configuration: ${{ parameters.build_config }} - msbuildArchitecture: ${{ parameters.buildArch }} - maximumCpuCount: true - logProjectEvents: true - workingFolder: '$(Build.BinariesDirectory)\${{ parameters.build_config }}' - createLogFile: true - - - task: VSBuild@1 - displayName: 'Build onnxruntime_perf_test' - inputs: - solution: '$(Build.BinariesDirectory)\${{ parameters.build_config }}\onnxruntime_perf_test.vcxproj' - platform: ${{ parameters.buildPlatform }} - configuration: ${{ parameters.build_config }} - msbuildArchitecture: ${{ parameters.buildArch }} - maximumCpuCount: true - logProjectEvents: true - workingFolder: '$(Build.BinariesDirectory)\${{ parameters.build_config }}' - createLogFile: true - - - task: VSBuild@1 - displayName: 'Build onnxruntime_test_all (to copy Qnn libs)' - inputs: - solution: '$(Build.BinariesDirectory)\${{ parameters.build_config }}\onnxruntime_test_all.vcxproj' - platform: ${{ parameters.buildPlatform }} - configuration: ${{ parameters.build_config }} - msbuildArchitecture: ${{ parameters.buildArch }} - maximumCpuCount: true - logProjectEvents: true - workingFolder: '$(Build.BinariesDirectory)\${{ parameters.build_config }}' - createLogFile: true - - - task: CmdLine@2 - displayName: 'Print contents of binaries directory' - inputs: - script: | - dir $(Build.BinariesDirectory)\${{ parameters.build_config }}\${{ parameters.build_config }} - - - template: win-esrp-dll.yml - parameters: - FolderPath: '$(Build.BinariesDirectory)\${{ parameters.build_config }}\${{ parameters.build_config }}' - DisplayName: 'ESRP - Sign dlls' - DoEsrp: ${{ parameters.DoEsrp }} - Pattern: 'onnxruntime*.dll' - - - ${{ if eq(parameters.PublishArchive, true) }}: - - template: c-api-artifacts-package-and-publish-steps-windows.yml + - template: set-version-number-variables-step.yml + + - task: UsePythonVersion@0 + inputs: + versionSpec: '3.12' + addToPath: true + + - template: jobs/download_win_qnn_sdk.yml parameters: - buildConfig: ${{ parameters.build_config }} - artifactName: 'onnxruntime-win-${{ parameters.buildPlatform }}-qnn' - artifactNameNoVersionString: 'onnxruntime-win-${{ parameters.buildPlatform }}-qnn' - DoEsrp: ${{ parameters.DoEsrp }} + QnnSDKVersion: ${{ parameters.QnnSdk }} + + - task: PythonScript@0 + displayName: 'Generate project' + inputs: + scriptPath: '$(Build.SourcesDirectory)\tools\ci_build\build.py' + arguments: '--use_qnn --qnn_home $(QnnSDKRootDir) $(commonBuildArgs)' + + - task: VSBuild@1 + displayName: 'Build onnxruntime' + inputs: + solution: '$(Build.BinariesDirectory)\${{ parameters.build_config }}\onnxruntime.vcxproj' + platform: ${{ parameters.buildPlatform }} + configuration: ${{ parameters.build_config }} + msbuildArchitecture: ${{ parameters.buildArch }} + maximumCpuCount: true + logProjectEvents: true + workingFolder: '$(Build.BinariesDirectory)\${{ parameters.build_config }}' + createLogFile: true + + - task: VSBuild@1 + displayName: 'Build onnx_test_runner' + inputs: + solution: '$(Build.BinariesDirectory)\${{ parameters.build_config }}\onnx_test_runner.vcxproj' + platform: ${{ parameters.buildPlatform }} + configuration: ${{ parameters.build_config }} + msbuildArchitecture: ${{ parameters.buildArch }} + maximumCpuCount: true + logProjectEvents: true + workingFolder: '$(Build.BinariesDirectory)\${{ parameters.build_config }}' + createLogFile: true + + - task: VSBuild@1 + displayName: 'Build onnxruntime_perf_test' + inputs: + solution: '$(Build.BinariesDirectory)\${{ parameters.build_config }}\onnxruntime_perf_test.vcxproj' + platform: ${{ parameters.buildPlatform }} + configuration: ${{ parameters.build_config }} + msbuildArchitecture: ${{ parameters.buildArch }} + maximumCpuCount: true + logProjectEvents: true + workingFolder: '$(Build.BinariesDirectory)\${{ parameters.build_config }}' + createLogFile: true + + - task: VSBuild@1 + displayName: 'Build onnxruntime_test_all (to copy Qnn libs)' + inputs: + solution: '$(Build.BinariesDirectory)\${{ parameters.build_config }}\onnxruntime_test_all.vcxproj' + platform: ${{ parameters.buildPlatform }} + configuration: ${{ parameters.build_config }} + msbuildArchitecture: ${{ parameters.buildArch }} + maximumCpuCount: true + logProjectEvents: true + workingFolder: '$(Build.BinariesDirectory)\${{ parameters.build_config }}' + createLogFile: true + + - task: CmdLine@2 + displayName: 'Print contents of binaries directory' + inputs: + script: | + dir $(Build.BinariesDirectory)\${{ parameters.build_config }}\${{ parameters.build_config }} - - task: MSBuild@1 - displayName: 'Restore NuGet Packages and create project.assets.json' - inputs: - solution: '$(Build.SourcesDirectory)\csharp\OnnxRuntime.DesktopOnly.CSharp.sln' - platform: 'Any CPU' - configuration: ${{ parameters.build_config }} - msbuildArguments: '-t:restore -p:OrtPackageId=$(OrtPackageId)' - workingDirectory: '$(Build.SourcesDirectory)\csharp' - - - task: MSBuild@1 - displayName: 'Build C# bindings' - inputs: - solution: '$(Build.SourcesDirectory)\csharp\OnnxRuntime.DesktopOnly.CSharp.sln' - platform: 'Any CPU' - configuration: ${{ parameters.build_config }} - msbuildArguments: '-p:OnnxRuntimeBuildDirectory="$(Build.BinariesDirectory)" -p:OrtPackageId=$(OrtPackageId) -p:IsReleaseBuild=${{ parameters.IsReleaseBuild }}' - workingDirectory: '$(Build.SourcesDirectory)\csharp' - - - ${{ if eq(parameters.DoEsrp, true) }}: - template: win-esrp-dll.yml parameters: - FolderPath: '$(Build.SourcesDirectory)\csharp\src\Microsoft.ML.OnnxRuntime\bin\${{ parameters.build_config }}' - DisplayName: 'ESRP - Sign C# dlls' + FolderPath: '$(Build.BinariesDirectory)\${{ parameters.build_config }}\${{ parameters.build_config }}' + DisplayName: 'ESRP - Sign dlls' DoEsrp: ${{ parameters.DoEsrp }} + Pattern: 'onnxruntime*.dll' - - task: MSBuild@1 - displayName: 'Build Nuget Packages' - inputs: - solution: '$(Build.SourcesDirectory)\csharp\OnnxRuntime.CSharp.proj' - platform: 'Any CPU' - configuration: ${{ parameters.build_config }} - msbuildArguments: '-t:CreatePackage -p:OnnxRuntimeBuildDirectory="$(Build.BinariesDirectory)" -p:OrtPackageId=$(OrtPackageId) -p:IsReleaseBuild=${{ parameters.IsReleaseBuild }} -p:TargetArchitecture=$(targetArchitecture)' - workingDirectory: '$(Build.SourcesDirectory)\csharp' - - - task: CopyFiles@2 - displayName: 'Copy native nuget package to: $(Build.ArtifactStagingDirectory)' - inputs: - SourceFolder: '$(Build.BinariesDirectory)\${{ parameters.build_config }}\${{ parameters.build_config }}' - Contents: '*.nupkg' - TargetFolder: '$(Build.ArtifactStagingDirectory)' - - - task: CopyFiles@2 - displayName: 'Copy native nuget symbols package to: $(Build.ArtifactStagingDirectory)' - inputs: - SourceFolder: '$(Build.BinariesDirectory)\${{ parameters.build_config }}\${{ parameters.build_config }}' - Contents: '*.snupkg' - TargetFolder: '$(Build.ArtifactStagingDirectory)' - - - ${{ if eq(parameters.Is1ES, true) }}: - - task: 1ES.PublishPipelineArtifact@1 - displayName: 'Publish Pipeline x64 NuGet Artifact' + - task: MSBuild@1 + displayName: 'Restore NuGet Packages and create project.assets.json' inputs: - artifactName: ${{ parameters.ArtifactName }} - targetPath: '$(Build.ArtifactStagingDirectory)' - - ${{ else }}: - - task: PublishPipelineArtifact@1 + solution: '$(Build.SourcesDirectory)\csharp\OnnxRuntime.DesktopOnly.CSharp.sln' + platform: 'Any CPU' + configuration: ${{ parameters.build_config }} + msbuildArguments: '-t:restore -p:OrtPackageId=$(OrtPackageId)' + workingDirectory: '$(Build.SourcesDirectory)\csharp' + + - task: MSBuild@1 + displayName: 'Build C# bindings' + inputs: + solution: '$(Build.SourcesDirectory)\csharp\OnnxRuntime.DesktopOnly.CSharp.sln' + platform: 'Any CPU' + configuration: ${{ parameters.build_config }} + msbuildArguments: '-p:OnnxRuntimeBuildDirectory="$(Build.BinariesDirectory)" -p:OrtPackageId=$(OrtPackageId) -p:IsReleaseBuild=${{ parameters.IsReleaseBuild }}' + workingDirectory: '$(Build.SourcesDirectory)\csharp' + + - ${{ if eq(parameters.DoEsrp, true) }}: + - template: win-esrp-dll.yml + parameters: + FolderPath: '$(Build.SourcesDirectory)\csharp\src\Microsoft.ML.OnnxRuntime\bin\${{ parameters.build_config }}' + DisplayName: 'ESRP - Sign C# dlls' + DoEsrp: ${{ parameters.DoEsrp }} + + - task: MSBuild@1 + displayName: 'Build Nuget Packages' + inputs: + solution: '$(Build.SourcesDirectory)\csharp\OnnxRuntime.CSharp.proj' + platform: 'Any CPU' + configuration: ${{ parameters.build_config }} + msbuildArguments: '-t:CreatePackage -p:OnnxRuntimeBuildDirectory="$(Build.BinariesDirectory)" -p:OrtPackageId=$(OrtPackageId) -p:IsReleaseBuild=${{ parameters.IsReleaseBuild }} -p:TargetArchitecture=$(targetArchitecture)' + workingDirectory: '$(Build.SourcesDirectory)\csharp' + + - task: CopyFiles@2 + displayName: 'Copy native nuget package to: $(Build.ArtifactStagingDirectory)' + inputs: + SourceFolder: '$(Build.BinariesDirectory)\${{ parameters.build_config }}\${{ parameters.build_config }}' + Contents: '*.nupkg' + TargetFolder: '$(Build.ArtifactStagingDirectory)' + + - task: CopyFiles@2 + displayName: 'Copy native nuget symbols package to: $(Build.ArtifactStagingDirectory)' + inputs: + SourceFolder: '$(Build.BinariesDirectory)\${{ parameters.build_config }}\${{ parameters.build_config }}' + Contents: '*.snupkg' + TargetFolder: '$(Build.ArtifactStagingDirectory)' + + - task: PublishPipelineArtifact@0 displayName: 'Publish Pipeline x64 NuGet Artifact' inputs: artifactName: ${{ parameters.ArtifactName }} diff --git a/tools/ci_build/github/azure-pipelines/templates/react-native-ci.yml b/tools/ci_build/github/azure-pipelines/templates/react-native-ci.yml index 52dbb76632e0c..7991916a47ca4 100644 --- a/tools/ci_build/github/azure-pipelines/templates/react-native-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/react-native-ci.yml @@ -62,14 +62,10 @@ stages: dependsOn: '${{parameters.InitialStageDependsOn}}' jobs: - job: ReactNative_CI_iOS - ${{ if eq(parameters.is1ES, false) }}: - pool: - vmImage: 'macOS-13' - ${{ if eq(parameters.is1ES, true) }}: - pool: - name: 'Azure Pipelines' - image: 'macOS-13' - os: 'macOS' + pool: + name: 'Azure Pipelines' + image: 'macOS-13' + os: 'macOS' timeoutInMinutes: 120 diff --git a/tools/ci_build/github/azure-pipelines/templates/web-ci.yml b/tools/ci_build/github/azure-pipelines/templates/web-ci.yml index 2e3589ee87c29..87836880cbdb8 100644 --- a/tools/ci_build/github/azure-pipelines/templates/web-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/web-ci.yml @@ -83,6 +83,9 @@ stages: git submodule update --init -- cmake/external/onnx workingDirectory: '$(Build.SourcesDirectory)' displayName: 'Checkout submodule onnx' + - task: NodeTool@0 + inputs: + versionSpec: '20.x' - template: linux-web-init-and-check.yml - task: Bash@3 displayName: 'Extract commit SHA and save to __commit.txt' diff --git a/tools/ci_build/github/azure-pipelines/templates/win-ci.yml b/tools/ci_build/github/azure-pipelines/templates/win-ci.yml index 69a06c3db24b8..600e6d857185f 100644 --- a/tools/ci_build/github/azure-pipelines/templates/win-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/win-ci.yml @@ -161,7 +161,7 @@ stages: displayName: 'Generate cmake config' inputs: scriptPath: '$(Build.SourcesDirectory)\tools\ci_build\build.py' - arguments: '--config RelWithDebInfo --use_binskim_compliant_compile_flags --enable_lto --disable_rtti --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --build_shared_lib --update --build --cmake_generator "$(VSGenerator)" --enable_onnx_tests $(TelemetryOption) ${{ parameters.buildparameter }} $(timeoutParameter) $(buildJavaParameter)' + arguments: '--config RelWithDebInfo --use_binskim_compliant_compile_flags --enable_lto --disable_rtti --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --build_shared_lib --update --build --use_vcpkg --use_vcpkg_ms_internal_asset_cache --cmake_generator "$(VSGenerator)" --enable_onnx_tests $(TelemetryOption) ${{ parameters.buildparameter }} $(timeoutParameter) $(buildJavaParameter)' workingDirectory: '$(Build.BinariesDirectory)' diff --git a/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml b/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml index 6868043f64d81..b77cab6a19ba0 100644 --- a/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml @@ -88,18 +88,10 @@ jobs: inputs: sourceFolder: $(Pipeline.Workspace)\artifacts contents: | - **\ort-*.wasm + **\*.* targetFolder: $(Build.SourcesDirectory)\js\web\dist flattenFolders: true - displayName: 'Binplace dist files (.wasm)' - - task: CopyFiles@2 - inputs: - sourceFolder: $(Pipeline.Workspace)\artifacts - contents: | - **\ort-*.mjs - targetFolder: $(Build.SourcesDirectory)\js\web\dist - flattenFolders: true - displayName: 'Binplace dist files (.mjs)' + displayName: 'Binplace dist files' - script: | npm ci workingDirectory: '$(Build.SourcesDirectory)\js' diff --git a/tools/ci_build/github/azure-pipelines/templates/win-web-multi-browsers.yml b/tools/ci_build/github/azure-pipelines/templates/win-web-multi-browsers.yml index 00df695889b1d..e201cc0ffdd5a 100644 --- a/tools/ci_build/github/azure-pipelines/templates/win-web-multi-browsers.yml +++ b/tools/ci_build/github/azure-pipelines/templates/win-web-multi-browsers.yml @@ -44,18 +44,10 @@ jobs: inputs: sourceFolder: $(Pipeline.Workspace)\artifacts contents: | - **\ort-*.wasm + **\*.* targetFolder: $(Build.SourcesDirectory)\js\web\dist flattenFolders: true - displayName: 'Binplace dist files (.wasm)' - - task: CopyFiles@2 - inputs: - sourceFolder: $(Pipeline.Workspace)\artifacts - contents: | - **\ort-*.mjs - targetFolder: $(Build.SourcesDirectory)\js\web\dist - flattenFolders: true - displayName: 'Binplace dist files (.mjs)' + displayName: 'Binplace dist files' - script: | npm ci workingDirectory: '$(Build.SourcesDirectory)\js' diff --git a/tools/ci_build/github/azure-pipelines/templates/windowsai-steps.yml b/tools/ci_build/github/azure-pipelines/templates/windowsai-steps.yml index 355a575307f0b..fb3ebdc760a7b 100644 --- a/tools/ci_build/github/azure-pipelines/templates/windowsai-steps.yml +++ b/tools/ci_build/github/azure-pipelines/templates/windowsai-steps.yml @@ -89,7 +89,7 @@ jobs: # must call vsdevcmd first to add cmake to PATH - script: | python --version - python "$(Build.SourcesDirectory)\tools\ci_build\build.py" --build_dir $(Build.BinariesDirectory) --parallel --use_binskim_compliant_compile_flags --build_shared_lib --enable_onnx_tests --ms_experimental --use_dml --use_winml --cmake_generator "Visual Studio 17 2022" --update --config RelWithDebInfo --enable_lto --use_telemetry --disable_rtti --enable_wcos --use_vcpkg --use_vcpkg_ms_internal_asset_cache --windows_sdk_version "10.0.22621.0" $(BuildFlags) --cmake_extra_defines "CMAKE_EXE_LINKER_FLAGS_RELWITHDEBINFO=/PROFILE" "CMAKE_SHARED_LINKER_FLAGS_RELWITHDEBINFO=/PROFILE" + python "$(Build.SourcesDirectory)\tools\ci_build\build.py" --build_dir $(Build.BinariesDirectory) --parallel --use_binskim_compliant_compile_flags --build_shared_lib --enable_onnx_tests --ms_experimental --use_dml --use_winml --cmake_generator "Visual Studio 17 2022" --update --config RelWithDebInfo --enable_lto --use_telemetry --disable_rtti --enable_wcos --windows_sdk_version "10.0.22621.0" $(BuildFlags) --cmake_extra_defines "CMAKE_EXE_LINKER_FLAGS_RELWITHDEBINFO=/PROFILE" "CMAKE_SHARED_LINKER_FLAGS_RELWITHDEBINFO=/PROFILE" workingDirectory: '$(Build.BinariesDirectory)' displayName: 'Generate cmake config' diff --git a/tools/ci_build/github/azure-pipelines/win-gpu-webgpu-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-gpu-webgpu-ci-pipeline.yml index a0f22fcfce14e..bb6c210161952 100644 --- a/tools/ci_build/github/azure-pipelines/win-gpu-webgpu-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/win-gpu-webgpu-ci-pipeline.yml @@ -105,31 +105,3 @@ stages: onnxruntime_webgpu_external_dawn_test.exe --no_proc_table displayName: Run tests (onnxruntime_webgpu_external_dawn_test) workingDirectory: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo' - -- stage: webgpu_minimal_build_edge - dependsOn: [] - jobs: - - template: templates/jobs/win-ci-vs-2022-job.yml - parameters: - BuildConfig: 'RelWithDebInfo' - EnvSetupScript: setup_env.bat - buildArch: x64 - additionalBuildFlags: >- - --build_shared_lib - --disable_exceptions - --disable_rtti - --enable_msvc_static_runtime - --enable_reduced_operator_type_support - --skip_tests - --use_binskim_compliant_compile_flags - --cmake_extra_defines onnxruntime_BUILD_UNIT_TESTS=OFF onnxruntime_DISABLE_SPARSE_TENSORS=ON onnxruntime_DISABLE_OPTIONAL_TYPE=ON - --minimal_build extended - --use_webgpu - msbuildPlatform: x64 - isX86: false - job_name_suffix: x64_RelWithDebInfo - RunOnnxRuntimeTests: false - ORT_EP_NAME: WebGPU - EnablePython: false - WITH_CACHE: true - MachinePool: onnxruntime-Win2022-VS2022-webgpu-A10 diff --git a/tools/ci_build/github/azure-pipelines/win-openvino-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-openvino-ci-pipeline.yml deleted file mode 100644 index f95ac526886fa..0000000000000 --- a/tools/ci_build/github/azure-pipelines/win-openvino-ci-pipeline.yml +++ /dev/null @@ -1,116 +0,0 @@ -##### start trigger Don't edit it manually, Please do edit set-trigger-rules.py #### -### please do rerun set-trigger-rules.py ### -trigger: - branches: - include: - - main - - rel-* - paths: - exclude: - - docs/** - - README.md - - CONTRIBUTING.md - - BUILD.md - - 'js/web' - - 'onnxruntime/core/providers/js' -pr: - branches: - include: - - main - - rel-* - paths: - exclude: - - docs/** - - README.md - - CONTRIBUTING.md - - BUILD.md - - 'js/web' - - 'onnxruntime/core/providers/js' -#### end trigger #### - -jobs: -- job: 'BUILD_OPENVINO_EP' - pool: 'onnxruntime-Win-CPU-2022' - variables: - MsbuildArguments: '-detailedsummary -maxcpucount -consoleloggerparameters:PerformanceSummary' - OnnxRuntimeBuildDirectory: '$(Build.BinariesDirectory)' - DOTNET_SKIP_FIRST_TIME_EXPERIENCE: true - buildArch: x64 - setVcvars: true - BuildConfig: 'RelWithDebInfo' - ALLOW_RELEASED_ONNX_OPSET_ONLY: '1' - TODAY: $[format('{0:dd}{0:MM}{0:yyyy}', pipeline.startTime)] - timeoutInMinutes: 240 - workspace: - clean: all - steps: - - - task: UsePythonVersion@0 - inputs: - versionSpec: '3.12' - addToPath: true - architecture: $(buildArch) - - - template: templates/jobs/download_win_openvino.yml - - - powershell: | - Write-Output "Setting up OpenVINO environment variables" - . "$(OpenVINORootDir)\setupvars.ps1" - - Write-Output "Exporting selected environment variables to pipeline" - - $vars = @( - "INTEL_OPENVINO_DIR", - "OpenVINO_DIR", - "OpenVINOGenAI_DIR", - "OPENVINO_LIB_PATHS", - "TBB_DIR", - "PATH", - "PYTHONPATH" - ) - - foreach ($var in $vars) { - if (Test-Path "Env:$var") { - $value = [System.Environment]::GetEnvironmentVariable($var, "Process") - Write-Output "Setting $var" - Write-Output "##vso[task.setvariable variable=$var;]$value" - } else { - Write-Output "Warning: $var is not set." - } - } - - Write-Output "Selected environment variables exported successfully" - displayName: 'Set up OpenVINO environment' - - - template: templates/jobs/win-ci-build-steps.yml - parameters: - WithCache: True - Today: $(TODAY) - AdditionalKey: "win-openvino | $(BuildConfig)" - BuildPyArguments: >- - --config $(BuildConfig) - --build_dir $(Build.BinariesDirectory) - --cmake_generator "Visual Studio 17 2022" - --build_shared_lib - --use_openvino CPU - --use_binskim_compliant_compile_flags - --update --parallel - MsbuildArguments: $(MsbuildArguments) - BuildArch: $(buildArch) - Platform: 'x64' - BuildConfig: $(BuildConfig) - - - powershell: | - Write-Output "Getting CPU information" - Get-WmiObject Win32_Processor | Select-Object Name, NumberOfCores, NumberOfLogicalProcessors, Architecture | Format-Table -AutoSize - - Write-Output "Starting unit tests" - python "$(Build.SourcesDirectory)\tools\ci_build\build.py" ` - --config "$(BuildConfig)" ` - --build_dir "$(Build.BinariesDirectory)" ` - --cmake_generator "Visual Studio 17 2022" ` - --build_shared_lib ` - --use_openvino CPU ` - --use_binskim_compliant_compile_flags ` - --test --enable_onnx_tests - displayName: 'Run unit tests' diff --git a/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml index 1c3d911fa7dbb..e08d7eb2b12de 100644 --- a/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml @@ -90,7 +90,7 @@ jobs: --config $(BuildConfig) --build_dir $(Build.BinariesDirectory) --cmake_generator "Visual Studio 17 2022" - --build_shared_lib --use_vcpkg --use_vcpkg_ms_internal_asset_cache + --build_shared_lib --use_qnn $(QnnLibKind) --qnn_home $(QnnSDKRootDir) --update --build --parallel diff --git a/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml index faef469e010f6..81de3335a07d2 100644 --- a/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml @@ -78,7 +78,7 @@ jobs: --build_dir $(Build.BinariesDirectory) --cmake_generator "Visual Studio 17 2022" --build_java - --build_shared_lib --use_vcpkg --use_vcpkg_ms_internal_asset_cache + --build_shared_lib --use_qnn $(QnnLibKind) --qnn_home $(QnnSDKRootDir) --use_binskim_compliant_compile_flags diff --git a/tools/ci_build/set-trigger-rules.py b/tools/ci_build/set-trigger-rules.py index 899aaaa95216a..78f59452d1284 100644 --- a/tools/ci_build/set-trigger-rules.py +++ b/tools/ci_build/set-trigger-rules.py @@ -16,6 +16,8 @@ "android-x86_64-crosscompile-ci-pipeline.yml", "bigmodels-ci-pipeline.yml", "linux-ci-pipeline.yml", + "linux-cpu-aten-pipeline.yml", + "linux-cpu-eager-pipeline.yml", "linux-dnnl-ci-pipeline.yml", "linux-gpu-ci-pipeline.yml", "linux-gpu-tensorrt-ci-pipeline.yml", @@ -34,7 +36,6 @@ "win-gpu-doc-gen-ci-pipeline.yml", "win-gpu-tensorrt-ci-pipeline.yml", "win-gpu-webgpu-ci-pipeline.yml", - "win-openvino-ci-pipeline.yml", "win-qnn-arm64-ci-pipeline.yml", "win-qnn-ci-pipeline.yml", ] diff --git a/tools/nuget/generate_nuspec_for_custom_nuget.py b/tools/nuget/generate_nuspec_for_custom_nuget.py deleted file mode 100644 index baf46743cbf1b..0000000000000 --- a/tools/nuget/generate_nuspec_for_custom_nuget.py +++ /dev/null @@ -1,150 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. - -import argparse -import glob -import os -import shutil - -from generate_nuspec_for_native_nuget import generate_metadata - - -def generate_files(lines, args): - files_list = [""] - platform_map = { - "win-arm64": args.win_arm64, - "win-x64": args.win_x64, - } - - avoid_keywords = {"pdb"} - processed_includes = set() - for platform, platform_dir in platform_map.items(): - for file in glob.glob(os.path.join(platform_dir, "lib", "*")): - if not os.path.isfile(file): - continue - if any(keyword in file for keyword in avoid_keywords): - continue - file_name = os.path.basename(file) - - files_list.append(f'') - - for file in glob.glob(os.path.join(platform_dir, "include", "*")): - if not os.path.isfile(file): - continue - file_name = os.path.basename(file) - if file_name in processed_includes: - continue - processed_includes.add(file_name) - files_list.append(f'') - - files_list.append( - f'' - ) - - files_list.append(f'') - files_list.append( - f'' - ) - files_list.append(f'') - files_list.append( - f'' - ) - - source_props = os.path.join( - args.root_dir, - "csharp", - "src", - "Microsoft.ML.OnnxRuntime", - "targets", - "netstandard", - "props.xml", - ) - target_props = os.path.join( - args.root_dir, - "csharp", - "src", - "Microsoft.ML.OnnxRuntime", - "targets", - "netstandard", - f"{args.package_name}.props", - ) - shutil.copyfile(source_props, target_props) - files_list.append(f'') - files_list.append(f'') - - source_targets = os.path.join( - args.root_dir, - "csharp", - "src", - "Microsoft.ML.OnnxRuntime", - "targets", - "netstandard", - "targets.xml", - ) - target_targets = os.path.join( - args.root_dir, - "csharp", - "src", - "Microsoft.ML.OnnxRuntime", - "targets", - "netstandard", - f"{args.package_name}.targets", - ) - shutil.copyfile(source_targets, target_targets) - files_list.append(f'') - files_list.append(f'') - - files_list.append("") - lines.extend(files_list) - - -def parse_arguments(): - parser = argparse.ArgumentParser( - description="Create a nuspec file for the custom nuget package.", - ) - - parser.add_argument("--nuspec_path", required=True, help="Nuspec output file path.") - parser.add_argument("--root_dir", required=True, help="ORT repository root.") - parser.add_argument( - "--commit_id", - required=True, - help="The last commit id included in this package.", - ) - parser.add_argument("--win_arm64", required=True, help="Ort win-arm64 directory") - parser.add_argument("--win_x64", required=True, help="Ort win-x64 directory") - parser.add_argument("--package_version", required=True, help="Version of the package") - parser.add_argument("--package_name", required=True, help="Name of the package") - - args = parser.parse_args() - - args.sdk_info = "" - - return args - - -def generate_nuspec(args: argparse.Namespace): - lines = [''] - lines.append("") - - generate_metadata(lines, args) - generate_files(lines, args) - - lines.append("") - return lines - - -def main(): - args = parse_arguments() - - lines = generate_nuspec(args) - - with open(os.path.join(args.nuspec_path), "w") as f: - for line in lines: - # Uncomment the printing of the line if you need to debug what's produced on a CI machine - print(line) - f.write(line) - f.write("\n") - - -if __name__ == "__main__": - main() diff --git a/tools/python/run_CIs_for_external_pr.py b/tools/python/run_CIs_for_external_pr.py index aca5f1df7d18b..1546a9143831a 100644 --- a/tools/python/run_CIs_for_external_pr.py +++ b/tools/python/run_CIs_for_external_pr.py @@ -24,7 +24,6 @@ def get_pipeline_names(): "Windows GPU DML CI Pipeline", "Windows GPU Doc Gen CI Pipeline", "Windows GPU TensorRT CI Pipeline", - "Windows OpenVINO CI Pipeline", "ONNX Runtime Web CI Pipeline", "Win_TRT_Minimal_CUDA_Test_CI", # linux diff --git a/tools/python/util/__init__.py b/tools/python/util/__init__.py index 8631218ca9e00..a669963e84bcf 100644 --- a/tools/python/util/__init__.py +++ b/tools/python/util/__init__.py @@ -7,8 +7,7 @@ from .run import run # noqa: F401 from .vcpkg_helpers import ( # noqa: F401 generate_android_triplets, - generate_linux_triplets, - generate_macos_triplets, + generate_posix_triplets, generate_vcpkg_triplets_for_emscripten, generate_windows_triplets, ) diff --git a/tools/python/util/vcpkg_helpers.py b/tools/python/util/vcpkg_helpers.py index 875a6186e55c2..d33b2f7675690 100644 --- a/tools/python/util/vcpkg_helpers.py +++ b/tools/python/util/vcpkg_helpers.py @@ -222,7 +222,6 @@ def generate_triplet_for_posix_platform( enable_asan: bool, crt_linkage: str, target_abi: str, - osx_deployment_target: str, ) -> None: """ Generate triplet file for POSIX platforms (Linux, macOS). @@ -236,7 +235,6 @@ def generate_triplet_for_posix_platform( enable_asan (bool): Flag indicating if AddressSanitizer is enabled. crt_linkage (str): The CRT linkage type ("static" or "dynamic"). target_abi (str): The target ABI, which maps to the VCPKG_TARGET_ARCHITECTURE variable. Valid options include x86, x64, arm, arm64, arm64ec, s390x, ppc64le, riscv32, riscv64, loongarch32, loongarch64, mips64. - osx_deployment_target (str, optional): The macOS deployment target version. The parameter sets the minimum macOS version for compiled binaries. It also changes what versions of the macOS platform SDK CMake will search for. See the CMake documentation for CMAKE_OSX_DEPLOYMENT_TARGET for more information. """ folder_name_parts = [] if enable_asan: @@ -343,8 +341,6 @@ def generate_triplet_for_posix_platform( else: osx_abi = target_abi f.write(f'set(VCPKG_OSX_ARCHITECTURES "{osx_abi}")\n') - if osx_deployment_target: - f.write(f'set(VCPKG_OSX_DEPLOYMENT_TARGET "{osx_deployment_target}")\n') f.write("set(CMAKE_POSITION_INDEPENDENT_CODE ON)\n") f.write( "list(APPEND VCPKG_CMAKE_CONFIGURE_OPTIONS --compile-no-warning-as-error -DBENCHMARK_ENABLE_WERROR=OFF)\n" @@ -505,58 +501,32 @@ def generate_windows_triplets(build_dir: str) -> None: add_port_configs(f, enable_exception, False) -def generate_linux_triplets(build_dir: str) -> None: +def generate_posix_triplets(build_dir: str) -> None: """ - Generate triplet files for Linux platforms. + Generate triplet files for POSIX platforms (Linux, macOS). Args: build_dir (str): The directory to save the generated triplet files. """ - target_abis = ["x86", "x64", "arm", "arm64", "s390x", "ppc64le", "riscv64", "loongarch64", "mips64"] - for enable_rtti in [True, False]: - for enable_exception in [True, False]: - for enable_binskim in [True, False]: - for enable_asan in [True, False]: - if enable_asan and enable_binskim: - continue - for target_abi in target_abis: - generate_triplet_for_posix_platform( - build_dir, - "linux", - enable_rtti, - enable_exception, - enable_binskim, - enable_asan, - "dynamic", - target_abi, - None, - ) - - -def generate_macos_triplets(build_dir: str, osx_deployment_target: str) -> None: - """ - Generate triplet files for macOS platforms. - - Args: - build_dir (str): The directory to save the generated triplet files. - osx_deployment_target (str, optional): The macOS deployment target version. - """ - target_abis = ["x64", "arm64", "universal2"] - for enable_rtti in [True, False]: - for enable_exception in [True, False]: - for enable_binskim in [True, False]: - for enable_asan in [True, False]: - if enable_asan and enable_binskim: - continue - for target_abi in target_abis: - generate_triplet_for_posix_platform( - build_dir, - "osx", - enable_rtti, - enable_exception, - enable_binskim, - enable_asan, - "dynamic", - target_abi, - osx_deployment_target, - ) + for os_name in ["linux", "osx"]: + if os_name == "linux": + target_abis = ["x86", "x64", "arm", "arm64", "s390x", "ppc64le", "riscv64", "loongarch64", "mips64"] + else: + target_abis = ["x64", "arm64", "universal2"] + for enable_rtti in [True, False]: + for enable_exception in [True, False]: + for enable_binskim in [True, False]: + for enable_asan in [True, False]: + if enable_asan and enable_binskim: + continue + for target_abi in target_abis: + generate_triplet_for_posix_platform( + build_dir, + os_name, + enable_rtti, + enable_exception, + enable_binskim, + enable_asan, + "dynamic", + target_abi, + ) diff --git a/winml/adapter/winml_adapter_model.cpp b/winml/adapter/winml_adapter_model.cpp index cf02c6fa2328b..195bf6e5f0ffd 100644 --- a/winml/adapter/winml_adapter_model.cpp +++ b/winml/adapter/winml_adapter_model.cpp @@ -593,13 +593,13 @@ ORT_API_STATUS_IMPL( input.set_name(input_name); if (info->type == ONNXType::ONNX_TYPE_TENSOR) { - auto num_dims = info->tensor_type_info->shape.NumDimensions(); + auto num_dims = info->data->shape.NumDimensions(); CreateTypeProto_Tensor( input.mutable_type()->mutable_tensor_type(), input_name, - (num_dims == 0) ? nullptr : &info->tensor_type_info->shape[0], + (num_dims == 0) ? nullptr : &info->data->shape[0], num_dims, - ONNXTensorElementDataTypeToTensorProto_DataType(info->tensor_type_info->type) + ONNXTensorElementDataTypeToTensorProto_DataType(info->data->type) ); } return nullptr; @@ -619,12 +619,12 @@ ORT_API_STATUS_IMPL( ONNX_NAMESPACE::TensorProto& input = *graph.add_initializer(); input.set_name(input_name); - auto num_dims = info->tensor_type_info->shape.NumDimensions(); + auto num_dims = info->data->shape.NumDimensions(); for (size_t i = 0; i < num_dims; i++) { - input.add_dims(info->tensor_type_info->shape[i]); + input.add_dims(info->data->shape[i]); } - input.set_data_type(ONNXTensorElementDataTypeToTensorProto_DataType(info->tensor_type_info->type)); + input.set_data_type(ONNXTensorElementDataTypeToTensorProto_DataType(info->data->type)); auto tensor = value->GetMutable(); input.set_raw_data(tensor->DataRaw(), tensor->SizeInBytes()); @@ -645,9 +645,9 @@ ORT_API_STATUS_IMPL( CreateTypeProto_Tensor( output.mutable_type()->mutable_tensor_type(), output_name, - &info->tensor_type_info->shape[0], - info->tensor_type_info->shape.NumDimensions(), - ONNXTensorElementDataTypeToTensorProto_DataType(info->tensor_type_info->type) + &info->data->shape[0], + info->data->shape.NumDimensions(), + ONNXTensorElementDataTypeToTensorProto_DataType(info->data->type) ); } return nullptr; From 788fc788fa77c5aa4d656dd841caff610b1cda08 Mon Sep 17 00:00:00 2001 From: jatinwadhwa921 Date: Mon, 10 Mar 2025 14:22:46 +0530 Subject: [PATCH 010/138] Revert "Merge pull request #607" This reverts commit 920ed58654c3fe3197334e16a10f53fd972051ff, reversing changes made to a6cdf62176c116e3a1e07f6cec1681c041d653b9. --- ThirdPartyNotices.txt | 35 + cmake/deps.txt | 1 - .../external/onnxruntime_external_deps.cmake | 54 +- cmake/nuget_helpers.cmake | 2 +- cmake/onnxruntime_framework.cmake | 5 +- cmake/onnxruntime_optimizer.cmake | 1 + cmake/onnxruntime_providers_js.cmake | 6 +- cmake/onnxruntime_python.cmake | 2 +- cmake/onnxruntime_session.cmake | 1 + cmake/onnxruntime_unittests.cmake | 43 +- cmake/onnxruntime_webassembly.cmake | 37 +- cmake/patches/dawn/dawn.patch | 113 ++- cmake/winml_sdk_helpers.cmake | 2 +- ...oft.ML.OnnxRuntime.FasterRcnnSample.csproj | 2 +- .../ManagedProjections.shared.cs | 3 +- .../NativeMethods.shared.cs | 4 +- .../core/framework/execution_provider.h | 16 + include/onnxruntime/core/graph/graph.h | 32 +- include/onnxruntime/core/graph/graph_viewer.h | 6 + .../core/graph/indexed_sub_graph.h | 6 + .../core/session/onnxruntime_c_api.h | 491 +++++++++++- .../core/session/onnxruntime_cxx_api.h | 261 ++++++- .../core/session/onnxruntime_cxx_inline.h | 350 ++++++++- .../onnxruntime_session_options_config_keys.h | 5 +- js/build_webgpu.bat | 79 ++ js/common/lib/tensor-impl-type-mapping.ts | 9 +- js/common/lib/tensor-impl.ts | 7 + js/common/package.json | 3 +- js/common/test/unit-tests/common.ts | 5 +- .../test/unit-tests/tensor/constructor-f16.ts | 62 ++ .../unit-tests/tensor/constructor-type.ts | 8 - js/web/lib/build-def.d.ts | 7 + js/web/lib/wasm/jsep/backend-webgpu.ts | 28 +- js/web/lib/wasm/jsep/backend-webnn.ts | 3 +- js/web/lib/wasm/jsep/init.ts | 144 ++-- .../ops/3rd-party/conv_backprop_webgpu.ts | 96 ++- js/web/lib/wasm/jsep/webgpu/ops/attention.ts | 2 +- .../lib/wasm/jsep/webgpu/program-manager.ts | 1 - js/web/lib/wasm/jsep/webgpu/types.ts | 10 - js/web/lib/wasm/proxy-wrapper.ts | 8 +- js/web/lib/wasm/session-options.ts | 116 ++- js/web/lib/wasm/wasm-core-impl.ts | 97 ++- js/web/lib/wasm/wasm-types.ts | 68 +- js/web/lib/wasm/wasm-utils-import.ts | 50 +- js/web/script/build.ts | 36 +- js/web/test/data/ops/conv-transpose.jsonc | 122 +++ js/web/test/e2e/exports/main.js | 11 +- js/web/test/e2e/exports/test.js | 22 + .../contrib_ops/webgpu/bert/bias_add.cc | 80 ++ .../contrib_ops/webgpu/bert/bias_add.h | 32 + .../contrib_ops/webgpu/bert/fast_gelu.cc | 4 +- .../webgpu/bert/flash_attention.cc | 4 +- .../webgpu/bert/rotary_embedding.cc | 14 +- .../webgpu/bert/skip_layer_norm.cc | 4 +- .../webgpu/quantization/dp4a_matmul_nbits.cc | 326 ++++++++ .../webgpu/quantization/dp4a_matmul_nbits.h | 56 ++ .../webgpu/quantization/matmul_nbits.cc | 322 +------- .../webgpu/quantization/matmul_nbits.h | 19 - .../subgroup_matrix_matmul_nbits.cc | 8 +- .../webgpu/webgpu_contrib_kernels.cc | 4 +- .../core/framework/compute_capability.h | 20 + .../core/framework/execution_provider.cc | 1 + .../core/framework/external_data_loader.cc | 7 +- .../core/framework/external_data_loader.h | 2 +- .../core/framework/fallback_cpu_capability.cc | 4 + .../core/framework/fallback_cpu_capability.h | 4 + .../core/framework/graph_partitioner.cc | 248 ++++--- .../core/framework/graph_partitioner.h | 9 +- .../core/framework/onnxruntime_typeinfo.cc | 71 +- .../core/framework/onnxruntime_typeinfo.h | 2 +- .../core/framework/session_state_utils.cc | 35 +- .../core/framework/tensor_type_and_shape.cc | 35 +- .../core/framework/tensorprotoutils.cc | 29 +- onnxruntime/core/framework/tensorprotoutils.h | 10 +- onnxruntime/core/graph/graph.cc | 295 +++++++- .../core/graph/graph_flatbuffers_utils.cc | 14 +- onnxruntime/core/graph/model.cc | 32 +- onnxruntime/core/graph/model.h | 8 +- .../core/graph/model_editor_api_types.h | 47 ++ .../core/optimizer/constant_folding.cc | 13 +- onnxruntime/core/optimizer/constant_folding.h | 18 + .../optimizer/graph_optimizer_registry.cc | 49 ++ .../core/optimizer/graph_optimizer_registry.h | 77 ++ .../constant_folding_dq_node.cc | 26 + .../constant_folding_dq_node.h | 37 + .../selection_and_optimization_func.cc | 99 +++ .../selection_and_optimization_func.h | 31 + .../providers/acl/acl_execution_provider.cc | 1 + .../providers/acl/acl_execution_provider.h | 1 + .../providers/cann/cann_execution_provider.cc | 1 + .../providers/cann/cann_execution_provider.h | 1 + .../coreml/coreml_execution_provider.cc | 1 + .../coreml/coreml_execution_provider.h | 1 + .../core/providers/cpu/controlflow/loop.cc | 4 +- .../cpu/quantization/conv_integer.cc | 7 +- .../core/providers/cuda/controlflow/loop.cc | 4 +- .../providers/cuda/cuda_execution_provider.cc | 1 + .../providers/cuda/cuda_execution_provider.h | 1 + .../core/providers/cuda/tensor/upsample.cc | 20 +- .../providers/cuda/tensor/upsample_impl.cu | 94 +-- .../providers/cuda/tensor/upsample_impl.h | 20 +- .../src/ExecutionProvider.cpp | 6 +- .../src/ExecutionProvider.h | 3 + .../providers/dnnl/dnnl_execution_provider.cc | 1 + .../providers/dnnl/dnnl_execution_provider.h | 1 + .../providers/js/js_execution_provider.cc | 1 + .../core/providers/js/js_execution_provider.h | 1 + .../migraphx/migraphx_execution_provider.cc | 1 + .../migraphx/migraphx_execution_provider.h | 1 + .../nnapi_builtin/nnapi_execution_provider.cc | 1 + .../nnapi_builtin/nnapi_execution_provider.h | 1 + .../openvino/backends/basic_backend.cc | 2 +- .../openvino/openvino_execution_provider.cc | 1 + .../openvino/openvino_execution_provider.h | 1 + .../qnn/builder/onnx_ctx_model_helper.cc | 38 +- .../qnn/builder/onnx_ctx_model_helper.h | 7 +- .../qnn/builder/qnn_backend_manager.cc | 2 + .../core/providers/qnn/qnn_allocator.cc | 4 +- .../providers/qnn/qnn_execution_provider.cc | 73 +- .../providers/qnn/qnn_execution_provider.h | 2 + .../core/providers/qnn/shared_context.h | 26 + .../rknpu/rknpu_execution_provider.cc | 1 + .../rknpu/rknpu_execution_provider.h | 1 + .../providers/rocm/rocm_execution_provider.cc | 1 + .../providers/rocm/rocm_execution_provider.h | 1 + .../providers/shared_library/provider_api.h | 1 + .../provider_bridge_provider.cc | 3 +- .../shared_library/provider_interfaces.h | 9 + .../shared_library/provider_wrappedtypes.h | 3 + .../providers/snpe/snpe_execution_provider.cc | 1 + .../providers/snpe/snpe_execution_provider.h | 1 + .../tensorrt/tensorrt_execution_provider.cc | 55 +- .../tensorrt/tensorrt_execution_provider.h | 31 + .../tensorrt_execution_provider_helper.cc | 129 ++++ .../vitisai/vitisai_execution_provider.cc | 2 +- .../vitisai/vitisai_execution_provider.h | 1 + .../vsinpu/vsinpu_execution_provider.cc | 1 + .../vsinpu/vsinpu_execution_provider.h | 1 + .../providers/webgpu/external_data_loader.cc | 40 + .../providers/webgpu/external_data_loader.h | 30 + .../core/providers/webgpu/generator/range.cc | 2 +- .../webgpu/math/binary_elementwise_ops.cc | 2 +- .../core/providers/webgpu/math/softmax.cc | 238 ++++++ .../core/providers/webgpu/math/softmax.h | 54 ++ .../webgpu/math/unary_elementwise_ops.cc | 2 +- .../core/providers/webgpu/nn/layer_norm.cc | 6 +- onnxruntime/core/providers/webgpu/program.cc | 20 + onnxruntime/core/providers/webgpu/program.h | 1 + .../core/providers/webgpu/program_manager.cc | 10 +- .../webgpu/reduction/reduction_ops.cc | 168 +++++ .../webgpu/reduction/reduction_ops.h | 62 ++ .../core/providers/webgpu/shader_helper.cc | 3 - .../core/providers/webgpu/shader_variable.cc | 2 +- .../core/providers/webgpu/tensor/cast.cc | 2 +- .../core/providers/webgpu/tensor/cast.h | 2 +- .../core/providers/webgpu/tensor/concat.cc | 2 +- .../core/providers/webgpu/tensor/expand.cc | 2 +- .../core/providers/webgpu/tensor/gather.cc | 2 +- .../core/providers/webgpu/tensor/pad.cc | 261 +++++++ .../core/providers/webgpu/tensor/pad.h | 40 + .../providers/webgpu/tensor/resize_impl.cc | 8 +- .../core/providers/webgpu/tensor/split.cc | 6 +- .../core/providers/webgpu/tensor/transpose.cc | 62 +- .../core/providers/webgpu/tensor/transpose.h | 2 + .../core/providers/webgpu/tensor/where.cc | 2 +- .../core/providers/webgpu/webgpu_context.cc | 61 +- .../webgpu/webgpu_execution_provider.cc | 38 +- .../webgpu/webgpu_execution_provider.h | 4 + .../webgpu/webgpu_pix_frame_generator.cc | 4 +- .../webgpu/webgpu_pix_frame_generator.h | 2 +- .../webgpu/webgpu_provider_factory.cc | 6 + .../impl/rotaryEmbedding_op_builder.cc | 14 +- .../providers/webnn/builders/model_builder.cc | 6 +- .../providers/webnn/builders/model_builder.h | 10 +- .../webnn/webnn_execution_provider.cc | 1 + .../webnn/webnn_execution_provider.h | 1 + .../xnnpack/xnnpack_execution_provider.cc | 1 + .../xnnpack/xnnpack_execution_provider.h | 1 + .../core/session/abi_session_options.cc | 17 +- onnxruntime/core/session/api_utils.cc | 25 - onnxruntime/core/session/api_utils.h | 9 - onnxruntime/core/session/custom_ops.cc | 25 +- onnxruntime/core/session/inference_session.cc | 78 +- onnxruntime/core/session/inference_session.h | 35 +- onnxruntime/core/session/model_editor_api.h | 65 ++ .../core/session/model_editor_c_api.cc | 358 +++++++++ onnxruntime/core/session/onnxruntime_c_api.cc | 328 ++++---- onnxruntime/core/session/ort_apis.h | 16 + .../core/session/provider_bridge_ort.cc | 23 +- onnxruntime/core/session/utils.cc | 125 ++++ onnxruntime/core/session/utils.h | 28 + .../execution_providers/qnn/quant_config.py | 6 +- .../python/tools/quantization/quantize.py | 32 +- .../tools/transformers/models/sam2/README.md | 31 +- .../models/sam2/benchmark_sam2.py | 15 +- .../models/sam2/benchmark_sam2.sh | 310 +++++--- .../models/sam2/convert_to_onnx.py | 14 +- .../transformers/models/sam2/image_decoder.py | 2 +- .../transformers/models/sam2/image_encoder.py | 74 +- .../transformers/models/sam2/mask_decoder.py | 2 +- .../models/sam2/prompt_encoder.py | 2 +- .../README.md | 10 +- .../command_args_parser.cc | 47 +- .../command_args_parser.h | 0 .../test/ep_weight_sharing_ctx_gen/main.cc | 247 ++++++ .../test_configuration.h | 7 +- .../test/framework/inference_session_test.cc | 1 + .../test/framework/session_state_test.cc | 27 +- onnxruntime/test/framework/type_info_test.cc | 26 +- onnxruntime/test/providers/base_tester.cc | 6 +- onnxruntime/test/providers/base_tester.h | 6 +- .../test/providers/cpu/math/softmax_test.cc | 13 +- .../providers/cpu/nn/conv_integer_test.cc | 40 + .../internal_testing_execution_provider.cc | 1 + .../internal_testing_execution_provider.h | 1 + .../test/providers/qnn/qnn_ep_context_test.cc | 267 ++++--- .../test/providers/qnn/qnn_test_utils.cc | 7 +- .../quantization/test_get_qdq_config.py | 56 ++ onnxruntime/test/qnn_ctx_gen/main.cc | 250 ------- .../test/shared_lib/custom_op_utils.cc | 20 + onnxruntime/test/shared_lib/custom_op_utils.h | 67 +- onnxruntime/test/shared_lib/test_inference.cc | 192 +++-- .../test/shared_lib/test_model_builder_api.cc | 701 ++++++++++++++++++ .../test/shared_lib/test_ort_format_models.cc | 14 +- onnxruntime/test/shared_lib/utils.h | 52 ++ .../test/testdata/cast_float_to_double.onnx | Bin 0 -> 136 bytes .../my_execution_provider.cc | 2 +- .../my_execution_provider.h | 2 +- onnxruntime/wasm/api.cc | 26 +- onnxruntime/wasm/api.h | 24 +- onnxruntime/wasm/js_post_js.js | 2 - onnxruntime/wasm/js_post_js_64.js | 2 - onnxruntime/wasm/post-webgpu.js | 261 +++++++ onnxruntime/wasm/pre-async.js | 132 ++++ onnxruntime/wasm/pre-jsep.js | 308 +++----- onnxruntime/wasm/pre.js | 15 +- setup.py | 2 +- tools/ci_build/build.py | 21 +- .../custom-nuget-packaging-pipeline.yml | 142 ++++ .../py-package-test-pipeline.yml | 2 + .../azure-pipelines/py-packaging-pipeline.yml | 50 +- .../qnn-ep-nuget-packaging-pipeline.yml | 148 ++-- .../rocm-nuget-packaging-pipeline.yml | 339 --------- .../rocm-publish-nuget-pipeline.yml | 21 - .../stages/nuget-cuda-packaging-stage.yml | 15 +- .../stages/nuget-qnn-packaging-stage.yml | 76 ++ .../stages/py-cpu-packaging-stage.yml | 124 ++-- ...acts-package-and-publish-steps-windows.yml | 16 + .../templates/jobs/download_win_openvino.yml | 64 ++ .../templates/linux-web-init-and-check.yml | 8 + .../templates/py-linux-qnn.yml | 118 +-- .../azure-pipelines/templates/py-linux.yml | 144 ++-- .../templates/py-package-smoking-test.yml | 13 +- .../templates/py-packaging-linux-test-cpu.yml | 18 +- .../templates/py-win-arm64-qnn.yml | 273 +++---- .../templates/py-win-arm64ec-qnn.yml | 241 +++--- .../templates/py-win-x64-qnn.yml | 21 +- .../azure-pipelines/templates/qnn-ep-win.yml | 260 ++++--- .../templates/react-native-ci.yml | 12 +- .../azure-pipelines/templates/web-ci.yml | 3 - .../azure-pipelines/templates/win-ci.yml | 2 +- .../azure-pipelines/templates/win-web-ci.yml | 12 +- .../templates/win-web-multi-browsers.yml | 12 +- .../templates/windowsai-steps.yml | 2 +- .../win-gpu-webgpu-ci-pipeline.yml | 28 + .../win-openvino-ci-pipeline.yml | 116 +++ .../win-qnn-arm64-ci-pipeline.yml | 2 +- .../azure-pipelines/win-qnn-ci-pipeline.yml | 2 +- tools/ci_build/set-trigger-rules.py | 3 +- .../nuget/generate_nuspec_for_custom_nuget.py | 150 ++++ tools/python/run_CIs_for_external_pr.py | 1 + tools/python/util/__init__.py | 3 +- tools/python/util/vcpkg_helpers.py | 78 +- winml/adapter/winml_adapter_model.cpp | 18 +- 274 files changed, 10047 insertions(+), 3285 deletions(-) create mode 100644 js/build_webgpu.bat create mode 100644 js/common/test/unit-tests/tensor/constructor-f16.ts create mode 100644 onnxruntime/contrib_ops/webgpu/bert/bias_add.cc create mode 100644 onnxruntime/contrib_ops/webgpu/bert/bias_add.h create mode 100644 onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc create mode 100644 onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h create mode 100644 onnxruntime/core/graph/model_editor_api_types.h create mode 100644 onnxruntime/core/optimizer/graph_optimizer_registry.cc create mode 100644 onnxruntime/core/optimizer/graph_optimizer_registry.h create mode 100644 onnxruntime/core/optimizer/qdq_transformer/constant_folding_dq_node.cc create mode 100644 onnxruntime/core/optimizer/qdq_transformer/constant_folding_dq_node.h create mode 100644 onnxruntime/core/optimizer/selection_and_optimization_func.cc create mode 100644 onnxruntime/core/optimizer/selection_and_optimization_func.h create mode 100644 onnxruntime/core/providers/webgpu/external_data_loader.cc create mode 100644 onnxruntime/core/providers/webgpu/external_data_loader.h create mode 100644 onnxruntime/core/providers/webgpu/math/softmax.cc create mode 100644 onnxruntime/core/providers/webgpu/math/softmax.h create mode 100644 onnxruntime/core/providers/webgpu/reduction/reduction_ops.cc create mode 100644 onnxruntime/core/providers/webgpu/reduction/reduction_ops.h create mode 100644 onnxruntime/core/providers/webgpu/tensor/pad.cc create mode 100644 onnxruntime/core/providers/webgpu/tensor/pad.h delete mode 100644 onnxruntime/core/session/api_utils.cc delete mode 100644 onnxruntime/core/session/api_utils.h create mode 100644 onnxruntime/core/session/model_editor_api.h create mode 100644 onnxruntime/core/session/model_editor_c_api.cc create mode 100644 onnxruntime/core/session/utils.cc create mode 100644 onnxruntime/core/session/utils.h rename onnxruntime/test/{qnn_ctx_gen => ep_weight_sharing_ctx_gen}/README.md (82%) rename onnxruntime/test/{qnn_ctx_gen => ep_weight_sharing_ctx_gen}/command_args_parser.cc (68%) rename onnxruntime/test/{qnn_ctx_gen => ep_weight_sharing_ctx_gen}/command_args_parser.h (100%) create mode 100644 onnxruntime/test/ep_weight_sharing_ctx_gen/main.cc rename onnxruntime/test/{qnn_ctx_gen => ep_weight_sharing_ctx_gen}/test_configuration.h (75%) delete mode 100644 onnxruntime/test/qnn_ctx_gen/main.cc create mode 100644 onnxruntime/test/shared_lib/test_model_builder_api.cc create mode 100644 onnxruntime/test/testdata/cast_float_to_double.onnx create mode 100644 onnxruntime/wasm/post-webgpu.js create mode 100644 onnxruntime/wasm/pre-async.js create mode 100644 tools/ci_build/github/azure-pipelines/custom-nuget-packaging-pipeline.yml delete mode 100644 tools/ci_build/github/azure-pipelines/rocm-nuget-packaging-pipeline.yml delete mode 100644 tools/ci_build/github/azure-pipelines/rocm-publish-nuget-pipeline.yml create mode 100644 tools/ci_build/github/azure-pipelines/stages/nuget-qnn-packaging-stage.yml create mode 100644 tools/ci_build/github/azure-pipelines/templates/jobs/download_win_openvino.yml create mode 100644 tools/ci_build/github/azure-pipelines/win-openvino-ci-pipeline.yml create mode 100644 tools/nuget/generate_nuspec_for_custom_nuget.py diff --git a/ThirdPartyNotices.txt b/ThirdPartyNotices.txt index 26084ab42ec1c..a449e42f6bf19 100644 --- a/ThirdPartyNotices.txt +++ b/ThirdPartyNotices.txt @@ -6045,3 +6045,38 @@ https://github.com/intel/neural-speed terms, and open source software license terms. These separate license terms govern your use of the third party programs as set forth in the "THIRD-PARTY-PROGRAMS" file. + +_____ + +dawn + +https://dawn.googlesource.com/dawn + + BSD 3-Clause License + + Copyright 2017-2023 The Dawn & Tint Authors + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are met: + + 1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + + 2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + + 3. Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/cmake/deps.txt b/cmake/deps.txt index d0bab93d3c16f..c7db8ef51505d 100644 --- a/cmake/deps.txt +++ b/cmake/deps.txt @@ -53,7 +53,6 @@ re2;https://github.com/google/re2/archive/refs/tags/2024-07-02.zip;646e1728269cd safeint;https://github.com/dcleblanc/SafeInt/archive/refs/tags/3.0.28.zip;23f252040ff6cb9f1fd18575b32fa8fb5928daac tensorboard;https://github.com/tensorflow/tensorboard/archive/373eb09e4c5d2b3cc2493f0949dc4be6b6a45e81.zip;67b833913605a4f3f499894ab11528a702c2b381 cutlass;https://github.com/NVIDIA/cutlass/archive/refs/tags/v3.5.1.zip;e49b2b964163d27765a5002d210a2f3c73771835 -utf8_range;https://github.com/protocolbuffers/utf8_range/archive/72c943dea2b9240cd09efde15191e144bc7c7d38.zip;9925739c9debc0efa2adcb194d371a35b6a03156 extensions;https://github.com/microsoft/onnxruntime-extensions/archive/c24b7bab0c12f53da76d0c31b03b9f0f8ec8f3b4.zip;239063aee4946a9af147b473a4c3da78ba7413b4 composable_kernel;https://github.com/ROCmSoftwarePlatform/composable_kernel/archive/204da9c522cebec5220bba52cd3542ebcaf99e7a.zip;1827348efd47831c13074245274d41b7cae8a557 directx_headers;https://github.com/microsoft/DirectX-Headers/archive/refs/tags/v1.613.1.zip;47653509a3371eabb156360f42faf582f314bf2e diff --git a/cmake/external/onnxruntime_external_deps.cmake b/cmake/external/onnxruntime_external_deps.cmake index ebf20ab21bbd2..a477d6edb3a3f 100644 --- a/cmake/external/onnxruntime_external_deps.cmake +++ b/cmake/external/onnxruntime_external_deps.cmake @@ -107,23 +107,6 @@ if(onnxruntime_USE_MIMALLOC) FetchContent_MakeAvailable(mimalloc) endif() -#Protobuf depends on utf8_range -onnxruntime_fetchcontent_declare( - utf8_range - URL ${DEP_URL_utf8_range} - URL_HASH SHA1=${DEP_SHA1_utf8_range} - EXCLUDE_FROM_ALL - FIND_PACKAGE_ARGS NAMES utf8_range -) - -set(utf8_range_ENABLE_TESTS OFF CACHE BOOL "Build test suite" FORCE) -set(utf8_range_ENABLE_INSTALL OFF CACHE BOOL "Configure installation" FORCE) - -# The next line will generate an error message "fatal: not a git repository", but it is ok. It is from flatbuffers -onnxruntime_fetchcontent_makeavailable(utf8_range) -# protobuf's cmake/utf8_range.cmake has the following line -include_directories(${utf8_range_SOURCE_DIR}) - # Download a protoc binary from Internet if needed if(NOT ONNX_CUSTOM_PROTOC_EXECUTABLE AND NOT onnxruntime_USE_VCPKG) # This part of code is only for users' convenience. The code couldn't handle all cases. Users always can manually @@ -304,7 +287,7 @@ if(NOT TARGET Boost::mp11) EXCLUDE_FROM_ALL FIND_PACKAGE_ARGS NAMES Boost ) - onnxruntime_fetchcontent_makeavailable(mp11) + onnxruntime_fetchcontent_makeavailable(mp11) if(NOT TARGET Boost::mp11) add_library(Boost::mp11 ALIAS Boost::headers) endif() @@ -442,6 +425,9 @@ target_include_directories(safeint_interface INTERFACE ${safeint_SOURCE_DIR}) # Flatbuffers +if(onnxruntime_USE_VCPKG) + find_package(flatbuffers REQUIRED) +else() # We do not need to build flatc for iOS or Android Cross Compile if (CMAKE_SYSTEM_NAME STREQUAL "iOS" OR CMAKE_SYSTEM_NAME STREQUAL "Android" OR CMAKE_SYSTEM_NAME STREQUAL "Emscripten") set(FLATBUFFERS_BUILD_FLATC OFF CACHE BOOL "FLATBUFFERS_BUILD_FLATC" FORCE) @@ -492,6 +478,7 @@ namespace std { using ::getenv; } endif() endif() endif() +endif() # ONNX if (NOT onnxruntime_USE_FULL_PROTOBUF) @@ -672,17 +659,10 @@ if (onnxruntime_USE_WEBGPU) # disable things we don't use set(DAWN_DXC_ENABLE_ASSERTS_IN_NDEBUG OFF) - set(DAWN_ENABLE_DESKTOP_GL OFF CACHE BOOL "" FORCE) - set(DAWN_ENABLE_OPENGLES OFF CACHE BOOL "" FORCE) - set(DAWN_SUPPORTS_GLFW_FOR_WINDOWING OFF CACHE BOOL "" FORCE) - set(DAWN_USE_GLFW OFF CACHE BOOL "" FORCE) - set(DAWN_USE_WINDOWS_UI OFF CACHE BOOL "" FORCE) set(DAWN_USE_X11 OFF CACHE BOOL "" FORCE) set(TINT_BUILD_TESTS OFF CACHE BOOL "" FORCE) set(TINT_BUILD_CMD_TOOLS OFF CACHE BOOL "" FORCE) - set(TINT_BUILD_GLSL_WRITER OFF CACHE BOOL "" FORCE) - set(TINT_BUILD_GLSL_VALIDATOR OFF CACHE BOOL "" FORCE) set(TINT_BUILD_IR_BINARY OFF CACHE BOOL "" FORCE) set(TINT_BUILD_SPV_READER OFF CACHE BOOL "" FORCE) # don't need. disabling is a large binary size saving set(TINT_BUILD_WGSL_WRITER ON CACHE BOOL "" FORCE) # needed to create cache key. runtime error if not enabled. @@ -732,7 +712,29 @@ if (onnxruntime_USE_WEBGPU) # # if we need to apply patches in the future, we can uncomment the following line. # # The dawn.patch contains the following changes: - # - https://dawn-review.googlesource.com/c/dawn/+/225514 + # + # - (public) CMake fix to support Emscripten v4.0.3+ + # This change allows Dawn to find the file "gen_struct_info.py" in the correct location. + # https://dawn-review.googlesource.com/c/dawn/+/225514 + # + # - (public) Fix emwgpu C++ implementation for buffer destroy + # In native implementation, wgpuBufferRelease will trigger the buffer destroy (if refcount decreased to 0). But + # in emwgpu implementation, the buffer destroy won't happen. This change fixes the bug. + # https://dawn-review.googlesource.com/c/dawn/+/226315 + # + # - (private) Allow "external" buffer in emwgpu C++ implementation + # This change allows WGPUBufferImpl to destroy the buffer when the refcount decreased to 0 only for non-external + # buffer. + # "external buffer" means the GPUBuffer instance created in JavaScript and imported to C++ by `importJsBuffer`. + # + # - (private) Remove hard-coded CMAKE_OSX_DEPLOYMENT_TARGET in Dawn's CMake files + # https://github.com/microsoft/onnxruntime/pull/23729 + # + # - (private) Fix external ref count for "external" device in emwgpu C++ implementation + # This change fixes the incorrect external ref count for class WGPUDeviceImpl when used with "external" device. + # "external device" means the GPUDevice instance created in JavaScript and imported to C++ by `importJsDevice`. + # + # PATCH_COMMAND ${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PROJECT_SOURCE_DIR}/patches/dawn/dawn.patch EXCLUDE_FROM_ALL ) diff --git a/cmake/nuget_helpers.cmake b/cmake/nuget_helpers.cmake index 22143ac422e9f..b066d1e9fb50e 100644 --- a/cmake/nuget_helpers.cmake +++ b/cmake/nuget_helpers.cmake @@ -1,7 +1,7 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -cmake_minimum_required(VERSION 3.0) +cmake_minimum_required(VERSION 3.5) # Determines the version of a native nuget package from the root packages.config. # diff --git a/cmake/onnxruntime_framework.cmake b/cmake/onnxruntime_framework.cmake index b1e98a9e0411c..9c9a25f8ee77e 100644 --- a/cmake/onnxruntime_framework.cmake +++ b/cmake/onnxruntime_framework.cmake @@ -36,10 +36,7 @@ elseif(onnxruntime_ENABLE_TRITON) endif() if (onnxruntime_MINIMAL_BUILD) - set(onnxruntime_framework_src_exclude - "${ONNXRUNTIME_ROOT}/core/framework/fallback_cpu_capability.h" - "${ONNXRUNTIME_ROOT}/core/framework/fallback_cpu_capability.cc" - ) + set(onnxruntime_framework_src_exclude) # custom ops support must be explicitly enabled in a minimal build. exclude if not. if (NOT onnxruntime_MINIMAL_BUILD_CUSTOM_OPS) diff --git a/cmake/onnxruntime_optimizer.cmake b/cmake/onnxruntime_optimizer.cmake index 9d680cd04af10..173c872d4cc06 100644 --- a/cmake/onnxruntime_optimizer.cmake +++ b/cmake/onnxruntime_optimizer.cmake @@ -9,6 +9,7 @@ if (onnxruntime_MINIMAL_BUILD) list(APPEND onnxruntime_optimizer_src_patterns "${ONNXRUNTIME_INCLUDE_DIR}/core/optimizer/graph_transformer.h" "${ONNXRUNTIME_ROOT}/core/optimizer/graph_transformer.cc" + "${ONNXRUNTIME_ROOT}/core/optimizer/graph_optimizer_registry.cc" ) if (onnxruntime_EXTENDED_MINIMAL_BUILD) diff --git a/cmake/onnxruntime_providers_js.cmake b/cmake/onnxruntime_providers_js.cmake index 9811eae611463..fefbab5082da4 100644 --- a/cmake/onnxruntime_providers_js.cmake +++ b/cmake/onnxruntime_providers_js.cmake @@ -1,6 +1,10 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. + if (onnxruntime_MINIMAL_BUILD AND NOT onnxruntime_EXTENDED_MINIMAL_BUILD) + message(FATAL_ERROR "JSEP can not be used in a basic minimal build. Please build with '--minimal_build extended'") + endif() + add_compile_definitions(USE_JSEP=1) file(GLOB_RECURSE onnxruntime_providers_js_cc_srcs @@ -18,4 +22,4 @@ onnxruntime_common onnxruntime_framework onnx onnx_proto ${PROTOBUF_LIB} flatbuffers Boost::mp11 Eigen3::Eigen ) - add_dependencies(onnxruntime_providers_js ${onnxruntime_EXTERNAL_DEPENDENCIES}) \ No newline at end of file + add_dependencies(onnxruntime_providers_js ${onnxruntime_EXTERNAL_DEPENDENCIES}) diff --git a/cmake/onnxruntime_python.cmake b/cmake/onnxruntime_python.cmake index aee6d2ff7655c..64b53c2912be0 100644 --- a/cmake/onnxruntime_python.cmake +++ b/cmake/onnxruntime_python.cmake @@ -1029,7 +1029,7 @@ if (onnxruntime_USE_QNN) add_custom_command( TARGET onnxruntime_pybind11_state POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy - $ + $ $/onnxruntime/capi/ ) if (EXISTS "${onnxruntime_QNN_HOME}/Qualcomm AI Hub Proprietary License.pdf") diff --git a/cmake/onnxruntime_session.cmake b/cmake/onnxruntime_session.cmake index 3d63285d50e72..2c2c59091fae5 100644 --- a/cmake/onnxruntime_session.cmake +++ b/cmake/onnxruntime_session.cmake @@ -22,6 +22,7 @@ endif() if (onnxruntime_MINIMAL_BUILD) set(onnxruntime_session_src_exclude "${ONNXRUNTIME_ROOT}/core/session/provider_bridge_ort.cc" + "${ONNXRUNTIME_ROOT}/core/session/model_builder_c_api.cc" ) list(REMOVE_ITEM onnxruntime_session_srcs ${onnxruntime_session_src_exclude}) diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index 0916aeb3dd92c..87aee2a174fab 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -236,14 +236,14 @@ function(AddTest) ) endif() # Set test timeout to 3 hours. - set_tests_properties(${_UT_TARGET} PROPERTIES TIMEOUT 7200) + set_tests_properties(${_UT_TARGET} PROPERTIES TIMEOUT 10800) else() add_test(NAME ${_UT_TARGET} COMMAND ${_UT_TARGET} ${TEST_ARGS} WORKING_DIRECTORY $ ) # Set test timeout to 3 hours. - set_tests_properties(${_UT_TARGET} PROPERTIES TIMEOUT 7200) + set_tests_properties(${_UT_TARGET} PROPERTIES TIMEOUT 10800) endif() endif() endfunction(AddTest) @@ -503,6 +503,7 @@ set (onnxruntime_shared_lib_test_SRC if (NOT onnxruntime_MINIMAL_BUILD) list(APPEND onnxruntime_shared_lib_test_SRC ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_inference.cc) + list(APPEND onnxruntime_shared_lib_test_SRC ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_model_builder_api.cc) endif() if(onnxruntime_RUN_ONNX_TESTS) @@ -1288,31 +1289,34 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) if(onnxruntime_USE_QNN) #qnn ctx generator - set(onnxruntime_qnn_ctx_gen_src_dir ${TEST_SRC_DIR}/qnn_ctx_gen) - set(onnxruntime_qnn_ctx_gen_src_patterns - "${onnxruntime_qnn_ctx_gen_src_dir}/*.cc" - "${onnxruntime_qnn_ctx_gen_src_dir}/*.h") + set(ep_weight_sharing_ctx_gen_src_dir ${TEST_SRC_DIR}/ep_weight_sharing_ctx_gen) + set(ep_weight_sharing_ctx_gen_src_patterns + "${ep_weight_sharing_ctx_gen_src_dir}/*.cc" + "${ep_weight_sharing_ctx_gen_src_dir}/*.h") - file(GLOB onnxruntime_qnn_ctx_gen_src CONFIGURE_DEPENDS - ${onnxruntime_qnn_ctx_gen_src_patterns} + file(GLOB ep_weight_sharing_ctx_gen_src CONFIGURE_DEPENDS + ${ep_weight_sharing_ctx_gen_src_patterns} ) - onnxruntime_add_executable(onnxruntime_qnn_ctx_gen ${onnxruntime_qnn_ctx_gen_src}) - target_include_directories(onnxruntime_qnn_ctx_gen PRIVATE ${onnx_test_runner_src_dir} ${ONNXRUNTIME_ROOT} - ${onnxruntime_graph_header} ${onnxruntime_exec_src_dir} - ${CMAKE_CURRENT_BINARY_DIR}) + onnxruntime_add_executable(ep_weight_sharing_ctx_gen ${ep_weight_sharing_ctx_gen_src}) + target_include_directories(ep_weight_sharing_ctx_gen PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR}) if (WIN32) - target_compile_options(onnxruntime_qnn_ctx_gen PRIVATE ${disabled_warnings}) + target_compile_options(ep_weight_sharing_ctx_gen PRIVATE ${disabled_warnings}) if (NOT DEFINED SYS_PATH_LIB) set(SYS_PATH_LIB shlwapi) endif() endif() - if(WIN32) - target_link_libraries(onnxruntime_qnn_ctx_gen PRIVATE debug dbghelp advapi32) + if (onnxruntime_BUILD_SHARED_LIB) + set(ep_weight_sharing_ctx_gen_libs onnxruntime_common onnxruntime ${onnxruntime_EXTERNAL_LIBRARIES} ${GETOPT_LIB_WIDE}) + target_link_libraries(ep_weight_sharing_ctx_gen PRIVATE ${ep_weight_sharing_ctx_gen_libs}) + if (WIN32) + target_link_libraries(ep_weight_sharing_ctx_gen PRIVATE debug dbghelp advapi32) + endif() + else() + target_link_libraries(ep_weight_sharing_ctx_gen PRIVATE onnxruntime_session ${onnxruntime_test_providers_libs} ${onnxruntime_EXTERNAL_LIBRARIES} ${GETOPT_LIB_WIDE}) endif() - target_link_libraries(onnxruntime_qnn_ctx_gen PRIVATE onnx_test_runner_common onnxruntime_test_utils onnxruntime_common onnxruntime_graph onnxruntime_session onnxruntime_providers onnxruntime_framework onnxruntime_util onnxruntime_mlas onnxruntime_optimizer onnxruntime_flatbuffers onnx_test_data_proto ${onnxruntime_test_providers_libs} ${onnxruntime_EXTERNAL_LIBRARIES} ${GETOPT_LIB_WIDE} ${SYS_PATH_LIB} ${CMAKE_DL_LIBS}) - set_target_properties(onnxruntime_qnn_ctx_gen PROPERTIES FOLDER "ONNXRuntimeTest") + set_target_properties(ep_weight_sharing_ctx_gen PROPERTIES FOLDER "ONNXRuntimeTest") endif() # shared lib @@ -1359,14 +1363,19 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) LIBS ${onnxruntime_shared_lib_test_LIBS} DEPENDS ${all_dependencies} ) + + target_include_directories(onnxruntime_shared_lib_test PRIVATE ${ONNXRUNTIME_ROOT}) + if (onnxruntime_USE_CUDA) target_include_directories(onnxruntime_shared_lib_test PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) target_sources(onnxruntime_shared_lib_test PRIVATE ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/cuda_ops.cu) endif() + if (onnxruntime_USE_ROCM) target_include_directories(onnxruntime_shared_lib_test PRIVATE ${onnxruntime_ROCM_HOME}/include) target_compile_definitions(onnxruntime_shared_lib_test PRIVATE __HIP_PLATFORM_AMD__) endif() + if (CMAKE_SYSTEM_NAME STREQUAL "Android") target_sources(onnxruntime_shared_lib_test PRIVATE "${ONNXRUNTIME_ROOT}/core/platform/android/cxa_demangle.cc" diff --git a/cmake/onnxruntime_webassembly.cmake b/cmake/onnxruntime_webassembly.cmake index 8106e46ccf580..f3afaf7033fd1 100644 --- a/cmake/onnxruntime_webassembly.cmake +++ b/cmake/onnxruntime_webassembly.cmake @@ -211,10 +211,14 @@ else() target_link_libraries(onnxruntime_webassembly PRIVATE tensorboard) endif() + set(onnxruntime_webassembly_script_deps "${ONNXRUNTIME_ROOT}/wasm/pre.js") + + set(EXPORTED_FUNCTIONS "_malloc,_free") if (onnxruntime_USE_JSEP) - set(EXPORTED_FUNCTIONS "_malloc,_free,_JsepOutput,_JsepGetNodeName") - else() - set(EXPORTED_FUNCTIONS "_malloc,_free") + string(APPEND EXPORTED_FUNCTIONS ",_JsepOutput,_JsepGetNodeName") + endif() + if (onnxruntime_USE_WEBGPU) + string(APPEND EXPORTED_FUNCTIONS ",_wgpuBufferRelease,_wgpuCreateInstance") endif() if (onnxruntime_ENABLE_WEBASSEMBLY_MEMORY64) @@ -312,13 +316,15 @@ else() target_compile_options(noexcep_operators PRIVATE ${SMEMORY_FLAG} -Wno-experimental) endif() target_link_options(onnxruntime_webassembly PRIVATE - --post-js "${ONNXRUNTIME_ROOT}/wasm/js_post_js_64.js" + "SHELL:--post-js \"${ONNXRUNTIME_ROOT}/wasm/js_post_js_64.js\"" ) + list(APPEND onnxruntime_webassembly_script_deps "${ONNXRUNTIME_ROOT}/wasm/js_post_js_64.js") else () set(MAXIMUM_MEMORY "4294967296") target_link_options(onnxruntime_webassembly PRIVATE - --post-js "${ONNXRUNTIME_ROOT}/wasm/js_post_js.js" + "SHELL:--post-js \"${ONNXRUNTIME_ROOT}/wasm/js_post_js.js\"" ) + list(APPEND onnxruntime_webassembly_script_deps "${ONNXRUNTIME_ROOT}/wasm/js_post_js.js") endif () target_link_options(onnxruntime_webassembly PRIVATE @@ -372,7 +378,6 @@ jsepDownload:_pp_") "SHELL:-s SIGNATURE_CONVERSIONS='${SIGNATURE_CONVERSIONS}'" ) endif () - set_target_properties(onnxruntime_webassembly PROPERTIES LINK_DEPENDS ${ONNXRUNTIME_ROOT}/wasm/pre.js) if (onnxruntime_USE_JSEP) # NOTE: "-s ASYNCIFY=1" is required for JSEP to work with WebGPU @@ -382,10 +387,8 @@ jsepDownload:_pp_") target_compile_definitions(onnxruntime_webassembly PRIVATE USE_JSEP=1) target_link_options(onnxruntime_webassembly PRIVATE "SHELL:--pre-js \"${ONNXRUNTIME_ROOT}/wasm/pre-jsep.js\"" - "SHELL:-s ASYNCIFY=1" - "SHELL:-s ASYNCIFY_STACK_SIZE=65536" ) - set_target_properties(onnxruntime_webassembly PROPERTIES LINK_DEPENDS ${ONNXRUNTIME_ROOT}/wasm/pre-jsep.js) + list(APPEND onnxruntime_webassembly_script_deps "${ONNXRUNTIME_ROOT}/wasm/pre-jsep.js") if (onnxruntime_ENABLE_WEBASSEMBLY_MEMORY64) target_link_options(onnxruntime_webassembly PRIVATE @@ -397,6 +400,20 @@ jsepDownload:_pp_") if (onnxruntime_USE_WEBGPU) target_compile_definitions(onnxruntime_webassembly PRIVATE USE_WEBGPU=1) + target_link_options(onnxruntime_webassembly PRIVATE + "SHELL:--post-js \"${ONNXRUNTIME_ROOT}/wasm/post-webgpu.js\"" + ) + list(APPEND onnxruntime_webassembly_script_deps "${ONNXRUNTIME_ROOT}/wasm/post-webgpu.js") + endif() + + if (onnxruntime_USE_JSEP OR onnxruntime_USE_WEBGPU OR onnxruntime_USE_WEBNN) + # if any of the above is enabled, we need to use the asyncify library + target_link_options(onnxruntime_webassembly PRIVATE + "SHELL:--pre-js \"${ONNXRUNTIME_ROOT}/wasm/pre-async.js\"" + "SHELL:-s ASYNCIFY=1" + "SHELL:-s ASYNCIFY_STACK_SIZE=65536" + ) + list(APPEND onnxruntime_webassembly_script_deps "${ONNXRUNTIME_ROOT}/wasm/pre-async.js") endif() if (onnxruntime_EMSCRIPTEN_SETTINGS) @@ -458,6 +475,8 @@ jsepDownload:_pp_") ) endif() + set_target_properties(onnxruntime_webassembly PROPERTIES LINK_DEPENDS "${onnxruntime_webassembly_script_deps}") + set(target_name_list ort) if (onnxruntime_ENABLE_TRAINING_APIS) diff --git a/cmake/patches/dawn/dawn.patch b/cmake/patches/dawn/dawn.patch index 2f85d5ab473b5..b578b858eac59 100644 --- a/cmake/patches/dawn/dawn.patch +++ b/cmake/patches/dawn/dawn.patch @@ -18,7 +18,7 @@ index 6e8ae37593..633af91eef 100644 @@ -77,9 +77,17 @@ if (${DAWN_ENABLE_EMSCRIPTEN}) "${arg_UNPARSED_ARGUMENTS}") endif() - + + # since Emscripten 4.0.3, file gen_struct_info.py is moved to outside of directory maint. + if (EXISTS "${DAWN_EMSCRIPTEN_TOOLCHAIN}/tools/gen_struct_info.py") + set(EM_GEN_STRUCT_INFO_SCRIPT "${DAWN_EMSCRIPTEN_TOOLCHAIN}/tools/gen_struct_info.py") @@ -34,3 +34,114 @@ index 6e8ae37593..633af91eef 100644 -q "${EM_BUILD_GEN_DIR}/struct_info_webgpu.json" "-I=${EM_BUILD_GEN_DIR}/include" +diff --git a/src/emdawnwebgpu/README.md b/src/emdawnwebgpu/README.md +index efd6491cd6..8ebc5d28b6 100644 +--- a/src/emdawnwebgpu/README.md ++++ b/src/emdawnwebgpu/README.md +@@ -56,7 +56,7 @@ Set up the build directory using emcmake + mkdir out/cmake-wasm + cd out/cmake-wasm + +-# Make sure the path is to the source checkout of Emscripten, not emsdk's release. ++# If using Emscripten v4.0.2 or lower, make sure the path is to the source checkout of Emscripten, not emsdk's release. + emcmake cmake -GNinja -DDAWN_EMSCRIPTEN_TOOLCHAIN="path/to/emscripten" ../.. + + ninja +diff --git a/third_party/emdawnwebgpu/webgpu.cpp b/third_party/emdawnwebgpu/webgpu.cpp +index f1c5a7d50e..16f2495712 100644 +--- a/third_party/emdawnwebgpu/webgpu.cpp ++++ b/third_party/emdawnwebgpu/webgpu.cpp +@@ -131,7 +131,6 @@ class RefCounted : NonMovable { + bool Release() { + if (mRefCount.fetch_sub(1u, std::memory_order_release) == 1u) { + std::atomic_thread_fence(std::memory_order_acquire); +- emwgpuDelete(this); + return true; + } + return false; +@@ -234,6 +233,7 @@ class Ref { + static void Release(T value) { + if (value != nullptr && value->RefCounted::Release()) { + delete value; ++ emwgpuDelete(value); + } + } + +@@ -641,7 +641,8 @@ struct WGPUAdapterImpl final : public EventSource, public RefCounted { + struct WGPUBufferImpl final : public EventSource, + public RefCountedWithExternalCount { + public: +- WGPUBufferImpl(const EventSource* source, bool mappedAtCreation); ++ WGPUBufferImpl(const EventSource* source, bool mappedAtCreation, bool isExternal); ++ ~WGPUBufferImpl(); + + void Destroy(); + const void* GetConstMappedRange(size_t offset, size_t size); +@@ -671,6 +672,7 @@ struct WGPUBufferImpl final : public EventSource, + }; + MapRequest mPendingMapRequest; + WGPUBufferMapState mMapState; ++ bool mIsExternal; + }; + + struct WGPUQueueImpl final : public EventSource, public RefCounted { +@@ -1164,11 +1166,15 @@ WGPUAdapter emwgpuCreateAdapter(const EventSource* source) { + + WGPUBuffer emwgpuCreateBuffer(const EventSource* source, + bool mappedAtCreation = false) { +- return new WGPUBufferImpl(source, mappedAtCreation); ++ return new WGPUBufferImpl(source, mappedAtCreation, true); + } + + WGPUDevice emwgpuCreateDevice(const EventSource* source, WGPUQueue queue) { +- return new WGPUDeviceImpl(source, queue); ++ // This function is only called from JS via `importJsDevice()`, which ++ // needs to increment the external ref count to fix the behavior. ++ WGPUDeviceImpl* device = new WGPUDeviceImpl(source, queue); ++ device->AddExternalRef(); ++ return device; + } + + WGPUQueue emwgpuCreateQueue(const EventSource* source) { +@@ -1275,15 +1281,22 @@ WGPUAdapterImpl::WGPUAdapterImpl(const EventSource* source) + // WGPUBuffer implementations. + // ---------------------------------------------------------------------------- + +-WGPUBufferImpl::WGPUBufferImpl(const EventSource* source, bool mappedAtCreation) ++WGPUBufferImpl::WGPUBufferImpl(const EventSource* source, bool mappedAtCreation, bool isExternal) + : EventSource(source), + mMapState(mappedAtCreation ? WGPUBufferMapState_Mapped +- : WGPUBufferMapState_Unmapped) { ++ : WGPUBufferMapState_Unmapped), ++ mIsExternal(isExternal) { + if (mappedAtCreation) { + mPendingMapRequest = {kNullFutureId, WGPUMapMode_Write}; + } + } + ++WGPUBufferImpl::~WGPUBufferImpl() { ++ if (!mIsExternal) { ++ Destroy(); ++ } ++} ++ + void WGPUBufferImpl::Destroy() { + emwgpuBufferDestroy(this); + AbortPendingMap("Buffer was destroyed before mapping was resolved."); +@@ -1504,6 +1517,7 @@ WGPUFuture WGPUShaderModuleImpl::GetCompilationInfo( + void wgpu##Name##Release(WGPU##Name o) { \ + if (o->Release()) { \ + delete o; \ ++ emwgpuDelete(o); \ + } \ + } + WGPU_OBJECTS(DEFINE_WGPU_DEFAULT_ADDREF_RELEASE) +@@ -1638,7 +1652,7 @@ void wgpuBufferUnmap(WGPUBuffer buffer) { + + WGPUBuffer wgpuDeviceCreateBuffer(WGPUDevice device, + const WGPUBufferDescriptor* descriptor) { +- WGPUBuffer buffer = new WGPUBufferImpl(device, descriptor->mappedAtCreation); ++ WGPUBuffer buffer = new WGPUBufferImpl(device, descriptor->mappedAtCreation, false); + emwgpuDeviceCreateBuffer(device, descriptor, buffer); + return buffer; + } diff --git a/cmake/winml_sdk_helpers.cmake b/cmake/winml_sdk_helpers.cmake index 9241fcd060caf..ca657311b7f14 100644 --- a/cmake/winml_sdk_helpers.cmake +++ b/cmake/winml_sdk_helpers.cmake @@ -1,7 +1,7 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -cmake_minimum_required(VERSION 3.0) +cmake_minimum_required(VERSION 3.5) # utility function(convert_forward_slashes_to_back input output) diff --git a/csharp/sample/Microsoft.ML.OnnxRuntime.FasterRcnnSample/Microsoft.ML.OnnxRuntime.FasterRcnnSample.csproj b/csharp/sample/Microsoft.ML.OnnxRuntime.FasterRcnnSample/Microsoft.ML.OnnxRuntime.FasterRcnnSample.csproj index f00a08a1a3595..b1452a64934c2 100644 --- a/csharp/sample/Microsoft.ML.OnnxRuntime.FasterRcnnSample/Microsoft.ML.OnnxRuntime.FasterRcnnSample.csproj +++ b/csharp/sample/Microsoft.ML.OnnxRuntime.FasterRcnnSample/Microsoft.ML.OnnxRuntime.FasterRcnnSample.csproj @@ -8,7 +8,7 @@ - + diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/ManagedProjections.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/ManagedProjections.shared.cs index 13117f23e8ef9..8916f11919cfe 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/ManagedProjections.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/ManagedProjections.shared.cs @@ -25,7 +25,7 @@ internal class ManagedTypeProjection /// /// /// - /// OrtValye created accoding to the metadata + /// OrtValue created according to the metadata internal static OrtValue CreateProjection(NamedOnnxValue namedOnnxValue, NodeMetadata metadata) { OrtValue result; @@ -191,4 +191,3 @@ private static OrtValue CreateTensorProjection(NamedOnnxValue node, NodeMetadata } } } - diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs index d628b065ceaa7..b64a5c3e5a4a2 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs @@ -847,7 +847,7 @@ internal class NativeLib /// Creates an instance of OrtSession with provided parameters /// /// Native OrtEnv instance - /// Byte array correspoonding to the model + /// Byte array corresponding to the model /// Size of the model in bytes /// Native SessionOptions instance /// Native OrtPrepackedWeightsContainer instance @@ -1258,7 +1258,7 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca /// /// Native SessionOptions instance /// Name of the initializer - /// Native OrtValue instnce + /// Native OrtValue instance [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DOrtAddInitializer(IntPtr /*(OrtSessionOptions*)*/ options, byte[] /*(const char*)*/ name, diff --git a/include/onnxruntime/core/framework/execution_provider.h b/include/onnxruntime/core/framework/execution_provider.h index c9a15de9ef897..2245ff5791feb 100644 --- a/include/onnxruntime/core/framework/execution_provider.h +++ b/include/onnxruntime/core/framework/execution_provider.h @@ -20,6 +20,7 @@ struct ComputeCapability; class KernelRegistry; struct KernelCreateInfo; class Node; +class GraphOptimizerRegistry; } // namespace onnxruntime #else #include @@ -129,10 +130,25 @@ class IExecutionProvider { and decide whether a node will be assigned to <*this> execution provider. For kernels registered in a kernel registry, `kernel_lookup` must be used to find a matching kernel for this EP. + + The graph_optimizer_registry is designed for enabling L2+ graph optimizations tailored for EPs. + These optimizations are applied after the graph partitioner assigns ComputeCapability to the EP + and before EP's "Compile" or fusion. + + Steps to use graph_optimizer_registry and create the optimization ComputeCapability: + 1. Lookup Optimizer: The EP calls provider bridge API to lookup pre-defined optimizer by name and get selection function. + - Example: g_host->GetOptimizerByName(optimizer_name, graph_optimizer_registry, selection_func) + 2. Run Selection Function: The EP executes the selection function to obtain the selection ComputeCapability. + - ComputeCapability.optimize_func would be set by the optimizer to the function that does the optimization. + 3. Create Optimization ComputeCapability: The EP uses the selection ComputeCapability to create the optimization ComputeCapability. + 4. Return ComputeCapability: The EP returns the final ComputeCapability, with nodes_to_optimize set to the optimization ComputeCapability. + + Note: For more detailed implementations of using graph_optimizer_registry, please refer to TensorRT EP. */ virtual std::vector> GetCapability(const onnxruntime::GraphViewer& graph_viewer, const IKernelLookup& kernel_lookup, + const GraphOptimizerRegistry& graph_optimizer_registry, IResourceAccountant* resource_accountant = nullptr) const; /** diff --git a/include/onnxruntime/core/graph/graph.h b/include/onnxruntime/core/graph/graph.h index 7798394b045dc..35b568e3f8e28 100644 --- a/include/onnxruntime/core/graph/graph.h +++ b/include/onnxruntime/core/graph/graph.h @@ -27,6 +27,7 @@ #include "core/common/span_utils.h" #include "core/common/status.h" #include "core/common/logging/logging.h" +#include "core/framework/ort_value.h" #include "core/framework/prepacked_weights_container.h" #include "core/graph/onnx_protobuf.h" #include "core/graph/basic_types.h" @@ -39,6 +40,9 @@ #include "core/graph/node_arg.h" #include "core/graph/ort_format_load_options.h" +// Type from Model Editor API in ORT C API so can't be in a namespace +struct OrtGraph; + namespace onnxruntime { class Graph; struct IndexedSubGraph; @@ -763,6 +767,10 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi */ bool GetInitializedTensor(const std::string& tensor_name, const ONNX_NAMESPACE::TensorProto*& value) const; + /** Populate `value` if an externally allocated OrtValue exists for an initializer with the given name. + */ + bool GetOrtValueInitializer(const std::string& name, OrtValue& value) const; + /** Gets all the initializer tensors in this Graph. */ const InitializedTensorSet& GetAllInitializedTensors() const noexcept { return name_to_initial_tensor_; } @@ -1430,6 +1438,16 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi const OrtFormatLoadOptions& load_options, const logging::Logger& logger, std::unique_ptr& graph); + static Status LoadFromModelEditorApiModel(const OrtGraph& api_graph, + const Model& owning_model, + const std::unordered_map& domain_to_version, + IOnnxRuntimeOpSchemaCollectionPtr schema_registry, + bool strict_shape_type_inference, + const logging::Logger& logger, + std::unique_ptr& graph); + + Status UpdateUsingModelEditorApiModel(const OrtModel& api_model); + #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) const RuntimeOptimizationRecordContainer& RuntimeOptimizations() const { return runtime_optimizations_; @@ -1630,7 +1648,8 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi // Implementation for initializer replacement Status ReplaceInitializedTensorImpl(ONNX_NAMESPACE::TensorProto new_initializer, bool is_external); - std::vector CreateNodeArgs(const google::protobuf::RepeatedPtrField& names, + template // range-initializer returning std::string + std::vector CreateNodeArgs(const StringRange& names, const ArgNameToTypeMap& name_to_type_map); void ToGraphProtoInternal(ONNX_NAMESPACE::GraphProto& graph_proto) const; @@ -1694,6 +1713,8 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi return nodes_[node_index].get(); } + Status LoadFromModelEditorApiModel(const OrtGraph& api_graph, bool updating_existing_graph = false); + const Model& owning_model_; // GraphProto to store name, version, initializer. @@ -1708,6 +1729,12 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi InitializedTensorSet name_to_initial_tensor_; + // Initializers that are external to the Graph. + // e.g. created from existing memory using CreateTensorWithDataAndDeleterAsOrtValue in the ORT API. + // As we need to convert to TensorProto for the optimizers to work and keep the deleter information we store them + // in the Graph instance and retrieve during session state finalization. + std::unordered_map ortvalue_initializers_; + std::unordered_set, std::hash, std::equal_to> sparse_tensor_names_; @@ -1744,6 +1771,7 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi // in some case, a fused sub-graph will happens multiple times in one model, we use a map // to store reusable-schema in lookup. InlinedHashMap> reusable_fused_schema_map_; + #endif // !defined(ORT_MINIMAL_BUILD) // Graph nodes. @@ -1806,7 +1834,7 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi std::unordered_map> node_arg_to_consumer_nodes_; #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) - const std::unordered_map domain_to_version_; + std::unordered_map domain_to_version_; // Model IR version. Version ir_version_{ONNX_NAMESPACE::Version::IR_VERSION}; diff --git a/include/onnxruntime/core/graph/graph_viewer.h b/include/onnxruntime/core/graph/graph_viewer.h index 9385e2f092e58..6a664d8be9c05 100644 --- a/include/onnxruntime/core/graph/graph_viewer.h +++ b/include/onnxruntime/core/graph/graph_viewer.h @@ -193,6 +193,12 @@ class GraphViewer { IOnnxRuntimeOpSchemaCollectionPtr GetSchemaRegistry() const { return graph_->GetSchemaRegistry(); } #endif + /** Populate `value` if an externally allocated OrtValue exists for an initializer with the given name. + */ + bool GetOrtValueInitializer(const std::string& name, OrtValue& value) const { + return graph_->GetOrtValueInitializer(name, value); + } + private: ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(GraphViewer); GraphViewer(const Graph& graph, const IndexedSubGraph* filter_info); diff --git a/include/onnxruntime/core/graph/indexed_sub_graph.h b/include/onnxruntime/core/graph/indexed_sub_graph.h index e457d3dcad1f1..088db79a7e005 100644 --- a/include/onnxruntime/core/graph/indexed_sub_graph.h +++ b/include/onnxruntime/core/graph/indexed_sub_graph.h @@ -72,6 +72,12 @@ struct IndexedSubGraph { return meta_def_.get(); } + /** Gets the mutable meta definition needed to represent this subgraph as a FunctionProto. + @returns MetaDef instance if it has been set. nullptr if not. */ + MetaDef* GetMutableMetaDef() { + return meta_def_.get(); + } + // Check if the accounting is enabled for the current EP bool IsAccountingEnabled() const { return resource_accountant != nullptr && diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 47e6389492f30..098de14bdfd61 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -305,6 +305,10 @@ ORT_RUNTIME_CLASS(OpAttr); ORT_RUNTIME_CLASS(Logger); ORT_RUNTIME_CLASS(ShapeInferContext); ORT_RUNTIME_CLASS(LoraAdapter); +ORT_RUNTIME_CLASS(ValueInfo); +ORT_RUNTIME_CLASS(Node); +ORT_RUNTIME_CLASS(Graph); +ORT_RUNTIME_CLASS(Model); #ifdef _WIN32 typedef _Return_type_success_(return == 0) OrtStatus* OrtStatusPtr; @@ -665,6 +669,9 @@ typedef struct OrtApi OrtApi; struct OrtTrainingApi; typedef struct OrtTrainingApi OrtTrainingApi; +struct OrtModelEditorApi; +typedef struct OrtModelEditorApi OrtModelEditorApi; + /** \brief The helper interface to get the right version of OrtApi * * Get a pointer to this structure through ::OrtGetApiBase @@ -847,7 +854,8 @@ struct OrtApi { * * \snippet{doc} snippets.dox OrtStatus Return Value */ - ORT_API2_STATUS(CreateSessionFromArray, _In_ const OrtEnv* env, _In_ const void* model_data, size_t model_data_length, + ORT_API2_STATUS(CreateSessionFromArray, _In_ const OrtEnv* env, + _In_ const void* model_data, size_t model_data_length, _In_ const OrtSessionOptions* options, _Outptr_ OrtSession** out); /** \brief Run the model in an ::OrtSession @@ -1340,6 +1348,8 @@ struct OrtApi { * Create a tensor with user's buffer. You can fill the buffer either before calling this function or after. * p_data is owned by caller. ReleaseValue won't release p_data. * + * If you wish to transfer ownership of p_data to ORT use CreateTensorWithDataAndDeleterAsOrtValue. + * * \param[in] info Memory description of where the p_data buffer resides (CPU vs GPU etc). * \param[in] p_data Pointer to the data buffer. * \param[in] p_data_len The number of bytes in the data buffer. @@ -1997,7 +2007,8 @@ struct OrtApi { /** \brief Get the value type from an ::OrtMapTypeInfo * * \param[in] map_type_info - * \param[out] type_info + * \param[out] type_info A copy of the OrtTypeInfo for the map value type. + * The user must free this value with ReleaseTypeInfo. * * \snippet{doc} snippets.dox OrtStatus Return Value */ @@ -2012,7 +2023,8 @@ struct OrtApi { * This is used by WinML to support model reflection APIs. * * \param[in] sequence_type_info - * \param[out] type_info + * \param[out] type_info A copy of the OrtTypeInfo for the sequence element type. + * The user must free this value with ReleaseTypeInfo. * * \snippet{doc} snippets.dox OrtStatus Return Value */ @@ -2887,7 +2899,8 @@ struct OrtApi { * \snippet{doc} snippets.dox OrtStatus Return Value */ ORT_API2_STATUS(CreateSessionWithPrepackedWeightsContainer, _In_ const OrtEnv* env, _In_ const ORTCHAR_T* model_path, - _In_ const OrtSessionOptions* options, _Inout_ OrtPrepackedWeightsContainer* prepacked_weights_container, + _In_ const OrtSessionOptions* options, + _Inout_ OrtPrepackedWeightsContainer* prepacked_weights_container, _Outptr_ OrtSession** out); /** \brief Create session from memory with prepacked weights container @@ -2910,7 +2923,8 @@ struct OrtApi { */ ORT_API2_STATUS(CreateSessionFromArrayWithPrepackedWeightsContainer, _In_ const OrtEnv* env, _In_ const void* model_data, size_t model_data_length, - _In_ const OrtSessionOptions* options, _Inout_ OrtPrepackedWeightsContainer* prepacked_weights_container, + _In_ const OrtSessionOptions* options, + _Inout_ OrtPrepackedWeightsContainer* prepacked_weights_container, _Outptr_ OrtSession** out); /// @} @@ -4293,8 +4307,8 @@ struct OrtApi { * specific type that is described by the returned ::OrtTypeInfo. * * \param[in] optional_type_info - * \param[out] out A pointer to the ::OrtTypeInfo for what the optional value could be. - * it is owned by OrtOptionalTypeInfo instance. + * \param[out] out A copy of ::OrtTypeInfo for what the optional value could be. + * The user must free this value with ReleaseTypeInfo. * * \snippet{doc} snippets.dox OrtStatus Return Value * @@ -4786,6 +4800,75 @@ struct OrtApi { */ ORT_API2_STATUS(SetEpDynamicOptions, _Inout_ OrtSession* sess, _In_reads_(kv_len) const char* const* keys, _In_reads_(kv_len) const char* const* values, _In_ size_t kv_len); + + /** \brief Release an OrtValueInfo instance if it was not added to an OrtGraph. + * \since Version 1.21. + */ + ORT_CLASS_RELEASE(ValueInfo); + + /** \brief Release an OrtNode if it was not added to an OrtGraph. + * \since Version 1.21. + */ + ORT_CLASS_RELEASE(Node); + + /** \brief Release an OrtGraph. + * \snippet{doc} snippets.dox OrtStatus Return Value + * \since Version 1.21. + */ + ORT_CLASS_RELEASE(Graph); + + /** \brief Release an OrtModel. + * \snippet{doc} snippets.dox OrtStatus Return Value + * \since Version 1.21. + */ + ORT_CLASS_RELEASE(Model); + + /** \brief Get the value name from an OrtValueInfo instance. + * \param[in] value_info The OrtValueInfo instance. + * \snippet{doc} snippets.dox OrtStatus Return Value + * \since Version 1.21. + */ + ORT_API2_STATUS(GetValueInfoName, _In_ const OrtValueInfo* value_info, _Out_ const char** name); + + /** \brief Get the type information from an OrtValueInfo instance. + * \param[in] value_info The OrtValueInfo instance. + * \snippet{doc} snippets.dox OrtStatus Return Value + * \since Version 1.21. + */ + ORT_API2_STATUS(GetValueInfoTypeInfo, _In_ const OrtValueInfo* value_info, _Outptr_ const OrtTypeInfo** type_info); + + /** \brief Get the Model Editor API instance + * + * Get the Model Editor API instance to create a new model or augment an existing model. + * + * \return Model Editor API struct + * + * \since Version 1.21. + */ + const OrtModelEditorApi*(ORT_API_CALL* GetModelEditorApi)(); + + /** \brief Create an OrtValue for a Tensor that uses pre-existing memory. + * + * ORT will take ownership of the memory and free it using the provided deleter when no longer in use. + * + * \param[in] deleter OrtAllocator instance that will be used to free the memory. + * Only the OrtAllocator:Info and OrtAllocator::Release functions are required. + * The OrtMemoryInfo returned by OrtAllocator::Info must match the location of p_data. + * \param[in] p_data Pointer to the memory that will be used by the Tensor. ORT will take ownership of the memory. + * \param[in] p_data_len Length of the memory in bytes. + * \param[in] shape Dimensions of the Tensor. All values should be > 0. + * \param[in] shape_len Number of dimensions in the shape array. + * \param[in] type Data type of the Tensor. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.21. + */ + ORT_API2_STATUS(CreateTensorWithDataAndDeleterAsOrtValue, _In_ OrtAllocator* deleter, + _In_ void* p_data, size_t p_data_len, + _In_ const int64_t* shape, size_t shape_len, + ONNXTensorElementDataType type, + _Outptr_ OrtValue** out); }; /* @@ -4900,6 +4983,400 @@ struct OrtCustomOp { void(ORT_API_CALL* ReleaseAliasMap)(_Frees_ptr_opt_ int* input_index, _Frees_ptr_opt_ int* output_index); }; +/** + * ORT Model Editor API + */ + +/** + * \brief The OrtModelEditorApi struct provides functions to create or edit an ONNX model. + * + * See onnxruntime/test/shared_lib/test_model_editor_api.cc for example usage. + * + * \since Version 1.21. + */ +struct OrtModelEditorApi { + // Model building/editing requires a full build. We return nullptr from GetModelEditorApi if this is a minimal + // build, so it doesn't matter if there are no function pointers in this struct as a user will never get an + // OrtModelEditorApi instance. We do however need a dummy field to avoid empty struct warning. +#if defined(ORT_MINIMAL_BUILD) + const bool not_defined_in_this_build; +#else + /** \brief Create an OrtTypeInfo instance for a Tensor. + * + * Create an OrtTypeInfo instance for a Tensor to use as graph inputs/outputs with the Model Editor API. + * + * User can release `tensor_info` after creating the OrtTypeInfo. + * + * \param[in] tensor_info Tensor type and shape information. + * \param[out] TypeInfo instance for the tensor. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.21. + */ + ORT_API2_STATUS(CreateTensorTypeInfo, _In_ const OrtTensorTypeAndShapeInfo* tensor_info, + _Outptr_ OrtTypeInfo** type_info); + + /** \brief Create an OrtTypeInfo instance for a SparseTensor. + * + * Create an OrtTypeInfo instance for a SparseTensor to use as graph inputs/outputs with the Model Editor API. + * + * User can release `tensor_info` after creating the OrtTypeInfo. + * + * \param[in] tensor_info SparseTensor type and shape information. + * \param[out] TypeInfo instance for the tensor. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.21. + */ + ORT_API2_STATUS(CreateSparseTensorTypeInfo, _In_ const OrtTensorTypeAndShapeInfo* tensor_info, + _Outptr_ OrtTypeInfo** type_info); + + /** \brief Create an OrtTypeInfo instance for a Map. + * + * Create an OrtTypeInfo instance for a Map to use as graph inputs/outputs with the Model Editor API. + * + * User can release `map_value_type` after creating the OrtTypeInfo. + * + * \param[in] map_key_type Key type for the map. + * \param[in] map_value_type Value type for the map. + * \param[out] TypeInfo instance for the map. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.21. + */ + ORT_API2_STATUS(CreateMapTypeInfo, ONNXTensorElementDataType map_key_type, _In_ const OrtTypeInfo* map_value_type, + _Outptr_ OrtTypeInfo** type_info); + + /** \brief Create an OrtTypeInfo instance for a Sequence. + * + * Create an OrtTypeInfo instance for a Sequence to use as graph inputs/outputs with the Model Editor API. + * + * User can release `sequence_type` after creating the OrtTypeInfo. + * + * \param[in] sequence_type Sequence type and shape information. + * \param[out] TypeInfo instance for the sequence. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.21. + */ + ORT_API2_STATUS(CreateSequenceTypeInfo, _In_ const OrtTypeInfo* sequence_type, _Outptr_ OrtTypeInfo** type_info); + + /** \brief Create an OrtTypeInfo instance for an Optional. + * + * Create an OrtTypeInfo instance for an Optional to use as graph inputs/outputs with the Model Editor API. + * + * User can release `contained_type` after creating the OrtTypeInfo. + * + * \param[in] tensor_info Tensor type and shape information. + * \param[out] TypeInfo instance for the tensor. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.21. + */ + ORT_API2_STATUS(CreateOptionalTypeInfo, _In_ const OrtTypeInfo* contained_type, _Outptr_ OrtTypeInfo** type_info); + + /** \brief Create an OrtValueInfo for use as an OrtGraph input or output. + * + * \param[in] name The name of the input or output. + * \param[in] type_info The type information for the input or output. The provided value is copied. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.21. + */ + ORT_API2_STATUS(CreateValueInfo, _In_ const char* name, _In_ const OrtTypeInfo* type_info, + _Outptr_ OrtValueInfo** value_info); + + /** \brief Create an OrtNode to add to an OrtGraph. + * + * Create an OrtNode. + * + * Create attributes with CreateOpAttr. OrtOpAttr instances are copied. + * + * \param[in] operator_name The name of the operator. + * \param[in] domain_name The domain of the operator. Use an empty string for ONNX operators. + * \param[in] node_name The name of the node. + * \param[in] input_names The names of the inputs. + * \param[in] input_names_len The number of input names. + * \param[in] output_names The names of the outputs. + * \param[in] output_names_len The number of output names. + * \param[in] attributes The optional attributes of the node. + * \param[in] attribs_len The number of attributes. May be zero. + * \param[out] node The OrtNode instance. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.21. + */ + ORT_API2_STATUS(CreateNode, _In_ const char* operator_name, _In_ const char* domain_name, _In_ const char* node_name, + _In_reads_(input_names_len) const char* const* input_names, size_t input_names_len, + _In_reads_(output_names_len) const char* const* output_names, size_t output_names_len, + _In_reads_(attribs_len) _In_opt_ OrtOpAttr** attributes, _In_ size_t attribs_len, + _Outptr_ OrtNode** node); + + /** \brief Create an OrtGraph + * \snippet{doc} snippets.dox OrtStatus Return Value + * \since Version 1.21. + */ + ORT_API2_STATUS(CreateGraph, _Outptr_ OrtGraph** graph); + + /** \brief Set the inputs for the OrtGraph. + * + * Set the graph inputs. This will replace any existing inputs with the new values. + * The OrtGraph takes ownership of the OrtValueInfo instances and you should NOT call ReleaseOrtValueInfo. + * + * \param[in] graph The OrtGraph instance to update. + * \param[in] inputs The input OrtValueInfo instances. + * \param[in] inputs_len The number of input OrtValueInfo instances. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.21. + */ + ORT_API2_STATUS(SetGraphInputs, _Inout_ OrtGraph* graph, + _In_reads_(inputs_len) _In_ OrtValueInfo** inputs, _In_ size_t inputs_len); + + /** \brief Set the outputs for the OrtGraph. + * + * Set the graph outputs. This will replace any existing outputs with the new values. + * The OrtGraph takes ownership of the OrtValueInfo instances provided and you should NOT call ReleaseOrtValueInfo. + * + * \param[in] graph The OrtGraph instance to update. + * \param[in] outputs The output OrtValueInfo instances. + * \param[in] outputs_len The number of output OrtValueInfo instances. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.21. + */ + ORT_API2_STATUS(SetGraphOutputs, _Inout_ OrtGraph* graph, + _In_reads_(outputs_len) _In_ OrtValueInfo** outputs, _In_ size_t outputs_len); + + /** \brief Add an initializer to the OrtGraph + * + * ORT will take ownership of the OrtValue and you should NOT call ReleaseOrtValue. + * + * Two options: + * + * Allocated memory: + * Use CreateTensorAsOrtValue (allocates memory) and populate the tensor with the data. + * Set `data_is_external` to false. + * + * Pre-existing memory: + * Use CreateTensorWithDataAsOrtValue or CreateTensorWithDataAndDeleterAsOrtValue to create an OrtValue + * with a tensor that contains a pointer to the existing data. + * Set `data_is_external` to true. + * + * The pointer must remain valid for the duration of the inference session. + * If using CreateTensorWithDataAsOrtValue you are responsible for freeing the memory after the inference session + * is released. + * If using CreateTensorWithDataAndDeleterAsOrtValue, ORT will free the memory using the provided deleter as + * soon as the OrtValue is no longer in use. + * + * NOTE: A tensor containing pre-existing memory MUST have 128 bytes of data or more. + * For smaller tensors use CreateTensorAsOrtValue. + * + * ONNX shape inferencing does not support external data. An initializer involved in shape inferencing is + * typically small (a single value or limited by the rank of a tensor) and uses less than 128 bytes of + * memory, so this limit acts as a simple catch-all rule to avoid issues. + * e.g. Reshape's `shape`, Clip's `min` and `max`, various ops `axes`. + * + * \param[in] graph The OrtGraph instance to update. + * \param[in] name The value name for the initializer. + * \param[in] tensor The OrtValue instance containing the tensor data. + * \param[in] data_is_external Set to true if the data is external and should not be copied. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.21. + */ + ORT_API2_STATUS(AddInitializerToGraph, _Inout_ OrtGraph* graph, _In_ const char* name, _In_ OrtValue* tensor, + bool data_is_external); + + /** \brief Add an OrtNode to an OrtGraph + * + * Add the node to the graph. The OrtGraph will take ownership of OrtNode and you should NOT call ReleaseOrtNode. + * + * \param[in] graph The OrtGraph instance to update. + * \param[in] node The OrtNode instance to add to the graph. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.21. + */ + ORT_API2_STATUS(AddNodeToGraph, _Inout_ OrtGraph* graph, _In_ OrtNode* node); + + /** \brief Create an OrtModel. + * + * Create an OrtModel. + * + * This can be used to build a new model, or to augment an existing model. + * + * \param[in] domain_names The domain names for the model. + * If augmenting an existing model add additional domains if needed. + * \param[in] opset_versions The opset versions for the model. + * If augmenting an existing model add additional opset versions if needed. + * \param[in] opset_entries_len The number of domain_names and opset_versions entries. + * Domain and opset entries should be 1:1 + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.21. + */ + ORT_API2_STATUS(CreateModel, + _In_reads_(opset_entries_len) const char* const* domain_names, + _In_reads_(opset_entries_len) const int* opset_versions, + size_t opset_entries_len, + _Outptr_ OrtModel** model); + + /** \brief Add an OrtGraph to an OrtModel. + * + * Add the graph to a model. This should be called once when creating a new model. + * + * The OrtModel takes ownership of the OrtGraph and you should NOT call ReleaseOrtGraph. + * + * \param[in] model The OrtModel instance to update. + * \param[in] graph The OrtGraph instance to add to the model. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.21. + */ + ORT_API2_STATUS(AddGraphToModel, _Inout_ OrtModel* model, _In_ OrtGraph* graph); + + /** \brief Create an OrtSession using the OrtModel. + * + * Create an inference session using the OrtModel instance. + * The OrtModel should have been populated with an OrtGraph containing nodes and initializers, and SetGraphInputs + * and SetGraphOutputs must have been called. + * This will validate the model, run optimizers, and prepare the session for inferencing. + * + * ReleaseOrtModel must be called to free the OrtModel after session creation. + * + * \param[in] env The OrtEnv instance. + * \param[in] model The OrtModel instance. + * \param[in] options The OrtSessionOptions instance. + * \param[out] out The OrtSession instance. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.21. + */ + ORT_API2_STATUS(CreateSessionFromModel, _In_ const OrtEnv* env, _In_ const OrtModel* model, + _In_ const OrtSessionOptions* options, _Outptr_ OrtSession** out); + + /** \brief Create an OrtSession to augment an existing model. + * + * Create an OrtSession with an existing model that will be augmented with additional nodes and initializers. + * Nodes can be added before or after the existing nodes in the model. ONNX Runtime will connect the nodes when the + * model is finalized. + * + * To add nodes and initializers to the existing model, first create an OrtModel using CreateModel. + * Add nodes and initializers to the OrtModel using AddNodeToGraph and AddInitializerToGraph. + * Graph inputs/outputs should be updated with SetGraphInputs and SetGraphOutputs as needed to reflect changes made + * by the new nodes. The list of graph inputs/outputs should be for the overall model and not just the new nodes. + * + * Add the new information from the OrtModel to the original model using ApplyModelToSession, and prepare the + * session for inferencing by calling FinalizeModelEditorSession. + * + * \param{in} env The OrtEnv instance. + * \param{in} model_path The path to the existing ONNX model to augment. + * \param{in} options The OrtSessionOptions instance. + * \param{out} out The created OrtSession instance. + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.21. + */ + ORT_API2_STATUS(CreateModelEditorSession, _In_ const OrtEnv* env, _In_ const ORTCHAR_T* model_path, + _In_ const OrtSessionOptions* options, + _Outptr_ OrtSession** out); + + /** \brief Create an OrtSession to augment an existing model. + * + * Create an OrtSession with an existing model that will be augmented with additional nodes and initializers. + * Nodes can be added before or after the existing nodes in the model. ONNX Runtime will connect the nodes when the + * model is finalized. + * + * To add nodes and initializers to the existing model, first create an OrtModel using CreateModel. + * Add nodes and initializers to the OrtModel using AddNodeToGraph and AddInitializerToGraph. + * Graph inputs/outputs should be updated with SetGraphInputs and SetGraphOutputs as needed to reflect changes made + * by the new nodes. The list of graph inputs/outputs should be for the overall model and not just the new nodes. + * + * Add the new information from the OrtModel to the original model using ApplyModelToSession, and prepare the + * session for inferencing by calling FinalizeModelEditorSession. + * + * \param{in} env The OrtEnv instance. + * \param{in} model_data The model data for the existing model to augment. + * \param{in} model_data_length The length of the model data. + * \param{in} options The OrtSessionOptions instance. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.21. + */ + ORT_API2_STATUS(CreateModelEditorSessionFromArray, _In_ const OrtEnv* env, + _In_ const void* model_data, size_t model_data_length, + _In_ const OrtSessionOptions* options, + _Outptr_ OrtSession** out); + + /** \brief Query the session for the opset version of a domain. + * + * When using the Model Editor API to augment a model, any new nodes must conform to the opset version of the + * original model. To do that the user must be able to discover that opset version. + * + * \param[in] session OrtSession to query + * \param[in] domain Domain to query. The ONNX domain is an empty string. + * \param[out] opset The opset version of the domain. + * + * \snippet{doc} snippets.dox OrtStatus Return Value. Returns an error if the domain is not used in the model. + * + * \since Version 1.21. + */ + ORT_API2_STATUS(SessionGetOpsetForDomain, _In_ const OrtSession* session, _In_ const char* domain, _Out_ int* opset); + + /** \brief Apply changes to augment the ONNX model in a session created using CreateModelEditorSession[FromArray] + * + * Adds new nodes and updates graph inputs/outputs using `model` to augment the original ONNX model in the session. + * All changes will be validated. + * Call FinalizeModelEditorSession to prepare the session for inferencing. + * + * Existing input/outputs will only be updated if the OrtGraph inputs/outputs are set in the OrtModel. + * i.e. you don't need to call SetGraphInputs/SetGraphOutputs if they are unchanged. + * + * ReleaseOrtModel must be called to free the OrtModel after it is applied to the session. + * + * \param[in] session OrtSession to update. Session must have been created using CreateModelEditorSession[FromArray]. + * \param[in] model OrtModel containing new nodes, new initializers, and updated graph input and/or output info. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.21. + */ + ORT_API2_STATUS(ApplyModelToModelEditorSession, _Inout_ OrtSession* session, _In_ OrtModel* model); + + /** \brief Finalize the Model Editor session that was created using CreateModelEditorSession[FromArray]. + * + * Finalize the Model Editor session that augmented an ONNX model by adding new nodes. + * This will run optimizers and prepare the session for inferencing. + * + * \param[in] session OrtSession to finalize. Session must have been created using CreateModelEditorSession[FromArray]. + * \param[in] options OrtSessionOptions to use for the session. + * \param[in] Optional prepacked_weights_container OrtPrepackedWeightsContainer to use for the session. + Set to nullptr if not used. + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.21. + */ + ORT_API2_STATUS(FinalizeModelEditorSession, _Inout_ OrtSession* session, _In_ const OrtSessionOptions* options, + _In_opt_ OrtPrepackedWeightsContainer* prepacked_weights_container); +#endif // !defined(ORT_MINIMAL_BUILD) +}; + /* * This is the old way to add the CUDA provider to the session, please use SessionOptionsAppendExecutionProvider_CUDA above to access the latest functionality * This function always exists, but will only succeed if Onnxruntime was built with CUDA support and the CUDA provider shared library exists diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index 123ef98901003..979b478e2fbb4 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -26,16 +26,17 @@ #include "onnxruntime_c_api.h" #include "onnxruntime_float16.h" +#include #include #include -#include #include #include #include -#include +#include #include #include -#include +#include +#include #ifdef ORT_NO_EXCEPTIONS #include @@ -120,7 +121,7 @@ const OrtApi* Global::api_ = OrtGetApiBase()->GetApi(ORT_API_VERSION); #endif #endif -/// This returns a reference to the OrtApi interface in use +/// This returns a reference to the ORT C API. inline const OrtApi& GetApi() noexcept { return *Global::api_; } /// @@ -143,6 +144,20 @@ std::string GetBuildInfoString(); /// vector of strings std::vector GetAvailableProviders(); +/// +/// This returns a reference to the ORT C Model Editor API. Used if building or augmenting a model at runtime. +/// +/// ORT C Model Editor API reference +inline const OrtModelEditorApi& GetModelEditorApi() { + auto* api = GetApi().GetModelEditorApi(); + if (api == nullptr) { + // minimal build + ORT_CXX_API_THROW("Model Editor API is not available in this build", ORT_FAIL); + } + + return *api; +} + /** \brief IEEE 754 half-precision floating point data type * * \details This struct is used for converting float to float16 and back @@ -523,6 +538,10 @@ ORT_DEFINE_RELEASE(Status); ORT_DEFINE_RELEASE(OpAttr); ORT_DEFINE_RELEASE(Op); ORT_DEFINE_RELEASE(KernelInfo); +ORT_DEFINE_RELEASE(ValueInfo); +ORT_DEFINE_RELEASE(Node); +ORT_DEFINE_RELEASE(Graph); +ORT_DEFINE_RELEASE(Model); #undef ORT_DEFINE_RELEASE @@ -559,7 +578,9 @@ struct Base { constexpr Base() = default; constexpr explicit Base(contained_type* p) noexcept : p_{p} {} - ~Base() { OrtRelease(p_); } + ~Base() { + OrtRelease(p_); + } Base(const Base&) = delete; Base& operator=(const Base&) = delete; @@ -635,9 +656,13 @@ struct AllocatedFree { struct AllocatorWithDefaultOptions; struct Env; +struct Graph; +struct Model; +struct Node; +struct ModelMetadata; struct TypeInfo; struct Value; -struct ModelMetadata; +struct ValueInfo; /** \brief unique_ptr typedef used to own strings allocated by OrtAllocators * and release them at the end of the scope. The lifespan of the given allocator @@ -1051,6 +1076,10 @@ struct ConstSessionImpl : Base { size_t GetOutputCount() const; ///< Returns the number of model outputs size_t GetOverridableInitializerCount() const; ///< Returns the number of inputs that have defaults that can be overridden + std::vector GetInputNames() const; + std::vector GetOutputNames() const; + std::vector GetOverridableInitializerNames() const; + /** \brief Returns a copy of input name at the specified index. * * \param index must less than the value returned by GetInputCount() @@ -1084,6 +1113,12 @@ struct ConstSessionImpl : Base { TypeInfo GetInputTypeInfo(size_t index) const; ///< Wraps OrtApi::SessionGetInputTypeInfo TypeInfo GetOutputTypeInfo(size_t index) const; ///< Wraps OrtApi::SessionGetOutputTypeInfo TypeInfo GetOverridableInitializerTypeInfo(size_t index) const; ///< Wraps OrtApi::SessionGetOverridableInitializerTypeInfo + + int GetOpset(const std::string& domain) const; ///< Wraps OrtApi::SessionGetOpsetForDomain + + // Will move before checkin if that's the case. + std::vector GetInputs() const; + std::vector GetOutputs() const; }; template @@ -1161,6 +1196,9 @@ struct SessionImpl : ConstSessionImpl { * \param[in] kv_len Number of elements in the keys and values arrays */ void SetEpDynamicOptions(const char* const* keys, const char* const* values, size_t kv_len); + + void FinalizeModelEditorSession(const Model& model, const SessionOptions& options, + OrtPrepackedWeightsContainer* prepacked_weights_container = nullptr); }; } // namespace detail @@ -1172,13 +1210,34 @@ using UnownedSession = detail::SessionImpl>; * */ struct Session : detail::SessionImpl { - explicit Session(std::nullptr_t) {} ///< Create an empty Session object, must be assigned a valid one to be used - Session(const Env& env, const ORTCHAR_T* model_path, const SessionOptions& options); ///< Wraps OrtApi::CreateSession + /// Create an empty Session object, must be assigned a valid one to be used. Wraps OrtApi::CreateSession + explicit Session(std::nullptr_t) {} + explicit Session(OrtSession* p) : SessionImpl{p} {} ///< C API Interop + + Session(const Env& env, const ORTCHAR_T* model_path, const SessionOptions& options); + + /// Wraps OrtApi::CreateSessionWithPrepackedWeightsContainer Session(const Env& env, const ORTCHAR_T* model_path, const SessionOptions& options, - OrtPrepackedWeightsContainer* prepacked_weights_container); ///< Wraps OrtApi::CreateSessionWithPrepackedWeightsContainer - Session(const Env& env, const void* model_data, size_t model_data_length, const SessionOptions& options); ///< Wraps OrtApi::CreateSessionFromArray + OrtPrepackedWeightsContainer* prepacked_weights_container); + + /// Wraps OrtApi::CreateSessionFromArray + Session(const Env& env, const void* model_data, size_t model_data_length, const SessionOptions& options); + + /// Wraps OrtApi::CreateSessionFromArrayWithPrepackedWeightsContainer Session(const Env& env, const void* model_data, size_t model_data_length, const SessionOptions& options, - OrtPrepackedWeightsContainer* prepacked_weights_container); ///< Wraps OrtApi::CreateSessionFromArrayWithPrepackedWeightsContainer + OrtPrepackedWeightsContainer* prepacked_weights_container); + +#if !defined(ORT_MINIMAL_BUILD) + /// Wraps OrtModelEditorApi::CreateSessionFromModel + Session(const Env& env, const Model& model, const SessionOptions& options); + + /// Wraps OrtModelEditorApi::CreateModelEditorSession + static Session CreateModelEditorSession(const Env& env, const ORTCHAR_T* model_path, const SessionOptions& options); + + /// Wraps OrtModelEditorApi::CreateModelEditorSession + static Session CreateModelEditorSession(const Env& env, const void* model_data, size_t model_data_length, + const SessionOptions& options); +#endif // !defined(ORT_MINIMAL_BUILD) ConstSession GetConst() const { return ConstSession{this->p_}; } UnownedSession GetUnowned() const { return UnownedSession{this->p_}; } @@ -1210,7 +1269,7 @@ using ConstMemoryInfo = detail::MemoryInfoImpl { static MemoryInfo CreateCpu(OrtAllocatorType type, OrtMemType mem_type1); explicit MemoryInfo(std::nullptr_t) {} ///< No instance is created - explicit MemoryInfo(OrtMemoryInfo* p) : MemoryInfoImpl{p} {} ///< Take ownership of a pointer created by C Api + explicit MemoryInfo(OrtMemoryInfo* p) : MemoryInfoImpl{p} {} ///< Take ownership of a pointer created by C API MemoryInfo(const char* name, OrtAllocatorType type, int id, OrtMemType mem_type); ConstMemoryInfo GetConst() const { return ConstMemoryInfo{this->p_}; } }; @@ -1233,6 +1292,7 @@ struct TensorTypeAndShapeInfoImpl : Base { [[deprecated("use GetShape()")]] void GetDimensions(int64_t* values, size_t values_count) const; ///< Wraps OrtApi::GetDimensions void GetSymbolicDimensions(const char** values, size_t values_count) const; ///< Wraps OrtApi::GetSymbolicDimensions + std::vector GetSymbolicDimensions() const; std::vector GetShape() const; ///< Uses GetDimensionsCount & GetDimensions to return a std::vector of the shape }; @@ -1248,8 +1308,18 @@ struct TensorTypeAndShapeInfo : detail::TensorTypeAndShapeInfoImpl; using Base::Base; - explicit TensorTypeAndShapeInfo(std::nullptr_t) {} ///< Create an empty TensorTypeAndShapeInfo object, must be assigned a valid one to be used - explicit TensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* p) : TensorTypeAndShapeInfoImpl{p} {} ///< Used for interop with the C API + /// Create an empty TensorTypeAndShapeInfo object, must be assigned a valid one to be used + explicit TensorTypeAndShapeInfo(std::nullptr_t) {} + /// Used for interop with the C API + explicit TensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* p) : TensorTypeAndShapeInfoImpl{p} {} + + // Create a TensorTypeAndShapeInfo object with the specified element type and dimensions + // symbolic_dims are optional, but should be 1:1 with dims. + // The value in symbolic_dims will be used for all entries in dims that are -1. + explicit TensorTypeAndShapeInfo(ONNXTensorElementDataType element_type, + const std::vector& dims, + const std::vector* symbolic_dims = nullptr); + ConstTensorTypeAndShapeInfo GetConst() const { return ConstTensorTypeAndShapeInfo{this->p_}; } }; @@ -1344,9 +1414,18 @@ struct TypeInfo : detail::TypeInfoImpl { using Base = detail::TypeInfoImpl; using Base::Base; - explicit TypeInfo(std::nullptr_t) {} ///< Create an empty TypeInfo object, must be assigned a valid one to be used + /// Create an empty TypeInfo object, must be assigned a valid one to be used + explicit TypeInfo(std::nullptr_t) {} explicit TypeInfo(OrtTypeInfo* p) : TypeInfoImpl{p} {} ///< C API Interop +#if !defined(ORT_MINIMAL_BUILD) + static TypeInfo CreateTensorInfo(ConstTensorTypeAndShapeInfo tensor_info); + static TypeInfo CreateSparseTensorInfo(ConstTensorTypeAndShapeInfo sparse_tensor_info); + static TypeInfo CreateSequenceTypeInfo(ConstTypeInfo sequence_type); + static TypeInfo CreateMapTypeInfo(ONNXTensorElementDataType key_type, ConstTypeInfo value_type); + static TypeInfo CreateOptionalTypeInfo(ConstTypeInfo contained_type); +#endif // !defined(ORT_MINIMAL_BUILD) + ConstTypeInfo GetConst() const { return ConstTypeInfo{this->p_}; } }; @@ -1701,7 +1780,8 @@ struct Value : detail::ValueImpl { * \param shape_len The number of tensor shape dimensions. */ template - static Value CreateTensor(const OrtMemoryInfo* info, T* p_data, size_t p_data_element_count, const int64_t* shape, size_t shape_len); + static Value CreateTensor(const OrtMemoryInfo* info, T* p_data, size_t p_data_element_count, + const int64_t* shape, size_t shape_len); /** \brief Creates a tensor with a user supplied buffer. Wraps OrtApi::CreateTensorWithDataAsOrtValue. * @@ -1712,11 +1792,25 @@ struct Value : detail::ValueImpl { * \param shape_len The number of tensor shape dimensions. * \param type The data type. */ - static Value CreateTensor(const OrtMemoryInfo* info, void* p_data, size_t p_data_byte_count, const int64_t* shape, size_t shape_len, + static Value CreateTensor(const OrtMemoryInfo* info, void* p_data, size_t p_data_byte_count, + const int64_t* shape, size_t shape_len, + ONNXTensorElementDataType type); + + /** \brief Creates a tensor with a user supplied buffer. Wraps OrtApi::CreateTensorWithDataAndDeleterAsOrtValue. + * + * \param deleter OrtAllocator that will be used to free the buffer when no longer required. + * \param p_data Pointer to the data buffer. + * \param p_data_byte_count The number of bytes in the data buffer. + * \param shape Pointer to the tensor shape dimensions. + * \param shape_len The number of tensor shape dimensions. + * \param type The data type. + */ + static Value CreateTensor(OrtAllocator* deleter, void* p_data, size_t p_data_byte_count, + const int64_t* shape, size_t shape_len, ONNXTensorElementDataType type); /** \brief Creates an OrtValue with a tensor using a supplied OrtAllocator. Wraps OrtApi::CreateTensorAsOrtValue. - * This overload will allocate the buffer for the tensor according to the supplied shape and data type. + * This overload will allocate the buffer for the tensor according to the supplied shape and data type. * The allocated buffer will be owned by the returned OrtValue and will be freed when the OrtValue is released. * The input data would need to be copied into the allocated buffer. * This API is not suitable for strings. @@ -1740,7 +1834,8 @@ struct Value : detail::ValueImpl { * \param shape_len The number of tensor shape dimensions. * \param type The data type. */ - static Value CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len, ONNXTensorElementDataType type); + static Value CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len, + ONNXTensorElementDataType type); /** \brief Creates an OrtValue with a Map Onnx type representation. * The API would ref-count the supplied OrtValues and they will be released @@ -2437,6 +2532,9 @@ struct CustomOpBase : OrtCustomOp { return std::vector{}; } + // Ort::CustomOpBase derived class should provide the following static method with the type/shape inferencing + // implementation if needed: + // static OrtStatusPtr InferOutputShape(Ort::ShapeInferContext& context) template decltype(&C::InferOutputShape) SetShapeInferFn(decltype(&C::InferOutputShape)) { OrtCustomOp::InferOutputShapeFn = [](const OrtCustomOp*, OrtShapeInferContext* ort_ctx) -> OrtStatusPtr { @@ -2459,6 +2557,129 @@ struct CustomOpBase : OrtCustomOp { int end_ver_ = MAX_CUSTOM_OP_END_VER; }; -} // namespace Ort +namespace detail { +template +struct ValueInfoImpl : Ort::detail::Base { + using B = Ort::detail::Base; + using B::B; + + std::string Name() const; + ConstTypeInfo TypeInfo() const; +}; +} // namespace detail + +// Const object holder that does not own the underlying object +using ConstValueInfo = detail::ValueInfoImpl>; +/** \brief Wrapper around ::OrtValueInfo + * + */ +struct ValueInfo : detail::ValueInfoImpl { + explicit ValueInfo(std::nullptr_t) {} ///< No instance is created + /// Take ownership of a pointer created by C API + explicit ValueInfo(OrtValueInfo* p) : ValueInfoImpl{p} {} + + // Create ValueInfo for a tensor + explicit ValueInfo(const std::string& name, const ConstTypeInfo& type_info); + + ConstValueInfo GetConst() const { return ConstValueInfo{this->p_}; } +}; + +namespace detail { +template +struct NodeImpl : Ort::detail::Base { + using B = Ort::detail::Base; + using B::B; +}; +} // namespace detail + +/** \brief Wrapper around ::OrtNode + * + */ +struct Node : detail::NodeImpl { + explicit Node(std::nullptr_t) {} ///< No instance is created + explicit Node(OrtNode* p) : NodeImpl{p} {} ///< Take ownership of a pointer created by C API + +#if !defined(ORT_MINIMAL_BUILD) + Node(const std::string& operator_name, const std::string& operator_domain, + const std::string& node_name, + const std::vector& input_names, + const std::vector& output_names); + + /// + /// Wraps CreateNode. Node takes ownership of attributes on success and updates the OpAttr in `attributes` to do so. + /// + Node(const std::string& operator_name, const std::string& operator_domain, + const std::string& node_name, + const std::vector& input_names, + const std::vector& output_names, + std::vector& attributes); + + private: + static void Init(const std::string& operator_name, const std::string& operator_domain, + const std::string& node_name, + const std::vector& input_names, + const std::vector& output_names, + std::vector& attributes, + OrtNode*& node); +#endif // !defined(ORT_MINIMAL_BUILD) +}; + +namespace detail { +template +struct GraphImpl : Ort::detail::Base { + using B = Ort::detail::Base; + using B::B; + +#if !defined(ORT_MINIMAL_BUILD) + void SetInputs(std::vector& inputs); + void SetOutputs(std::vector& outputs); + void AddInitializer(const std::string& name, Value& initializer, bool data_is_external); // Graph takes ownership of Value + void AddNode(Node& node); // Graph takes ownership of Node +#endif // !defined(ORT_MINIMAL_BUILD) +}; +} // namespace detail + +/** \brief Wrapper around ::OrtGraph + * + */ +struct Graph : detail::GraphImpl { + explicit Graph(std::nullptr_t) {} ///< No instance is created + explicit Graph(OrtGraph* p) : GraphImpl{p} {} ///< Take ownership of a pointer created by C API +#if !defined(ORT_MINIMAL_BUILD) + Graph(); +#endif +}; + +namespace detail { +template +struct ModelImpl : Ort::detail::Base { + using B = Ort::detail::Base; + using B::B; + +#if !defined(ORT_MINIMAL_BUILD) + void AddGraph(Graph& graph); +#endif +}; +} // namespace detail + +// Const object holder that does not own the underlying object +using ConstModel = detail::ModelImpl>; + +/** \brief Wrapper around ::OrtModel + * + */ +struct Model : detail::ModelImpl { + using DomainOpsetPair = std::pair; + + explicit Model(std::nullptr_t) {} ///< No instance is created + explicit Model(OrtModel* p) : ModelImpl{p} {} ///< Take ownership of a pointer created by C API + +#if !defined(ORT_MINIMAL_BUILD) + explicit Model(const std::vector& opsets); +#endif + + ConstModel GetConst() const { return ConstModel{this->p_}; } +}; +} // namespace Ort #include "onnxruntime_cxx_inline.h" diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index 3aeb9412f350e..48c5e52e33c53 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -10,7 +10,9 @@ #include #include #include +#include #include +#include // Convert OrtStatus to Ort::Status and return // instead of throwing @@ -995,6 +997,59 @@ inline size_t ConstSessionImpl::GetOverridableInitializerCount() const { return out; } +template +inline std::vector ConstSessionImpl::GetInputNames() const { + AllocatorWithDefaultOptions allocator; + + auto num_inputs = GetInputCount(); + std::vector input_names; + input_names.reserve(num_inputs); + + for (size_t i = 0; i < num_inputs; ++i) { + char* name = nullptr; + ThrowOnError(GetApi().SessionGetInputName(this->p_, i, allocator, &name)); + input_names.push_back(name); + allocator.Free(name); + } + + return input_names; +} + +template +inline std::vector ConstSessionImpl::GetOutputNames() const { + AllocatorWithDefaultOptions allocator; + + auto num_inputs = GetOutputCount(); + std::vector output_names; + output_names.reserve(num_inputs); + + for (size_t i = 0; i < num_inputs; ++i) { + char* name = nullptr; + ThrowOnError(GetApi().SessionGetOutputName(this->p_, i, allocator, &name)); + output_names.push_back(name); + allocator.Free(name); + } + + return output_names; +} + +template +inline std::vector ConstSessionImpl::GetOverridableInitializerNames() const { + AllocatorWithDefaultOptions allocator; + + auto num_initializers = GetOverridableInitializerCount(); + std::vector initializer_names; + initializer_names.reserve(num_initializers); + + for (size_t i = 0; i < num_initializers; ++i) { + char* name = nullptr; + ThrowOnError(GetApi().SessionGetOverridableInitializerName(this->p_, i, allocator, &name)); + initializer_names.push_back(name); + } + + return initializer_names; +} + template inline AllocatedStringPtr ConstSessionImpl::GetInputNameAllocated(size_t index, OrtAllocator* allocator) const { char* out; @@ -1051,6 +1106,45 @@ inline TypeInfo ConstSessionImpl::GetOverridableInitializerTypeInfo(size_t in return TypeInfo{out}; } +#if !defined(ORT_MINIMAL_BUILD) +template +inline int ConstSessionImpl::GetOpset(const std::string& domain) const { + int opset; + ThrowOnError(GetModelEditorApi().SessionGetOpsetForDomain(this->p_, domain.c_str(), &opset)); + return opset; +} +#endif // !defined(ORT_MINIMAL_BUILD) + +template +std::vector ConstSessionImpl::GetInputs() const { + const std::vector input_names = GetInputNames(); + + std::vector inputs; + inputs.reserve(input_names.size()); + + for (size_t i = 0; i < input_names.size(); ++i) { + auto type_info = GetInputTypeInfo(i); + inputs.emplace_back(ValueInfo{input_names[i], type_info.GetConst()}); + } + + return inputs; +} + +template +std::vector ConstSessionImpl::GetOutputs() const { + const std::vector output_names = GetOutputNames(); + + std::vector outputs; + outputs.reserve(output_names.size()); + + for (size_t i = 0; i < output_names.size(); ++i) { + auto type_info = GetOutputTypeInfo(i); + outputs.emplace_back(ValueInfo{output_names[i], type_info.GetConst()}); + } + + return outputs; +} + template inline std::vector SessionImpl::Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count, const char* const* output_names, size_t output_count) { @@ -1098,6 +1192,15 @@ inline void SessionImpl::SetEpDynamicOptions(const char* const* keys, const c ThrowOnError(GetApi().SetEpDynamicOptions(this->p_, keys, values, kv_len)); } +#if !defined(ORT_MINIMAL_BUILD) +template +inline void SessionImpl::FinalizeModelEditorSession(const Model& model, const SessionOptions& options, + OrtPrepackedWeightsContainer* prepacked_weights_container) { + ThrowOnError(GetModelEditorApi().ApplyModelToModelEditorSession(this->p_, model)); + ThrowOnError(GetModelEditorApi().FinalizeModelEditorSession(this->p_, options, prepacked_weights_container)); +} +#endif // #if !defined(ORT_MINIMAL_BUILD) + } // namespace detail inline SessionOptions::SessionOptions() { @@ -1144,6 +1247,32 @@ inline Session::Session(const Env& env, const void* model_data, size_t model_dat prepacked_weights_container, &this->p_)); } +#if !defined(ORT_MINIMAL_BUILD) +inline Session::Session(const Env& env, const Model& model, const SessionOptions& options) { + ThrowOnError(GetModelEditorApi().CreateSessionFromModel(env, model.GetConst(), options, &this->p_)); +} + +// static +inline Session Session::CreateModelEditorSession(const Env& env, const ORTCHAR_T* model_path, + const SessionOptions& options) { + OrtSession* session = nullptr; + ThrowOnError(GetModelEditorApi().CreateModelEditorSession(env, model_path, options, &session)); + return Session(session); +} + +// static +inline Session Session::CreateModelEditorSession(const Env& env, const void* model_data, size_t model_data_length, + const SessionOptions& options) { + OrtSession* session = nullptr; + ThrowOnError(GetModelEditorApi().CreateModelEditorSessionFromArray(env, model_data, model_data_length, options, + &session)); + return Session(session); +} + +void FinalizeModelEditorSession(const Model& model, const SessionOptions& options, + OrtPrepackedWeightsContainer* prepacked_weights_container); +#endif // #if !defined(ORT_MINIMAL_BUILD) + inline AllocatedStringPtr ModelMetadata::GetProducerNameAllocated(OrtAllocator* allocator) const { char* out; ThrowOnError(GetApi().ModelMetadataGetProducerName(p_, allocator, &out)); @@ -1211,6 +1340,59 @@ inline int64_t ModelMetadata::GetVersion() const { return out; } +inline TensorTypeAndShapeInfo::TensorTypeAndShapeInfo(ONNXTensorElementDataType element_type, + const std::vector& dims, + const std::vector* symbolic_dims) { + ThrowOnError(GetApi().CreateTensorTypeAndShapeInfo(&p_)); + ThrowOnError(GetApi().SetTensorElementType(p_, element_type)); + ThrowOnError(GetApi().SetDimensions(p_, dims.data(), dims.size())); + + if (symbolic_dims) { + std::vector symbolic_dims_cstr; + symbolic_dims_cstr.reserve(symbolic_dims->size()); + std::transform(symbolic_dims->begin(), symbolic_dims->end(), std::back_inserter(symbolic_dims_cstr), + [](const std::string& s) { return s.c_str(); }); + ThrowOnError(GetApi().SetSymbolicDimensions(p_, symbolic_dims_cstr.data(), symbolic_dims_cstr.size())); + } +} + +#if !defined(ORT_MINIMAL_BUILD) +// static +inline TypeInfo TypeInfo::CreateTensorInfo(ConstTensorTypeAndShapeInfo tensor_type_and_shape_info) { + OrtTypeInfo* output = nullptr; + ThrowOnError(GetModelEditorApi().CreateTensorTypeInfo(tensor_type_and_shape_info, &output)); + return TypeInfo{output}; +} + +// static +inline TypeInfo TypeInfo::CreateSparseTensorInfo(ConstTensorTypeAndShapeInfo sparse_tensor_type_and_shape_info) { + OrtTypeInfo* output = nullptr; + ThrowOnError(GetModelEditorApi().CreateSparseTensorTypeInfo(sparse_tensor_type_and_shape_info, &output)); + return TypeInfo{output}; +} + +// static +inline TypeInfo TypeInfo::CreateSequenceTypeInfo(ConstTypeInfo sequence_type) { + OrtTypeInfo* output; + ThrowOnError(GetModelEditorApi().CreateSequenceTypeInfo(sequence_type, &output)); + return TypeInfo{output}; +} + +// static +inline TypeInfo TypeInfo::CreateMapTypeInfo(ONNXTensorElementDataType key_type, ConstTypeInfo value_type) { + OrtTypeInfo* output; + ThrowOnError(GetModelEditorApi().CreateMapTypeInfo(key_type, value_type, &output)); + return TypeInfo{output}; +} + +// static +inline TypeInfo TypeInfo::CreateOptionalTypeInfo(ConstTypeInfo contained_type) { + OrtTypeInfo* output; + ThrowOnError(GetModelEditorApi().CreateOptionalTypeInfo(contained_type, &output)); + return TypeInfo{output}; +} +#endif // #if !defined(ORT_MINIMAL_BUILD) + namespace detail { template @@ -1244,9 +1426,16 @@ inline void TensorTypeAndShapeInfoImpl::GetSymbolicDimensions(const char** va ThrowOnError(GetApi().GetSymbolicDimensions(this->p_, values, values_count)); } +template +inline std::vector TensorTypeAndShapeInfoImpl::GetSymbolicDimensions() const { + std::vector out(GetDimensionsCount(), nullptr); + ThrowOnError(GetApi().GetSymbolicDimensions(this->p_, out.data(), out.size())); + return out; +} + template inline std::vector TensorTypeAndShapeInfoImpl::GetShape() const { - std::vector out(GetDimensionsCount(), 0); + std::vector out(GetDimensionsCount(), -1); ThrowOnError(GetApi().GetDimensions(this->p_, out.data(), out.size())); return out; } @@ -1560,23 +1749,35 @@ void ValueImpl::FillSparseTensorBlockSparse(const OrtMemoryInfo* data_mem_inf } // namespace detail template -inline Value Value::CreateTensor(const OrtMemoryInfo* info, T* p_data, size_t p_data_element_count, const int64_t* shape, size_t shape_len) { +inline Value Value::CreateTensor(const OrtMemoryInfo* info, T* p_data, size_t p_data_element_count, + const int64_t* shape, size_t shape_len) { return CreateTensor(info, p_data, p_data_element_count * sizeof(T), shape, shape_len, TypeToTensorType::type); } -inline Value Value::CreateTensor(const OrtMemoryInfo* info, void* p_data, size_t p_data_byte_count, const int64_t* shape, size_t shape_len, +inline Value Value::CreateTensor(const OrtMemoryInfo* info, void* p_data, size_t p_data_byte_count, + const int64_t* shape, size_t shape_len, ONNXTensorElementDataType type) { OrtValue* out; ThrowOnError(GetApi().CreateTensorWithDataAsOrtValue(info, p_data, p_data_byte_count, shape, shape_len, type, &out)); return Value{out}; } +inline Value Value::CreateTensor(OrtAllocator* deleter, void* p_data, size_t p_data_byte_count, + const int64_t* shape, size_t shape_len, + ONNXTensorElementDataType type) { + OrtValue* out; + ThrowOnError(GetApi().CreateTensorWithDataAndDeleterAsOrtValue(deleter, p_data, p_data_byte_count, + shape, shape_len, type, &out)); + return Value{out}; +} + template inline Value Value::CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len) { return CreateTensor(allocator, shape, shape_len, TypeToTensorType::type); } -inline Value Value::CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len, ONNXTensorElementDataType type) { +inline Value Value::CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len, + ONNXTensorElementDataType type) { OrtValue* out; ThrowOnError(GetApi().CreateTensorAsOrtValue(allocator, shape, shape_len, type, &out)); return Value{out}; @@ -1594,7 +1795,8 @@ inline Value Value::CreateSparseTensor(const OrtMemoryInfo* info, void* p_data, const Shape& values_shape, ONNXTensorElementDataType type) { OrtValue* out; ThrowOnError(GetApi().CreateSparseTensorWithValuesAsOrtValue(info, p_data, dense_shape.shape, dense_shape.shape_len, - values_shape.shape, values_shape.shape_len, type, &out)); + values_shape.shape, values_shape.shape_len, type, + &out)); return Value{out}; } @@ -2167,4 +2369,142 @@ inline const OrtOpAttr* ShapeInferContext::GetAttrHdl(const char* attr_name) con return attr_hdl; } +namespace detail { +inline std::vector StringsToCharPtrs(const std::vector& strings) { + std::vector ptrs; + ptrs.reserve(strings.size()); + std::transform(strings.begin(), strings.end(), std::back_inserter(ptrs), + [](const std::string& s) { return s.c_str(); }); + + return ptrs; +} +} // namespace detail + +#if !defined(ORT_MINIMAL_BUILD) +// static +inline void Node::Init(const std::string& operator_name, const std::string& operator_domain, + const std::string& node_name, + const std::vector& input_names, + const std::vector& output_names, + std::vector& attributes, + OrtNode*& node) { + auto inputs = detail::StringsToCharPtrs(input_names); + auto outputs = detail::StringsToCharPtrs(output_names); + + std::vector attributes_ptrs; + attributes_ptrs.reserve(attributes.size()); + std::transform(attributes.begin(), attributes.end(), std::back_inserter(attributes_ptrs), + [](OpAttr& attr) -> OrtOpAttr* { return attr; }); + + ThrowOnError(GetModelEditorApi().CreateNode(operator_name.c_str(), operator_domain.c_str(), node_name.c_str(), + inputs.data(), inputs.size(), + outputs.data(), outputs.size(), + attributes_ptrs.data(), attributes_ptrs.size(), + &node)); + + // Node now owns the attributes + std::for_each(attributes.begin(), attributes.end(), [](OpAttr& attr) { attr.release(); }); +} + +inline Node::Node(const std::string& operator_name, const std::string& operator_domain, + const std::string& node_name, + const std::vector& input_names, + const std::vector& output_names, + std::vector& attributes) { + Init(operator_name, operator_domain, node_name, input_names, output_names, attributes, p_); +} + +inline Node::Node(const std::string& operator_name, const std::string& operator_domain, + const std::string& node_name, + const std::vector& input_names, + const std::vector& output_names) { + std::vector empty_attributes; + Init(operator_name, operator_domain, node_name, input_names, output_names, empty_attributes, p_); +} + +inline Graph::Graph() { + ThrowOnError(GetModelEditorApi().CreateGraph(&p_)); +} + +inline Model::Model(const std::vector& opsets) { + std::vector domains; + std::vector versions; + domains.reserve(opsets.size()); + versions.reserve(opsets.size()); + + for (const auto& pair : opsets) { + domains.push_back(pair.first.c_str()); + versions.push_back(pair.second); + } + + ThrowOnError(GetModelEditorApi().CreateModel(domains.data(), versions.data(), opsets.size(), &p_)); +} + +inline ValueInfo::ValueInfo(const std::string& name, const ConstTypeInfo& type_info) { + ThrowOnError(GetModelEditorApi().CreateValueInfo(name.c_str(), type_info, &p_)); +} +#endif // !defined(ORT_MINIMAL_BUILD) + +namespace detail { +template <> +inline std::string ValueInfoImpl::Name() const { + const char* name = nullptr; + ThrowOnError(GetApi().GetValueInfoName(this->p_, &name)); + return name; +} + +template <> +inline ConstTypeInfo ValueInfoImpl::TypeInfo() const { + const OrtTypeInfo* type_info = nullptr; + ThrowOnError(GetApi().GetValueInfoTypeInfo(this->p_, &type_info)); + return ConstTypeInfo{type_info}; +} + +#if !defined(ORT_MINIMAL_BUILD) +template <> +inline void GraphImpl::SetInputs(std::vector& inputs) { + std::vector inputs_ptrs; + inputs_ptrs.reserve(inputs.size()); + std::transform(inputs.begin(), inputs.end(), std::back_inserter(inputs_ptrs), + [](ValueInfo& vi) -> OrtValueInfo* { return vi; }); + + ThrowOnError(GetModelEditorApi().SetGraphInputs(p_, inputs_ptrs.data(), inputs_ptrs.size())); + + // Graph now owns the inputs + std::for_each(inputs.begin(), inputs.end(), [](ValueInfo& vi) { vi.release(); }); +} + +template <> +inline void GraphImpl::SetOutputs(std::vector& outputs) { + std::vector outputs_ptrs; + outputs_ptrs.reserve(outputs.size()); + std::transform(outputs.begin(), outputs.end(), std::back_inserter(outputs_ptrs), + [](ValueInfo& vi) -> OrtValueInfo* { return vi; }); + + ThrowOnError(GetModelEditorApi().SetGraphOutputs(p_, outputs_ptrs.data(), outputs_ptrs.size())); + + // Graph now owns the outputs + std::for_each(outputs.begin(), outputs.end(), [](ValueInfo& vi) { vi.release(); }); +} + +template <> +inline void GraphImpl::AddInitializer(const std::string& name, Value& initializer, bool data_is_external) { + // Graph takes ownership of `initializer` + ThrowOnError(GetModelEditorApi().AddInitializerToGraph(p_, name.c_str(), initializer.release(), data_is_external)); +} + +template <> +inline void GraphImpl::AddNode(Node& node) { + // Graph takes ownership of `node` + ThrowOnError(GetModelEditorApi().AddNodeToGraph(p_, node.release())); +} + +template <> +inline void ModelImpl::AddGraph(Graph& graph) { + // Model takes ownership of `graph` + ThrowOnError(GetModelEditorApi().AddGraphToModel(p_, graph.release())); +} +#endif // !defined(ORT_MINIMAL_BUILD) + +} // namespace detail } // namespace Ort diff --git a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h index 117a2cdabca2f..af1f9c04b2831 100644 --- a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h +++ b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h @@ -315,9 +315,12 @@ static const char* const kOrtSessionOptionEpContextEmbedMode = "ep.context_embed // in case user need to merge/connect multiple EPContext nodes in one model static const char* const kOrtSessionOptionEpContextNodeNamePrefix = "ep.context_node_name_prefix"; -// Share EP related resources across EPs +// Share EP related resources across sessions static const char* const kOrtSessionOptionShareEpContexts = "ep.share_ep_contexts"; +// Stop to share EP related resources across sessions from then on +static const char* const kOrtSessionOptionStopShareEpContexts = "ep.stop_share_ep_contexts"; + // Use this config when dumping EP context model with an external initializers file // All initializers will be inside the external data file if specified, otherwise all in Onnx file static const char* const kOrtSessionOptionsEpContextModelExternalInitializersFileName = diff --git a/js/build_webgpu.bat b/js/build_webgpu.bat new file mode 100644 index 0000000000000..95413509e701d --- /dev/null +++ b/js/build_webgpu.bat @@ -0,0 +1,79 @@ +@echo off + +rem build_webgpu.bat --- build onnxruntime-web with WebGPU EP +rem +rem Usage: +rem build_webgpu.bat config [clean] +rem +rem Options: +rem config Build configuration, "d" or "r" +rem clean Perform a clean build, "clean" or empty + +setlocal enabledelayedexpansion + +set ROOT=%~dp0..\ +set BUILD_DIR=%ROOT%build_webgpu + +:arg1 +if ["%~1"]==["d"] ( + set CONFIG=Debug + set CONFIG_EXTRA_FLAG= + @rem --enable_wasm_profiling --wasm_run_tests_in_browser + @rem --cmake_extra_defines onnxruntime_ENABLE_WEBASSEMBLY_OUTPUT_OPTIMIZED_MODEL=1 + @rem --enable_wasm_debug_info + goto :arg2 +) +if ["%~1"]==["r"] ( + set CONFIG=Release + set CONFIG_EXTRA_FLAG= + @rem --enable_wasm_api_exception_catching --disable_rtti + goto :arg2 +) +echo Invalid configuration "%~1", must be "d"(Debug) or "r"(Release) +exit /b 1 + +:arg2 +if ["%~2"]==["clean"] ( + goto :clean +) +if not exist "%ROOT%js\web\dist" ( + goto :npm_ci +) + +goto :build_wasm + +:clean +if exist "%BUILD_DIR%" ( + rd /s /q %BUILD_DIR% +) + +pushd %ROOT% +git submodule sync --recursive +git submodule update --init --recursive +popd + +:npm_ci +pushd %ROOT%js +call npm ci +popd +pushd %ROOT%js\common +call npm ci +popd +pushd %ROOT%js\web +call npm ci +call npm run pull:wasm +popd + +:build_wasm + +set PATH=C:\Program Files\Git\usr\bin;%PATH% + +call %ROOT%build.bat --config %CONFIG% %CONFIG_EXTRA_FLAG% --skip_submodule_sync --build_wasm --target onnxruntime_webassembly --skip_tests^ + --enable_wasm_simd --enable_wasm_threads --use_jsep --use_webnn --use_webgpu --build_dir %BUILD_DIR% + +IF NOT "%ERRORLEVEL%" == "0" ( + exit /b %ERRORLEVEL% +) + +copy /Y %BUILD_DIR%\%CONFIG%\ort-wasm-simd-threaded.jsep.wasm %ROOT%js\web\dist\ +copy /Y %BUILD_DIR%\%CONFIG%\ort-wasm-simd-threaded.jsep.mjs %ROOT%js\web\dist\ diff --git a/js/common/lib/tensor-impl-type-mapping.ts b/js/common/lib/tensor-impl-type-mapping.ts index 14dbdca707220..58f4cc6281b09 100644 --- a/js/common/lib/tensor-impl-type-mapping.ts +++ b/js/common/lib/tensor-impl-type-mapping.ts @@ -44,12 +44,6 @@ export const NUMERIC_TENSOR_TYPEDARRAY_TO_TYPE_MAP = new Map { isTypedArrayChecked = true; const isBigInt64ArrayAvailable = typeof BigInt64Array !== 'undefined' && BigInt64Array.from; const isBigUint64ArrayAvailable = typeof BigUint64Array !== 'undefined' && BigUint64Array.from; + + // eslint-disable-next-line @typescript-eslint/naming-convention, @typescript-eslint/no-explicit-any + const Float16Array = (globalThis as any).Float16Array; const isFloat16ArrayAvailable = typeof Float16Array !== 'undefined' && Float16Array.from; if (isBigInt64ArrayAvailable) { diff --git a/js/common/lib/tensor-impl.ts b/js/common/lib/tensor-impl.ts index 8feb8d7205fa1..2c54bdbfb6874 100644 --- a/js/common/lib/tensor-impl.ts +++ b/js/common/lib/tensor-impl.ts @@ -261,6 +261,13 @@ export class Tensor implements TensorInterface { } else { throw new TypeError(`A Uint8ClampedArray tensor's data must be type of uint8`); } + } else if (arg0 === 'float16' && arg1 instanceof Uint16Array && typedArrayConstructor !== Uint16Array) { + // when Float16Array is available and data is of type Uint16Array. + // We allow Uint16Array to be passed in as data for 'float16' tensor until Float16Array is generally + // supported in JavaScript environment. + + // eslint-disable-next-line @typescript-eslint/no-explicit-any + data = new (globalThis as any).Float16Array(arg1.buffer, arg1.byteOffset, arg1.length); } else { throw new TypeError(`A ${type} tensor's data must be type of ${typedArrayConstructor}`); } diff --git a/js/common/package.json b/js/common/package.json index 3d8d3f6533cfe..2d331bb42e4c7 100644 --- a/js/common/package.json +++ b/js/common/package.json @@ -15,7 +15,8 @@ "build": "node ./build.js", "prepare": "npm run build", "pretest": "tsc --build ./test", - "test": "mocha ./test/**/*.js --timeout 30000" + "test": "mocha \"./test/**/*.js\" --timeout 30000", + "test:f16": "mocha -n js-float16array \"./test/**/*.js\" --timeout 30000" }, "devDependencies": { "typedoc": "^0.25.7" diff --git a/js/common/test/unit-tests/common.ts b/js/common/test/unit-tests/common.ts index 0a6e4e5dd6ebd..bbbceed605bd4 100644 --- a/js/common/test/unit-tests/common.ts +++ b/js/common/test/unit-tests/common.ts @@ -29,9 +29,10 @@ export const NUMBER_COMPATIBLE_NUMERICAL_TYPES = [ export const BIGINT_TYPES = [['int64', BigInt64Array, true] as const, ['uint64', BigUint64Array, true] as const]; /** - * float16 type, data represented by Uint16Array + * float16 type, data represented by Uint16Array/Float16Array */ -export const FLOAT16_TYPE = ['float16', Uint16Array, false] as const; +// eslint-disable-next-line @typescript-eslint/no-explicit-any +export const FLOAT16_TYPE = ['float16', (globalThis as any).Float16Array ?? Uint16Array, false] as const; /** * A list of all numerical types. diff --git a/js/common/test/unit-tests/tensor/constructor-f16.ts b/js/common/test/unit-tests/tensor/constructor-f16.ts new file mode 100644 index 0000000000000..38c6ac037c5f9 --- /dev/null +++ b/js/common/test/unit-tests/tensor/constructor-f16.ts @@ -0,0 +1,62 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import assert from 'assert/strict'; +import { Tensor } from 'onnxruntime-common'; + +// eslint-disable-next-line @typescript-eslint/no-explicit-any +const globalF16 = (globalThis as any).Float16Array; + +(globalF16 ? describe : describe.skip)('Tensor Constructor Tests - check type float16 (Float16Array available)', () => { + it("[float16] new Tensor('float16', numbers, dims): allow number array when Float16Array is available", () => { + const tensor = new Tensor('float16', [1, 2, 3, 4], [2, 2]); + assert.equal(tensor.type, 'float16', "tensor.type should be 'float16'"); + assert(tensor.data instanceof globalF16, "tensor.data should be an instance of 'Float16Array'"); + assert.equal(tensor.data[0], 1, 'tensor.data[0] should be 1'); + assert.equal(tensor.data[1], 2, 'tensor.data[1] should be 2'); + assert.equal(tensor.data[2], 3, 'tensor.data[2] should be 3'); + assert.equal(tensor.data[3], 4, 'tensor.data[3] should be 4'); + assert.equal(tensor.data.length, 4, 'tensor.data.length should be 4'); + }); + + it("[float16] new Tensor('float16', float16array, dims): allow Float16Array when Float16Array is available", () => { + const tensor = new Tensor('float16', new globalF16([1, 2, 3, 4]), [2, 2]); + assert.equal(tensor.type, 'float16', "tensor.type should be 'float16'"); + assert(tensor.data instanceof globalF16, "tensor.data should be an instance of 'Float16Array'"); + assert.equal(tensor.data[0], 1, 'tensor.data[0] should be 1'); + assert.equal(tensor.data[1], 2, 'tensor.data[1] should be 2'); + assert.equal(tensor.data[2], 3, 'tensor.data[2] should be 3'); + assert.equal(tensor.data[3], 4, 'tensor.data[3] should be 4'); + assert.equal(tensor.data.length, 4, 'tensor.data.length should be 4'); + }); + + it("[float16] new Tensor('float16', uint16array, dims): allow Uint16Array when Float16Array is available", () => { + const tensor = new Tensor('float16', new Uint16Array([15360, 16384, 16896, 17408]), [2, 2]); + assert.equal(tensor.type, 'float16', "tensor.type should be 'float16'"); + assert(tensor.data instanceof globalF16, "tensor.data should be an instance of 'Float16Array'"); + assert.equal(tensor.data[0], 1, 'tensor.data[0] should be 1'); + assert.equal(tensor.data[1], 2, 'tensor.data[1] should be 2'); + assert.equal(tensor.data[2], 3, 'tensor.data[2] should be 3'); + assert.equal(tensor.data[3], 4, 'tensor.data[3] should be 4'); + assert.equal(tensor.data.length, 4, 'tensor.data.length should be 4'); + }); +}); + +(globalF16 ? describe.skip : describe)( + 'Tensor Constructor Tests - check type float16 (Float16Array not available)', + () => { + it( + "[float16] new Tensor('float16', numbers, dims): " + + "expect to throw because it's not allowed to construct 'float16' tensor from number array", + () => { + assert.throws(() => new Tensor('float16', [1, 2, 3, 4], [2, 2]), TypeError); + }, + ); + + it("[float16] new Tensor('float16', uint16array, dims): allow Uint16Array", () => { + const tensor = new Tensor('float16', new Uint16Array([15360, 16384, 16896, 17408]), [2, 2]); + assert.equal(tensor.type, 'float16', "tensor.type should be 'float16'"); + assert(tensor.data instanceof Uint16Array, "tensor.data should be an instance of 'Uint16Array'"); + }); + }, +); diff --git a/js/common/test/unit-tests/tensor/constructor-type.ts b/js/common/test/unit-tests/tensor/constructor-type.ts index 02390800e8611..d86e18ba744b8 100644 --- a/js/common/test/unit-tests/tensor/constructor-type.ts +++ b/js/common/test/unit-tests/tensor/constructor-type.ts @@ -105,14 +105,6 @@ describe('Tensor Constructor Tests - check types', () => { assert(tensor.data instanceof Uint8Array, "tensor.data should be an instance of 'Uint8Array'"); }); - it( - "[float16] new Tensor('float16', numbers, dims): " + - "expect to throw because it's not allowed to construct 'float16' tensor from number array", - () => { - assert.throws(() => new Tensor('float16', [1, 2, 3, 4], [2, 2]), TypeError); - }, - ); - it("[badtype] new Tensor('a', numbers, dims): expect to throw because 'a' is an invalid type", () => { assert.throws(() => new TensorAny('a', [1, 2, 3, 4], [2, 2]), TypeError); }); diff --git a/js/web/lib/build-def.d.ts b/js/web/lib/build-def.d.ts index 59f64a3179605..83a52ebaefe05 100644 --- a/js/web/lib/build-def.d.ts +++ b/js/web/lib/build-def.d.ts @@ -40,6 +40,13 @@ interface BuildDefinitions { */ readonly ENABLE_BUNDLE_WASM_JS: boolean; + /** + * defines whether to use WebGPU EP instead of JSEP for WebGPU backend. + * + * This flag requires the corresponding WebAssembly artifact to be built with `--use_webgpu` flag. + */ + readonly USE_WEBGPU_EP: boolean; + // #endregion // #region Build definitions for ESM diff --git a/js/web/lib/wasm/jsep/backend-webgpu.ts b/js/web/lib/wasm/jsep/backend-webgpu.ts index a0010df4643a4..413e89111740e 100644 --- a/js/web/lib/wasm/jsep/backend-webgpu.ts +++ b/js/web/lib/wasm/jsep/backend-webgpu.ts @@ -13,7 +13,6 @@ import { ProgramManager } from './webgpu/program-manager'; import { AdapterInfo, ComputeContext, - DeviceInfo, GpuArchitecture, GpuData, GpuVendor, @@ -135,26 +134,6 @@ class AdapterInfoImpl implements AdapterInfo { } } -class DeviceInfoImpl implements DeviceInfo { - readonly subgroupsSupported: boolean; - readonly subgroupsF16Supported: boolean; - readonly subgroupSizeRange?: readonly [number, number]; - - constructor(device: GPUDevice) { - this.subgroupsSupported = device.features.has('subgroups' as GPUFeatureName); - this.subgroupsF16Supported = device.features.has('subgroups' as GPUFeatureName); - // Currently subgroups feature is still experimental and size attributes are not in the WebGPU IDL, so we have to - // workaround the IDL type checks. - // TODO: clean this after subgroups feature is settled in IDL. - const deviceSubgroupsLimits = device.limits as { minSubgroupSize?: number; maxSubgroupSize?: number }; - if (!this.subgroupsSupported || !deviceSubgroupsLimits.minSubgroupSize || !deviceSubgroupsLimits.maxSubgroupSize) { - this.subgroupSizeRange = undefined; - } else { - this.subgroupSizeRange = [deviceSubgroupsLimits.minSubgroupSize, deviceSubgroupsLimits.maxSubgroupSize]; - } - } -} - /** * this class is designed to store status and being used as a singleton for JSEP. It will be passed to jsepInit() as * the first parameter so that it is stored for future use. @@ -162,7 +141,6 @@ class DeviceInfoImpl implements DeviceInfo { export class WebGpuBackend { adapterInfo: AdapterInfoImpl; device: GPUDevice; - deviceInfo: DeviceInfoImpl; /** * an instance of GpuDataManager to manage a GpuDataId -> GpuBuffer mapping */ @@ -274,13 +252,9 @@ export class WebGpuBackend { } requireFeatureIfAvailable('shader-f16'); // Try subgroups - if (requireFeatureIfAvailable('subgroups' as GPUFeatureName)) { - // If subgroups feature is available, also try subgroups-f16 - requireFeatureIfAvailable('subgroups-f16' as GPUFeatureName); - } + requireFeatureIfAvailable('subgroups' as GPUFeatureName); this.device = await adapter.requestDevice(deviceDescriptor); - this.deviceInfo = new DeviceInfoImpl(this.device); this.adapterInfo = new AdapterInfoImpl(adapter.info || (await adapter.requestAdapterInfo())); this.gpuDataManager = createGpuDataManager(this); this.programManager = new ProgramManager(this); diff --git a/js/web/lib/wasm/jsep/backend-webnn.ts b/js/web/lib/wasm/jsep/backend-webnn.ts index 2b9a9208e2e53..55784ae13ad7a 100644 --- a/js/web/lib/wasm/jsep/backend-webnn.ts +++ b/js/web/lib/wasm/jsep/backend-webnn.ts @@ -314,7 +314,8 @@ export class WebNNBackend { bufferView = new Float32Array(buffer); break; case 'float16': - bufferView = new Uint16Array(buffer); + bufferView = + typeof Float16Array !== 'undefined' && Float16Array.from ? new Float16Array(buffer) : new Uint16Array(buffer); break; case 'int32': bufferView = new Int32Array(buffer); diff --git a/js/web/lib/wasm/jsep/init.ts b/js/web/lib/wasm/jsep/init.ts index b4071eae51c8f..8ab6b054bf8a7 100644 --- a/js/web/lib/wasm/jsep/init.ts +++ b/js/web/lib/wasm/jsep/init.ts @@ -1,23 +1,17 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import { Env } from 'onnxruntime-common'; +import type { Env } from 'onnxruntime-common'; import { calculateTensorSizeInBytes, DataType } from '../wasm-common'; import type { OrtWasmModule } from '../wasm-types'; -import { WebGpuBackend } from './backend-webgpu'; +import type { WebGpuBackend } from './backend-webgpu'; import { LOG_DEBUG } from './log'; -import { TensorView } from './tensor-view'; +import type { TensorView } from './tensor-view'; import { ShapeUtil } from './util'; -import { - AdapterInfo, - ComputeContext, - ComputeContextInputsOutputsMapping, - DeviceInfo, - ProgramInfo, -} from './webgpu/types'; +import type { AdapterInfo, ComputeContext, ComputeContextInputsOutputsMapping, ProgramInfo } from './webgpu/types'; import { WebNNBackend } from './backend-webnn'; /* eslint-disable no-bitwise */ @@ -76,7 +70,6 @@ class TensorViewImpl implements TensorView { class ComputeContextImpl implements ComputeContext { readonly adapterInfo: AdapterInfo; - readonly deviceInfo: DeviceInfo; readonly opKernelContext: number; readonly inputs: readonly TensorView[]; readonly outputCount: number; @@ -94,7 +87,6 @@ class ComputeContextImpl implements ComputeContext { contextDataOffset: number, ) { this.adapterInfo = backend.adapterInfo; - this.deviceInfo = backend.deviceInfo; // extract context data const ptrSize = module.PTR_SIZE; @@ -205,79 +197,83 @@ export const init = async ( } if (name === 'webgpu') { - const backend = new WebGpuBackend(); - await backend.initialize(env, gpuAdapter!); + if (!BUILD_DEFS.USE_WEBGPU_EP) { + // eslint-disable-next-line @typescript-eslint/no-require-imports, @typescript-eslint/no-var-requires + const webGpuBackendImpl = require('./backend-webgpu').WebGpuBackend; + const backend = new webGpuBackendImpl(); + await backend.initialize(env, gpuAdapter!); - jsepInit('webgpu', [ - // backend - backend, + jsepInit('webgpu', [ + // backend + backend, + + // jsepAlloc() + (size: number) => backend.alloc(Number(size)), - // jsepAlloc() - (size: number) => backend.alloc(Number(size)), + // jsepFree() + (ptr: number) => backend.free(ptr), - // jsepFree() - (ptr: number) => backend.free(ptr), + // jsepCopy(src, dst, size, isSourceGpu) + (src: number, dst: number, size: number, isSourceGpu = false) => { + if (isSourceGpu) { + LOG_DEBUG( + 'verbose', + () => `[WebGPU] jsepCopyGpuToGpu: src=${Number(src)}, dst=${Number(dst)}, size=${Number(size)}`, + ); + backend.memcpy(Number(src), Number(dst)); + } else { + LOG_DEBUG( + 'verbose', + () => + `[WebGPU] jsepCopyCpuToGpu: dataOffset=${Number(src)}, gpuDataId=${Number(dst)}, size=${Number(size)}`, + ); + const data = module.HEAPU8.subarray(Number(src >>> 0), Number(src >>> 0) + Number(size)); + backend.upload(Number(dst), data); + } + }, - // jsepCopy(src, dst, size, isSourceGpu) - (src: number, dst: number, size: number, isSourceGpu = false) => { - if (isSourceGpu) { + // jsepCopyAsync(src, dst, size) + async (gpuDataId: number, dataOffset: number, size: number): Promise => { LOG_DEBUG( 'verbose', - () => `[WebGPU] jsepCopyGpuToGpu: src=${Number(src)}, dst=${Number(dst)}, size=${Number(size)}`, + () => `[WebGPU] jsepCopyGpuToCpu: gpuDataId=${gpuDataId}, dataOffset=${dataOffset}, size=${size}`, ); - backend.memcpy(Number(src), Number(dst)); - } else { - LOG_DEBUG( - 'verbose', - () => - `[WebGPU] jsepCopyCpuToGpu: dataOffset=${Number(src)}, gpuDataId=${Number(dst)}, size=${Number(size)}`, - ); - const data = module.HEAPU8.subarray(Number(src >>> 0), Number(src >>> 0) + Number(size)); - backend.upload(Number(dst), data); - } - }, - // jsepCopyAsync(src, dst, size) - async (gpuDataId: number, dataOffset: number, size: number): Promise => { - LOG_DEBUG( - 'verbose', - () => `[WebGPU] jsepCopyGpuToCpu: gpuDataId=${gpuDataId}, dataOffset=${dataOffset}, size=${size}`, - ); - - await backend.download(Number(gpuDataId), () => - module.HEAPU8.subarray(Number(dataOffset) >>> 0, Number(dataOffset + size) >>> 0), - ); - }, + await backend.download(Number(gpuDataId), () => + module.HEAPU8.subarray(Number(dataOffset) >>> 0, Number(dataOffset + size) >>> 0), + ); + }, - // jsepCreateKernel - (kernelType: string, kernelId: number, attribute: unknown) => - backend.createKernel( - kernelType, - Number(kernelId), - attribute, - module.UTF8ToString(module._JsepGetNodeName!(Number(kernelId))), - ), + // jsepCreateKernel + (kernelType: string, kernelId: number, attribute: unknown) => + backend.createKernel( + kernelType, + Number(kernelId), + attribute, + module.UTF8ToString(module._JsepGetNodeName!(Number(kernelId))), + ), - // jsepReleaseKernel - (kernel: number) => backend.releaseKernel(kernel), + // jsepReleaseKernel + (kernel: number) => backend.releaseKernel(kernel), - // jsepRun - (kernel: number, contextDataOffset: number, sessionHandle: number, errors: Array>) => { - LOG_DEBUG( - 'verbose', - () => - `[WebGPU] jsepRun: sessionHandle=${sessionHandle}, kernel=${kernel}, contextDataOffset=${contextDataOffset}`, - ); - const context = new ComputeContextImpl(module, backend, Number(contextDataOffset)); - return backend.computeKernel(Number(kernel), context, errors); - }, - // jsepCaptureBegin - () => backend.captureBegin(), - // jsepCaptureEnd - () => backend.captureEnd(), - // jsepReplay - () => backend.replay(), - ]); + // jsepRun + (kernel: number, contextDataOffset: number, sessionHandle: number, errors: Array>) => { + LOG_DEBUG( + 'verbose', + () => + `[WebGPU] jsepRun: sessionHandle=${sessionHandle}, kernel=${kernel}, contextDataOffset=${contextDataOffset}`, + ); + const context = new ComputeContextImpl(module, backend, Number(contextDataOffset)); + return backend.computeKernel(Number(kernel), context, errors); + }, + // jsepCaptureBegin + () => backend.captureBegin(), + // jsepCaptureEnd + () => backend.captureEnd(), + // jsepReplay + () => backend.replay(), + ]); + } } else { const backend = new WebNNBackend(env); jsepInit('webnn', [ diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts index ad1de42106d6d..50620cea33863 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts @@ -46,6 +46,11 @@ export const createConvTranspose2DProgramInfo = ( const inputChannelsPerGroup = wShape[2] / group; const outputChannelsPerGroup = wShape[3]; const aComponents = isChannelsLast ? getMaxComponents(inputChannelsPerGroup) : 1; + const packInputAs4 = isChannelsLast && outputChannelsPerGroup === 1 && inputChannelsPerGroup >= 4; + const inputChannelsPerGroupInt = packInputAs4 + ? Math.floor(inputChannelsPerGroup / 4) * 4 + : Math.floor(inputChannelsPerGroup / aComponents) * aComponents; + const inputChannelsRemainder = inputChannelsPerGroup - inputChannelsPerGroupInt; const components = isChannelsLast ? getMaxComponents(outputChannelsPerGroup) : 1; const bComponents = isChannelsLast ? (outputChannelsPerGroup === 1 ? aComponents : components) : 1; const outputSize = ShapeUtil.size(outputShape) / components; @@ -78,6 +83,7 @@ export const createConvTranspose2DProgramInfo = ( { type: DataType.uint32, data: dilations }, { type: DataType.uint32, data: effectiveFilterDims }, { type: DataType.int32, data: pads }, + { type: DataType.uint32, data: inputChannelsPerGroupInt }, { type: DataType.uint32, data: inputChannelsPerGroup }, { type: DataType.uint32, data: outputChannelsPerGroup }, ...createTensorShapeVariables(inputs[0].dims, inputs[1].dims), @@ -96,6 +102,7 @@ export const createConvTranspose2DProgramInfo = ( { name: 'dilations', type: 'u32', length: filterDims.length }, { name: 'effective_filter_dims', type: 'u32', length: effectiveFilterDims.length }, { name: 'pads', type: 'i32', length: pads.length }, + { name: 'input_channels_per_group_int', type: 'u32' }, { name: 'input_channels_per_group', type: 'u32' }, { name: 'output_channels_per_group', type: 'u32' }, ]; @@ -114,16 +121,40 @@ export const createConvTranspose2DProgramInfo = ( const calculateResult = (): string => { let calcStr = ''; - if (aComponents === 1) { - calcStr += ` - let w_offset = ${w.indicesToOffset(`${w.type.indices}(u32(wRPerm), u32(wCPerm), inputChannel, wOutChannel)`)}; - let wValue = ${w.getByOffset(`w_offset / ${bComponents}`)}; - dotProd = dotProd + xValue * wValue;`; + if (packInputAs4) { + if (aComponents === 4) { + calcStr += ` + let xValue = ${dy.getByOffset('x_offset')}; + let wValue = ${w.getByOffset('w_offset')}; + dotProd = dotProd + dot(xValue, wValue); + x_offset += 1u; + w_offset += 1u;`; + } else if (aComponents === 2) { + calcStr += ` + dotProd = dotProd + dot(vec4<${dataType}>(${dy.getByOffset('x_offset')}, ${dy.getByOffset('x_offset + 1u')}), vec4<${dataType}>(${w.getByOffset('w_offset')}, ${w.getByOffset('w_offset + 1u')})); + x_offset += 2u; + w_offset += 2u;`; + } else if (aComponents === 1) { + calcStr += ` + dotProd = dotProd + dot(vec4<${dataType}>(${dy.getByOffset('x_offset')}, ${dy.getByOffset('x_offset + 1u')}, ${dy.getByOffset('x_offset + 2u')}, ${dy.getByOffset('x_offset + 3u')}), vec4<${dataType}>(${w.getByOffset('w_offset')}, ${w.getByOffset('w_offset + 1u')}, ${w.getByOffset('w_offset + 2u')}, ${w.getByOffset('w_offset + 3u')})); + x_offset += 4u; + w_offset += 4u;`; + } } else { - if (outputChannelsPerGroup === 1) { + calcStr += ` + let xValue = ${ + isChannelsLast + ? dy.getByOffset( + `${dy.indicesToOffset(`${dy.type.indices}(batch, idyR, idyC, inputChannel)`)} / ${aComponents}`, + ) + : dy.get('batch', 'inputChannel', 'idyR', 'idyC') + }; + `; + if (aComponents === 1) { calcStr += ` - let wValue = ${w.getByOffset(`${w.indicesToOffset(`${w.type.indices}(u32(wRPerm), u32(wCPerm), inputChannel, wOutChannel)`)} / ${bComponents}`)}; - dotProd = dotProd + dot(xValue, wValue);`; + let w_offset = ${w.indicesToOffset(`${w.type.indices}(u32(wRPerm), u32(wCPerm), inputChannel, wOutChannel)`)}; + let wValue = ${w.getByOffset(`w_offset / ${bComponents}`)}; + dotProd = dotProd + xValue * wValue;`; } else { for (let c = 0; c < aComponents; c++) { calcStr += ` @@ -134,6 +165,32 @@ export const createConvTranspose2DProgramInfo = ( } return calcStr; }; + const calculateRemainder = (): string => { + if (inputChannelsRemainder === 0) { + return ''; + } + if (!packInputAs4) { + throw new Error(`packInputAs4 ${packInputAs4} is not true.`); + } + let calcStr = ''; + if (aComponents === 1) { + calcStr += 'dotProd = dotProd'; + for (let i = 0; i < inputChannelsRemainder; i++) { + calcStr += ` + + ${dy.getByOffset(`x_offset + ${i}`)} * ${w.getByOffset(`w_offset + ${i}`)}`; + } + calcStr += ';'; + } else if (aComponents === 2) { + if (inputChannelsRemainder !== 2) { + throw new Error(`Invalid inputChannelsRemainder ${inputChannelsRemainder}.`); + } + calcStr += ` + let xValue = ${dy.getByOffset('x_offset')}; + let wValue = ${w.getByOffset('w_offset')}; + dotProd = dotProd + dot(xValue, wValue);`; + } + return calcStr; + }; const codeSnippet = ` let outputIndices = ${output.offsetToIndices(`global_idx * ${components}`)}; let batch = ${output.indicesGet('outputIndices', 0)}; @@ -169,7 +226,6 @@ export const createConvTranspose2DProgramInfo = ( // Minimum wC >= 0 that satisfies (dyCCorner + wC) % (uniforms.strides.y) == 0 wC = u32(((dyCCorner + i32(uniforms.strides.y) - 1) / i32(uniforms.strides.y)) * i32(uniforms.strides.y) - dyCCorner); } - for (; wC < uniforms.effective_filter_dims.y; wC = wC + 1) { if (wC % uniforms.dilations.y != 0) { continue; @@ -182,17 +238,19 @@ export const createConvTranspose2DProgramInfo = ( } let idyC: u32 = u32(dyC); var inputChannel = groupId * uniforms.input_channels_per_group; - for (var d2: u32 = 0; d2 < uniforms.input_channels_per_group; d2 = d2 + ${aComponents}) { - let xValue = ${ - isChannelsLast - ? dy.getByOffset( - `${dy.indicesToOffset(`${dy.type.indices}(batch, idyR, idyC, inputChannel)`)} / ${aComponents}`, - ) - : dy.get('batch', 'inputChannel', 'idyR', 'idyC') - }; + ${ + packInputAs4 + ? ` + var x_offset = ${dy.indicesToOffset(`${dy.type.indices}(batch, idyR, idyC, inputChannel)`)} / ${aComponents}; + var w_offset = ${w.indicesToOffset(`${w.type.indices}(wRPerm, wCPerm, inputChannel, wOutChannel)`)} / ${bComponents}; + ` + : '' + } + for (var d2: u32 = 0; d2 < uniforms.input_channels_per_group_int; d2 = d2 + ${packInputAs4 ? 4 : aComponents}) { ${calculateResult()} - inputChannel = inputChannel + ${aComponents}; + inputChannel = inputChannel + ${packInputAs4 ? 4 : aComponents}; } + ${calculateRemainder()} wC = wC + uniforms.strides.y - 1; } wR = wR + uniforms.strides[0] - 1; @@ -211,7 +269,7 @@ export const createConvTranspose2DProgramInfo = ( return { name: 'ConvTranspose2D', shaderCache: { - hint: `${attributes.cacheKey};${aComponents}${bComponents}${components}${outputChannelsPerGroup === 1}`, + hint: `${attributes.cacheKey};${aComponents}${bComponents}${components}${packInputAs4}${inputChannelsRemainder}`, inputDependencies, }, getRunData: () => ({ diff --git a/js/web/lib/wasm/jsep/webgpu/ops/attention.ts b/js/web/lib/wasm/jsep/webgpu/ops/attention.ts index 6a78c8ae3b190..6a8dffb73fa08 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/attention.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/attention.ts @@ -433,7 +433,7 @@ const createInPlaceSoftmaxProgramInfo = ( getShaderSource, getRunData: () => ({ outputs: [], - dispatchGroup: { x: Math.ceil(totalSequenceLength / WG), y: sequenceLength, z: batchSize * numHeads }, + dispatchGroup: { x: 1, y: sequenceLength, z: batchSize * numHeads }, programUniforms, }), }; diff --git a/js/web/lib/wasm/jsep/webgpu/program-manager.ts b/js/web/lib/wasm/jsep/webgpu/program-manager.ts index 2c5180c5db3ee..18d505f57655a 100644 --- a/js/web/lib/wasm/jsep/webgpu/program-manager.ts +++ b/js/web/lib/wasm/jsep/webgpu/program-manager.ts @@ -99,7 +99,6 @@ export class ProgramManager { const extensionsInfo: Array<{ feature: GPUFeatureName; extension: string }> = [ { feature: 'shader-f16', extension: 'f16' }, { feature: 'subgroups' as GPUFeatureName, extension: 'subgroups' }, - { feature: 'subgroups-f16' as GPUFeatureName, extension: 'subgroups_f16' }, ]; extensionsInfo.forEach((info) => { if (device.features.has(info.feature)) { diff --git a/js/web/lib/wasm/jsep/webgpu/types.ts b/js/web/lib/wasm/jsep/webgpu/types.ts index 9321ac170d036..f3cfc6cb98cae 100644 --- a/js/web/lib/wasm/jsep/webgpu/types.ts +++ b/js/web/lib/wasm/jsep/webgpu/types.ts @@ -21,11 +21,6 @@ export interface AdapterInfo { isArchitecture: (architecture: GpuArchitecture) => boolean; isVendor: (vendor: GpuVendor) => boolean; } -export interface DeviceInfo { - readonly subgroupsSupported: boolean; - readonly subgroupsF16Supported: boolean; - readonly subgroupSizeRange?: readonly [number, number]; -} export interface GpuData { type: GpuDataType; @@ -165,11 +160,6 @@ export interface ComputeContext { */ readonly adapterInfo: AdapterInfo; - /** - * gpu device info - */ - readonly deviceInfo: DeviceInfo; - /** * stores the pointer to OpKernelContext */ diff --git a/js/web/lib/wasm/proxy-wrapper.ts b/js/web/lib/wasm/proxy-wrapper.ts index 5d97bb83e3475..30b1f5101e5f2 100644 --- a/js/web/lib/wasm/proxy-wrapper.ts +++ b/js/web/lib/wasm/proxy-wrapper.ts @@ -12,7 +12,11 @@ import { } from './proxy-messages'; import * as core from './wasm-core-impl'; import { initializeWebAssembly } from './wasm-factory'; -import { importProxyWorker, inferWasmPathPrefixFromScriptSrc } from './wasm-utils-import'; +import { + importProxyWorker, + inferWasmPathPrefixFromScriptSrc, + isEsmImportMetaUrlHardcodedAsFileUri, +} from './wasm-utils-import'; const isProxy = (): boolean => !!env.wasm.proxy && typeof document !== 'undefined'; let proxyWorker: Worker | undefined; @@ -116,7 +120,7 @@ export const initializeWebAssemblyAndOrtRuntime = async (): Promise => { BUILD_DEFS.IS_ESM && BUILD_DEFS.ENABLE_BUNDLE_WASM_JS && !message.in!.wasm.wasmPaths && - (objectUrl || BUILD_DEFS.ESM_IMPORT_META_URL?.startsWith('file:')) + (objectUrl || isEsmImportMetaUrlHardcodedAsFileUri) ) { // for a build bundled the wasm JS, if either of the following conditions is met: // - the proxy worker is loaded from a blob URL diff --git a/js/web/lib/wasm/session-options.ts b/js/web/lib/wasm/session-options.ts index 17e564247863d..89a4484e5a1c4 100644 --- a/js/web/lib/wasm/session-options.ts +++ b/js/web/lib/wasm/session-options.ts @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import { InferenceSession } from 'onnxruntime-common'; +import type { InferenceSession } from 'onnxruntime-common'; import { getInstance } from './wasm-factory'; import { allocWasmString, checkLastError, iterateExtraOptions } from './wasm-utils'; @@ -54,13 +54,28 @@ const appendDefaultOptions = (options: InferenceSession.SessionOptions): void => } }; -const setExecutionProviders = ( +const appendSessionConfig = (sessionOptionsHandle: number, key: string, value: string, allocs: number[]): void => { + const keyDataOffset = allocWasmString(key, allocs); + const valueDataOffset = allocWasmString(value, allocs); + if (getInstance()._OrtAddSessionConfigEntry(sessionOptionsHandle, keyDataOffset, valueDataOffset) !== 0) { + checkLastError(`Can't set a session config entry: ${key} - ${value}.`); + } +}; + +const appendEpOption = (epOptions: Array<[number, number]>, key: string, value: string, allocs: number[]): void => { + const keyDataOffset = allocWasmString(key, allocs); + const valueDataOffset = allocWasmString(value, allocs); + epOptions.push([keyDataOffset, valueDataOffset]); +}; + +const setExecutionProviders = async ( sessionOptionsHandle: number, executionProviders: readonly InferenceSession.ExecutionProviderConfig[], allocs: number[], -): void => { +): Promise => { for (const ep of executionProviders) { let epName = typeof ep === 'string' ? ep : ep.name; + const epOptions: Array<[number, number]> = []; // check EP name switch (epName) { @@ -71,26 +86,44 @@ const setExecutionProviders = ( // const context = (webnnOptions as InferenceSession.WebNNOptionsWithMLContext)?.context; const deviceType = (webnnOptions as InferenceSession.WebNNContextOptions)?.deviceType; if (deviceType) { - const keyDataOffset = allocWasmString('deviceType', allocs); - const valueDataOffset = allocWasmString(deviceType, allocs); - if (getInstance()._OrtAddSessionConfigEntry(sessionOptionsHandle, keyDataOffset, valueDataOffset) !== 0) { - checkLastError(`Can't set a session config entry: 'deviceType' - ${deviceType}.`); - } + appendSessionConfig(sessionOptionsHandle, 'deviceType', deviceType, allocs); } } break; case 'webgpu': - epName = 'JS'; - if (typeof ep !== 'string') { - const webgpuOptions = ep as InferenceSession.WebGpuExecutionProviderOption; - if (webgpuOptions?.preferredLayout) { - if (webgpuOptions.preferredLayout !== 'NCHW' && webgpuOptions.preferredLayout !== 'NHWC') { - throw new Error(`preferredLayout must be either 'NCHW' or 'NHWC': ${webgpuOptions.preferredLayout}`); + if (BUILD_DEFS.USE_WEBGPU_EP) { + epName = 'WebGPU'; + let customDevice: GPUDevice | undefined; + + if (typeof ep !== 'string') { + const customOptions = ep as unknown as { device: GPUDevice }; + if (customOptions.device) { + if (typeof GPUDevice !== 'undefined' && customOptions.device instanceof GPUDevice) { + customDevice = customOptions.device; + } else { + throw new Error('Invalid GPU device set in WebGPU EP options.'); + } } - const keyDataOffset = allocWasmString('preferredLayout', allocs); - const valueDataOffset = allocWasmString(webgpuOptions.preferredLayout, allocs); - if (getInstance()._OrtAddSessionConfigEntry(sessionOptionsHandle, keyDataOffset, valueDataOffset) !== 0) { - checkLastError(`Can't set a session config entry: 'preferredLayout' - ${webgpuOptions.preferredLayout}.`); + + // TODO: handle more options + } + + const info = getInstance().webgpuRegisterDevice!(customDevice); + if (info) { + const [deviceId, instanceHandle, deviceHandle] = info; + appendEpOption(epOptions, 'deviceId', deviceId.toString(), allocs); + appendEpOption(epOptions, 'webgpuInstance', instanceHandle.toString(), allocs); + appendEpOption(epOptions, 'webgpuDevice', deviceHandle.toString(), allocs); + } + } else { + epName = 'JS'; + if (typeof ep !== 'string') { + const webgpuOptions = ep as InferenceSession.WebGpuExecutionProviderOption; + if (webgpuOptions?.preferredLayout) { + if (webgpuOptions.preferredLayout !== 'NCHW' && webgpuOptions.preferredLayout !== 'NHWC') { + throw new Error(`preferredLayout must be either 'NCHW' or 'NHWC': ${webgpuOptions.preferredLayout}`); + } + appendSessionConfig(sessionOptionsHandle, 'preferredLayout', webgpuOptions.preferredLayout, allocs); } } } @@ -103,13 +136,34 @@ const setExecutionProviders = ( } const epNameDataOffset = allocWasmString(epName, allocs); - if (getInstance()._OrtAppendExecutionProvider(sessionOptionsHandle, epNameDataOffset) !== 0) { + const epOptionsCount = epOptions.length; + let keysOffset = 0; + let valuesOffset = 0; + if (epOptionsCount > 0) { + keysOffset = getInstance()._malloc(epOptionsCount * getInstance().PTR_SIZE); + allocs.push(keysOffset); + valuesOffset = getInstance()._malloc(epOptionsCount * getInstance().PTR_SIZE); + allocs.push(valuesOffset); + for (let i = 0; i < epOptionsCount; i++) { + getInstance().setValue(keysOffset + i * getInstance().PTR_SIZE, epOptions[i][0], '*'); + getInstance().setValue(valuesOffset + i * getInstance().PTR_SIZE, epOptions[i][1], '*'); + } + } + if ( + (await getInstance()._OrtAppendExecutionProvider( + sessionOptionsHandle, + epNameDataOffset, + keysOffset, + valuesOffset, + epOptionsCount, + )) !== 0 + ) { checkLastError(`Can't append execution provider: ${epName}.`); } } }; -export const setSessionOptions = (options?: InferenceSession.SessionOptions): [number, number[]] => { +export const setSessionOptions = async (options?: InferenceSession.SessionOptions): Promise<[number, number[]]> => { const wasm = getInstance(); let sessionOptionsHandle = 0; const allocs: number[] = []; @@ -155,20 +209,19 @@ export const setSessionOptions = (options?: InferenceSession.SessionOptions): [n } if (sessionOptions.executionProviders) { - setExecutionProviders(sessionOptionsHandle, sessionOptions.executionProviders, allocs); + await setExecutionProviders(sessionOptionsHandle, sessionOptions.executionProviders, allocs); } if (sessionOptions.enableGraphCapture !== undefined) { if (typeof sessionOptions.enableGraphCapture !== 'boolean') { throw new Error(`enableGraphCapture must be a boolean value: ${sessionOptions.enableGraphCapture}`); } - const keyDataOffset = allocWasmString('enableGraphCapture', allocs); - const valueDataOffset = allocWasmString(sessionOptions.enableGraphCapture.toString(), allocs); - if (wasm._OrtAddSessionConfigEntry(sessionOptionsHandle, keyDataOffset, valueDataOffset) !== 0) { - checkLastError( - `Can't set a session config entry: 'enableGraphCapture' - ${sessionOptions.enableGraphCapture}.`, - ); - } + appendSessionConfig( + sessionOptionsHandle, + 'enableGraphCapture', + sessionOptions.enableGraphCapture.toString(), + allocs, + ); } if (sessionOptions.freeDimensionOverrides) { @@ -188,12 +241,7 @@ export const setSessionOptions = (options?: InferenceSession.SessionOptions): [n if (sessionOptions.extra !== undefined) { iterateExtraOptions(sessionOptions.extra, '', new WeakSet>(), (key, value) => { - const keyDataOffset = allocWasmString(key, allocs); - const valueDataOffset = allocWasmString(value, allocs); - - if (wasm._OrtAddSessionConfigEntry(sessionOptionsHandle, keyDataOffset, valueDataOffset) !== 0) { - checkLastError(`Can't set a session config entry: ${key} - ${value}.`); - } + appendSessionConfig(sessionOptionsHandle, key, value, allocs); }); } diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts index 4bccfa76fdda3..dbcf80adf3552 100644 --- a/js/web/lib/wasm/wasm-core-impl.ts +++ b/js/web/lib/wasm/wasm-core-impl.ts @@ -102,11 +102,20 @@ export const initRuntime = async (env: Env): Promise => { * @param epName */ export const initEp = async (env: Env, epName: string): Promise => { + // initialize ASYNCIFY support + getInstance().asyncInit?.(); + + if (epName === 'webgpu' && BUILD_DEFS.USE_WEBGPU_EP) { + getInstance().webgpuInit!((device) => { + env.webgpu.device = device; + }); + } + if (!BUILD_DEFS.DISABLE_JSEP) { // eslint-disable-next-line @typescript-eslint/no-require-imports, @typescript-eslint/no-var-requires const initJsep = require('./jsep/init').init; - if (epName === 'webgpu') { + if (epName === 'webgpu' && !BUILD_DEFS.USE_WEBGPU_EP) { // perform WebGPU availability check if (typeof navigator === 'undefined' || !navigator.gpu) { throw new Error('WebGPU is not supported in current environment'); @@ -270,7 +279,7 @@ export const createSession = async ( const outputNamesUTF8Encoded = []; try { - [sessionOptionsHandle, allocs] = setSessionOptions(options); + [sessionOptionsHandle, allocs] = await setSessionOptions(options); if (options?.externalData && wasm.mountExternalData) { const loadingPromises = []; @@ -278,7 +287,7 @@ export const createSession = async ( const path = typeof file === 'string' ? file : file.path; loadingPromises.push( loadFile(typeof file === 'string' ? file : file.data).then((data) => { - wasm.mountExternalData!(path, data); + wasm.mountExternalData(path, data); }), ); } @@ -312,6 +321,7 @@ export const createSession = async ( } sessionHandle = await wasm._OrtCreateSession(modelDataOffset, modelDataLength, sessionOptionsHandle); + wasm.webgpuOnCreateSession?.(sessionHandle); if (sessionHandle === 0) { checkLastError("Can't create a session."); } @@ -444,6 +454,7 @@ export const releaseSession = (sessionId: number): void => { } wasm.jsepOnReleaseSession?.(sessionId); + wasm.webgpuOnReleaseSession?.(sessionId); inputNamesUTF8Encoded.forEach((buf) => wasm._OrtFree(buf)); outputNamesUTF8Encoded.forEach((buf) => wasm._OrtFree(buf)); @@ -491,11 +502,20 @@ export const prepareInputOutputTensor = async ( const gpuBuffer = tensor[2].gpuBuffer; dataByteLength = calculateTensorSizeInBytes(tensorDataTypeStringToEnum(dataType), dims)!; - const registerBuffer = wasm.jsepRegisterBuffer; - if (!registerBuffer) { - throw new Error('Tensor location "gpu-buffer" is not supported without using WebGPU.'); + if (BUILD_DEFS.USE_WEBGPU_EP) { + const registerBuffer = wasm.webgpuRegisterBuffer; + if (!registerBuffer) { + throw new Error('Tensor location "gpu-buffer" is not supported without using WebGPU.'); + } + + rawData = registerBuffer(gpuBuffer, sessionId); + } else { + const registerBuffer = wasm.jsepRegisterBuffer; + if (!registerBuffer) { + throw new Error('Tensor location "gpu-buffer" is not supported without using WebGPU.'); + } + rawData = registerBuffer(sessionId, index, gpuBuffer, dataByteLength); } - rawData = registerBuffer(sessionId, index, gpuBuffer, dataByteLength); } else if (location === 'ml-tensor') { const mlTensor = tensor[2].mlTensor as MLTensor; dataByteLength = calculateTensorSizeInBytes(tensorDataTypeStringToEnum(dataType), dims)!; @@ -791,7 +811,7 @@ export const run = async ( // If a certain output's preferred location is GPU but the tensor is empty, we still need to create a CPU // tensor for it. There is no mapping GPU buffer for an empty tensor. if (preferredLocation === 'gpu-buffer' && size > 0) { - const getBuffer = wasm.jsepGetBuffer; + const getBuffer = BUILD_DEFS.USE_WEBGPU_EP ? wasm.webgpuGetBuffer : wasm.jsepGetBuffer; if (!getBuffer) { throw new Error('preferredLocation "gpu-buffer" is not supported without using WebGPU.'); } @@ -804,20 +824,43 @@ export const run = async ( // do not release the tensor right now. it will be released when user calls tensor.dispose(). keepOutputTensor = true; - output.push([ - type, - dims, - { - gpuBuffer, - download: wasm.jsepCreateDownloader!(gpuBuffer, bufferSize, type), - dispose: () => { - if (wasm._OrtReleaseTensor(tensor) !== 0) { - checkLastError("Can't release tensor."); - } + if (BUILD_DEFS.USE_WEBGPU_EP) { + wasm.webgpuRegisterBuffer!(gpuBuffer, sessionId, dataOffset); + const downloadDataFunction = wasm.webgpuCreateDownloader!(gpuBuffer, bufferSize, sessionId); + output.push([ + type, + dims, + { + gpuBuffer, + download: async () => { + const arrayBuffer = await downloadDataFunction(); + const data = new (tensorTypeToTypedArrayConstructor(type!))(arrayBuffer); + return data as Tensor.DataTypeMap[Tensor.GpuBufferDataTypes]; + }, + dispose: () => { + if (wasm._OrtReleaseTensor(tensor) !== 0) { + checkLastError("Can't release tensor."); + } + }, }, - }, - 'gpu-buffer', - ]); + 'gpu-buffer', + ]); + } else { + output.push([ + type, + dims, + { + gpuBuffer, + download: wasm.jsepCreateDownloader!(gpuBuffer, bufferSize, type), + dispose: () => { + if (wasm._OrtReleaseTensor(tensor) !== 0) { + checkLastError("Can't release tensor."); + } + }, + }, + 'gpu-buffer', + ]); + } } else if (preferredLocation === 'ml-tensor' && size > 0) { const ensureTensor = wasm.jsepEnsureTensor; if (!ensureTensor) { @@ -887,6 +930,18 @@ export const run = async ( } finally { wasm.stackRestore(beforeRunStack); + if (BUILD_DEFS.USE_WEBGPU_EP) { + inputTensors.forEach((t) => { + if (t && t[3] === 'gpu-buffer') { + wasm.webgpuUnregisterBuffer!(t[2].gpuBuffer); + } + }); + outputTensors.forEach((t) => { + if (t && t[3] === 'gpu-buffer') { + wasm.webgpuUnregisterBuffer!(t[2].gpuBuffer); + } + }); + } inputTensorHandles.forEach((v) => wasm._OrtReleaseTensor(v)); outputTensorHandles.forEach((v) => wasm._OrtReleaseTensor(v)); inputOutputAllocs.forEach((p) => wasm._free(p)); diff --git a/js/web/lib/wasm/wasm-types.ts b/js/web/lib/wasm/wasm-types.ts index b4871e145f4d7..9b2ec71fd351d 100644 --- a/js/web/lib/wasm/wasm-types.ts +++ b/js/web/lib/wasm/wasm-types.ts @@ -41,18 +41,6 @@ export declare namespace JSEP { type DownloadTensorFunction = (tensorId: number, dstBuffer: ArrayBufferView | ArrayBuffer) => Promise; export interface Module extends WebGpuModule, WebNnModule { - /** - * Mount the external data file to an internal map, which will be used during session initialization. - * - * @param externalDataFilePath - specify the relative path of the external data file. - * @param externalDataFileData - specify the content data. - */ - mountExternalData(externalDataFilePath: string, externalDataFileData: Uint8Array): void; - /** - * Unmount all external data files from the internal map. - */ - unmountExternalData(): void; - /** * This is the entry of JSEP initialization. This function is called once when initializing ONNX Runtime per * backend. This function initializes Asyncify support. If name is 'webgpu', also initializes WebGPU backend and @@ -294,6 +282,21 @@ export declare namespace JSEP { } } +export declare namespace WebGpu { + export interface Module { + webgpuInit(setDefaultDevice: (device: GPUDevice) => void): void; + webgpuRegisterDevice( + device?: GPUDevice, + ): undefined | [deviceId: number, instanceHandle: number, deviceHandle: number]; + webgpuOnCreateSession(sessionHandle: number): void; + webgpuOnReleaseSession(sessionHandle: number): void; + webgpuRegisterBuffer(buffer: GPUBuffer, sessionHandle: number, bufferHandle?: number): number; + webgpuUnregisterBuffer(buffer: GPUBuffer): void; + webgpuGetBuffer(bufferHandle: number): GPUBuffer; + webgpuCreateDownloader(gpuBuffer: GPUBuffer, size: number, sessionHandle: number): () => Promise; + } +} + export interface OrtInferenceAPIs { _OrtInit(numThreads: number, loggingLevel: number): number; @@ -358,7 +361,13 @@ export interface OrtInferenceAPIs { logVerbosityLevel: number, optimizedModelFilePath: number, ): number; - _OrtAppendExecutionProvider(sessionOptionsHandle: number, name: number): number; + _OrtAppendExecutionProvider( + sessionOptionsHandle: number, + name: number, + providerOptionsKeys: number, + providerOptionsValues: number, + numKeys: number, + ): Promise; _OrtAddFreeDimensionOverride(sessionOptionsHandle: number, name: number, dim: number): number; _OrtAddSessionConfigEntry(sessionOptionsHandle: number, configKey: number, configValue: number): number; _OrtReleaseSessionOptions(sessionOptionsHandle: number): number; @@ -373,8 +382,11 @@ export interface OrtInferenceAPIs { /** * The interface of the WebAssembly module for ONNX Runtime, compiled from C++ source code by Emscripten. */ -export interface OrtWasmModule extends EmscriptenModule, OrtInferenceAPIs, Partial { - PTR_SIZE: number; +export interface OrtWasmModule + extends EmscriptenModule, + OrtInferenceAPIs, + Partial, + Partial { // #region emscripten functions stackSave(): number; stackRestore(stack: number): void; @@ -387,7 +399,31 @@ export interface OrtWasmModule extends EmscriptenModule, OrtInferenceAPIs, Parti stringToUTF8(str: string, offset: number, maxBytes: number): void; // #endregion + // #region ORT shared + + readonly PTR_SIZE: 4 | 8; + + /** + * Mount the external data file to an internal map, which will be used during session initialization. + * + * @param externalDataFilePath - specify the relative path of the external data file. + * @param externalDataFileData - specify the content data. + */ + mountExternalData(externalDataFilePath: string, externalDataFileData: Uint8Array): void; + /** + * Unmount all external data files from the internal map. + */ + unmountExternalData(): void; + + /** + * This function patches the WebAssembly module to support Asyncify. This function should be called at least once + * before any ORT API is called. + */ + asyncInit?(): void; + + // #endregion + // #region config - numThreads?: number; + readonly numThreads?: number; // #endregion } diff --git a/js/web/lib/wasm/wasm-utils-import.ts b/js/web/lib/wasm/wasm-utils-import.ts index 871b575d71edc..a8e27f6f334bc 100644 --- a/js/web/lib/wasm/wasm-utils-import.ts +++ b/js/web/lib/wasm/wasm-utils-import.ts @@ -11,6 +11,39 @@ import { isNode } from './wasm-utils-env'; */ const origin = isNode || typeof location === 'undefined' ? undefined : location.origin; +/** + * Some bundlers (eg. Webpack) will rewrite `import.meta.url` to a file URL at compile time. + * + * This function checks if `import.meta.url` starts with `file:`, but using the `>` and `<` operators instead of + * `startsWith` function so that code minimizers can remove the dead code correctly. + * + * For example, if we use terser to minify the following code: + * ```js + * if ("file://hard-coded-filename".startsWith("file:")) { + * console.log(1) + * } else { + * console.log(2) + * } + * + * if ("file://hard-coded-filename" > "file:" && "file://hard-coded-filename" < "file;") { + * console.log(3) + * } else { + * console.log(4) + * } + * ``` + * + * The minified code will be: + * ```js + * "file://hard-coded-filename".startsWith("file:")?console.log(1):console.log(2),console.log(3); + * ``` + * + * (use Terser 5.39.0 with default options, https://try.terser.org/) + * + * @returns true if the import.meta.url is hardcoded as a file URI. + */ +export const isEsmImportMetaUrlHardcodedAsFileUri = + BUILD_DEFS.IS_ESM && BUILD_DEFS.ESM_IMPORT_META_URL! > 'file:' && BUILD_DEFS.ESM_IMPORT_META_URL! < 'file;'; + const getScriptSrc = (): string | undefined => { // if Nodejs, return undefined if (isNode) { @@ -26,9 +59,22 @@ const getScriptSrc = (): string | undefined => { // new URL('actual-bundle-name.js', import.meta.url).href // ``` // So that bundler can preprocess the URL correctly. - if (BUILD_DEFS.ESM_IMPORT_META_URL?.startsWith('file:')) { + if (isEsmImportMetaUrlHardcodedAsFileUri) { // if the rewritten URL is a relative path, we need to use the origin to resolve the URL. - return new URL(new URL(BUILD_DEFS.BUNDLE_FILENAME, BUILD_DEFS.ESM_IMPORT_META_URL).href, origin).href; + + // The following is a workaround for Vite. + // + // Vite uses a bundler(rollup/rolldown) that does not rewrite `import.meta.url` to a file URL. So in theory, this + // code path should not be executed in Vite. However, the bundler does not know it and it still try to load the + // following pattern: + // - `return new URL('filename', import.meta.url).href` + // + // By replacing the pattern above with the following code, we can skip the resource loading behavior: + // - `const URL2 = URL; return new URL2('filename', import.meta.url).href;` + // + // And it still works in Webpack. + const URL2 = URL; + return new URL(new URL2(BUILD_DEFS.BUNDLE_FILENAME, BUILD_DEFS.ESM_IMPORT_META_URL).href, origin).href; } return BUILD_DEFS.ESM_IMPORT_META_URL; diff --git a/js/web/script/build.ts b/js/web/script/build.ts index 6006de62b41b6..98e61c9f87fbb 100644 --- a/js/web/script/build.ts +++ b/js/web/script/build.ts @@ -27,7 +27,8 @@ const args = minimist(process.argv.slice(2)); * --bundle-mode=node * Build a single ort-web bundle for nodejs. */ -const BUNDLE_MODE: 'prod' | 'dev' | 'perf' | 'node' = args['bundle-mode'] || 'prod'; +const BUNDLE_MODE: 'prod' | 'dev' | 'perf' | 'node' = + process.env.npm_config_bundle_mode || args['bundle-mode'] || 'prod'; /** * --debug @@ -41,7 +42,18 @@ const BUNDLE_MODE: 'prod' | 'dev' | 'perf' | 'node' = args['bundle-mode'] || 'pr * Enable debug mode. In this mode, esbuild metafile feature will be enabled. Full bundle analysis will be saved to a * file as JSON. */ -const DEBUG = args.debug; // boolean|'verbose'|'save' +const DEBUG = process.env.npm_config_debug || args.debug; // boolean|'verbose'|'save' + +/** + * --webgpu-ep + * --no-webgpu-ep (default) + * + * Enable or disable the use of WebGPU EP. If enabled, the WebGPU EP will be used. If disabled, the WebGPU backend will + * be used with JSEP. + * + * (temporary) This flag is used to test the WebGPU EP integration. It will be removed in the future. + */ +const USE_WEBGPU_EP = process.env.npm_config_webgpu_ep ?? args['webgpu-ep'] ?? false; /** * Root folder of the source code: `/js/` @@ -57,6 +69,7 @@ const DEFAULT_DEFINE = { 'BUILD_DEFS.DISABLE_WASM': 'false', 'BUILD_DEFS.DISABLE_WASM_PROXY': 'false', 'BUILD_DEFS.ENABLE_BUNDLE_WASM_JS': 'false', + 'BUILD_DEFS.USE_WEBGPU_EP': JSON.stringify(!!USE_WEBGPU_EP), 'BUILD_DEFS.IS_ESM': 'false', 'BUILD_DEFS.ESM_IMPORT_META_URL': 'undefined', @@ -123,13 +136,17 @@ async function minifyWasmModuleJsForBrowser(filepath: string): Promise { // ``` // with: // ``` - // new Worker(import.meta.url.startsWith('file:') - // ? new URL(BUILD_DEFS.BUNDLE_FILENAME, import.meta.url) - // : new URL(import.meta.url), ... + // new Worker((() => { + // const URL2 = URL; + // return import.meta.url > 'file:' && import.meta.url < 'file;' + // ? new URL2(BUILD_DEFS.BUNDLE_FILENAME, import.meta.url) + // : new URL(import.meta.url); + // })(), ... // ``` // // NOTE: this is a workaround for some bundlers that does not support runtime import.meta.url. - // TODO: in emscripten 3.1.61+, need to update this code. + // + // Check more details in the comment of `isEsmImportMetaUrlHardcodedAsFileUri()` and `getScriptSrc()` in file `lib/wasm/wasm-utils-import.ts`. // First, check if there is exactly one occurrence of "new Worker(new URL(import.meta.url)". const matches = [...contents.matchAll(/new Worker\(new URL\(import\.meta\.url\),/g)]; @@ -142,7 +159,12 @@ async function minifyWasmModuleJsForBrowser(filepath: string): Promise { // Replace the only occurrence. contents = contents.replace( /new Worker\(new URL\(import\.meta\.url\),/, - `new Worker(import.meta.url.startsWith('file:')?new URL(BUILD_DEFS.BUNDLE_FILENAME, import.meta.url):new URL(import.meta.url),`, + `new Worker((() => { + const URL2 = URL; + return (import.meta.url > 'file:' && import.meta.url < 'file;') + ? new URL2(BUILD_DEFS.BUNDLE_FILENAME, import.meta.url) + : new URL(import.meta.url); + })(),`, ); // Use terser to minify the code with special configurations: diff --git a/js/web/test/data/ops/conv-transpose.jsonc b/js/web/test/data/ops/conv-transpose.jsonc index 6429845d23df9..008d58530ee36 100644 --- a/js/web/test/data/ops/conv-transpose.jsonc +++ b/js/web/test/data/ops/conv-transpose.jsonc @@ -348,6 +348,128 @@ } ] }, + { + "name": "ConvTranspose NHWC- group - A", + "operator": "ConvTranspose", + "inputShapeDefinitions": "rankOnly", + "opset": { "domain": "", "version": 17 }, + "attributes": [ + { "name": "kernel_shape", "data": [1, 1], "type": "ints" }, + { "name": "group", "data": 2, "type": "int" } + ], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 18.0, 20.0, 22.0, 24.0, 26.0, 28.0, 30.0, 32.0, 34.0], + "dims": [1, 2, 3, 3], + "type": "float32" + }, + { + "data": [1.0, 2.0], + "dims": [2, 1, 1, 1], + "type": "float32" + } + ], + "outputs": [ + { + "data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 36, 40, 44, 48, 52, 56, 60, 64, 68], + "dims": [1, 2, 3, 3], + "type": "float32" + } + ] + } + ] + }, + { + "name": "ConvTranspose NHWC- group - B", + "operator": "ConvTranspose", + "inputShapeDefinitions": "rankOnly", + "opset": { "domain": "", "version": 17 }, + "attributes": [ + { "name": "kernel_shape", "data": [2, 2], "type": "ints" }, + { "name": "group", "data": 3, "type": "int" } + ], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [ + 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, + 19.0, 20.0, 21.0, 22.0, 23.0, 0, 0, 0 + ], + "dims": [1, 3, 3, 3], + "type": "float32" + }, + { + "data": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0], + "dims": [3, 1, 2, 2], + "type": "float32" + }, + { + "data": [0.125, 0.25, 0.375], + "dims": [3], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 0.125, 1.125, 4.125, 4.125, 3.125, 13.125, 23.125, 18.125, 15.125, 43.125, 53.125, 36.125, 18.125, 45.125, + 52.125, 32.125, 45.25, 104.25, 115.25, 66.25, 123.25, 279.25, 305.25, 172.25, 159.25, 357.25, 383.25, + 214.25, 105.25, 232.25, 247.25, 136.25, 162.375, 351.375, 370.375, 200.375, 387.375, 833.375, 875.375, + 470.375, 231.375, 494.375, 517.375, 276.375, 0.375, 0.375, 0.375, 0.375 + ], + "dims": [1, 3, 4, 4], + "type": "float32" + } + ] + } + ] + }, + { + "name": "ConvTranspose NHWC- group - C", + "operator": "ConvTranspose", + "inputShapeDefinitions": "rankOnly", + "opset": { "domain": "", "version": 17 }, + "attributes": [ + { "name": "kernel_shape", "data": [2, 2], "type": "ints" }, + { "name": "group", "data": 3, "type": "int" } + ], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [ + 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, + 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0, 32.0, 33.0, 34.0, 35.0 + ], + "dims": [1, 3, 3, 4], + "type": "float32" + }, + { + "data": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0], + "dims": [3, 1, 2, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 0, 1, 4, 7, 6, 4, 16, 26, 36, 26, 20, 56, 66, 76, 50, 24, 59, 66, 73, 44, 60, 137, 148, 159, 90, 164, 368, + 394, 420, 234, 212, 472, 498, 524, 290, 140, 307, 322, 337, 184, 216, 465, 484, 503, 270, 516, 1104, 1146, + 1188, 634, 596, 1272, 1314, 1356, 722, 352, 747, 770, 793, 420 + ], + "dims": [1, 3, 4, 5], + "type": "float32" + } + ] + } + ] + }, { "name": "ConvTranspose with bias addition C", "operator": "ConvTranspose", diff --git a/js/web/test/e2e/exports/main.js b/js/web/test/e2e/exports/main.js index 8ed22a6784e7c..d8c7bbf69039f 100644 --- a/js/web/test/e2e/exports/main.js +++ b/js/web/test/e2e/exports/main.js @@ -3,7 +3,7 @@ 'use strict'; -const { runDevTest, runProdTest } = require('./test'); +const { runDevTest, runProdTest, verifyAssets } = require('./test'); const { installOrtPackages } = require('./utils'); /** @@ -29,5 +29,14 @@ module.exports = async function main(PRESERVE, PACKAGES_TO_INSTALL) { await runDevTest('vite-default', '\x1b[32m➜\x1b[39m \x1b[1mLocal\x1b[22m:', 5173); await runProdTest('vite-default', '\x1b[32m➜\x1b[39m \x1b[1mLocal\x1b[22m:', 4173); + + await verifyAssets('vite-default', async (cwd) => { + const globby = await import('globby'); + + return { + test: 'File "dist/assets/**/ort.*.mjs" should not exist', + success: globby.globbySync('dist/assets/**/ort.*.mjs', { cwd }).length === 0, + }; + }); } }; diff --git a/js/web/test/e2e/exports/test.js b/js/web/test/e2e/exports/test.js index 9c5ed745ab0b5..e2bcffea97519 100644 --- a/js/web/test/e2e/exports/test.js +++ b/js/web/test/e2e/exports/test.js @@ -121,7 +121,29 @@ async function runProdTest(testCaseName, ready, port) { await runTest(testCaseName, ['prod'], ready, 'npm run start', port); } +async function verifyAssets(testCaseName, testers) { + testers = Array.isArray(testers) ? testers : [testers]; + const wd = path.join(__dirname, 'testcases', testCaseName); + + console.log(`[${testCaseName}] Verifying assets...`); + + const testResults = []; + + try { + for (const tester of testers) { + testResults.push(await tester(wd)); + } + + if (testResults.some((r) => !r.success)) { + throw new Error(`[${testCaseName}] asset verification failed.`); + } + } finally { + console.log(`[${testCaseName}] asset verification result:`, testResults); + } +} + module.exports = { runDevTest, runProdTest, + verifyAssets, }; diff --git a/onnxruntime/contrib_ops/webgpu/bert/bias_add.cc b/onnxruntime/contrib_ops/webgpu/bert/bias_add.cc new file mode 100644 index 0000000000000..65c14e8cb0bdd --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/bert/bias_add.cc @@ -0,0 +1,80 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/webgpu_supported_types.h" +#include "contrib_ops/webgpu/bert/bias_add.h" +#include "contrib_ops/webgpu/webgpu_contrib_kernels.h" + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +ONNX_OPERATOR_KERNEL_EX( + BiasAdd, + kMSDomain, + 1, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedFloatTypes()), + BiasAdd); + +Status BiasAddProgram::GenerateShaderCode(ShaderHelper& shader) const { + const ShaderVariableHelper& input = shader.AddInput("input"); + const ShaderVariableHelper& bias = shader.AddInput("bias"); + const ShaderVariableHelper& residual = shader.AddInput("residual"); + const ShaderVariableHelper& output = shader.AddOutput("output"); + + shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size") + << "let value = " << input.GetByOffset("global_idx") + << " + " << bias.GetByOffset("global_idx % uniforms.channels") + << " + " << residual.GetByOffset("global_idx") << ";\n" + << output.SetByOffset("global_idx", "value"); + + return Status::OK(); +} + +static int64_t GetMaxComponents(int64_t size) { + if (size % 4 == 0) { + return 4; + } else if (size % 2 == 0) { + return 2; + } + return 1; +} + +Status BiasAdd::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const { + const auto* input = context.Input(0); + const auto* bias = context.Input(1); + const auto* residual = context.Input(2); + + TensorShape input_shape = input->Shape(); + + if (input_shape.NumDimensions() != 3) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "BiasAdd input should have 3 dimensions."); + } + + int64_t channels = input_shape[2]; + int64_t components = GetMaxComponents(channels); + channels /= components; + + TensorShape bias_shape = bias->Shape(); + if (bias_shape.NumDimensions() != 1 || bias_shape[0] != channels) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "BiasAdd bias should have 1 dimension with size equal to the number of channels."); + } + + auto* output = context.Output(0, input_shape); + int64_t output_size = output->Shape().Size() / components; + + BiasAddProgram program{}; + program.AddInputs({{input}, {bias}, {residual}}) + .AddOutput({output}) + .SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) + .AddUniformVariables({{static_cast(output_size)}, + {static_cast(channels)}}); + return context.RunProgram(program); +} + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/bert/bias_add.h b/onnxruntime/contrib_ops/webgpu/bert/bias_add.h new file mode 100644 index 0000000000000..58cc5f09f8003 --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/bert/bias_add.h @@ -0,0 +1,32 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/webgpu/program.h" +#include "core/providers/webgpu/webgpu_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +using namespace onnxruntime::webgpu; +using onnxruntime::webgpu::ComputeContext; + +class BiasAddProgram final : public Program { + public: + BiasAddProgram() : Program{"BiasAdd"} {} + Status GenerateShaderCode(ShaderHelper& sh) const override; + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"output_size", ProgramUniformVariableDataType::Uint32}, + {"channels", ProgramUniformVariableDataType::Uint32}); +}; + +class BiasAdd final : public WebGpuKernel { + public: + BiasAdd(const OpKernelInfo& info) : WebGpuKernel(info) {} + Status ComputeInternal(ComputeContext& context) const override; +}; + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/bert/fast_gelu.cc b/onnxruntime/contrib_ops/webgpu/bert/fast_gelu.cc index a5cae7e7f6747..29ea4f81dd5e1 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/fast_gelu.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/fast_gelu.cc @@ -50,7 +50,7 @@ Status FastGelu::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) c const auto* bias = context.Input(1); auto* output = context.Output(0, input->Shape()); - uint32_t data_size = gsl::narrow(output->Shape().Size()); + uint32_t data_size = onnxruntime::narrow(output->Shape().Size()); if (data_size == 0) { return Status::OK(); } @@ -60,7 +60,7 @@ Status FastGelu::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) c int bias_components = 1; if (bias != nullptr) { - bias_size = gsl::narrow(bias->Shape().Size()); + bias_size = onnxruntime::narrow(bias->Shape().Size()); if (bias_size % 4 == 0) { bias_components = 4; bias_size = bias_size / 4; diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc index 57ae8a7e5ba74..1e95d3d9610ff 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc @@ -98,7 +98,7 @@ Status CopyKVCache(onnxruntime::webgpu::ComputeContext& context, const WebgpuAtt program.AddOutputs({{present_key, ProgramTensorMetadataDependency::Rank, components}, {present_value, ProgramTensorMetadataDependency::Rank, components}}) .AddIndices(valid_present_shape); - program.SetDispatchGroupSize(gsl::narrow(valid_kv_size + 63 / 64)) + program.SetDispatchGroupSize(onnxruntime::narrow(valid_kv_size + 63 / 64)) .SetWorkgroupSize(64) .CacheHint(has_past, parameters.qkv_format_, parameters.past_present_share_buffer_) .AddUniformVariables({{static_cast(valid_kv_size)}, @@ -379,7 +379,7 @@ Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const { if (sg_size > 8) { for (var i:u32 = 0; i < qkv_head_size_vec; i++) { - var val = select(vec4(0), v_tile[capped_sg_id][i], k_start + capped_sg_id < seq_causal_length); + var val = v_tile[capped_sg_id][i]; var sum = subgroupShuffle(val, 0) * qk_1[0]; sum += subgroupShuffle(val, 1) * qk_1[1]; sum += subgroupShuffle(val, 2) * qk_1[2]; diff --git a/onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.cc b/onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.cc index bc8b7493fc916..20e1583e0da8f 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.cc @@ -66,11 +66,11 @@ Status RotaryEmbedding::ComputeInternal(onnxruntime::webgpu::ComputeContext& con const auto* sin_cache = context.Input(3); auto* output = context.Output(0, input_shape); - const auto batch_size = gsl::narrow(input->Shape()[0]); - const auto batch_stride = gsl::narrow(input_shape.SizeFromDimension(1)); - const auto sequence_length = gsl::narrow(input_shape[input_shape.NumDimensions() - 2]); + const auto batch_size = onnxruntime::narrow(input->Shape()[0]); + const auto batch_stride = onnxruntime::narrow(input_shape.SizeFromDimension(1)); + const auto sequence_length = onnxruntime::narrow(input_shape[input_shape.NumDimensions() - 2]); const auto hidden_size = batch_stride / sequence_length; - const auto half_rotary_embedding_dim = gsl::narrow(cos_cache->Shape()[1]); + const auto half_rotary_embedding_dim = onnxruntime::narrow(cos_cache->Shape()[1]); const auto head_size = rotary_embedding_dim_ == 0 ? half_rotary_embedding_dim * 2 : hidden_size / num_heads_; // Rotary embeddings will be calculated in a pair-wise fashion. In accordance, use the shape @@ -85,11 +85,11 @@ Status RotaryEmbedding::ComputeInternal(onnxruntime::webgpu::ComputeContext& con std::vector global_dims(rank); std::vector global_strides(rank); for (size_t j = 0; j < rank; ++j) { - global_dims[j] = gsl::narrow(global_shape[j]); - global_strides[j] = gsl::narrow(global_shape.SizeFromDimension(j + 1)); + global_dims[j] = onnxruntime::narrow(global_shape[j]); + global_strides[j] = onnxruntime::narrow(global_shape.SizeFromDimension(j + 1)); } - const auto output_size = gsl::narrow(global_shape.Size()); + const auto output_size = onnxruntime::narrow(global_shape.Size()); RotaryEmbeddingProgram program{interleaved_}; const auto input_output_strides = input_shape.NumDimensions() == 3 diff --git a/onnxruntime/contrib_ops/webgpu/bert/skip_layer_norm.cc b/onnxruntime/contrib_ops/webgpu/bert/skip_layer_norm.cc index a1840257d734f..d5d4632c01e2a 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/skip_layer_norm.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/skip_layer_norm.cc @@ -122,7 +122,7 @@ Status SkipLayerNorm::ComputeInternal(onnxruntime::webgpu::ComputeCo } const bool is_fp16 = x->GetElementType() == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16; - const uint32_t hidden_size = gsl::narrow(x_shape[x_shape.NumDimensions() - 1]); + const uint32_t hidden_size = onnxruntime::narrow(x_shape[x_shape.NumDimensions() - 1]); const int components = GetMaxComponents(hidden_size); const bool has_input_skip_bias_sum = input_skip_bias_sum != nullptr; @@ -133,7 +133,7 @@ Status SkipLayerNorm::ComputeInternal(onnxruntime::webgpu::ComputeCo .AddInputs({{skip, ProgramTensorMetadataDependency::Type, components}}) .AddInputs({{gamma, ProgramTensorMetadataDependency::Type, components}}) .AddOutputs({{output, ProgramTensorMetadataDependency::None, components}}) - .SetDispatchGroupSize(gsl::narrow(ceil(1.0 * data_size / hidden_size))) + .SetDispatchGroupSize(onnxruntime::narrow(ceil(1.0 * data_size / hidden_size))) .AddUniformVariables({ {static_cast(components)}, }) diff --git a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc new file mode 100644 index 0000000000000..05cbfb1f99c48 --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc @@ -0,0 +1,326 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h" +#include "core/providers/webgpu/shader_helper.h" + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +Status DP4AMatMulQuantizeProgram::GenerateShaderCode(ShaderHelper& shader) const { + shader.AddInput("input_a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); + shader.AddOutput("output", ShaderUsage::UseUniform); + shader.AddOutput("scales", ShaderUsage::UseUniform); + shader.AdditionalImplementation() << R"ADDNL_FN( + fn readInput(offset: u32) -> input_a_value_t + { + if (offset > uniforms.input_size) { + return input_a_value_t(0); + } + return input_a[offset]; + } + )ADDNL_FN"; + shader.MainFunctionBody() << R"MAIN_FN( + var local_a : array, 32>; + var max_value:vec4 = vec4(0); + for (var idx:u32=0;idx<32;idx+=1) + { + local_a[idx] = readInput(workgroup_idx*32 + idx); + max_value = max(max_value, abs(local_a[idx])); + } + var scale = max(max_value.x, max_value.y); + scale = max(scale, max_value.z); + scale = max(scale, max_value.w); + for (var idx:u32=0;idx<32;idx+=1) + { + output[workgroup_idx*32+idx] = pack4x8snorm(vec4(local_a[idx]/scale)); + } + // 127 is the max value of signed int8 [-127,127] used by pack4x8snorm for 1.0f. + scales[workgroup_idx] = scale/127; + )MAIN_FN"; + return Status::OK(); +} + +Status DP4AMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const { + shader.AddInput("input_a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias); + shader.AddInput("scales_a", ShaderUsage::UseUniform); + shader.AddInput("input_b", ShaderUsage::UseUniform); + shader.AddInput("scales_b", ShaderUsage::UseUniform); + shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseElementTypeAlias); + + // This shader implements co-operative matrix multiply. The key idea here is to + // assume there is a primitive for medium size matrix multiply a subgroup can perform, + // using all its lanes and pooling all its registers to keep the values in registry. + // + // The entire workgroup which has N subgroups first loads a tile into shared memory, + // Then each subgroup loads a subtile from shared memory into registers and uses + // the medium size matrix multiply primitive to perform the math. + // The values for tile/subtile size are chosen to conform to the resource limits + // of an alderlake/tiger lake gpu. A tile is 64x64, workgroup is 256 threads - + // therefore there are 16 subgroups and 16 lanes in each subgroup. + // K the hidden dimension is paged in from RAM at k tile size which is 64. + // All this puts the shared memory requirement slightly above 16KB. + // WebGPU limit is 16KB, output is moved to registers instead of SHM to make + // everything fit in shared memory. + // + // Each subgroup performs a 16 x 64 x 16 multiply which is implemented with + // subgroup shuffle as a placeholder for the day the medium matrix mul primitive + // becomes available in WGSL. The registry requirements is ~2KB per subgroup, on + // Alderlake/Tigerlake subgroup has 8KB of registry space pooling the + // 512B of registry from each lane. + // + // The medium size matmul is implemented using dot4I8Packed, so the inputs for + // this shader require A to be int8 quantized with block size 64. B is regular + // matmulnbits input with block size 32. + + shader.AdditionalImplementation() << " const block_size = " << block_size_ << ";"; + + shader.AdditionalImplementation() << R"ADDNL_FN( + const tile_size = 64; + const subtile_size = 16; + const tile_size_k = 32; + const vec_factor = 4; + const u32_factor = 4; + const tile_size_k_vec = 2; + + // Shared memory + var tile_A : array, tile_size>, tile_size_k_vec>; // 64 x 32 + var scale_A : array; // 64 x 1 + var tile_B : array, tile_size>, tile_size_k_vec>; // 64 x 32 + var scale_B : array; // 64 x 1 + + fn loadSHMA(a_global_base:u32, kidx_v:u32, row: u32, col: u32) + { + let a_global = a_global_base + row; + if (a_global >= uniforms.M) + { + return; + } + tile_A[col][row] = input_a[a_global*uniforms.K16+kidx_v+col]; + if (col == 0) + { + // kidx_v - covers 16 values of k + scale_A[row] = scales_a[a_global*(uniforms.K/128) + kidx_v/8]; + } + } + + fn loadSHMB(b_global_base:u32, kidx_v:u32, row: u32, col: u32) + { + let b_global = b_global_base + row; + if (b_global >= uniforms.N) + { + return; + } + + let b_value = input_b[b_global*uniforms.K16+kidx_v+col]; + var b_value_lower = vec4(unpack4xU8(b_value[0] & 0x0F0F0F0Fu)) - vec4(8); + var b_value_upper = vec4(unpack4xU8((b_value[0] >> 4) & 0x0F0F0F0Fu)) - vec4(8); + tile_B[col][row][0] = pack4xI8(vec4(b_value_lower[0], b_value_upper[0], b_value_lower[1], b_value_upper[1])); + tile_B[col][row][1] = pack4xI8(vec4(b_value_lower[2], b_value_upper[2], b_value_lower[3], b_value_upper[3])); + b_value_lower = vec4(unpack4xU8(b_value[1] & 0x0F0F0F0Fu)) - vec4(8); + b_value_upper = vec4(unpack4xU8((b_value[1] >> 4) & 0x0F0F0F0Fu)) - vec4(8); + tile_B[col][row][2] = pack4xI8(vec4(b_value_lower[0], b_value_upper[0], b_value_lower[1], b_value_upper[1])); + tile_B[col][row][3] = pack4xI8(vec4(b_value_lower[2], b_value_upper[2], b_value_lower[3], b_value_upper[3])); + if (col == 0) + { + // kidx_v - each kidx_v covers 16 values of k + scale_B[row] = scales_b[b_global*(uniforms.K/block_size) + kidx_v/(block_size/16)]; + } + } + + // Scaled dot product of 8 packed unsigned integers. + fn SDP8AI(a1:vec4, b1:vec4, a2:vec4, b2:vec4, scale:output_element_t) -> output_element_t + { + var local_sum = dot4I8Packed(a1[0], b1[0]); + local_sum += dot4I8Packed(a1[1], b1[1]); + local_sum += dot4I8Packed(a1[2], b1[2]); + local_sum += dot4I8Packed(a1[3], b1[3]); + local_sum += dot4I8Packed(a2[0], b2[0]); + local_sum += dot4I8Packed(a2[1], b2[1]); + local_sum += dot4I8Packed(a2[2], b2[2]); + local_sum += dot4I8Packed(a2[3], b2[3]); + return output_element_t(local_sum) * scale; + } + )ADDNL_FN"; + + shader.MainFunctionBody() << R"MAIN_FN( + // During the load phase we use all 256 threads to load 64 rows of A/B. + // For each row we load tile_size_k_vec (2) vectorized elements, which are 32 elements of K. + let a_global_base = workgroup_id.x * tile_size; + let b_global_base = workgroup_id.y * tile_size; + let load_AorB = u32(local_idx/128); + let load_row = u32((local_idx%128)/2); + let load_col = u32(local_idx%2); + + // During the compute phase, we have the 64x64 tile split into + // subtiles of 16x16. We have a grid of 4x4 subtiles. + let subtile_id = u32(local_idx / subtile_size); + let subtile_idx = u32(subtile_id / 4); + let subtile_idy = u32(subtile_id % 4); + let base_A = subtile_idx * 16; + let base_B = subtile_idy * 16; + // For each subtile we have 16 threads assigned. + let a_idx = u32(local_idx % subtile_size); + + var lane_output1: vec4; + var lane_output2: vec4; + var lane_output3: vec4; + var lane_output4: vec4; + // K's vectrorization is 16 items per index. See input_a/input_b. + // tile_size_k_vec - is the k tile size in vectorized space (1/16). That is + // k tile size is 32. In vectorized space that is 32/16 = 2. + for (var kidx_v:u32 = 0; kidx_v < uniforms.K16; kidx_v+=tile_size_k_vec) + { + // Load Phase: Populate shared memory for the workgroup. + if (load_AorB == 0) + { + loadSHMA(a_global_base, kidx_v, load_row, load_col); + } + else + { + loadSHMB(b_global_base, kidx_v, load_row, load_col); + } + workgroupBarrier(); + + // Compute phase: Perform matmul for this subtile 16 x 32 x 16. + // Step 1: Load from shared memory into registers across entire subgroup. + var own_a0: vec4 = tile_A[0][base_A + a_idx]; + var own_a1: vec4 = tile_A[1][base_A + a_idx]; + var own_scale_a: output_element_t = scale_A[base_A + a_idx]; + if (sg_size == 16) + { + var own_b0: vec4 = tile_B[0][base_B + sg_id]; + var own_b1: vec4 = tile_B[1][base_B + sg_id]; + var own_scale_b: output_element_t = scale_B[base_B + sg_id]; + // Step 2: Access registers across the subgroup using subgroupShuffle and perform the matmul. + lane_output1[0] += SDP8AI(own_a0, subgroupShuffle(own_b0, 0), own_a1, subgroupShuffle(own_b1, 0), subgroupShuffle(own_scale_b, 0) * own_scale_a); + lane_output1[1] += SDP8AI(own_a0, subgroupShuffle(own_b0, 1), own_a1, subgroupShuffle(own_b1, 1), subgroupShuffle(own_scale_b, 1) * own_scale_a); + lane_output1[2] += SDP8AI(own_a0, subgroupShuffle(own_b0, 2), own_a1, subgroupShuffle(own_b1, 2), subgroupShuffle(own_scale_b, 2) * own_scale_a); + lane_output1[3] += SDP8AI(own_a0, subgroupShuffle(own_b0, 3), own_a1, subgroupShuffle(own_b1, 3), subgroupShuffle(own_scale_b, 3) * own_scale_a); + + lane_output2[0] += SDP8AI(own_a0, subgroupShuffle(own_b0, 4), own_a1, subgroupShuffle(own_b1, 4), subgroupShuffle(own_scale_b, 4) * own_scale_a); + lane_output2[1] += SDP8AI(own_a0, subgroupShuffle(own_b0, 5), own_a1, subgroupShuffle(own_b1, 5), subgroupShuffle(own_scale_b, 5) * own_scale_a); + lane_output2[2] += SDP8AI(own_a0, subgroupShuffle(own_b0, 6), own_a1, subgroupShuffle(own_b1, 6), subgroupShuffle(own_scale_b, 6) * own_scale_a); + lane_output2[3] += SDP8AI(own_a0, subgroupShuffle(own_b0, 7), own_a1, subgroupShuffle(own_b1, 7), subgroupShuffle(own_scale_b, 7) * own_scale_a); + + lane_output3[0] += SDP8AI(own_a0, subgroupShuffle(own_b0, 8), own_a1, subgroupShuffle(own_b1, 8), subgroupShuffle(own_scale_b, 8) * own_scale_a); + lane_output3[1] += SDP8AI(own_a0, subgroupShuffle(own_b0, 9), own_a1, subgroupShuffle(own_b1, 9), subgroupShuffle(own_scale_b, 9) * own_scale_a); + lane_output3[2] += SDP8AI(own_a0, subgroupShuffle(own_b0, 10), own_a1, subgroupShuffle(own_b1, 10), subgroupShuffle(own_scale_b, 10) * own_scale_a); + lane_output3[3] += SDP8AI(own_a0, subgroupShuffle(own_b0, 11), own_a1, subgroupShuffle(own_b1, 11), subgroupShuffle(own_scale_b, 11) * own_scale_a); + + lane_output4[0] += SDP8AI(own_a0, subgroupShuffle(own_b0, 12), own_a1, subgroupShuffle(own_b1, 12), subgroupShuffle(own_scale_b, 12) * own_scale_a); + lane_output4[1] += SDP8AI(own_a0, subgroupShuffle(own_b0, 13), own_a1, subgroupShuffle(own_b1, 13), subgroupShuffle(own_scale_b, 13) * own_scale_a); + lane_output4[2] += SDP8AI(own_a0, subgroupShuffle(own_b0, 14), own_a1, subgroupShuffle(own_b1, 14), subgroupShuffle(own_scale_b, 14) * own_scale_a); + lane_output4[3] += SDP8AI(own_a0, subgroupShuffle(own_b0, 15), own_a1, subgroupShuffle(own_b1, 15), subgroupShuffle(own_scale_b, 15) * own_scale_a); + } + else + { + // Code for other subgroup sizes, simply doesnt use subgroups at all. + // Relies on reads from single location tile_B[][base_B + col] by all + // being optimized by the hardware. + lane_output1[0] += SDP8AI(own_a0, tile_B[0][base_B + 0], own_a1, tile_B[1][base_B + 0], own_scale_a * scale_B[base_B + 0]); + lane_output1[1] += SDP8AI(own_a0, tile_B[0][base_B + 1], own_a1, tile_B[1][base_B + 1], own_scale_a * scale_B[base_B + 1]); + lane_output1[2] += SDP8AI(own_a0, tile_B[0][base_B + 2], own_a1, tile_B[1][base_B + 2], own_scale_a * scale_B[base_B + 2]); + lane_output1[3] += SDP8AI(own_a0, tile_B[0][base_B + 3], own_a1, tile_B[1][base_B + 3], own_scale_a * scale_B[base_B + 3]); + + lane_output2[0] += SDP8AI(own_a0, tile_B[0][base_B + 4], own_a1, tile_B[1][base_B + 4], own_scale_a * scale_B[base_B + 4]); + lane_output2[1] += SDP8AI(own_a0, tile_B[0][base_B + 5], own_a1, tile_B[1][base_B + 5], own_scale_a * scale_B[base_B + 5]); + lane_output2[2] += SDP8AI(own_a0, tile_B[0][base_B + 6], own_a1, tile_B[1][base_B + 6], own_scale_a * scale_B[base_B + 6]); + lane_output2[3] += SDP8AI(own_a0, tile_B[0][base_B + 7], own_a1, tile_B[1][base_B + 7], own_scale_a * scale_B[base_B + 7]); + + lane_output3[0] += SDP8AI(own_a0, tile_B[0][base_B + 8], own_a1, tile_B[1][base_B + 8], own_scale_a * scale_B[base_B + 8]); + lane_output3[1] += SDP8AI(own_a0, tile_B[0][base_B + 9], own_a1, tile_B[1][base_B + 9], own_scale_a * scale_B[base_B + 9]); + lane_output3[2] += SDP8AI(own_a0, tile_B[0][base_B + 10], own_a1, tile_B[1][base_B + 10], own_scale_a * scale_B[base_B + 10]); + lane_output3[3] += SDP8AI(own_a0, tile_B[0][base_B + 11], own_a1, tile_B[1][base_B + 11], own_scale_a * scale_B[base_B + 11]); + + lane_output4[0] += SDP8AI(own_a0, tile_B[0][base_B + 12], own_a1, tile_B[1][base_B + 12], own_scale_a * scale_B[base_B + 12]); + lane_output4[1] += SDP8AI(own_a0, tile_B[0][base_B + 13], own_a1, tile_B[1][base_B + 13], own_scale_a * scale_B[base_B + 13]); + lane_output4[2] += SDP8AI(own_a0, tile_B[0][base_B + 14], own_a1, tile_B[1][base_B + 14], own_scale_a * scale_B[base_B + 14]); + lane_output4[3] += SDP8AI(own_a0, tile_B[0][base_B + 15], own_a1, tile_B[1][base_B + 15], own_scale_a * scale_B[base_B + 15]); + } + workgroupBarrier(); + } + + let a_global = a_global_base + base_A + a_idx; + let b_global = b_global_base + base_B; + let output_idx = ((a_global) * uniforms.N + b_global)/4; + // This creates a shader requirement that uniforms.N % 16 == 0 + if (a_global < uniforms.M && b_global < uniforms.N) + { + output[output_idx] = lane_output1; + output[output_idx+1] = lane_output2; + output[output_idx+2] = lane_output3; + output[output_idx+3] = lane_output4; + } + )MAIN_FN"; + + return Status::OK(); +} + +Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor* scales, + uint32_t M, + uint32_t N, + uint32_t K, + uint32_t block_size, + onnxruntime::webgpu::ComputeContext& context, + Tensor* y) { + constexpr uint32_t kVec4Components = 4; + constexpr uint32_t kVec2Components = 2; + constexpr uint32_t kU32Components = 4; + + constexpr uint32_t kBlockSizeA = 128; + DP4AMatMulQuantizeProgram quantize_program; + quantize_program.SetWorkgroupSize(1); + quantize_program.SetDispatchGroupSize(M * K / kBlockSizeA, 1, 1); + TensorShape a_quant_shape{1, M, K / kU32Components}; + Tensor a_quant = context.CreateGPUTensor(DataTypeImpl::GetType(), a_quant_shape); + TensorShapeVector a_scales_dims({1, 1, M, K / kBlockSizeA}); + Tensor a_scale = context.CreateGPUTensor(a->DataType(), a_scales_dims); + quantize_program.AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, static_cast(kVec4Components)}}) + .AddOutputs({{&a_quant, ProgramTensorMetadataDependency::Rank, a_quant.Shape(), 1}, + {&a_scale, ProgramTensorMetadataDependency::Rank, a_scale.Shape(), 1}}) + .AddUniformVariable({static_cast(M * K / kVec4Components)}); + ORT_RETURN_IF_ERROR(context.RunProgram(quantize_program)); + + constexpr uint32_t kTileSize = 64; + TensorShape reshaped_y_shape{1, M, N / kVec4Components}; + DP4AMatMulNBitsProgram mul_program{block_size}; + mul_program.SetWorkgroupSize(256); + mul_program.SetDispatchGroupSize( + (M + kTileSize - 1) / kTileSize, + (N + kTileSize - 1) / kTileSize, 1); + mul_program.AddInputs({{&a_quant, ProgramTensorMetadataDependency::TypeAndRank, static_cast(kVec4Components)}, + {&a_scale, ProgramTensorMetadataDependency::TypeAndRank, 1}, + {b, ProgramTensorMetadataDependency::TypeAndRank, static_cast(kVec2Components * kU32Components)}, + {scales, ProgramTensorMetadataDependency::TypeAndRank, 1}}) + .AddUniformVariables({{static_cast(M)}, + {static_cast(N)}, + {static_cast(K)}, + {static_cast(K / 8)}, + {static_cast(K / 16)}}) + .AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, reshaped_y_shape, static_cast(kVec4Components)}) + .CacheHint("Block" + std::to_string(block_size)); + return context.RunProgram(mul_program); +} + +bool CanApplyDP4AMatrixMatMulNBits(onnxruntime::webgpu::ComputeContext& context, + uint64_t accuracy_level, + uint32_t block_size, + uint32_t batch_count, + uint32_t N, + uint32_t K, + uint32_t components_k, + bool has_zero_points) { + // macOS - Avoid using dp4a on Metal, as it does not appear to have native dp4a support. + // https://github.com/gpuweb/gpuweb/issues/2677#issuecomment-1713292226 + bool use_dp4a = context.Device().HasFeature(wgpu::FeatureName::Subgroups) && + context.AdapterInfo().backendType != wgpu::BackendType::Metal; + return (accuracy_level == 4 && block_size % 32 == 0 && + batch_count == 1 && components_k == 4 && K % 64 == 0 && N % 16 == 0 && + !has_zero_points && use_dp4a); +} + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h new file mode 100644 index 0000000000000..15b86d78301ad --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h @@ -0,0 +1,56 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/webgpu/program.h" +#include "core/providers/webgpu/webgpu_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +using namespace onnxruntime::webgpu; + +class DP4AMatMulQuantizeProgram final : public Program { + public: + DP4AMatMulQuantizeProgram() : Program{"DP4AMatMulQuantize"} {} + Status GenerateShaderCode(ShaderHelper& sh) const override; + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"input_size", ProgramUniformVariableDataType::Uint32}); +}; + +class DP4AMatMulNBitsProgram final : public Program { + public: + DP4AMatMulNBitsProgram(uint32_t block_size) : Program{"DP4AMatMulNBits"}, block_size_(block_size) {} + Status GenerateShaderCode(ShaderHelper& sh) const override; + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES( + {"M", ProgramUniformVariableDataType::Uint32}, + {"N", ProgramUniformVariableDataType::Uint32}, + {"K", ProgramUniformVariableDataType::Uint32}, + {"K8", ProgramUniformVariableDataType::Uint32}, + {"K16", ProgramUniformVariableDataType::Uint32}); + + private: + uint32_t block_size_; +}; + +Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor* scales, + uint32_t M, + uint32_t N, + uint32_t K, + uint32_t block_size, + onnxruntime::webgpu::ComputeContext& context, + Tensor* y); + +bool CanApplyDP4AMatrixMatMulNBits(onnxruntime::webgpu::ComputeContext& context, + uint64_t accuracy_level, + uint32_t block_size, + uint32_t batch_count, + uint32_t N, + uint32_t K, + uint32_t components_k, + bool has_zero_points); + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc index 28d622b2c9c33..cce10a59fbd4b 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc @@ -5,6 +5,7 @@ #include "contrib_ops/webgpu/quantization/matmul_nbits.h" #include "contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.h" +#include "contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h" #include "contrib_ops/webgpu/webgpu_contrib_kernels.h" #include "core/providers/cpu/math/matmul_helper.h" #include "core/providers/webgpu/shader_helper.h" @@ -371,7 +372,7 @@ Status MatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const { } } else { const std::string quantized_data_type = QuantizedDataType(a.NumComponents()); - const int output_element_number = y.NumComponents() * gsl::narrow(output_number_); + const int output_element_number = y.NumComponents() * onnxruntime::narrow(output_number_); const uint32_t shared_memory_size = output_number_ * WORKGROUP_SIZE; std::string offset = "workgroup_idx * " + std::to_string(output_number_); @@ -532,255 +533,6 @@ Status MatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const { return Status::OK(); } -Status DP4AMatMulQuantizeProgram::GenerateShaderCode(ShaderHelper& shader) const { - shader.AddInput("input_a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); - shader.AddOutput("output", ShaderUsage::UseUniform); - shader.AddOutput("scales", ShaderUsage::UseUniform); - shader.AdditionalImplementation() << R"ADDNL_FN( - fn readInput(offset: u32) -> input_a_value_t - { - if (offset > uniforms.input_size) { - return input_a_value_t(0); - } - return input_a[offset]; - } -)ADDNL_FN"; - shader.MainFunctionBody() << R"MAIN_FN( - var local_a : array, 32>; - var max_value:vec4 = vec4(0); - for (var idx:u32=0;idx<32;idx+=1) - { - local_a[idx] = readInput(workgroup_idx*32 + idx); - max_value = max(max_value, abs(local_a[idx])); - } - var scale = max(max_value.x, max_value.y); - scale = max(scale, max_value.z); - scale = max(scale, max_value.w); - for (var idx:u32=0;idx<32;idx+=1) - { - output[workgroup_idx*32+idx] = pack4x8snorm(vec4(local_a[idx]/scale)); - } - // 127 is the max value of signed int8 [-127,127] used by pack4x8snorm for 1.0f. - scales[workgroup_idx] = scale/127; -)MAIN_FN"; - return Status::OK(); -} - -Status DP4AMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const { - shader.AddInput("input_a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias); - shader.AddInput("scales_a", ShaderUsage::UseUniform); - shader.AddInput("input_b", ShaderUsage::UseUniform); - shader.AddInput("scales_b", ShaderUsage::UseUniform); - shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseElementTypeAlias); - - // This shader implements co-operative matrix multiply. The key idea here is to - // assume there is a primitive for medium size matrix multiply a subgroup can perform, - // using all its lanes and pooling all its registers to keep the values in registry. - // - // The entire workgroup which has N subgroups first loads a tile into shared memory, - // Then each subgroup loads a subtile from shared memory into registers and uses - // the medium size matrix multiply primitive to perform the math. - // The values for tile/subtile size are chosen to conform to the resource limits - // of an alderlake/tiger lake gpu. A tile is 64x64, workgroup is 256 threads - - // therefore there are 16 subgroups and 16 lanes in each subgroup. - // K the hidden dimension is paged in from RAM at k tile size which is 64. - // All this puts the shared memory requirement slightly above 16KB. - // WebGPU limit is 16KB, output is moved to registers instead of SHM to make - // everything fit in shared memory. - // - // Each subgroup performs a 16 x 64 x 16 multiply which is implemented with - // subgroup shuffle as a placeholder for the day the medium matrix mul primitive - // becomes available in WGSL. The registry requirements is ~2KB per subgroup, on - // Alderlake/Tigerlake subgroup has 8KB of registry space pooling the - // 512B of registry from each lane. - // - // The medium size matmul is implemented using dot4I8Packed, so the inputs for - // this shader require A to be int8 quantized with block size 64. B is regular - // matmulnbits input with block size 32. - - shader.AdditionalImplementation() << R"ADDNL_FN( - const tile_size = 64; - const subtile_size = 16; - const tile_size_k = 32; - const vec_factor = 4; - const u32_factor = 4; - const tile_size_k_vec = 2; - const block_size = 32; - - // Shared memory - var tile_A : array, tile_size>, tile_size_k_vec>; // 64 x 32 - var scale_A : array; // 64 x 1 - var tile_B : array, tile_size>, tile_size_k_vec>; // 64 x 32 - var scale_B : array; // 64 x 1 - - fn loadSHMA(a_global_base:u32, kidx_v:u32, row: u32, col: u32) - { - let a_global = a_global_base + row; - if (a_global >= uniforms.M) - { - return; - } - tile_A[col][row] = input_a[a_global*uniforms.K16+kidx_v+col]; - if (col == 0) - { - // kidx_v - covers 16 values of k - scale_A[row] = scales_a[a_global*(uniforms.K/128) + kidx_v/8]; - } - } - - fn loadSHMB(b_global_base:u32, kidx_v:u32, row: u32, col: u32) - { - let b_global = b_global_base + row; - if (b_global >= uniforms.N) - { - return; - } - - let b_value = input_b[b_global*uniforms.K16+kidx_v+col]; - var b_value_lower = vec4(unpack4xU8(b_value[0] & 0x0F0F0F0Fu)) - vec4(8); - var b_value_upper = vec4(unpack4xU8((b_value[0] >> 4) & 0x0F0F0F0Fu)) - vec4(8); - tile_B[col][row][0] = pack4xI8(vec4(b_value_lower[0], b_value_upper[0], b_value_lower[1], b_value_upper[1])); - tile_B[col][row][1] = pack4xI8(vec4(b_value_lower[2], b_value_upper[2], b_value_lower[3], b_value_upper[3])); - b_value_lower = vec4(unpack4xU8(b_value[1] & 0x0F0F0F0Fu)) - vec4(8); - b_value_upper = vec4(unpack4xU8((b_value[1] >> 4) & 0x0F0F0F0Fu)) - vec4(8); - tile_B[col][row][2] = pack4xI8(vec4(b_value_lower[0], b_value_upper[0], b_value_lower[1], b_value_upper[1])); - tile_B[col][row][3] = pack4xI8(vec4(b_value_lower[2], b_value_upper[2], b_value_lower[3], b_value_upper[3])); - if (col == 0) - { - // kidx_v - each kidx_v covers 16 values of k - scale_B[row] = scales_b[b_global*(uniforms.K/32) + kidx_v/2]; - } - } - - // Scaled dot product of 8 packed unsigned integers. - fn SDP8AI(a1:vec4, b1:vec4, a2:vec4, b2:vec4, scale:output_element_t) -> output_element_t - { - var local_sum = dot4I8Packed(a1[0], b1[0]); - local_sum += dot4I8Packed(a1[1], b1[1]); - local_sum += dot4I8Packed(a1[2], b1[2]); - local_sum += dot4I8Packed(a1[3], b1[3]); - local_sum += dot4I8Packed(a2[0], b2[0]); - local_sum += dot4I8Packed(a2[1], b2[1]); - local_sum += dot4I8Packed(a2[2], b2[2]); - local_sum += dot4I8Packed(a2[3], b2[3]); - return output_element_t(local_sum) * scale; - } -)ADDNL_FN"; - - shader.MainFunctionBody() << R"MAIN_FN( - // During the load phase we use all 256 threads to load 64 rows of A/B. - // For each row we load tile_size_k_vec (2) vectorized elements, which are 32 elements of K. - let a_global_base = workgroup_id.x * tile_size; - let b_global_base = workgroup_id.y * tile_size; - let load_AorB = u32(local_idx/128); - let load_row = u32((local_idx%128)/2); - let load_col = u32(local_idx%2); - - // During the compute phase, we have the 64x64 tile split into - // subtiles of 16x16. We have a grid of 4x4 subtiles. - let subtile_id = u32(local_idx / subtile_size); - let subtile_idx = u32(subtile_id / 4); - let subtile_idy = u32(subtile_id % 4); - let base_A = subtile_idx * 16; - let base_B = subtile_idy * 16; - // For each subtile we have 16 threads assigned. - let a_idx = u32(local_idx % subtile_size); - - var lane_output1: vec4; - var lane_output2: vec4; - var lane_output3: vec4; - var lane_output4: vec4; - // K's vectrorization is 16 items per index. See input_a/input_b. - // tile_size_k_vec - is the k tile size in vectorized space (1/16). That is - // k tile size is 32. In vectorized space that is 32/16 = 2. - for (var kidx_v:u32 = 0; kidx_v < uniforms.K16; kidx_v+=tile_size_k_vec) - { - // Load Phase: Populate shared memory for the workgroup. - if (load_AorB == 0) - { - loadSHMA(a_global_base, kidx_v, load_row, load_col); - } - else - { - loadSHMB(b_global_base, kidx_v, load_row, load_col); - } - workgroupBarrier(); - - // Compute phase: Perform matmul for this subtile 16 x 32 x 16. - // Step 1: Load from shared memory into registers across entire subgroup. - var own_a0: vec4 = tile_A[0][base_A + a_idx]; - var own_a1: vec4 = tile_A[1][base_A + a_idx]; - var own_scale_a: output_element_t = scale_A[base_A + a_idx]; - if (sg_size == 16) - { - var own_b0: vec4 = tile_B[0][base_B + sg_id]; - var own_b1: vec4 = tile_B[1][base_B + sg_id]; - var own_scale_b: output_element_t = scale_B[base_B + sg_id]; - // Step 2: Access registers across the subgroup using subgroupShuffle and perform the matmul. - lane_output1[0] += SDP8AI(own_a0, subgroupShuffle(own_b0, 0), own_a1, subgroupShuffle(own_b1, 0), subgroupShuffle(own_scale_b, 0) * own_scale_a); - lane_output1[1] += SDP8AI(own_a0, subgroupShuffle(own_b0, 1), own_a1, subgroupShuffle(own_b1, 1), subgroupShuffle(own_scale_b, 1) * own_scale_a); - lane_output1[2] += SDP8AI(own_a0, subgroupShuffle(own_b0, 2), own_a1, subgroupShuffle(own_b1, 2), subgroupShuffle(own_scale_b, 2) * own_scale_a); - lane_output1[3] += SDP8AI(own_a0, subgroupShuffle(own_b0, 3), own_a1, subgroupShuffle(own_b1, 3), subgroupShuffle(own_scale_b, 3) * own_scale_a); - - lane_output2[0] += SDP8AI(own_a0, subgroupShuffle(own_b0, 4), own_a1, subgroupShuffle(own_b1, 4), subgroupShuffle(own_scale_b, 4) * own_scale_a); - lane_output2[1] += SDP8AI(own_a0, subgroupShuffle(own_b0, 5), own_a1, subgroupShuffle(own_b1, 5), subgroupShuffle(own_scale_b, 5) * own_scale_a); - lane_output2[2] += SDP8AI(own_a0, subgroupShuffle(own_b0, 6), own_a1, subgroupShuffle(own_b1, 6), subgroupShuffle(own_scale_b, 6) * own_scale_a); - lane_output2[3] += SDP8AI(own_a0, subgroupShuffle(own_b0, 7), own_a1, subgroupShuffle(own_b1, 7), subgroupShuffle(own_scale_b, 7) * own_scale_a); - - lane_output3[0] += SDP8AI(own_a0, subgroupShuffle(own_b0, 8), own_a1, subgroupShuffle(own_b1, 8), subgroupShuffle(own_scale_b, 8) * own_scale_a); - lane_output3[1] += SDP8AI(own_a0, subgroupShuffle(own_b0, 9), own_a1, subgroupShuffle(own_b1, 9), subgroupShuffle(own_scale_b, 9) * own_scale_a); - lane_output3[2] += SDP8AI(own_a0, subgroupShuffle(own_b0, 10), own_a1, subgroupShuffle(own_b1, 10), subgroupShuffle(own_scale_b, 10) * own_scale_a); - lane_output3[3] += SDP8AI(own_a0, subgroupShuffle(own_b0, 11), own_a1, subgroupShuffle(own_b1, 11), subgroupShuffle(own_scale_b, 11) * own_scale_a); - - lane_output4[0] += SDP8AI(own_a0, subgroupShuffle(own_b0, 12), own_a1, subgroupShuffle(own_b1, 12), subgroupShuffle(own_scale_b, 12) * own_scale_a); - lane_output4[1] += SDP8AI(own_a0, subgroupShuffle(own_b0, 13), own_a1, subgroupShuffle(own_b1, 13), subgroupShuffle(own_scale_b, 13) * own_scale_a); - lane_output4[2] += SDP8AI(own_a0, subgroupShuffle(own_b0, 14), own_a1, subgroupShuffle(own_b1, 14), subgroupShuffle(own_scale_b, 14) * own_scale_a); - lane_output4[3] += SDP8AI(own_a0, subgroupShuffle(own_b0, 15), own_a1, subgroupShuffle(own_b1, 15), subgroupShuffle(own_scale_b, 15) * own_scale_a); - } - else - { - // Code for other subgroup sizes, simply doesnt use subgroups at all. - // Relies on reads from single location tile_B[][base_B + col] by all - // being optimized by the hardware. - lane_output1[0] += SDP8AI(own_a0, tile_B[0][base_B + 0], own_a1, tile_B[1][base_B + 0], own_scale_a * scale_B[base_B + 0]); - lane_output1[1] += SDP8AI(own_a0, tile_B[0][base_B + 1], own_a1, tile_B[1][base_B + 1], own_scale_a * scale_B[base_B + 1]); - lane_output1[2] += SDP8AI(own_a0, tile_B[0][base_B + 2], own_a1, tile_B[1][base_B + 2], own_scale_a * scale_B[base_B + 2]); - lane_output1[3] += SDP8AI(own_a0, tile_B[0][base_B + 3], own_a1, tile_B[1][base_B + 3], own_scale_a * scale_B[base_B + 3]); - - lane_output2[0] += SDP8AI(own_a0, tile_B[0][base_B + 4], own_a1, tile_B[1][base_B + 4], own_scale_a * scale_B[base_B + 4]); - lane_output2[1] += SDP8AI(own_a0, tile_B[0][base_B + 5], own_a1, tile_B[1][base_B + 5], own_scale_a * scale_B[base_B + 5]); - lane_output2[2] += SDP8AI(own_a0, tile_B[0][base_B + 6], own_a1, tile_B[1][base_B + 6], own_scale_a * scale_B[base_B + 6]); - lane_output2[3] += SDP8AI(own_a0, tile_B[0][base_B + 7], own_a1, tile_B[1][base_B + 7], own_scale_a * scale_B[base_B + 7]); - - lane_output3[0] += SDP8AI(own_a0, tile_B[0][base_B + 8], own_a1, tile_B[1][base_B + 8], own_scale_a * scale_B[base_B + 8]); - lane_output3[1] += SDP8AI(own_a0, tile_B[0][base_B + 9], own_a1, tile_B[1][base_B + 9], own_scale_a * scale_B[base_B + 9]); - lane_output3[2] += SDP8AI(own_a0, tile_B[0][base_B + 10], own_a1, tile_B[1][base_B + 10], own_scale_a * scale_B[base_B + 10]); - lane_output3[3] += SDP8AI(own_a0, tile_B[0][base_B + 11], own_a1, tile_B[1][base_B + 11], own_scale_a * scale_B[base_B + 11]); - - lane_output4[0] += SDP8AI(own_a0, tile_B[0][base_B + 12], own_a1, tile_B[1][base_B + 12], own_scale_a * scale_B[base_B + 12]); - lane_output4[1] += SDP8AI(own_a0, tile_B[0][base_B + 13], own_a1, tile_B[1][base_B + 13], own_scale_a * scale_B[base_B + 13]); - lane_output4[2] += SDP8AI(own_a0, tile_B[0][base_B + 14], own_a1, tile_B[1][base_B + 14], own_scale_a * scale_B[base_B + 14]); - lane_output4[3] += SDP8AI(own_a0, tile_B[0][base_B + 15], own_a1, tile_B[1][base_B + 15], own_scale_a * scale_B[base_B + 15]); - } - workgroupBarrier(); - } - - let a_global = a_global_base + base_A + a_idx; - let b_global = b_global_base + base_B; - let output_idx = ((a_global) * uniforms.N + b_global)/4; - // This creates a shader requirement that uniforms.N % 16 == 0 - if (a_global < uniforms.M && b_global < uniforms.N) - { - output[output_idx] = lane_output1; - output[output_idx+1] = lane_output2; - output[output_idx+2] = lane_output3; - output[output_idx+3] = lane_output4; - } -)MAIN_FN"; - - return Status::OK(); -} - Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const { const Tensor* a = context.Input(0); const Tensor* b = context.Input(1); @@ -796,16 +548,16 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context TensorShape b_shape({N_, K_}); ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b_shape, false, true)); auto* y = context.Output(0, helper.OutputShape()); - const uint32_t data_size = gsl::narrow(y->Shape().Size()); + const uint32_t data_size = onnxruntime::narrow(y->Shape().Size()); if (data_size == 0) { return Status::OK(); } - const uint32_t batch_count = gsl::narrow(helper.OutputOffsets().size()); - const uint32_t M = gsl::narrow(helper.M()); - const uint32_t N = gsl::narrow(helper.N()); - const uint32_t K = gsl::narrow(helper.K()); - const uint32_t block_size = gsl::narrow(block_size_); + const uint32_t batch_count = onnxruntime::narrow(helper.OutputOffsets().size()); + const uint32_t M = onnxruntime::narrow(helper.M()); + const uint32_t N = onnxruntime::narrow(helper.N()); + const uint32_t K = onnxruntime::narrow(helper.K()); + const uint32_t block_size = onnxruntime::narrow(block_size_); constexpr uint32_t nbits = 4; const uint32_t n_blocks_per_col = (K + block_size - 1) / block_size; @@ -822,56 +574,17 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context return ApplySubgroupMatrixMatMulNBits(a, b, scales, M, N, K, context, y); } - const bool has_subgroup = context.Device().HasFeature(wgpu::FeatureName::Subgroups); - // macOS - Avoid using dp4a on Metal, as it does not appear to have native dp4a support. - // https://github.com/gpuweb/gpuweb/issues/2677#issuecomment-1713292226 - const bool use_dp4a = has_subgroup && context.AdapterInfo().backendType != wgpu::BackendType::Metal; - if (accuracy_level_ == 4 && block_size == 32 && - batch_count == 1 && components_a == 4 && K % 64 == 0 && N % 16 == 0 && - !has_zero_points && use_dp4a && M >= kMinMForTileOptimization) { - constexpr uint32_t kVec4Components = 4; - constexpr uint32_t kVec2Components = 2; - constexpr uint32_t kU32Components = 4; - - constexpr uint32_t kBlockSizeA = 128; - DP4AMatMulQuantizeProgram quantize_program; - quantize_program.SetWorkgroupSize(1); - quantize_program.SetDispatchGroupSize(M * K / kBlockSizeA, 1, 1); - TensorShape a_quant_shape{1, M, K / kU32Components}; - Tensor a_quant = context.CreateGPUTensor(DataTypeImpl::GetType(), a_quant_shape); - TensorShapeVector a_scales_dims({1, 1, M, K / kBlockSizeA}); - Tensor a_scale = context.CreateGPUTensor(a->DataType(), a_scales_dims); - quantize_program.AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow(kVec4Components)}}) - .AddOutputs({{&a_quant, ProgramTensorMetadataDependency::Rank, a_quant.Shape(), gsl::narrow(1)}, - {&a_scale, ProgramTensorMetadataDependency::Rank, a_scale.Shape(), gsl::narrow(1)}}) - .AddUniformVariable({static_cast(M * K / kVec4Components)}); - ORT_RETURN_IF_ERROR(context.RunProgram(quantize_program)); - - constexpr uint32_t kTileSize = 64; - TensorShape reshaped_y_shape{1, M, N / kVec4Components}; - DP4AMatMulNBitsProgram mul_program; - mul_program.SetWorkgroupSize(256); - mul_program.SetDispatchGroupSize( - (M + kTileSize - 1) / kTileSize, - (N + kTileSize - 1) / kTileSize, 1); - mul_program.AddInputs({{&a_quant, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow(kVec4Components)}, - {&a_scale, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow(1)}, - {b, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow(kVec2Components * kU32Components)}, - {scales, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow(1)}}) - .AddUniformVariables({{static_cast(M)}, - {static_cast(N)}, - {static_cast(K)}, - {static_cast(K / 8)}, - {static_cast(K / 16)}}) - .AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, reshaped_y_shape, gsl::narrow(kVec4Components)}); - return context.RunProgram(mul_program); + if (M >= kMinMForTileOptimization && + CanApplyDP4AMatrixMatMulNBits(context, accuracy_level_, block_size, batch_count, N, K, components_a, has_zero_points)) { + return ApplyDP4AMatrixMatMulNBits(a, b, scales, M, N, K, block_size, context, y); } // TODO: Support output_number > 1. Some cases are failed when output_number > 1. constexpr uint32_t output_number = 1; const uint32_t tile_m = M > kMinMForTileOptimization ? 4 : 1; + const bool has_subgroup = context.Device().HasFeature(wgpu::FeatureName::Subgroups); const bool use_subgroup = has_subgroup && context.AdapterInfo().vendor == std::string_view{"intel"} && components_a == 4 && block_size == 32; - MatMulNBitsProgram program{output_number, block_size, tile_m, gsl::narrow(components_b), has_zero_points, use_subgroup}; + MatMulNBitsProgram program{output_number, block_size, tile_m, static_cast(components_b), has_zero_points, use_subgroup}; if (M > kMinMForTileOptimization && block_size == 32) { components = 1; constexpr uint32_t workgroup_size = 64; @@ -884,7 +597,8 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context program.CacheHint("T_M" + std::to_string(tile_m) + "Subgroup" + std::to_string(use_subgroup)); } else if (block_size == 32) { components = 1; - constexpr uint32_t workgroup_size = 64; + // TODO: Tune the workgroup size when `M=1`. + constexpr uint32_t workgroup_size = 128; const uint32_t workgroup_y = N % 8 == 0 ? 8 : 1; const uint32_t workgroup_x = workgroup_size / workgroup_y; program.SetWorkgroupSize(workgroup_x, workgroup_y, 1); @@ -900,10 +614,10 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context TensorShape reshaped_y_shape{batch_count, M, N / components}; program - .AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, reshaped_a_shape, gsl::narrow(components_a)}, - {b, ProgramTensorMetadataDependency::TypeAndRank, reshaped_b_shape, gsl::narrow(components_b * 4 /** b will be accessed as uint32 which includs 4 uint8. So here we need to multiply 4.*/)}, + .AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, reshaped_a_shape, static_cast(components_a)}, + {b, ProgramTensorMetadataDependency::TypeAndRank, reshaped_b_shape, static_cast(components_b * 4 /** b will be accessed as uint32 which includs 4 uint8. So here we need to multiply 4.*/)}, {scales, ProgramTensorMetadataDependency::None}}) - .AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, reshaped_y_shape, gsl::narrow(components)}) + .AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, reshaped_y_shape, static_cast(components)}) .AddUniformVariable({block_size}); if (has_zero_points) { program.AddInput({zero_points, ProgramTensorMetadataDependency::None, {(zero_points->Shape().Size() + 3) / 4}, 4}); diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h index 3d72629bf6b25..10221e19c7400 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h @@ -35,25 +35,6 @@ class MatMulNBitsProgram final : public Program { bool use_subgroup_; }; -class DP4AMatMulQuantizeProgram final : public Program { - public: - DP4AMatMulQuantizeProgram() : Program{"DP4AMatMulQuantize"} {} - Status GenerateShaderCode(ShaderHelper& sh) const override; - WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"input_size", ProgramUniformVariableDataType::Uint32}); -}; - -class DP4AMatMulNBitsProgram final : public Program { - public: - DP4AMatMulNBitsProgram() : Program{"DP4AMatMulNBits"} {} - Status GenerateShaderCode(ShaderHelper& sh) const override; - WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES( - {"M", ProgramUniformVariableDataType::Uint32}, - {"N", ProgramUniformVariableDataType::Uint32}, - {"K", ProgramUniformVariableDataType::Uint32}, - {"K8", ProgramUniformVariableDataType::Uint32}, - {"K16", ProgramUniformVariableDataType::Uint32}); -}; - class MatMulNBits final : public WebGpuKernel { public: MatMulNBits(const OpKernelInfo& info) : WebGpuKernel(info) { diff --git a/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc index 2944a4d61b8ef..cb024d2a758a9 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc @@ -185,13 +185,13 @@ Status ApplySubgroupMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Te mul_program.SetDispatchGroupSize( (N + kTileSizeB - 1) / kTileSizeB, (M + kTileSizeA - 1) / kTileSizeA, 1); - mul_program.AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow(1)}, - {b, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow(kU32Components)}, - {scales, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow(1)}}) + mul_program.AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, 1}, + {b, ProgramTensorMetadataDependency::TypeAndRank, static_cast(kU32Components)}, + {scales, ProgramTensorMetadataDependency::TypeAndRank, 1}}) .AddUniformVariables({{static_cast(M)}, {static_cast(N)}, {static_cast(K)}}) - .AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, y_shape, gsl::narrow(1)}); + .AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, y_shape, 1}); return context.RunProgram(mul_program); } diff --git a/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc b/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc index 2e7ed5a16a2f0..068a94c7390e2 100644 --- a/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc @@ -37,8 +37,8 @@ Status RegisterWebGpuContribKernels(KernelRegistry& kernel_registry) { static const BuildKernelCreateInfoFn function_table[] = { BuildKernelCreateInfo, // default entry to avoid the list become empty after ops-reducing // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, BuildKernelCreateInfo, // BuildKernelCreateInfo, // BuildKernelCreateInfo, diff --git a/onnxruntime/core/framework/compute_capability.h b/onnxruntime/core/framework/compute_capability.h index 5f21ba2f013e0..819264b3960e7 100644 --- a/onnxruntime/core/framework/compute_capability.h +++ b/onnxruntime/core/framework/compute_capability.h @@ -2,8 +2,11 @@ // Licensed under the MIT License. #pragma once +#include #include "core/common/common.h" #include "core/graph/indexed_sub_graph.h" +#include "core/graph/graph.h" +#include "core/optimizer/graph_optimizer_registry.h" namespace onnxruntime { // A structure encodes a subgraph and the method to run it. @@ -21,5 +24,22 @@ struct ComputeCapability { ComputeCapability(std::unique_ptr t_sub_graph) : sub_graph(std::move(t_sub_graph)) {} + + // Optional function to optimize this ComputeCapability. + // This will be called by ORT once the ComputeCapability is assigned to the EP. + std::function + optimization_func; + + // Optional ComputeCapability instances for sets of nodes within this ComputeCapability that should be optimized. + // when an optimization is applied, ORT will update this ComputeCapability to reflect the changes made. + // IndexedSubGraph.nodes: + // - update based on RemovedNode/AddNode calls + // IndexedSubGraph.MetaDef (if present): + // - inputs and outputs will be unchanged + // - constant_initializers MAY change if we constant fold an initializer during optimization + std::vector> nodes_to_optimize; }; } // namespace onnxruntime diff --git a/onnxruntime/core/framework/execution_provider.cc b/onnxruntime/core/framework/execution_provider.cc index 3a937a119d03b..df85daa006a43 100644 --- a/onnxruntime/core/framework/execution_provider.cc +++ b/onnxruntime/core/framework/execution_provider.cc @@ -14,6 +14,7 @@ namespace onnxruntime { std::vector> IExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph, const IKernelLookup& kernel_lookup, + const GraphOptimizerRegistry&, IResourceAccountant*) const { std::vector> result; for (const auto& node : graph.Nodes()) { diff --git a/onnxruntime/core/framework/external_data_loader.cc b/onnxruntime/core/framework/external_data_loader.cc index fe73a55735631..c577805e69cc4 100644 --- a/onnxruntime/core/framework/external_data_loader.cc +++ b/onnxruntime/core/framework/external_data_loader.cc @@ -60,7 +60,12 @@ common::Status LoadWebAssemblyExternalData(const Env& env, break; case 1: // Load external data to GPU. - Module.jsepUploadExternalBuffer(dataIdOrBuffer, data); + // TODO: use a unified interface for upload external buffer. + if (Module.webgpuUploadExternalBuffer) { + Module.webgpuUploadExternalBuffer(dataIdOrBuffer, data); + } else { + Module.jsepUploadExternalBuffer(dataIdOrBuffer, data); + } break; default: return 4; // Unknown error occurred in memory copy. diff --git a/onnxruntime/core/framework/external_data_loader.h b/onnxruntime/core/framework/external_data_loader.h index 117da7d0a4afa..90d48ca800797 100644 --- a/onnxruntime/core/framework/external_data_loader.h +++ b/onnxruntime/core/framework/external_data_loader.h @@ -42,7 +42,7 @@ class IExternalDataLoader { enum class ExternalDataLoadType { CPU = 0, -#if defined(USE_JSEP) +#if defined(USE_JSEP) || defined(USE_WEBGPU) WEBGPU_BUFFER = 1, #endif }; diff --git a/onnxruntime/core/framework/fallback_cpu_capability.cc b/onnxruntime/core/framework/fallback_cpu_capability.cc index 1eb7420b44d2c..d3e435c0341b0 100644 --- a/onnxruntime/core/framework/fallback_cpu_capability.cc +++ b/onnxruntime/core/framework/fallback_cpu_capability.cc @@ -1,6 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) + #include "core/framework/fallback_cpu_capability.h" #include "core/common/inlined_containers.h" @@ -176,3 +178,5 @@ std::unordered_set GetCpuPreferredNodes(const onnxruntime::GraphViewe } } // namespace onnxruntime + +#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) diff --git a/onnxruntime/core/framework/fallback_cpu_capability.h b/onnxruntime/core/framework/fallback_cpu_capability.h index bca75adbfd5a7..ddcc1de96d2af 100644 --- a/onnxruntime/core/framework/fallback_cpu_capability.h +++ b/onnxruntime/core/framework/fallback_cpu_capability.h @@ -3,6 +3,8 @@ #pragma once +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) + #include #include "core/common/inlined_containers_fwd.h" #include "core/framework/execution_provider.h" // for IExecutionProvider::IKernelLookup @@ -26,3 +28,5 @@ std::unordered_set GetCpuPreferredNodes(const GraphViewer& graph, const logging::Logger& logger); } // namespace onnxruntime + +#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) diff --git a/onnxruntime/core/framework/graph_partitioner.cc b/onnxruntime/core/framework/graph_partitioner.cc index 111f8e0a5fc34..ff4d300f665b1 100644 --- a/onnxruntime/core/framework/graph_partitioner.cc +++ b/onnxruntime/core/framework/graph_partitioner.cc @@ -142,13 +142,15 @@ struct GetCapabilityForEPParams { std::reference_wrapper debug_graph_fn; #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) IResourceAccountant* resource_accountant; + std::reference_wrapper graph_optimizer_registry; }; auto get_capabilities = [](const IExecutionProvider& ep, const GraphViewer& graph_viewer, const IExecutionProvider::IKernelLookup& kernel_lookup, - IResourceAccountant* resource_accountant) { - auto capabilities = ep.GetCapability(graph_viewer, kernel_lookup, resource_accountant); + IResourceAccountant* resource_accountant, + const GraphOptimizerRegistry& graph_optimizer_registry) { + auto capabilities = ep.GetCapability(graph_viewer, kernel_lookup, graph_optimizer_registry, resource_accountant); // In theory an EP could return an empty capability. Remove those. capabilities.erase(std::remove_if(capabilities.begin(), capabilities.end(), @@ -182,10 +184,11 @@ static Status GetCapabilityForEP(const GetCapabilityForEPParams& params, const l auto& graph = params.graph.get(); auto& capabilities = params.capabilities.get(); + const auto& graph_optimizer_registry = params.graph_optimizer_registry.get(); { const GraphViewer graph_viewer(graph); - capabilities = get_capabilities(current_ep, graph_viewer, kernel_lookup, params.resource_accountant); + capabilities = get_capabilities(current_ep, graph_viewer, kernel_lookup, params.resource_accountant, graph_optimizer_registry); if (capabilities.empty()) { return Status::OK(); @@ -223,7 +226,7 @@ static Status GetCapabilityForEP(const GetCapabilityForEPParams& params, const l capabilities.clear(); const GraphViewer graph_viewer(graph); - capabilities = get_capabilities(current_ep, graph_viewer, kernel_lookup, params.resource_accountant); + capabilities = get_capabilities(current_ep, graph_viewer, kernel_lookup, params.resource_accountant, graph_optimizer_registry); // all nodes with an index >= first_new_node with domain of kMSInternalNHWCDomain should be in the capabilities InlinedHashSet new_nodes_in_capabilities; @@ -261,6 +264,7 @@ static Status GetCapabilityForEP(const GetCapabilityForEPParams& params, const l static Status GetCapabilityForEPForAotInlining(const GraphViewer& graph_viewer, const KernelRegistryManager& kernel_registry_mgr, const IExecutionProvider& current_ep, + const GraphOptimizerRegistry& graph_optimizer_registry, const logging::Logger& logger, std::vector>& capabilities) { const auto& ep_type = current_ep.Type(); @@ -272,14 +276,62 @@ static Status GetCapabilityForEPForAotInlining(const GraphViewer& graph_viewer, logger}; // TODO: Provide EP with a capability to look inside the functions. - capabilities = get_capabilities(current_ep, graph_viewer, kernel_lookup, nullptr); + capabilities = get_capabilities(current_ep, graph_viewer, kernel_lookup, nullptr, graph_optimizer_registry); return Status::OK(); } /** - * Check if a node can be placed on a specific provider. - * Do nothing if the node is already assigned + * Check whether the given IndexedSubGraph is available for assigning to a specific provider. + * + */ +static bool IsIndexedSubGraphAvailableForAssignment(Graph& graph, + const IndexedSubGraph& capability, + GraphPartitioner::Mode mode, + const std::string& provider_type) { + // The provider can run a single node in the if not using meta-defs. + if (capability.GetMetaDef() == nullptr && capability.nodes.size() == 1) { + auto* node = graph.GetNode(capability.nodes[0]); + if (nullptr != node && node->GetExecutionProviderType().empty()) { + // The node was not fused or assigned. + return true; + } + return false; + } + + // if mode is kAssignOnly we want all nodes that can _potentially_ be taken by compiling EPs to be assigned, + // so that we aggregate the nodes covered and ensure the original nodes remain in the ORT format model by + // preventing level 2 and 3 optimizers from changing them. optimizers check the EP the node is assigned to + // and only make changes if the EP is on the optimizer's list of supported EPs. an EP that compiles nodes + // should never be on those lists. + // + // when the ORT format model is loaded we will process it normally with EP priority being applied for + // whichever EPs are enabled at the time. + // + // e.g. an Android NNAPI EP may take different/overlapping nodes to a iOS CoreML EP. + // We want the ORT format model to be able to be run as efficiently as possible on either platform, + // so we want all the nodes that either may take to be preserved. If we did not do this we would + // need to create one ORT format model for Android and one for iOS. + if (mode == GraphPartitioner::Mode::kAssignOnly) { + return true; + } + + for (auto node_index : capability.nodes) { + const auto* node = graph.GetNode(node_index); + if ((nullptr == node) || + (!node->GetExecutionProviderType().empty() && node->GetExecutionProviderType() != provider_type)) { + // The node was fused or assigned, so that the whole sub-graph will not be assigned to this + // The assumption is that this can only run the sub-graph as a whole unit. + return false; + } + } + + return true; +} + +/** + * Return a fused node or assign the nodes in the indexed subgraph to the current EP. + * * \param graph * \param capability * \param kernel_registry_mgr @@ -298,75 +350,42 @@ static Node* PlaceNode(Graph& graph, const IndexedSubGraph& capability, if (nullptr == capability.GetMetaDef()) { TryAssignSingleNode(graph, capability, provider_type); } else { - // The can run a fused in the . + const bool acc_enabled = capability.IsAccountingEnabled(); + if (mode == GraphPartitioner::Mode::kNormal) { + std::ostringstream oss; + oss << provider_type << "_" << capability.GetMetaDef()->name << "_" << fused_node_unique_id++; + std::string node_name = oss.str(); - // Check whether any node in the was already assigned. If so it cannot be stolen as assignment is done - // in order of EP priority - bool sub_graph_available_for_assignment = true; - if (mode != GraphPartitioner::Mode::kAssignOnly) { - // if mode is kAssignOnly we want all nodes that can _potentially_ be taken by compiling EPs to be assigned, - // so that we aggregate the nodes covered and ensure the original nodes remain in the ORT format model by - // preventing level 2 and 3 optimizers from changing them. optimizers check the EP the node is assigned to - // and only make changes if the EP is on the optimizer's list of supported EPs. an EP that compiles nodes - // should never be on those lists. - // - // when the ORT format model is loaded we will process it normally with EP priority being applied for - // whichever EPs are enabled at the time. - // - // e.g. an Android NNAPI EP may take different/overlapping nodes to a iOS CoreML EP. - // We want the ORT format model to be able to be run as efficiently as possible on either platform, - // so we want all the nodes that either may take to be preserved. If we did not do this we would - // need to create one ORT format model for Android and one for iOS. - for (auto node_index : capability.nodes) { - const auto* node = graph.GetNode(node_index); - if ((nullptr == node) || - (!node->GetExecutionProviderType().empty() && node->GetExecutionProviderType() != provider_type)) { - // The node was fused or assigned, so that the whole sub-graph will not be assigned to this - // The assumption is that this can only run the sub-graph as a whole unit. - sub_graph_available_for_assignment = false; - break; - } + Node* fused_node = nullptr; + if (fusion_style == IExecutionProvider::FusionStyle::Function) { + fused_node = &graph.FuseSubGraph(capability, node_name); + } else { + // create a fused node without copying everything to a Function body. The IndexedSubGraph will be passed + // through to Compile via a filtered GraphViewer. + fused_node = &graph.BeginFuseSubGraph(capability, node_name); } - } - if (sub_graph_available_for_assignment) { - const bool acc_enabled = capability.IsAccountingEnabled(); - if (mode == GraphPartitioner::Mode::kNormal) { - std::ostringstream oss; - oss << provider_type << "_" << capability.GetMetaDef()->name << "_" << fused_node_unique_id++; - std::string node_name = oss.str(); - - Node* fused_node = nullptr; - if (fusion_style == IExecutionProvider::FusionStyle::Function) { - fused_node = &graph.FuseSubGraph(capability, node_name); - } else { - // create a fused node without copying everything to a Function body. The IndexedSubGraph will be passed - // through to Compile via a filtered GraphViewer. - fused_node = &graph.BeginFuseSubGraph(capability, node_name); - } - - fused_node->SetExecutionProviderType(provider_type); - if (acc_enabled) { - // We account for the fused node. We operate under assumption - // that the fused node would use no more memory when the nodes we are fusing. - // and potentially less than that, and therefore, no threshold check is needed here. - // All threshold checks are done within the EP. - capability.ComputeAndAccountForNode(*fused_node); - } + fused_node->SetExecutionProviderType(provider_type); + if (acc_enabled) { + // We account for the fused node. We operate under assumption + // that the fused node would use no more memory when the nodes we are fusing. + // and potentially less than that, and therefore, no threshold check is needed here. + // All threshold checks are done within the EP. + capability.ComputeAndAccountForNode(*fused_node); + } - result = fused_node; - } else { - // assign the nodes in the indexed subgraph to the current EP so that level 2+ optimizers will not change them. - // This is used when exporting an ORT format model to maintain the original nodes and re-do the fusion - // at runtime. The original nodes provide a fallback if fewer nodes can be fused at runtime due to device - // capabilities. - for (size_t i = 0, limit = capability.nodes.size(); i < limit; ++i) { - auto* node = graph.GetNode(capability.nodes[i]); - if (node != nullptr) { - node->SetExecutionProviderType(provider_type); - if (acc_enabled) { - capability.AccountForNode(i); - } + result = fused_node; + } else { + // assign the nodes in the indexed subgraph to the current EP so that level 2+ optimizers will not change them. + // This is used when exporting an ORT format model to maintain the original nodes and re-do the fusion + // at runtime. The original nodes provide a fallback if fewer nodes can be fused at runtime due to device + // capabilities. + for (size_t i = 0, limit = capability.nodes.size(); i < limit; ++i) { + auto* node = graph.GetNode(capability.nodes[i]); + if (node != nullptr) { + node->SetExecutionProviderType(provider_type); + if (acc_enabled) { + capability.AccountForNode(i); } } } @@ -386,7 +405,8 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, FuncManager& func_mgr, int& fused_node_unique_id, const layout_transformation::TransformLayoutFunction& transform_layout_fn, const layout_transformation::DebugGraphFn& debug_graph_fn, - const logging::Logger& logger, IResourceAccountant* resource_accountant) { + const logging::Logger& logger, IResourceAccountant* resource_accountant, + const GraphOptimizerRegistry& graph_optimizer_registry) { // handle testing edge case where optimizers or constant lifting results in graph with no nodes. // doing it here saves all providers checking for this in GetCapability if (graph.NumberOfNodes() == 0) { @@ -400,7 +420,7 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, FuncManager& func_mgr, // we pass through the FuncManager from the top level graph ORT_RETURN_IF_ERROR(PartitionOnnxFormatModelImpl(*subgraph, func_mgr, kernel_registry_mgr, fused_kernel_registry, current_ep, mode, fused_node_unique_id, - transform_layout_fn, debug_graph_fn, logger, resource_accountant)); + transform_layout_fn, debug_graph_fn, logger, resource_accountant, graph_optimizer_registry)); } } @@ -424,7 +444,8 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, FuncManager& func_mgr, mode, std::cref(transform_layout_fn), std::cref(debug_graph_fn), - resource_accountant}; + resource_accountant, + std::ref(graph_optimizer_registry)}; ORT_RETURN_IF_ERROR(GetCapabilityForEP(get_capability_params, logger)); if (capabilities.empty()) { @@ -450,7 +471,30 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, FuncManager& func_mgr, entry->sub_graph->GetMetaDef() != nullptr; })); for (auto& capability : capabilities) { - Node* n = PlaceNode(graph, *capability->sub_graph, fusion_style, type, mode, fused_node_unique_id); + // The can run a fused in the . + // Check whether any node in the was already assigned. If so it cannot be stolen as assignment is done + // in order of EP priority + bool sub_graph_available_for_assignment = IsIndexedSubGraphAvailableForAssignment(graph, *capability->sub_graph, mode, type); + + // If the is available to be assigned to the EP and the ComputeCapability has nodes_to_optimize, + // run EP related optimizations and update ComputeCapability. + if (sub_graph_available_for_assignment && !capability->nodes_to_optimize.empty()) { + for (auto& optimization_cc : capability->nodes_to_optimize) { + if (optimization_cc->optimization_func) { + auto status = optimization_cc->optimization_func(graph, *optimization_cc, *capability, graph_optimizer_registry); + if (status != Status::OK()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, type, "The optimization function failed to finish."); + } + // #TODO: Handle nested optimization ComputeCapability + } + } + } + + Node* n = nullptr; + if (sub_graph_available_for_assignment) { + n = PlaceNode(graph, *capability->sub_graph, fusion_style, type, mode, fused_node_unique_id); + } + if (n != nullptr) { // searching in kernel registries, if no kernel registered for the fused_node, use compile approach if (!KernelRegistryManager::HasImplementationOf(kernel_registry_mgr, *n, type, logger)) { @@ -587,6 +631,7 @@ static Status InlineNodes(Graph& graph, bool& modified_graph) { static Status InlineFunctionsAOTImpl(const ExecutionProviders& execution_providers, const KernelRegistryManager& kernel_registry_mgr, Graph& graph, + const GraphOptimizerRegistry& graph_optimizer_registry, const logging::Logger& logger, InlinedHashSet& not_inlined, size_t& inlined_count) { @@ -603,6 +648,7 @@ static Status InlineFunctionsAOTImpl(const ExecutionProviders& execution_provide ORT_RETURN_IF_ERROR(InlineFunctionsAOTImpl(execution_providers, kernel_registry_mgr, *subgraph, + graph_optimizer_registry, logger, not_inlined, inlined_count)); @@ -627,7 +673,7 @@ static Status InlineFunctionsAOTImpl(const ExecutionProviders& execution_provide InlinedHashSet claimed_by_ep; for (const auto& ep : execution_providers) { std::vector> capabilities; - ORT_RETURN_IF_ERROR(GetCapabilityForEPForAotInlining(graph_viewer, kernel_registry_mgr, *ep, logger, + ORT_RETURN_IF_ERROR(GetCapabilityForEPForAotInlining(graph_viewer, kernel_registry_mgr, *ep, graph_optimizer_registry, logger, capabilities)); for (auto& capability : capabilities) { const auto& nodes = capability->sub_graph->nodes; @@ -667,23 +713,28 @@ static Status InlineFunctionsAOTImpl(const ExecutionProviders& execution_provide } // Validate the ep_context_path to make sure it is file path and check whether the file exist already -static Status EpContextFilePathCheck(const std::string& ep_context_path, - const std::filesystem::path& model_path) { - std::filesystem::path context_cache_path; +static Status GetValidatedEpContextPath(const std::filesystem::path& ep_context_path, + const std::filesystem::path& model_path, + std::filesystem::path& context_cache_path) { if (!ep_context_path.empty()) { context_cache_path = ep_context_path; if (!context_cache_path.has_filename()) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "context_file_path should not point to a folder."); } } else if (!model_path.empty()) { - context_cache_path = model_path.native() + ORT_TSTR("_ctx.onnx"); + auto pos = model_path.native().find_last_of(ORT_TSTR(".")); + if (pos != std::string::npos) { + context_cache_path = model_path.native().substr(0, pos) + ORT_TSTR("_ctx.onnx"); + } else { + context_cache_path = model_path.native() + ORT_TSTR("_ctx.onnx"); + } } else { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Both ep_context_path and model_path are empty."); } if (std::filesystem::exists(context_cache_path)) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to generate EP context model since the file '", - context_cache_path, "' exist already."); + context_cache_path, "' exist already. Please remove the EP context model if you want to re-generate it."); } return Status::OK(); @@ -714,15 +765,7 @@ static Status CreateEpContextModel(const ExecutionProviders& execution_providers }; std::filesystem::path context_cache_path; - const std::filesystem::path& model_path = graph.ModelPath(); - - if (!ep_context_path.empty()) { - context_cache_path = ep_context_path; - } else if (!model_path.empty()) { - context_cache_path = model_path.native() + ORT_TSTR("_ctx.onnx"); - } else { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Both ep_context_path and model_path are empty"); - } + ORT_RETURN_IF_ERROR(GetValidatedEpContextPath(ep_context_path, graph.ModelPath(), context_cache_path)); Model ep_context_model(graph.Name(), false, graph.GetModel().MetaData(), graph.GetModel().ModelPath(), // use source model path so that external initializers can find the data file path @@ -794,6 +837,7 @@ static Status PartitionOnnxFormatModel(const PartitionParams& partition_params, const ExecutionProviders& execution_providers, KernelRegistryManager& kernel_registry_manager, const std::optional& acc_map, + const GraphOptimizerRegistry& graph_optimizer_registry, const logging::Logger& logger) { bool modified_graph = false; @@ -817,7 +861,7 @@ static Status PartitionOnnxFormatModel(const PartitionParams& partition_params, fused_kernel_registry, *ep, mode, fused_node_unique_id, transform_layout_function, partition_params.debug_graph_fn, - logger, resource_accountant)); + logger, resource_accountant, graph_optimizer_registry)); } // expand any nodes that have an ONNX function definition but no matching ORT kernel. @@ -838,6 +882,7 @@ static Status PartitionOnnxFormatModel(const PartitionParams& partition_params, static Status PartitionOrtFormatModelImpl(const PartitionParams& partition_params, KernelRegistryManager& kernel_registry_mgr, IExecutionProvider& current_ep, + const GraphOptimizerRegistry& graph_optimizer_registry, const logging::Logger& logger) { // handle testing edge case where optimizers or constant lifting results in graph with no nodes. // doing it here saves all providers checking for this in GetCapability @@ -853,7 +898,7 @@ static Status PartitionOrtFormatModelImpl(const PartitionParams& partition_param PartitionParams subgraph_partition_params = partition_params; subgraph_partition_params.graph = std::ref(subgraph); ORT_RETURN_IF_ERROR(PartitionOrtFormatModelImpl(subgraph_partition_params, kernel_registry_mgr, - current_ep, logger)); + current_ep, graph_optimizer_registry, logger)); } } @@ -869,7 +914,8 @@ static Status PartitionOrtFormatModelImpl(const PartitionParams& partition_param std::cref(partition_params.transform_layout_function), std::cref(partition_params.debug_graph_fn), #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) - nullptr + nullptr, + std::ref(graph_optimizer_registry) }; // clang-format on @@ -962,10 +1008,11 @@ static Status PartitionOrtFormatModelImpl(const PartitionParams& partition_param static Status PartitionOrtFormatModel(const PartitionParams& partition_params, const ExecutionProviders& execution_providers, KernelRegistryManager& kernel_registry_manager, + const GraphOptimizerRegistry& graph_optimizer_registry, const logging::Logger& logger) { // process full graph with each EP for (const auto& ep : execution_providers) { - ORT_RETURN_IF_ERROR(PartitionOrtFormatModelImpl(partition_params, kernel_registry_manager, *ep, logger)); + ORT_RETURN_IF_ERROR(PartitionOrtFormatModelImpl(partition_params, kernel_registry_manager, *ep, graph_optimizer_registry, logger)); } return Status::OK(); @@ -992,6 +1039,7 @@ Status GraphPartitioner::InlineFunctionsAOT(Model& model, ORT_RETURN_IF_ERROR(InlineFunctionsAOTImpl(execution_providers, kernel_registry_manager, graph, + *graph_optimizer_registry_, logger, not_inlined, inlined_count)); @@ -1048,8 +1096,7 @@ Status GraphPartitioner::Partition(Graph& graph, FuncManager& func_mgr, std::ref(*fused_kernel_registry), std::ref(fused_node_unique_id), std::cref(transform_layout_function), - std::cref(debug_graph_fn), - }; + std::cref(debug_graph_fn)}; #else // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) @@ -1068,7 +1115,8 @@ Status GraphPartitioner::Partition(Graph& graph, FuncManager& func_mgr, if (ep_context_enabled) { std::string ep_context_path = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextFilePath, ""); // Check before EP compile graphs - ORT_RETURN_IF_ERROR(EpContextFilePathCheck(ep_context_path, graph.ModelPath())); + std::filesystem::path context_cache_path; + ORT_RETURN_IF_ERROR(GetValidatedEpContextPath(ep_context_path, graph.ModelPath(), context_cache_path)); } // We use this only if Resource Aware Partitioning is enabled for any of the EPs @@ -1077,7 +1125,7 @@ Status GraphPartitioner::Partition(Graph& graph, FuncManager& func_mgr, ORT_RETURN_IF_ERROR(NodeStatsRecorder::CreateAccountants(config_options, graph.ModelPath(), ep_acc_map)); ORT_RETURN_IF_ERROR(PartitionOnnxFormatModel(partition_params, mode, providers_, kernel_registry_mgr_, - ep_acc_map, logger)); + ep_acc_map, *graph_optimizer_registry_, logger)); if (ep_context_enabled) { std::string ep_context_path = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextFilePath, ""); @@ -1091,7 +1139,7 @@ Status GraphPartitioner::Partition(Graph& graph, FuncManager& func_mgr, return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "ONNX models are not supported in this build."); #endif //! defined(ORT_MINIMAL_BUILD) } else { - ORT_RETURN_IF_ERROR(PartitionOrtFormatModel(partition_params, providers_, kernel_registry_mgr_, logger)); + ORT_RETURN_IF_ERROR(PartitionOrtFormatModel(partition_params, providers_, kernel_registry_mgr_, *graph_optimizer_registry_, logger)); } #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) diff --git a/onnxruntime/core/framework/graph_partitioner.h b/onnxruntime/core/framework/graph_partitioner.h index d1ef193cf1520..b9d4022cb5a14 100644 --- a/onnxruntime/core/framework/graph_partitioner.h +++ b/onnxruntime/core/framework/graph_partitioner.h @@ -7,6 +7,7 @@ #include "core/graph/graph.h" #include "core/framework/fuse_nodes_funcs.h" #include "core/framework/transform_layout_functions.h" +#include "core/optimizer/graph_optimizer_registry.h" namespace onnxruntime { @@ -24,9 +25,12 @@ class GraphPartitioner { }; // The order of providers represents the user preference. - GraphPartitioner(KernelRegistryManager& kernel_registry_mgr, const ExecutionProviders& providers) + GraphPartitioner(KernelRegistryManager& kernel_registry_mgr, + const ExecutionProviders& providers, + std::unique_ptr graph_optimizer_registry) : kernel_registry_mgr_(kernel_registry_mgr), - providers_(providers) { + providers_(providers), + graph_optimizer_registry_(std::move(graph_optimizer_registry)) { } // Run partitioning. @@ -64,6 +68,7 @@ class GraphPartitioner { KernelRegistryManager& kernel_registry_mgr_; const ExecutionProviders& providers_; + std::unique_ptr graph_optimizer_registry_; }; } // namespace onnxruntime diff --git a/onnxruntime/core/framework/onnxruntime_typeinfo.cc b/onnxruntime/core/framework/onnxruntime_typeinfo.cc index a884927abddb7..1c446840b7938 100644 --- a/onnxruntime/core/framework/onnxruntime_typeinfo.cc +++ b/onnxruntime/core/framework/onnxruntime_typeinfo.cc @@ -10,8 +10,8 @@ #include "core/framework/sparse_tensor.h" #include "core/graph/onnx_protobuf.h" #include "core/session/ort_apis.h" +#include "core/session/model_editor_api.h" #include "core/framework/error_code_helper.h" - #include "core/framework/tensor_type_and_shape.h" #include "core/framework/onnxruntime_map_type_info.h" #include "core/framework/onnxruntime_sequence_type_info.h" @@ -40,7 +40,7 @@ OrtTypeInfo::OrtTypeInfo(std::unique_ptr optional_type_info : type(ONNX_TYPE_OPTIONAL), optional_type_info(std::move(optional_type_info)) {} OrtTypeInfo::OrtTypeInfo(ONNXType type, std::unique_ptr data) noexcept - : type(type), data(std::move(data)) { + : type(type), tensor_type_info(std::move(data)) { } OrtTypeInfo::~OrtTypeInfo() = default; @@ -55,7 +55,9 @@ ORT_API_STATUS_IMPL(OrtApis::GetOnnxTypeFromTypeInfo, _In_ const struct OrtTypeI ORT_API_STATUS_IMPL(OrtApis::CastTypeInfoToTensorInfo, _In_ const struct OrtTypeInfo* input, _Outptr_result_maybenull_ const struct OrtTensorTypeAndShapeInfo** out) { API_IMPL_BEGIN - *out = (input->type == ONNX_TYPE_TENSOR || input->type == ONNX_TYPE_SPARSETENSOR) ? input->data.get() : nullptr; + *out = (input->type == ONNX_TYPE_TENSOR || input->type == ONNX_TYPE_SPARSETENSOR) + ? input->tensor_type_info.get() + : nullptr; return nullptr; API_IMPL_END } @@ -84,8 +86,8 @@ ORT_API_STATUS_IMPL(OrtApis::CastTypeInfoToOptionalTypeInfo, _In_ const OrtTypeI API_IMPL_END } -ORT_API_STATUS_IMPL(OrtApis::GetDenotationFromTypeInfo, _In_ const OrtTypeInfo* type_info, _Out_ const char** const out, - _Out_ size_t* len) { +ORT_API_STATUS_IMPL(OrtApis::GetDenotationFromTypeInfo, _In_ const OrtTypeInfo* type_info, + _Out_ const char** const out, _Out_ size_t* len) { API_IMPL_BEGIN *out = type_info->denotation.c_str(); *len = type_info->denotation.size(); @@ -93,6 +95,61 @@ ORT_API_STATUS_IMPL(OrtApis::GetDenotationFromTypeInfo, _In_ const OrtTypeInfo* API_IMPL_END } +#if !defined(ORT_MINIMAL_BUILD) +ORT_API_STATUS_IMPL(OrtModelEditorAPI::CreateTensorTypeInfo, _In_ const OrtTensorTypeAndShapeInfo* tensor_info, + _Out_ OrtTypeInfo** type_info) { + API_IMPL_BEGIN + auto ti = std::make_unique(ONNXType::ONNX_TYPE_TENSOR); + ti->tensor_type_info = tensor_info->Clone(); + *type_info = ti.release(); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtModelEditorAPI::CreateSparseTensorTypeInfo, _In_ const OrtTensorTypeAndShapeInfo* tensor_info, + _Out_ OrtTypeInfo** type_info) { + API_IMPL_BEGIN + auto ti = std::make_unique(ONNXType::ONNX_TYPE_SPARSETENSOR); + ti->tensor_type_info = tensor_info->Clone(); + *type_info = ti.release(); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtModelEditorAPI::CreateMapTypeInfo, ONNXTensorElementDataType map_key_type, + _In_ const OrtTypeInfo* map_value_type, _Out_ OrtTypeInfo** type_info) { + API_IMPL_BEGIN + auto ti = std::make_unique(ONNXType::ONNX_TYPE_MAP); + ti->map_type_info = std::make_unique(map_key_type, map_value_type->Clone()); + *type_info = ti.release(); + + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtModelEditorAPI::CreateSequenceTypeInfo, _In_ const OrtTypeInfo* sequence_type, + _Out_ OrtTypeInfo** type_info) { + API_IMPL_BEGIN + auto ti = std::make_unique(ONNXType::ONNX_TYPE_SEQUENCE); + ti->sequence_type_info = std::make_unique(sequence_type->Clone()); + *type_info = ti.release(); + + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtModelEditorAPI::CreateOptionalTypeInfo, _In_ const OrtTypeInfo* contained_type, + _Out_ OrtTypeInfo** type_info) { + API_IMPL_BEGIN + auto ti = std::make_unique(ONNXType::ONNX_TYPE_OPTIONAL); + ti->optional_type_info = std::make_unique(contained_type->Clone()); + *type_info = ti.release(); + + return nullptr; + API_IMPL_END +} +#endif // !defined(ORT_MINIMAL_BUILD) + ORT_API(void, OrtApis::ReleaseTypeInfo, _Frees_ptr_opt_ OrtTypeInfo* ptr) { std::unique_ptr p(ptr); } @@ -298,8 +355,8 @@ std::unique_ptr OrtTypeInfo::Clone() const { #endif case ONNX_TYPE_TENSOR: { std::unique_ptr info; - if (data) { - info = data->Clone(); + if (tensor_type_info) { + info = tensor_type_info->Clone(); } result = MakePtr(type, std::move(info)); result->denotation = denotation; diff --git a/onnxruntime/core/framework/onnxruntime_typeinfo.h b/onnxruntime/core/framework/onnxruntime_typeinfo.h index 72d263d5fa442..54bb946e0d36b 100644 --- a/onnxruntime/core/framework/onnxruntime_typeinfo.h +++ b/onnxruntime/core/framework/onnxruntime_typeinfo.h @@ -31,7 +31,7 @@ struct OrtTypeInfo { ONNXType type; std::string denotation; - std::unique_ptr data; + std::unique_ptr tensor_type_info; std::unique_ptr map_type_info; std::unique_ptr sequence_type_info; std::unique_ptr optional_type_info; diff --git a/onnxruntime/core/framework/session_state_utils.cc b/onnxruntime/core/framework/session_state_utils.cc index 83a353615bc35..9d45ec38e5a32 100644 --- a/onnxruntime/core/framework/session_state_utils.cc +++ b/onnxruntime/core/framework/session_state_utils.cc @@ -81,6 +81,11 @@ static common::Status ExtDataTensorProtoToTensor(const Env& env, ORT_RETURN_IF_ERROR(utils::GetExtDataFromTensorProto(env, proto_path.c_str(), tensor_proto, ext_data_buf, ext_data_len, ext_data_deleter, buffered_tensor, &prepacked_for_graph)); + if constexpr (endian::native != endian::little) { + if (!proto_path.empty() && (proto_path.compare(onnxruntime::utils::kTensorProtoMemoryAddressTag) != 0)) { + utils::ConvertRawDataInTensorProto(const_cast(&tensor_proto), ext_data_buf, ext_data_len); + } + } // NB: creating a do-nothing allocator per tensor is wasteful; can perhaps be // avoided if the Tensor class implements the do-nothing behavior when given a @@ -203,13 +208,12 @@ static common::Status DeserializeTensorProto(const Env& env, const std::basic_st } } -common::Status AllocateTensor( - const onnxruntime::MemBuffer* m, - std::unique_ptr& p_tensor, - const onnxruntime::DataTypeImpl* const& type, - onnxruntime::TensorShape& tensor_shape, - bool use_device_allocator_for_initializers, - const onnxruntime::AllocatorPtr& alloc) { +common::Status AllocateTensor(const onnxruntime::MemBuffer* m, + std::unique_ptr& p_tensor, + const onnxruntime::DataTypeImpl* const& type, + onnxruntime::TensorShape& tensor_shape, + bool use_device_allocator_for_initializers, + const onnxruntime::AllocatorPtr& alloc) { if (m != nullptr) { p_tensor = std::make_unique(type, tensor_shape, m->GetBuffer(), m->GetAllocInfo()); if (m->GetLen() < p_tensor->SizeInBytes()) { @@ -354,6 +358,7 @@ common::Status SaveInitializedTensors( } ORT_RETURN_IF_ERROR(planner.Trace(entry.first, entry.second)); } + // 2. allocate weight buffer on different locations // planned_initializers_memory_size_in_byte is not actual physical size. // It's the virtual size computed by planner. @@ -386,6 +391,9 @@ common::Status SaveInitializedTensors( if (user_supplied_initializer_ids.find(entry.first) != user_supplied_initializer_ids.end()) { ort_value = *(session_options.initializers_to_share_map.at(name)); LOGS(logger, INFO) << "Using user supplied initializer with name (" << name << ")."; + + } else if (graph.GetOrtValueInitializer(name, ort_value)) { + // populated OrtValue from the Graph instance } else { const ONNX_NAMESPACE::TensorProto& tensor_proto = *(entry.second); @@ -397,10 +405,9 @@ common::Status SaveInitializedTensors( session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsUseDeviceAllocatorForInitializers, "0") == "1"; Tensor* p_tensor = nullptr; - if (auto iter = buffered_tensors.find(name); - iter != buffered_tensors.end()) { - p_tensor = iter->second.release(); - buffered_tensors.erase(iter); + auto buffered_tensors_iter = buffered_tensors.find(name); + if (buffered_tensors_iter != buffered_tensors.end()) { + p_tensor = buffered_tensors_iter->second.get(); } Status st = DeserializeTensorProto(env, graph_loc, tensor_proto, (m.has_value()) ? &*m : nullptr, alloc, @@ -412,6 +419,12 @@ common::Status SaveInitializedTensors( oss << "Deserialize tensor " << name << " failed." << st.ErrorMessage(); return Status(st.Category(), st.Code(), oss.str()); } + + if (p_tensor != nullptr) { + // p_tensor was wrapped in a deleter by DeserializeTensorProto so we can simply release it here. + ORT_IGNORE_RETURN_VALUE(buffered_tensors_iter->second.release()); + buffered_tensors.erase(buffered_tensors_iter); + } } // 'name' is a reference to a string within the TensorProto that save_tensor_func may free diff --git a/onnxruntime/core/framework/tensor_type_and_shape.cc b/onnxruntime/core/framework/tensor_type_and_shape.cc index 418e46924fb9f..9bbea279da82d 100644 --- a/onnxruntime/core/framework/tensor_type_and_shape.cc +++ b/onnxruntime/core/framework/tensor_type_and_shape.cc @@ -49,10 +49,27 @@ ORT_API_STATUS_IMPL(OrtApis::SetTensorElementType, _Inout_ OrtTensorTypeAndShape API_IMPL_END } -ORT_API_STATUS_IMPL(OrtApis::SetDimensions, OrtTensorTypeAndShapeInfo* this_ptr, +ORT_API_STATUS_IMPL(OrtApis::SetDimensions, OrtTensorTypeAndShapeInfo* info, _In_ const int64_t* dim_values, size_t dim_count) { API_IMPL_BEGIN - this_ptr->shape = onnxruntime::TensorShape(dim_values, dim_count); + if (std::any_of(dim_values, dim_values + dim_count, [](int64_t v) { return v < -1; })) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "dim_values must be -1 (symbolic dimension) or larger."); + } + + auto num_dims = std::max(dim_count, info->dim_params.size()); + + // make shape and dim_values consistent + info->dim_params.resize(num_dims, ""); + + onnxruntime::TensorShapeVector dims; + dims.resize(num_dims, -1); + + for (size_t idx = 0; idx < dim_count; ++idx) { + dims[idx] = dim_values[idx]; + } + + info->shape = onnxruntime::TensorShape(dims); + return nullptr; API_IMPL_END } @@ -88,10 +105,22 @@ ORT_API_STATUS_IMPL(OrtApis::GetSymbolicDimensions, ORT_API_STATUS_IMPL(OrtApis::SetSymbolicDimensions, _In_ struct OrtTensorTypeAndShapeInfo* info, _In_ const char** names, _In_ size_t dim_params_length) { + auto num_dims = std::max(info->shape.NumDimensions(), dim_params_length); + + // make shape and dim_values consistent + if (num_dims > info->shape.NumDimensions()) { + auto dim_values = info->shape.AsShapeVector(); + dim_values.resize(num_dims, -1); + info->shape = onnxruntime::TensorShape(dim_values); + } + info->dim_params.clear(); + info->dim_params.resize(num_dims, ""); + for (size_t idx = 0; idx < dim_params_length; ++idx) { - info->dim_params.push_back(names[idx]); + info->dim_params[idx] = names[idx]; } + return nullptr; } diff --git a/onnxruntime/core/framework/tensorprotoutils.cc b/onnxruntime/core/framework/tensorprotoutils.cc index 17c37b8882168..94a2a6677358e 100644 --- a/onnxruntime/core/framework/tensorprotoutils.cc +++ b/onnxruntime/core/framework/tensorprotoutils.cc @@ -270,10 +270,15 @@ void SetRawDataInTensorProto(ONNX_NAMESPACE::TensorProto& tensor_proto, std::str tensor_proto.set_raw_data(std::move(param)); } -void ConvertRawDataInTensorProto(TensorProto* tensor) { +void ConvertRawDataInTensorProto(TensorProto* tensor, + void* ext_data_buf, + size_t ext_data_len) { size_t element_size = 1; char* bytes = NULL; size_t num_elements = 0; + if (ext_data_buf && !ext_data_len) { + return; + } switch (tensor->data_type()) { case TensorProto_DataType_FLOAT: bytes = reinterpret_cast(tensor->mutable_float_data()->mutable_data()); @@ -337,6 +342,15 @@ void ConvertRawDataInTensorProto(TensorProto* tensor) { num_elements = (tensor->raw_data().size()) / element_size; bytes = const_cast(tensor->mutable_raw_data()->c_str()); } + + if (element_size == 1) { + return; + } + if (ext_data_buf) { + ORT_ENFORCE(ext_data_len % element_size == 0); + num_elements = ext_data_len / element_size; + bytes = reinterpret_cast(ext_data_buf); + } for (size_t i = 0; i < num_elements; ++i) { char* start_byte = bytes + i * element_size; char* end_byte = start_byte + element_size - 1; @@ -1317,22 +1331,15 @@ ONNX_NAMESPACE::TensorProto TensorToTensorProto(const Tensor& tensor, const auto* raw_data = tensor.DataRaw(); ORT_ENFORCE(raw_data, "Missing raw data for tensor proto. Invalid tensor."); static_assert(sizeof(void*) <= sizeof(ExternalDataInfo::OFFSET_TYPE)); - tensor_proto.set_data_location(ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL); // we reinterpret_cast this back to void* in tensorprotoutils.cc:GetExtDataFromTensorProto. // use intptr_t as OFFSET_TYPE is signed. in theory you could get a weird looking value if the address uses the // high bit, but that should be unlikely in a scenario where we care about memory usage enough to use this path. auto offset = narrow(reinterpret_cast(raw_data)); - ONNX_NAMESPACE::StringStringEntryProto* entry = tensor_proto.mutable_external_data()->Add(); - entry->set_key("location"); - entry->set_value(ToUTF8String(onnxruntime::utils::kTensorProtoMemoryAddressTag)); - entry = tensor_proto.mutable_external_data()->Add(); - entry->set_key("offset"); - entry->set_value(std::to_string(offset)); - entry = tensor_proto.mutable_external_data()->Add(); - entry->set_key("length"); - entry->set_value(std::to_string(tensor.SizeInBytes())); + ExternalDataInfo::SetExternalLocationToProto(onnxruntime::utils::kTensorProtoMemoryAddressTag, + offset, tensor.SizeInBytes(), tensor_proto); + } else { utils::SetRawDataInTensorProto(tensor_proto, tensor.DataRaw(), tensor.SizeInBytes()); } diff --git a/onnxruntime/core/framework/tensorprotoutils.h b/onnxruntime/core/framework/tensorprotoutils.h index f5dec7ae988f2..79eae48c10411 100644 --- a/onnxruntime/core/framework/tensorprotoutils.h +++ b/onnxruntime/core/framework/tensorprotoutils.h @@ -41,12 +41,18 @@ Status GetExternalDataInfo(const ONNX_NAMESPACE::TensorProto& tensor_proto, ExternalDataInfo::PrepackedInfos* prepacked_infos = nullptr); /** * This function is used to convert the endianess of Tensor data. + * If ext_data_buf is provided, then this buffer content's endianess + * will be changed. * Mostly, will be used in big endian system to support the model file * generated on little endian system. - * @param initializer given initializer tensor + * @param tensor_proto given initializer tensor + * @param ext_data_buf optional externl data buffer + * @param ext_data_len optional externl data buffer lengeh * @returns None */ -void ConvertRawDataInTensorProto(ONNX_NAMESPACE::TensorProto* initializer); +void ConvertRawDataInTensorProto(ONNX_NAMESPACE::TensorProto* tensor_proto, + void* ext_data_buf = NULL, + size_t ext_data_len = 0); /** * Wrapper function for set_raw_data. diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index e4915616b7b7c..39ffc6a5b0cee 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -7,30 +7,34 @@ #include #include #include -#include #include +#include -#include "core/common/common.h" #include + +#include "core/common/common.h" #include "core/common/inlined_containers.h" #include "core/common/logging/logging.h" #include "core/common/narrow.h" #include "core/flatbuffers/flatbuffers_utils.h" +#include "core/framework/tensor_type_and_shape.h" #include "core/flatbuffers/schema/ort.fbs.h" -#include "core/framework/tensor_shape.h" #include "core/framework/tensor_external_data_info.h" +#include "core/framework/tensor_shape.h" +#include "core/framework/tensor_type_and_shape.h" #include "core/framework/tensorprotoutils.h" #include "core/framework/utils.h" +#include "core/graph/function_utils.h" #include "core/graph/graph_flatbuffers_utils.h" #include "core/graph/graph_viewer.h" #include "core/graph/indexed_sub_graph.h" #include "core/graph/model.h" +#include "core/graph/model_editor_api_types.h" #include "core/graph/model_load_utils.h" #include "core/graph/model_saving_options.h" #include "core/graph/node_attr_utils.h" #include "core/graph/op.h" #include "core/graph/runtime_optimization_record_container.h" -#include "core/graph/function_utils.h" #if !defined(ORT_MINIMAL_BUILD) #include "core/graph/function.h" @@ -3500,6 +3504,10 @@ void Graph::RemoveInitializedTensor(const std::string& tensor_name) { #if !defined(DISABLE_SPARSE_TENSORS) sparse_tensor_names_.erase(tensor_name); #endif + + // doesn't matter if it existed or not + ORT_IGNORE_RETURN_VALUE(ortvalue_initializers_.erase(tensor_name)); + SetGraphResolveNeeded(); } else { #if !defined(DISABLE_SPARSE_TENSORS) @@ -3631,8 +3639,8 @@ Status Graph::InjectExternalInitializersFromFilesInMemory( return Status::OK(); } -#endif // DISABLE_EXTERNAL_INITIALIZERS +#endif // DISABLE_EXTERNAL_INITIALIZERS #endif // !defined(ORT_MINIMAL_BUILD) bool Graph::GetInitializedTensor(const std::string& tensor_name, const TensorProto*& value) const { @@ -3645,6 +3653,16 @@ bool Graph::GetInitializedTensor(const std::string& tensor_name, const TensorPro return true; } +bool Graph::GetOrtValueInitializer(const std::string& name, OrtValue& value) const { + auto it = ortvalue_initializers_.find(name); + if (it == ortvalue_initializers_.end()) { + return false; + } + + value = it->second; + return true; +} + void Graph::CleanAllInitializedTensors() noexcept { name_to_initial_tensor_.clear(); #if !defined(DISABLE_SPARSE_TENSORS) @@ -3660,6 +3678,8 @@ void Graph::CleanAllInitializedTensors() noexcept { delete graph_proto_->mutable_initializer()->ReleaseCleared(); } #endif + + ortvalue_initializers_.clear(); } const ONNX_NAMESPACE::TensorProto* Graph::GetConstantInitializer(const std::string& initializer_name, @@ -3709,13 +3729,14 @@ void Graph::AddValueInfo(const NodeArg* new_value_info) { value_info_.insert(new_value_info); } -std::vector Graph::CreateNodeArgs(const google::protobuf::RepeatedPtrField& names, +template +std::vector Graph::CreateNodeArgs(const StringRange& names, const ArgNameToTypeMap& name_to_type_map) { const auto name_to_type_map_end = name_to_type_map.end(); std::vector results; results.reserve(names.size()); - for (auto& name : names) { + for (const std::string& name : names) { const TypeProto* type = nullptr; auto name_to_type_iter = name_to_type_map.find(name); @@ -4076,27 +4097,51 @@ ONNX_NAMESPACE::GraphProto Graph::ToGraphProto() const { // This is used for constructing full path for external data // if it exists + auto add_initializer = [](TensorList& output_initializers, const TensorProto& initializer) -> void { + TensorProto& output = *output_initializers.Add(); + output = initializer; + + // copy any in-memory external data into raw data + if (utils::HasExternalData(initializer)) { + const std::filesystem::path ignored; + std::basic_string location; + onnxruntime::FileOffsetType file_offset; + SafeInt tensor_byte_size; + + ORT_THROW_IF_ERROR(utils::GetExternalDataInfo(initializer, ignored, location, file_offset, tensor_byte_size)); + + if (location == onnxruntime::utils::kTensorProtoMemoryAddressTag) { + // file_offset is address + void* data = reinterpret_cast(file_offset); + + // set in raw data + output.clear_data_location(); + output.set_raw_data(data, tensor_byte_size); + } + } + }; + + auto* mutable_initializers = result.mutable_initializer(); + #if !defined(DISABLE_SPARSE_TENSORS) const auto& model_path = ModelPath(); // We want to make sure that sparse initializers do not appear // as dense duplicates within the initializers list. - if (!sparse_tensor_names_.empty()) { - const auto sparse_end = sparse_tensor_names_.end(); - auto* mutable_initializer = result.mutable_initializer(); - for (const auto& initializer : graph_proto_->initializer()) { - if (sparse_end == sparse_tensor_names_.find(initializer.name())) { - *mutable_initializer->Add() = initializer; - } else { - auto& sparse_initializer = *result.add_sparse_initializer(); - auto status = utils::DenseTensorToSparseTensorProto(initializer, model_path, sparse_initializer); - ORT_ENFORCE(status.IsOK(), "Failed to convert dense initializer to sparse"); - } + const bool has_sparse_initializers = !sparse_tensor_names_.empty(); + const auto sparse_end = sparse_tensor_names_.end(); + for (const auto& initializer : graph_proto_->initializer()) { + if (!has_sparse_initializers || sparse_end == sparse_tensor_names_.find(initializer.name())) { + add_initializer(*mutable_initializers, initializer); + } else { + auto& sparse_initializer = *result.add_sparse_initializer(); + auto status = utils::DenseTensorToSparseTensorProto(initializer, model_path, sparse_initializer); + ORT_ENFORCE(status.IsOK(), "Failed to convert dense initializer to sparse"); } - } else { - *result.mutable_initializer() = graph_proto_->initializer(); } #else - *result.mutable_initializer() = graph_proto_->initializer(); + for (const auto& initializer : graph_proto_->initializer()) { + add_initializer(*mutable_initializers, initializer); + } #endif return result; @@ -5345,6 +5390,9 @@ Status Graph::InlineFunction(Node& callnode) { } void Graph::SetInputs(gsl::span inputs) { + graph_inputs_including_initializers_.clear(); + graph_inputs_excluding_initializers_.clear(); + // creating graph from scratch // rely on SetGraphInputsOutputs() to fix up graph_inputs_excluding_initializers_ // if is_loaded_from_model_file_ == false @@ -5353,7 +5401,6 @@ void Graph::SetInputs(gsl::span inputs) { if (is_loaded_from_model_file_) { // graph loaded from model file - graph_inputs_excluding_initializers_.clear(); for (const auto* input : inputs) { ORT_ENFORCE(input->Exists(), "Input to set must exist."); if (name_to_initial_tensor_.find(input->Name()) == name_to_initial_tensor_.end()) { @@ -5370,6 +5417,7 @@ void Graph::SetInputs(gsl::span inputs) { } void Graph::SetOutputs(gsl::span outputs) { + graph_outputs_.clear(); graph_outputs_.reserve(outputs.size()); graph_outputs_.assign(outputs.begin(), outputs.end()); @@ -5688,4 +5736,207 @@ common::Status Graph::LoadFromOrtFormat(const onnxruntime::fbs::Graph& fbs_graph return Status::OK(); } +#if !defined(ORT_MINIMAL_BUILD) +namespace { +ValueInfoProto OrtValueInfoToOnnx(const OrtValueInfo& vi) { + // the model builder API checks that the OrtValueInfo has a complete and valid OrtTypeInfo instance and that the + // name is not null/empty. + ORT_ENFORCE(vi.type_info->type == ONNX_TYPE_TENSOR, + "Internal error. Model Editor API should only allow OrtValueInfo for tensor to be created."); + + ValueInfoProto value_info_proto; + value_info_proto.set_name(vi.name); + + auto* tensor = value_info_proto.mutable_type()->mutable_tensor_type(); + const OrtTensorTypeAndShapeInfo& tensor_info = *vi.type_info->tensor_type_info.get(); + tensor->set_elem_type(tensor_info.type); + + auto& shape = *tensor->mutable_shape(); + + size_t idx = 0; + for (auto dim : tensor_info.shape.GetDims()) { + auto& dim_proto = *shape.add_dim(); + if (dim >= 0) { + dim_proto.set_dim_value(dim); + } else { + const std::string& dim_param = tensor_info.dim_params[idx]; + // if empty leave the new dim_proto with neither dim_value nor dim_param set. this represents an 'unknown' dim + if (!dim_param.empty()) { + dim_proto.set_dim_param(dim_param); + } + } + } + + return value_info_proto; +} +} // namespace + +Status Graph::LoadFromModelEditorApiModel(const OrtGraph& api_graph, bool updating_existing_graph) { + ArgNameToTypeMap name_to_type_map; + + // NOTE: need to create NodeArgs as we go along + + // add inputs first. the shape from an input for a non-const initializer is preferred, so we want to create the + // NodeArg for the value using that + + auto add_graph_inputs_outputs = [&, this]( + const InlinedVector>& graph_inputs_or_outputs, + bool is_input) { + // when updating a model we don't require the inputs or outputs to be set if they're unchanged. + if (updating_existing_graph && graph_inputs_or_outputs.empty()) { + return; + } + + std::vector node_args; + node_args.reserve(graph_inputs_or_outputs.size()); + for (auto& ort_value_info : graph_inputs_or_outputs) { + ValueInfoProto value_info = OrtValueInfoToOnnx(*ort_value_info); + + name_to_type_map[value_info.name()] = value_info.type(); + node_args.push_back(&GetOrCreateNodeArg(value_info.name(), &value_info.type())); + } + + if (is_input) { + SetInputs(node_args); + } else { + SetOutputs(node_args); + } + }; + + auto add_initializers = [this](const std::unordered_map>& initializers, + bool is_external) { + for (auto& name_and_ortvalue : initializers) { + // convert from OrtValue to TensorProto + const std::string& name = name_and_ortvalue.first; + OrtValue& v = *name_and_ortvalue.second; + + ORT_ENFORCE(v.IsTensor(), "Initializers must be Tensors"); + const Tensor& t = v.Get(); + TensorProto& tensor_proto = *graph_proto_->add_initializer(); + + tensor_proto.set_name(name); + tensor_proto.set_data_type(t.GetElementType()); + for (auto dim : t.Shape().GetDims()) { + tensor_proto.add_dims(dim); + } + + if (is_external) { + // pre-existing memory that we don't own. avoid a copy by storing the pointer in the ExternalDataInfo + const void* data_offset = t.DataRaw(); // address of memory not offset into file + auto offset = narrow(reinterpret_cast(data_offset)); + + ExternalDataInfo::SetExternalLocationToProto(onnxruntime::utils::kTensorProtoMemoryAddressTag, + offset, t.SizeInBytes(), tensor_proto); + + // add OrtValue to ortvalue_initializers_ to keep it alive and to store the deleter if provided. + ortvalue_initializers_.emplace(name, std::move(v)); + } else { + tensor_proto.set_raw_data(t.DataRaw(), t.SizeInBytes()); + } + + TypeProto type_proto{TypeProtoFromTensorProto(tensor_proto)}; + ORT_IGNORE_RETURN_VALUE(GetOrCreateNodeArg(name, &type_proto)); + + name_to_initial_tensor_.emplace(name, &tensor_proto); + } + }; + + // process graph inputs first as we want the type/shape from them to be preferred if a graph input + // has a matching initializer + add_graph_inputs_outputs(api_graph.inputs, /*input*/ true); + + // add initializers + ortvalue_initializers_.reserve(api_graph.external_initializers.size()); + add_initializers(api_graph.external_initializers, /*is_external*/ true); + add_initializers(api_graph.initializers, /*is_external*/ false); + + // add graph outputs + add_graph_inputs_outputs(api_graph.outputs, /*input*/ false); + + // add nodes + for (const auto& ort_node : api_graph.nodes) { + const OrtNode& node = *ort_node; + + // convert Constant nodes to initializers + if (node.operator_name == "Constant" && node.domain_name == kOnnxDomain) { + // graph_proto_ provides storage + TensorProto& tensor = *graph_proto_->add_initializer(); + + // create NodeProto from OrtNode so we can use the existing conversion functions + NodeProto node_proto; + + // 'Constant' node has no inputs or attributes + ORT_RETURN_IF_NOT(node.input_names.empty() && node.attributes.size() == 1 && node.output_names.size() == 1, + node.node_name, + " is an invalid 'Constant' node. " + "Must have no inputs, one attribute and one output. "); + + node_proto.add_attribute()->CopyFrom(node.attributes[0]); + node_proto.add_output(node.output_names[0]); + + node_proto.set_op_type(node.operator_name); + node_proto.set_name(node.node_name); + node_proto.set_domain(node.domain_name); + + ORT_RETURN_IF_ERROR(utils::ConstantNodeProtoToTensorProto(node_proto, /*model_path*/ "", tensor)); + name_to_initial_tensor_.emplace(node.output_names[0], &tensor); + + continue; + } + + auto input_defs = CreateNodeArgs(node.input_names, name_to_type_map); + auto output_defs = CreateNodeArgs(node.output_names, name_to_type_map); + + const auto num_attributes = node.attributes.size(); + + NodeAttributes attributes; + attributes.reserve(num_attributes); + + for (const auto& attr : node.attributes) { + attributes[attr.name()] = attr; + } + + ORT_IGNORE_RETURN_VALUE(AddNode(node.node_name, node.operator_name, /*doc_string*/ "", + input_defs, output_defs, &attributes, node.domain_name)); + } + + return Resolve(); +} + +// static +Status Graph::LoadFromModelEditorApiModel(const OrtGraph& api_graph, + const Model& owning_model, + const std::unordered_map& domain_to_version, + IOnnxRuntimeOpSchemaCollectionPtr schema_registry, + bool strict_shape_type_inference, + const logging::Logger& logger, + std::unique_ptr& graph) { + graph = std::make_unique(owning_model, + domain_to_version, + schema_registry, + /*parent_graph*/ nullptr, /*parent_node*/ nullptr, + logger, + strict_shape_type_inference); + + return graph->LoadFromModelEditorApiModel(api_graph); +} + +Status Graph::UpdateUsingModelEditorApiModel(const OrtModel& api_model) { + for (auto& entry : api_model.domain_to_version) { + if (auto it = domain_to_version_.find(entry.first); it != domain_to_version_.end()) { + if (it->second != entry.second) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Domain version can not be changed for '", entry.first, + "'. Current version: ", it->second); + } + } else { + domain_to_version_.insert(entry); + } + } + + // this will replace inputs/outputs and add nodes. + return LoadFromModelEditorApiModel(*api_model.graph, /*updating_existing_graph*/ true); +} + +#endif // !defined(ORT_MINIMAL_BUILD) } // namespace onnxruntime diff --git a/onnxruntime/core/graph/graph_flatbuffers_utils.cc b/onnxruntime/core/graph/graph_flatbuffers_utils.cc index 922759b02e75f..199aa79cc1dde 100644 --- a/onnxruntime/core/graph/graph_flatbuffers_utils.cc +++ b/onnxruntime/core/graph/graph_flatbuffers_utils.cc @@ -300,8 +300,6 @@ Status LoadInitializerOrtFormat(const fbs::Tensor& fbs_tensor, TensorProto& init const auto* fbs_raw_data = fbs_tensor.raw_data(); if (fbs_raw_data) { if (load_options.can_use_flatbuffer_for_initializers && fbs_raw_data->size() > 127) { - initializer.set_data_location(ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL); - static_assert(sizeof(void*) <= sizeof(ExternalDataInfo::OFFSET_TYPE)); const void* data_offset = fbs_raw_data->Data(); // we reinterpret_cast this back to void* in tensorprotoutils.cc:GetExtDataFromTensorProto. @@ -309,15 +307,9 @@ Status LoadInitializerOrtFormat(const fbs::Tensor& fbs_tensor, TensorProto& init // high bit, but that should be unlikely in a scenario where we care about memory usage enough to use this path. auto offset = narrow(reinterpret_cast(data_offset)); - ONNX_NAMESPACE::StringStringEntryProto* entry = initializer.mutable_external_data()->Add(); - entry->set_key("location"); - entry->set_value(ToUTF8String(onnxruntime::utils::kTensorProtoMemoryAddressTag)); - entry = initializer.mutable_external_data()->Add(); - entry->set_key("offset"); - entry->set_value(std::to_string(offset)); - entry = initializer.mutable_external_data()->Add(); - entry->set_key("length"); - entry->set_value(std::to_string(fbs_raw_data->size())); + ExternalDataInfo::SetExternalLocationToProto(onnxruntime::utils::kTensorProtoMemoryAddressTag, + offset, fbs_raw_data->size(), initializer); + } else { // fbs_raw_data is uint8_t vector, so the size is byte size initializer.set_raw_data(fbs_raw_data->Data(), fbs_raw_data->size()); diff --git a/onnxruntime/core/graph/model.cc b/onnxruntime/core/graph/model.cc index be0531e6473fb..7629e40c1b5fe 100644 --- a/onnxruntime/core/graph/model.cc +++ b/onnxruntime/core/graph/model.cc @@ -7,6 +7,7 @@ #include "core/flatbuffers/flatbuffers_utils.h" #include "core/framework/tensorprotoutils.h" #include "core/graph/model.h" +#include "core/graph/model_editor_api_types.h" #include "core/graph/model_load_utils.h" #ifdef _MSC_VER @@ -738,6 +739,36 @@ Status Model::Load(int fd, const PathString& model_path, std::shared_ptr& return Status::OK(); } +// static +common::Status Model::LoadFromModelEditorApiModel(const OrtModel& model_editor_api_model, + const IOnnxRuntimeOpSchemaRegistryList* local_registries, + const ModelOptions& options, + const logging::Logger& logger, + std::unique_ptr& model) { + model = std::make_unique(); + model->model_proto_.set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); + // The optimizer Initializer class requires a path if external data is used, however in the Graph API usage the + // external data is pointing to pre-allocated memory and does not require a path. Set a dummy value to make it happy. + model->model_path_ = std::filesystem::path("_GRAPH_API_MODEL_"); + + auto schema_registry = std::make_shared(); + if (local_registries != nullptr) { + for (const auto& schema_collection : *local_registries) { + schema_registry->RegisterRegistry(schema_collection); + } + } + + ORT_RETURN_IF_ERROR(Graph::LoadFromModelEditorApiModel(*model_editor_api_model.graph, + *model, + model_editor_api_model.domain_to_version, + schema_registry, + options.strict_shape_type_inference, + logger, + model->graph_)); + + return Status::OK(); +} + Status Model::Save(Model& model, int p_fd) { if (p_fd < 0) { return Status(ONNXRUNTIME, INVALID_ARGUMENT, " is less than 0."); @@ -917,5 +948,4 @@ common::Status Model::LoadFromOrtFormat(const fbs::Model& fbs_model, #endif return Status::OK(); } - } // namespace onnxruntime diff --git a/onnxruntime/core/graph/model.h b/onnxruntime/core/graph/model.h index 2d2086aef41fd..6fd94c60d6b99 100644 --- a/onnxruntime/core/graph/model.h +++ b/onnxruntime/core/graph/model.h @@ -280,6 +280,12 @@ class Model { const logging::Logger& logger, const ModelOptions& options = {}); + static common::Status LoadFromModelEditorApiModel(const OrtModel& graph_api_model, + const IOnnxRuntimeOpSchemaRegistryList* local_registries, + const ModelOptions& options, + const logging::Logger& logger, + std::unique_ptr& model); + common::Status SaveToOrtFormat(flatbuffers::FlatBufferBuilder& builder, flatbuffers::Offset& model) const; @@ -333,7 +339,7 @@ class Model { ModelMetaData model_metadata_; // Path to model file. May be empty. - const std::filesystem::path model_path_; + std::filesystem::path model_path_; // Main graph of the model. std::unique_ptr graph_; diff --git a/onnxruntime/core/graph/model_editor_api_types.h b/onnxruntime/core/graph/model_editor_api_types.h new file mode 100644 index 0000000000000..d72bd13093b61 --- /dev/null +++ b/onnxruntime/core/graph/model_editor_api_types.h @@ -0,0 +1,47 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/common/inlined_containers_fwd.h" +#include "core/framework/ort_value.h" +#include "core/framework/onnxruntime_typeinfo.h" +#include "core/graph/onnx_protobuf.h" + +// ORT C interface types for OrtGraphApi can't be in a namespace. +// We need to define them here so onnxruntime::Model can be created from OrtModel. + +struct OrtValueInfo { + std::string name; + std::unique_ptr type_info; +}; + +struct OrtOpAttr { + ONNX_NAMESPACE::AttributeProto attr_proto; +}; + +struct OrtNode { + std::string operator_name; + std::string domain_name; + std::string node_name; + + // OrtOpAttr is 1:1 with ONNX_NAMESPACE::AttributeProto currently. + // https://github.com/microsoft/onnxruntime/blob/bd5a759d0cdbed6e7f611c990d4eb5457a9ecf60/onnxruntime/core/session/standalone_op_invoker.cc#L318 + onnxruntime::InlinedVector attributes; + onnxruntime::InlinedVector input_names; + onnxruntime::InlinedVector output_names; + + // FUTURE if we need control flow nodes + // std::unordered_map subgraphs; +}; + +struct OrtGraph { + onnxruntime::InlinedVector> inputs; + onnxruntime::InlinedVector> outputs; + std::unordered_map> initializers; + std::unordered_map> external_initializers; + std::vector> nodes; +}; + +struct OrtModel { + std::unique_ptr graph; + std::unordered_map domain_to_version; +}; diff --git a/onnxruntime/core/optimizer/constant_folding.cc b/onnxruntime/core/optimizer/constant_folding.cc index e755b4bfa6364..e36eef672c1ed 100644 --- a/onnxruntime/core/optimizer/constant_folding.cc +++ b/onnxruntime/core/optimizer/constant_folding.cc @@ -21,7 +21,16 @@ ConstantFolding::ConstantFolding(const IExecutionProvider& execution_provider, const ConfigOptions& config_options, const InlinedHashSet& compatible_execution_providers, const InlinedHashSet& excluded_initializers) noexcept - : GraphTransformer("ConstantFolding", compatible_execution_providers), + : ConstantFolding("ConstantFolding", execution_provider, skip_dequantize_linear, config_options, compatible_execution_providers, excluded_initializers) { +} + +ConstantFolding::ConstantFolding(const std::string& name, + const IExecutionProvider& execution_provider, + bool skip_dequantize_linear, + const ConfigOptions& config_options, + const InlinedHashSet& compatible_execution_providers, + const InlinedHashSet& excluded_initializers) noexcept + : GraphTransformer(name, compatible_execution_providers), skip_dequantize_linear_(skip_dequantize_linear), config_options_(config_options), excluded_initializers_(excluded_initializers), @@ -144,7 +153,7 @@ Status ConstantFolding::ApplyImpl(Graph& graph, bool& modified, int graph_level, for (NodeIndex i : order) { auto* node = graph.GetNode(i); - if (!node) { + if (!node || !AllowConstantFolding(*node)) { continue; } diff --git a/onnxruntime/core/optimizer/constant_folding.h b/onnxruntime/core/optimizer/constant_folding.h index 14eb2a9c5f06b..29bc67d560788 100644 --- a/onnxruntime/core/optimizer/constant_folding.h +++ b/onnxruntime/core/optimizer/constant_folding.h @@ -28,6 +28,24 @@ class ConstantFolding : public GraphTransformer { const InlinedHashSet& compatible_execution_providers = {}, const InlinedHashSet& excluded_initializers = {}) noexcept; + protected: + /** + * Same as the constructor above but with a name provided by derived class. + */ + ConstantFolding(const std::string& name, + const IExecutionProvider& execution_provider, + bool skip_dequantize_linear, + const ConfigOptions& config_options, + const InlinedHashSet& compatible_execution_providers = {}, + const InlinedHashSet& excluded_initializers = {}) noexcept; + /** + * Derived class can implement this virtual function to limit the nodes that can be constant folded. + */ + virtual bool AllowConstantFolding(const Node& node) const { + ORT_UNUSED_PARAMETER(node); + return true; + } + private: Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override; diff --git a/onnxruntime/core/optimizer/graph_optimizer_registry.cc b/onnxruntime/core/optimizer/graph_optimizer_registry.cc new file mode 100644 index 0000000000000..8ede372470485 --- /dev/null +++ b/onnxruntime/core/optimizer/graph_optimizer_registry.cc @@ -0,0 +1,49 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/optimizer/graph_optimizer_registry.h" +#include "core/optimizer/graph_transformer_utils.h" +#include "core/optimizer/selection_and_optimization_func.h" +#include "core/optimizer/qdq_transformer/constant_folding_dq_node.h" + +using namespace onnxruntime; +using namespace ::onnxruntime::common; + +namespace onnxruntime { +#if !defined(ORT_MINIMAL_BUILD) +GraphOptimizerRegistry::GraphOptimizerRegistry(const onnxruntime::SessionOptions* sess_options, + const onnxruntime::IExecutionProvider* cpu_ep, + const logging::Logger* logger) : session_options_(sess_options), + cpu_ep_(cpu_ep), + logger_(logger) { + auto status = CreatePredefinedSelectionFuncs(); + ORT_ENFORCE(status.IsOK(), "Could not create pre-defined selection functions. Error Message: ", + status.ErrorMessage()); +} + +Status GraphOptimizerRegistry::CreatePredefinedSelectionFuncs() { + transformer_name_to_selection_func_[kConstantFoldingDQ] = ConstantFoldingDQFuncs::Select; + + return Status::OK(); +} + +std::optional GraphOptimizerRegistry::GetSelectionFunc(std::string& name) const { + auto lookup = transformer_name_to_selection_func_.find(name); + if (lookup != transformer_name_to_selection_func_.end()) { + return transformer_name_to_selection_func_.at(name); + } + LOGS(*logger_, WARNING) << "Can't find selection function of " << name; + return std::nullopt; +} +#else +GraphOptimizerRegistry::GraphOptimizerRegistry(const onnxruntime::SessionOptions* sess_options, + const onnxruntime::IExecutionProvider* cpu_ep, + const logging::Logger* logger) : session_options_(sess_options), + cpu_ep_(cpu_ep), + logger_(logger) {} + +std::optional GraphOptimizerRegistry::GetSelectionFunc(std::string& /*name*/) const { + return std::nullopt; +} +#endif +} // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/graph_optimizer_registry.h b/onnxruntime/core/optimizer/graph_optimizer_registry.h new file mode 100644 index 0000000000000..15c9287c0eac8 --- /dev/null +++ b/onnxruntime/core/optimizer/graph_optimizer_registry.h @@ -0,0 +1,77 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/common/inlined_containers.h" +#include "core/common/logging/logging.h" +#include "core/common/common.h" +#include "core/optimizer/graph_transformer.h" +#include "core/framework/execution_providers.h" +#include "core/framework/compute_capability.h" + +namespace onnxruntime { +/** + * Optimizer's selection function: Selects a set of nodes from a given graph for optimization. Additional key/value strings can be provided to configure the optimizer. + * If needed, use graph_optimizer_registry to access the session options, the CPU EP and the logger. + * + * Optimizer's optimization function: Gets the nodes in ComputeCapability from nodes_to_optimize. Use graph_optimizer_registry to access the session options, the CPU EP + * and the logger if needed to create the optimizer. Run optimization on the nodes/subgraph, and finally, update the ComputeCapability. + * + */ +using KeyValueConfig = std::unordered_map; +using SelectionFunc = std::function>(const GraphViewer&, + const KeyValueConfig&, + const GraphOptimizerRegistry& graph_optimizer_registry)>; +using OptimizationFunc = std::function; + +/** + * A registration/lookup class for re-usable optimizers for EPs. + */ +class GraphOptimizerRegistry { + public: + /** + * The constructor takes in session options, the CPU EP and a logger as these are required by some optimizers. + */ + GraphOptimizerRegistry(const onnxruntime::SessionOptions* sess_options, + const onnxruntime::IExecutionProvider* cpu_ep, + const logging::Logger* logger); + + /** + * Get optimizer selection function. If the optimizer name can't be found, return nullopt. + */ + std::optional GetSelectionFunc(std::string& name) const; + + /** + * Get CPU EP. + */ + const onnxruntime::IExecutionProvider& GetCpuEp() const { return *cpu_ep_; } + + /** + * Get Session Options. + */ + const onnxruntime::SessionOptions& GetSessionOptions() const { return *session_options_; } + + /** + * Get Logger. + */ + const logging::Logger* GetLogger() const { return logger_; } + + private: + const onnxruntime::SessionOptions* session_options_; + const onnxruntime::IExecutionProvider* cpu_ep_; + const logging::Logger* logger_; + +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) + InlinedHashMap transformer_name_to_selection_func_; + + /** + * Create pre-defined selection functions. + */ + Status CreatePredefinedSelectionFuncs(); +#endif +}; +} // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/qdq_transformer/constant_folding_dq_node.cc b/onnxruntime/core/optimizer/qdq_transformer/constant_folding_dq_node.cc new file mode 100644 index 0000000000000..a2f46d6ae693c --- /dev/null +++ b/onnxruntime/core/optimizer/qdq_transformer/constant_folding_dq_node.cc @@ -0,0 +1,26 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/optimizer/qdq_transformer/constant_folding_dq_node.h" +#include "core/optimizer/graph_optimizer_registry.h" +#include "core/graph/graph_utils.h" + +namespace onnxruntime { + +ConstantFoldingDQ::ConstantFoldingDQ(const IExecutionProvider& execution_provider, + bool skip_dequantize_linear, + const ConfigOptions& config_options, + const InlinedHashSet& node_index_set, + const InlinedHashSet& compatible_execution_providers, + const InlinedHashSet& excluded_initializers) noexcept + : ConstantFolding("ConstantFoldingDQ", execution_provider, skip_dequantize_linear, config_options, compatible_execution_providers, excluded_initializers), + node_index_set_(node_index_set) {} + +bool ConstantFoldingDQ::AllowConstantFolding(const Node& node) const { + if (node_index_set_.find(node.Index()) != node_index_set_.end()) { + return true; + } + return false; +} + +} // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/qdq_transformer/constant_folding_dq_node.h b/onnxruntime/core/optimizer/qdq_transformer/constant_folding_dq_node.h new file mode 100644 index 0000000000000..7aed87fa06adb --- /dev/null +++ b/onnxruntime/core/optimizer/qdq_transformer/constant_folding_dq_node.h @@ -0,0 +1,37 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/optimizer/graph_transformer.h" +#include "core/optimizer/constant_folding.h" +#include "core/framework/ort_value.h" +#include +#include "core/framework/execution_provider.h" + +namespace onnxruntime { + +/** +@class ConstantFoldingDQ + +It's the derived class from ConstantFolding. +*/ +class ConstantFoldingDQ : public ConstantFolding { + public: + /*! Constant folding will not be applied to nodes that have one of initializers from excluded_initializers as input. + \param execution_provider Execution provider instance to execute constant folding. + */ + ConstantFoldingDQ(const IExecutionProvider& execution_provider, + bool skip_dequantize_linear, + const ConfigOptions& config_options, + const InlinedHashSet& node_index_set, + const InlinedHashSet& compatible_execution_providers = {}, + const InlinedHashSet& excluded_initializers = {}) noexcept; + + bool AllowConstantFolding(const Node& node) const override; + + private: + InlinedHashSet node_index_set_; +}; + +} // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/selection_and_optimization_func.cc b/onnxruntime/core/optimizer/selection_and_optimization_func.cc new file mode 100644 index 0000000000000..151c61952a631 --- /dev/null +++ b/onnxruntime/core/optimizer/selection_and_optimization_func.cc @@ -0,0 +1,99 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "selection_and_optimization_func.h" +#include "core/graph/graph_utils.h" +#include "core/framework/compute_capability.h" +#include "core/optimizer/qdq_transformer/constant_folding_dq_node.h" + +namespace onnxruntime { + +std::vector> ConstantFoldingDQFuncs::Select(const GraphViewer& graph_viewer, + const KeyValueConfig& /*config*/, + const GraphOptimizerRegistry& /*graph_optimizer_registry*/) { + std::vector> result; + std::unique_ptr sub_graph = std::make_unique(); + const std::vector& node_index = graph_viewer.GetNodesInTopologicalOrder(ExecutionOrder::PRIORITY_BASED /*priority-based topological sort*/); + InitializedTensorSet constant_inputs; + const InlinedHashSet excluded_initializers; + + // Select DequantizeLinear node where all inputs are constant + for (const auto& index : node_index) { + const auto& node = graph_viewer.GetNode(index); + if (node->OpType() != "DequantizeLinear") { + continue; + } + if (!graph_utils::AllNodeInputsAreConstant(graph_viewer.GetGraph(), *node, constant_inputs, excluded_initializers)) { + continue; + } + sub_graph->nodes.push_back(index); + } + + result.push_back(std::make_unique(std::move(sub_graph))); + result.back()->optimization_func = ConstantFoldingDQFuncs::Optimize; + return result; +} + +Status ConstantFoldingDQFuncs::Optimize(Graph& graph, + const ComputeCapability& optimization_cc, + ComputeCapability& cc_to_update, + const GraphOptimizerRegistry& graph_optimizer_registry) { + std::string optimizer_name = kConstantFoldingDQ; + std::unordered_set original_initializers_to_remove; + std::unordered_set new_initializers_to_add; + InlinedHashSet dq_node_index_set; + + // iterate the nodes in node_to_optimize to: + // 1. get original initializers to remove + // 2. add new initializers + // 3. create dq node index set + for (const auto& index : optimization_cc.sub_graph->nodes) { + auto node = graph.GetNode(index); + if (node->OpType() != "DequantizeLinear") { + continue; + } + auto input_0 = node->InputDefs()[0]; + auto output_0 = node->OutputDefs()[0]; + original_initializers_to_remove.insert(input_0->Name()); + new_initializers_to_add.insert(output_0->Name()); + dq_node_index_set.insert(index); + } + + static auto transformer = std::make_unique(graph_optimizer_registry.GetCpuEp(), + false /*skip_dequantize_linear*/, + graph_optimizer_registry.GetSessionOptions().config_options, + dq_node_index_set); + + bool modified = false; + ORT_RETURN_IF_ERROR(transformer->Apply(graph, modified, *graph_optimizer_registry.GetLogger())); + + // update the overall ComputeCapability + std::vector updated_nodes; + for (auto index : cc_to_update.sub_graph->nodes) { + if (dq_node_index_set.find(index) != dq_node_index_set.end()) { + continue; + } + updated_nodes.push_back(index); + } + cc_to_update.sub_graph->nodes = updated_nodes; + + auto meta_def = cc_to_update.sub_graph->GetMutableMetaDef(); + std::vector updated_constant_initializers; + + for (auto constant_initializer : meta_def->constant_initializers) { + if (original_initializers_to_remove.find(constant_initializer) != original_initializers_to_remove.end()) { + continue; + } + updated_constant_initializers.push_back(constant_initializer); + } + + for (auto constant_initializer : new_initializers_to_add) { + updated_constant_initializers.push_back(constant_initializer); + } + + meta_def->constant_initializers = updated_constant_initializers; + + return Status::OK(); +} + +} // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/selection_and_optimization_func.h b/onnxruntime/core/optimizer/selection_and_optimization_func.h new file mode 100644 index 0000000000000..6ad62518833b0 --- /dev/null +++ b/onnxruntime/core/optimizer/selection_and_optimization_func.h @@ -0,0 +1,31 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/optimizer/graph_optimizer_registry.h" +#include "core/framework/compute_capability.h" +#include "core/graph/graph_viewer.h" + +namespace onnxruntime { +static const std::string kConstantFoldingDQ = "ConstantFoldingDQ"; + +/** + * Optimizer's selection function: Selects a set of nodes from a given graph for optimization. Additional key/value strings can be provided to configure the optimizer. + * If needed, use graph_optimizer_registry to access the session options, the CPU EP and the logger. + * + * Optimizer's optimization function: Gets the nodes in ComputeCapability from nodes_to_optimize. Use graph_optimizer_registry to access the session options, the CPU EP + * and the logger if needed to create the optimizer. Run optimization on the nodes/subgraph, and finally, update the ComputeCapability. + * + */ + +struct ConstantFoldingDQFuncs { + static std::vector> Select(const GraphViewer& graph_viewer, + const KeyValueConfig& configs, + const GraphOptimizerRegistry& graph_optimizer_registry); + static Status Optimize(Graph& graph, + const ComputeCapability& optimization_cc, + ComputeCapability& cc_to_update, + const GraphOptimizerRegistry& graph_optimizer_registry); +}; +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/acl/acl_execution_provider.cc b/onnxruntime/core/providers/acl/acl_execution_provider.cc index ede476ff74d1b..def1d5e4b704c 100644 --- a/onnxruntime/core/providers/acl/acl_execution_provider.cc +++ b/onnxruntime/core/providers/acl/acl_execution_provider.cc @@ -153,6 +153,7 @@ std::shared_ptr ACLExecutionProvider::GetKernelRegistry() const std::vector> ACLExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph, const IKernelLookup& kernel_lookup, + const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant*) const { std::vector> result; for (const auto& node : graph.Nodes()) { diff --git a/onnxruntime/core/providers/acl/acl_execution_provider.h b/onnxruntime/core/providers/acl/acl_execution_provider.h index d635e56add30b..80e4aaaf021e3 100755 --- a/onnxruntime/core/providers/acl/acl_execution_provider.h +++ b/onnxruntime/core/providers/acl/acl_execution_provider.h @@ -39,6 +39,7 @@ class ACLExecutionProvider : public IExecutionProvider { std::vector> GetCapability( const onnxruntime::GraphViewer& graph, const IKernelLookup& kernel_lookup, + const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* resource_accountant) const override; Status OnRunStart(const onnxruntime::RunOptions&) override; diff --git a/onnxruntime/core/providers/cann/cann_execution_provider.cc b/onnxruntime/core/providers/cann/cann_execution_provider.cc index 07e83933a890c..be09eefba791b 100644 --- a/onnxruntime/core/providers/cann/cann_execution_provider.cc +++ b/onnxruntime/core/providers/cann/cann_execution_provider.cc @@ -1254,6 +1254,7 @@ GetSubGraphPartition(const std::vector& topological_order, const std: std::vector> CANNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer, const IKernelLookup& kernel_lookup, + const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant*) const { std::vector> result; diff --git a/onnxruntime/core/providers/cann/cann_execution_provider.h b/onnxruntime/core/providers/cann/cann_execution_provider.h index 5ff935463a1c1..f28ae77e49f83 100644 --- a/onnxruntime/core/providers/cann/cann_execution_provider.h +++ b/onnxruntime/core/providers/cann/cann_execution_provider.h @@ -56,6 +56,7 @@ class CANNExecutionProvider : public IExecutionProvider { std::vector> GetCapability( const onnxruntime::GraphViewer& graph_viewer, const IKernelLookup& kernel_lookup, + const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* resource_accountant) const override; Status Compile(const std::vector& fused_nodes_and_graphs, diff --git a/onnxruntime/core/providers/coreml/coreml_execution_provider.cc b/onnxruntime/core/providers/coreml/coreml_execution_provider.cc index 3fa3868267c9b..cc7beed6bb298 100644 --- a/onnxruntime/core/providers/coreml/coreml_execution_provider.cc +++ b/onnxruntime/core/providers/coreml/coreml_execution_provider.cc @@ -39,6 +39,7 @@ CoreMLExecutionProvider::~CoreMLExecutionProvider() {} std::vector> CoreMLExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer, const IKernelLookup& /*kernel_lookup*/, + const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* /* resource_accountant */) const { std::vector> result; diff --git a/onnxruntime/core/providers/coreml/coreml_execution_provider.h b/onnxruntime/core/providers/coreml/coreml_execution_provider.h index 0609bf6af726d..574ae1fc0106b 100644 --- a/onnxruntime/core/providers/coreml/coreml_execution_provider.h +++ b/onnxruntime/core/providers/coreml/coreml_execution_provider.h @@ -20,6 +20,7 @@ class CoreMLExecutionProvider : public IExecutionProvider { std::vector> GetCapability(const onnxruntime::GraphViewer& graph_viewer, const IKernelLookup& /*kernel_lookup*/, + const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* resource_accountant) const override; #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) diff --git a/onnxruntime/core/providers/cpu/controlflow/loop.cc b/onnxruntime/core/providers/cpu/controlflow/loop.cc index c65dd2a04bf55..b33b1f189594b 100644 --- a/onnxruntime/core/providers/cpu/controlflow/loop.cc +++ b/onnxruntime/core/providers/cpu/controlflow/loop.cc @@ -244,7 +244,7 @@ static Status ConcatenateCpuOutput(void* /*stream*/, // we can't easily use a C++ template for the tensor element type, // so use a span for some protection but work in bytes - gsl::span output_span = gsl::make_span(static_cast(output), + gsl::span output_span = gsl::make_span(static_cast(output), output_size_in_bytes); for (size_t i = 0, num_iterations = per_iteration_output.size(); i < num_iterations; ++i) { @@ -257,7 +257,7 @@ static Status ConcatenateCpuOutput(void* /*stream*/, " Expected:", per_iteration_shape, " Got:", iteration_data.Shape()); } - auto src = gsl::make_span(static_cast(iteration_data.DataRaw()), + auto src = gsl::make_span(static_cast(iteration_data.DataRaw()), bytes_per_iteration); auto dst = output_span.subspan(i * bytes_per_iteration, bytes_per_iteration); gsl::copy(src, dst); diff --git a/onnxruntime/core/providers/cpu/quantization/conv_integer.cc b/onnxruntime/core/providers/cpu/quantization/conv_integer.cc index 03b39e19ed748..f3c6b18f8e753 100644 --- a/onnxruntime/core/providers/cpu/quantization/conv_integer.cc +++ b/onnxruntime/core/providers/cpu/quantization/conv_integer.cc @@ -34,17 +34,18 @@ ONNX_OPERATOR_KERNEL_EX( ConvInteger); Status ConvInteger::Compute(OpKernelContext* context) const { - size_t num_inputs = OpKernel::Node().InputDefs().size(); + const auto input_defs = Node().InputDefs(); + size_t num_inputs = input_defs.size(); const auto* X = context->Input(0); const auto* W = context->Input(1); uint8_t input_offset = 0; uint8_t filter_offset = 0; - if (num_inputs >= 3) { + if (num_inputs >= 3 && input_defs[2]->Exists()) { const auto* X_Zero_Point = context->Input(2); ORT_ENFORCE(IsScalarOr1ElementVector(X_Zero_Point), "Must be a scalar or 1D tensor or size 1."); input_offset = *(X_Zero_Point->Data()); } - if (num_inputs >= 4) { + if (num_inputs >= 4 && input_defs[3]->Exists()) { const auto* W_Zero_Point = context->Input(3); ORT_ENFORCE(IsScalarOr1ElementVector(W_Zero_Point), "Non per-tensor quantization is not supported now."); filter_offset = *(W_Zero_Point->Data()); diff --git a/onnxruntime/core/providers/cuda/controlflow/loop.cc b/onnxruntime/core/providers/cuda/controlflow/loop.cc index 3295b73a800c9..d66de7c74e647 100644 --- a/onnxruntime/core/providers/cuda/controlflow/loop.cc +++ b/onnxruntime/core/providers/cuda/controlflow/loop.cc @@ -84,10 +84,10 @@ static Status ConcatenateGpuOutput(void* stream, std::vector& per_iter CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(cur_output, iteration_data.DataRaw(), bytes_per_iteration, cudaMemcpyDeviceToDevice, static_cast(stream))); - cur_output = static_cast((static_cast(cur_output) + bytes_per_iteration)); + cur_output = static_cast((static_cast(cur_output) + bytes_per_iteration)); } - ORT_ENFORCE(static_cast(cur_output) - static_cast(output) == output_size_in_bytes, + ORT_ENFORCE(static_cast(cur_output) - static_cast(output) == output_size_in_bytes, "Concatenation did not fill output buffer as expected."); return Status::OK(); diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index b675c08e5f804..54fb4429c0536 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -2660,6 +2660,7 @@ std::unique_ptr CUDAExecutionProvider::GetDataTransf std::vector> CUDAExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph, const IKernelLookup& kernel_lookup, + const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* resource_accountant) const { std::vector> result; const logging::Logger& logger = *GetLogger(); diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.h b/onnxruntime/core/providers/cuda/cuda_execution_provider.h index 79a48e7cb89e1..a75e81f1f0c6d 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.h +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.h @@ -73,6 +73,7 @@ class CUDAExecutionProvider : public IExecutionProvider { std::vector> GetCapability( const onnxruntime::GraphViewer& graph, const IKernelLookup& kernel_lookup, + const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* resource_accountant) const override; int GetDeviceId() const override { return info_.device_id; } diff --git a/onnxruntime/core/providers/cuda/tensor/upsample.cc b/onnxruntime/core/providers/cuda/tensor/upsample.cc index cbf745d3c7b4f..a38fe1efad540 100644 --- a/onnxruntime/core/providers/cuda/tensor/upsample.cc +++ b/onnxruntime/core/providers/cuda/tensor/upsample.cc @@ -290,16 +290,16 @@ Status Upsample::BaseCompute(OpKernelContext* context, scales_div[i] = fast_divmod(gsl::narrow_cast(ceil(scales[i]))); } - UpampleImpl(Stream(context), - mode_, - rank, - (UpsampleMode::LINEAR == mode_) ? (rank == 2 ? X_dims[0] : X_dims[2]) : 0, - input_strides, - output_div_pitches, - scales_div, - reinterpret_cast(X->Data()), - reinterpret_cast(Y->MutableData()), - output_count); + UpsampleImpl(Stream(context), + mode_, + rank, + (UpsampleMode::LINEAR == mode_) ? (rank == 2 ? X_dims[0] : X_dims[2]) : 0, + input_strides, + output_div_pitches, + scales_div, + reinterpret_cast(X->Data()), + reinterpret_cast(Y->MutableData()), + output_count); } return Status::OK(); diff --git a/onnxruntime/core/providers/cuda/tensor/upsample_impl.cu b/onnxruntime/core/providers/cuda/tensor/upsample_impl.cu index d1c2ae6332994..24aeada559979 100644 --- a/onnxruntime/core/providers/cuda/tensor/upsample_impl.cu +++ b/onnxruntime/core/providers/cuda/tensor/upsample_impl.cu @@ -8,12 +8,12 @@ namespace onnxruntime { namespace cuda { template -__global__ void _UpampleNearestKernel(const TArray input_pitches, - const TArray output_div_pitches, - const TArray scales_div, - const T* __restrict__ input_data, - T* __restrict__ output_data, - const size_t N) { +__global__ void _UpsampleNearestKernel(const TArray input_pitches, + const TArray output_div_pitches, + const TArray scales_div, + const T* __restrict__ input_data, + T* __restrict__ output_data, + const size_t N) { CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N); CUDA_LONG input_index = 0; CUDA_LONG output_index = id; @@ -36,13 +36,13 @@ __global__ void _UpampleNearestKernel(const TArray input_pitches, // This is the common use-case where the 4-D input (batched multi-channel images) // is usually of shape [N, C, H, W] and the scales are [1.0, 1.0, height_scale, width_scale] template -__global__ void _UpampleBilinear4DInputKernel(const int64_t input_dim2, - const TArray input_pitches, - const TArray output_div_pitches, - const TArray scales_div, - const T* __restrict__ input_data, - T* __restrict__ output_data, - const size_t N) { +__global__ void _UpsampleBilinear4DInputKernel(const int64_t input_dim2, + const TArray input_pitches, + const TArray output_div_pitches, + const TArray scales_div, + const T* __restrict__ input_data, + T* __restrict__ output_data, + const size_t N) { CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N); CUDA_LONG input_index = 0; @@ -95,13 +95,13 @@ __global__ void _UpampleBilinear4DInputKernel(const int64_t input_dim2, // The following method supports a 2-D input in 'Linear mode' template -__global__ void _UpampleBilinear2DInputKernel(const int64_t input_dim0, - const TArray input_pitches, - const TArray output_div_pitches, - const TArray scales_div, - const T* __restrict__ input_data, - T* __restrict__ output_data, - const size_t N) { +__global__ void _UpsampleBilinear2DInputKernel(const int64_t input_dim0, + const TArray input_pitches, + const TArray output_div_pitches, + const TArray scales_div, + const T* __restrict__ input_data, + T* __restrict__ output_data, + const size_t N) { CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N); CUDA_LONG input_index = 0; @@ -147,32 +147,32 @@ __global__ void _UpampleBilinear2DInputKernel(const int64_t input_dim0, } template -void UpampleImpl(cudaStream_t stream, - const onnxruntime::UpsampleMode upsample_mode, - const size_t rank, - const int64_t input_dim2, - const TArray& input_pitches, - const TArray& output_div_pitches, - const TArray& scales_div, - const T* input_data, - T* output_data, - const size_t N) { +void UpsampleImpl(cudaStream_t stream, + const onnxruntime::UpsampleMode upsample_mode, + const size_t rank, + const int64_t input_dim2, + const TArray& input_pitches, + const TArray& output_div_pitches, + const TArray& scales_div, + const T* input_data, + T* output_data, + const size_t N) { int blocksPerGrid = (int)(ceil(static_cast(N) / GridDim::maxThreadsPerBlock)); if (onnxruntime::UpsampleMode::NN == upsample_mode) { if (rank == 4) { - _UpampleNearestKernel<<>>( + _UpsampleNearestKernel<<>>( input_pitches, output_div_pitches, scales_div, input_data, output_data, N); } else if (rank == 3) { - _UpampleNearestKernel<<>>( + _UpsampleNearestKernel<<>>( input_pitches, output_div_pitches, scales_div, input_data, output_data, N); } else if (rank == 2) { - _UpampleNearestKernel<<>>( + _UpsampleNearestKernel<<>>( input_pitches, output_div_pitches, scales_div, input_data, output_data, N); } else if (rank == 1) { - _UpampleNearestKernel<<>>( + _UpsampleNearestKernel<<>>( input_pitches, output_div_pitches, scales_div, input_data, output_data, N); } else { @@ -180,11 +180,11 @@ void UpampleImpl(cudaStream_t stream, } } else if (onnxruntime::UpsampleMode::LINEAR == upsample_mode) { if (rank == 4) { - _UpampleBilinear4DInputKernel<<>>( + _UpsampleBilinear4DInputKernel<<>>( input_dim2, input_pitches, output_div_pitches, scales_div, input_data, output_data, N); } else if (rank == 2) { - _UpampleBilinear2DInputKernel<<>>( + _UpsampleBilinear2DInputKernel<<>>( input_dim2, input_pitches, output_div_pitches, scales_div, input_data, output_data, N); } else { @@ -197,17 +197,17 @@ void UpampleImpl(cudaStream_t stream, } } -#define SPECIALIZED_IMPL(T) \ - template void UpampleImpl(cudaStream_t stream, \ - const onnxruntime::UpsampleMode upsample_mode, \ - const size_t rank, \ - const int64_t input_dim2, \ - const TArray& input_pitches, \ - const TArray& output_div_pitches, \ - const TArray& scales_div, \ - const T* input_data, \ - T* output_data, \ - const size_t N); +#define SPECIALIZED_IMPL(T) \ + template void UpsampleImpl(cudaStream_t stream, \ + const onnxruntime::UpsampleMode upsample_mode, \ + const size_t rank, \ + const int64_t input_dim2, \ + const TArray& input_pitches, \ + const TArray& output_div_pitches, \ + const TArray& scales_div, \ + const T* input_data, \ + T* output_data, \ + const size_t N); SPECIALIZED_IMPL(float) SPECIALIZED_IMPL(double) diff --git a/onnxruntime/core/providers/cuda/tensor/upsample_impl.h b/onnxruntime/core/providers/cuda/tensor/upsample_impl.h index 250ec6b272e34..fb47ad8301615 100644 --- a/onnxruntime/core/providers/cuda/tensor/upsample_impl.h +++ b/onnxruntime/core/providers/cuda/tensor/upsample_impl.h @@ -11,16 +11,16 @@ namespace onnxruntime { namespace cuda { template -void UpampleImpl(cudaStream_t stream, - const onnxruntime::UpsampleMode upsample_mode, - const size_t rank, - const int64_t input_dim2, - const TArray& input_pitches, - const TArray& output_div_pitches, - const TArray& scales_div, - const T* input_data, - T* output_data, - const size_t N); +void UpsampleImpl(cudaStream_t stream, + const onnxruntime::UpsampleMode upsample_mode, + const size_t rank, + const int64_t input_dim2, + const TArray& input_pitches, + const TArray& output_div_pitches, + const TArray& scales_div, + const T* input_data, + T* output_data, + const size_t N); } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp index 9d23b8b950272..868b2103586f9 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp @@ -93,12 +93,13 @@ namespace Dml ExecutionProvider::GetCapability( const onnxruntime::GraphViewer& graph, const onnxruntime::IExecutionProvider::IKernelLookup& kernel_lookup, + const onnxruntime::GraphOptimizerRegistry& graph_optimizer_registry, onnxruntime::IResourceAccountant* resource_accountant) const { #ifdef ENABLE_GRAPH_COMPILATION - return m_impl->GetCapability(graph, kernel_lookup, resource_accountant, *GetLogger()); + return m_impl->GetCapability(graph, kernel_lookup, graph_optimizer_registry, resource_accountant, *GetLogger()); #else - return onnxruntime::IExecutionProvider::GetCapability(graph, kernel_lookup, resource_accountant); + return onnxruntime::IExecutionProvider::GetCapability(graph, kernel_lookup, graph_optimizer_registry, resource_accountant); #endif } @@ -878,6 +879,7 @@ namespace Dml ExecutionProviderImpl::GetCapability( const onnxruntime::GraphViewer& graph, const onnxruntime::IExecutionProvider::IKernelLookup& kernel_lookup, + const onnxruntime::GraphOptimizerRegistry& /* graph_optimizer_registry */, onnxruntime::IResourceAccountant*, const onnxruntime::logging::Logger& logger) const { uint32_t deviceDataTypeMask = GetSupportedDeviceDataTypeMask(); // Each bit corresponds to each DML_TENSOR_DATA_TYPE. diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h index 7f420f8850001..aa3d8b0b4a409 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h @@ -13,6 +13,7 @@ namespace onnxruntime { class IResourceAccountant; +class GraphOptimizerRegistry; } namespace WRL { @@ -93,6 +94,7 @@ namespace Dml GetCapability( const onnxruntime::GraphViewer& graph, const onnxruntime::IExecutionProvider::IKernelLookup& kernel_lookup, + const onnxruntime::GraphOptimizerRegistry& graph_optimizer_registry, onnxruntime::IResourceAccountant* resource_accountant, const onnxruntime::logging::Logger& logger) const; @@ -288,6 +290,7 @@ namespace Dml std::vector> GetCapability(const onnxruntime::GraphViewer& graph, const onnxruntime::IExecutionProvider::IKernelLookup& kernel_lookup, + const onnxruntime::GraphOptimizerRegistry& /* graph_optimizer_registry */, onnxruntime::IResourceAccountant* resource_accountant) const final override; onnxruntime::common::Status OnSessionInitializationEnd() override diff --git a/onnxruntime/core/providers/dnnl/dnnl_execution_provider.cc b/onnxruntime/core/providers/dnnl/dnnl_execution_provider.cc index 4da82b351f1d6..d0e5b0b1588ef 100644 --- a/onnxruntime/core/providers/dnnl/dnnl_execution_provider.cc +++ b/onnxruntime/core/providers/dnnl/dnnl_execution_provider.cc @@ -147,6 +147,7 @@ std::vector> DnnlExecutionProvider::GetSupportedNodes(con std::vector> DnnlExecutionProvider::GetCapability( const GraphViewer& graph_viewer, const IKernelLookup& /*kernel_lookup*/, + const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* /* resource_accountant */) const { // follow from coreml ep's Getcapability diff --git a/onnxruntime/core/providers/dnnl/dnnl_execution_provider.h b/onnxruntime/core/providers/dnnl/dnnl_execution_provider.h index bde18e139f2a3..8f951efef2a94 100644 --- a/onnxruntime/core/providers/dnnl/dnnl_execution_provider.h +++ b/onnxruntime/core/providers/dnnl/dnnl_execution_provider.h @@ -25,6 +25,7 @@ class DnnlExecutionProvider : public IExecutionProvider { std::vector> GetCapability(const onnxruntime::GraphViewer& graph, const IKernelLookup& /*kernel_lookup*/, + const GraphOptimizerRegistry& /* graph_optimizer_registry */, onnxruntime::IResourceAccountant* /* resource_accountant */) const override; common::Status Compile(const std::vector& fused_nodes_and_graphs, diff --git a/onnxruntime/core/providers/js/js_execution_provider.cc b/onnxruntime/core/providers/js/js_execution_provider.cc index 9d00436150286..d8e24ff1f5053 100644 --- a/onnxruntime/core/providers/js/js_execution_provider.cc +++ b/onnxruntime/core/providers/js/js_execution_provider.cc @@ -791,6 +791,7 @@ std::vector JsExecutionProvider::CreatePreferredAllocators() { std::vector> JsExecutionProvider::GetCapability( const onnxruntime::GraphViewer& graph, const IKernelLookup& kernel_lookup, + const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* /* resource_accountant */) const { InlinedVector candidates; // `tenative_candidates` is a subset of `candidates`. diff --git a/onnxruntime/core/providers/js/js_execution_provider.h b/onnxruntime/core/providers/js/js_execution_provider.h index 4bead50fc782e..c87303209c689 100644 --- a/onnxruntime/core/providers/js/js_execution_provider.h +++ b/onnxruntime/core/providers/js/js_execution_provider.h @@ -45,6 +45,7 @@ class JsExecutionProvider : public IExecutionProvider { std::vector> GetCapability( const onnxruntime::GraphViewer& graph_viewer, const IKernelLookup& /*kernel_lookup*/, + const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* /* resource_accountant */) const override; std::shared_ptr GetKernelRegistry() const override; diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc index 1558d22137c05..9a694b03387ae 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc @@ -993,6 +993,7 @@ GetPartitionedSubgraphs(const std::vector& topological_order, std::vector> MIGraphXExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer, const IKernelLookup& /*kernel_lookup*/, + const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* /* resource_accountant */) const { std::vector> result; auto model = graph_viewer.CreateModel(*GetLogger()); diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h index d6af991f9b77e..7c89b5ec544a1 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h @@ -69,6 +69,7 @@ class MIGraphXExecutionProvider : public IExecutionProvider { std::vector> GetCapability(const onnxruntime::GraphViewer& graph_viewer, const IKernelLookup& /*kernel_lookup*/, + const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* /* resource_accountant */) const override; common::Status Compile(const std::vector& fused_nodes, diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.cc index 27bd584e2d3c6..28cfde817a620 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.cc @@ -81,6 +81,7 @@ NnapiExecutionProvider::~NnapiExecutionProvider() {} std::vector> NnapiExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer, const IKernelLookup& /*kernel_lookup*/, + const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* /* resource_accountant */) const { std::vector> result; const logging::Logger& logger = *GetLogger(); diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.h b/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.h index ebf9372eb668d..a2269fdd89436 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.h +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.h @@ -26,6 +26,7 @@ class NnapiExecutionProvider : public IExecutionProvider { std::vector> GetCapability(const onnxruntime::GraphViewer& graph_view, const IKernelLookup& /*kernel_lookup*/, + const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* /* resource_accountant */) const override; #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) diff --git a/onnxruntime/core/providers/openvino/backends/basic_backend.cc b/onnxruntime/core/providers/openvino/backends/basic_backend.cc index 9d4ad88e2c2b3..d026ce386e5c3 100644 --- a/onnxruntime/core/providers/openvino/backends/basic_backend.cc +++ b/onnxruntime/core/providers/openvino/backends/basic_backend.cc @@ -647,7 +647,7 @@ void BasicBackend::CompleteAsyncInference(Ort::KernelContext& context, OVInferRe const auto& out_name = item.first; auto node = item.second; Ort::UnownedValue output_tensor = GetOutputTensor(context, - std::move(out_name), + out_name, subgraph_context_.output_names, node); auto mem_info = output_tensor.GetTensorMemoryInfo(); diff --git a/onnxruntime/core/providers/openvino/openvino_execution_provider.cc b/onnxruntime/core/providers/openvino/openvino_execution_provider.cc index 12c16e9c9b8f6..6482a07ee92bc 100644 --- a/onnxruntime/core/providers/openvino/openvino_execution_provider.cc +++ b/onnxruntime/core/providers/openvino/openvino_execution_provider.cc @@ -107,6 +107,7 @@ OpenVINOExecutionProvider::~OpenVINOExecutionProvider() { std::vector> OpenVINOExecutionProvider::GetCapability(const GraphViewer& graph_viewer, const IKernelLookup& /*kernel_lookup*/, + const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* /* resource_accountant */) const { std::vector> result; diff --git a/onnxruntime/core/providers/openvino/openvino_execution_provider.h b/onnxruntime/core/providers/openvino/openvino_execution_provider.h index bbcca583b074b..020aec16e507c 100644 --- a/onnxruntime/core/providers/openvino/openvino_execution_provider.h +++ b/onnxruntime/core/providers/openvino/openvino_execution_provider.h @@ -51,6 +51,7 @@ class OpenVINOExecutionProvider : public IExecutionProvider { std::vector> GetCapability(const GraphViewer& graph_viewer, const IKernelLookup& /*kernel_lookup*/, + const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* /* resource_accountant */) const override; Status Compile(const std::vector& fused_nodes, diff --git a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc index 3df231e53e7c0..d85277627a3de 100644 --- a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc +++ b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc @@ -198,35 +198,13 @@ Status LoadQnnCtxFromOnnxGraph(const onnxruntime::GraphViewer& graph_viewer, return Status::OK(); } -// Figure out the real context cache file path -// return true if context cache file exists -bool ValidateContextCacheFilePath(bool is_qnn_ctx_model, - const std::string& customer_context_cache_path, - const onnxruntime::PathString& model_pathstring, - onnxruntime::PathString& context_cache_path) { - // always try the path set by user first, it's the only way to set it if load model from memory - if (!customer_context_cache_path.empty()) { - context_cache_path = ToPathString(customer_context_cache_path); - } else if (!model_pathstring.empty()) { // model loaded from file - if (is_qnn_ctx_model) { - // it's a context cache model, just use the model path - context_cache_path = model_pathstring; - } else if (!model_pathstring.empty()) { - // this is not a normal Onnx model, no customer path, create a default path for generation: model_path + _ctx.onnx - context_cache_path = model_pathstring + ToPathString("_ctx.onnx"); - } - } - - return std::filesystem::is_regular_file(context_cache_path) && std::filesystem::exists(context_cache_path); -} - Status CreateEPContextNodes(Model* model, unsigned char* buffer, uint64_t buffer_size, const std::string& sdk_build_version, const std::vector& fused_nodes_and_graphs, const QnnModelLookupTable& qnn_models, - const onnxruntime::PathString& context_cache_path, + const onnxruntime::PathString& context_model_path, bool qnn_context_embed_mode, uint64_t max_spill_fill_buffer_size, const logging::Logger& logger) { @@ -262,7 +240,19 @@ Status CreateEPContextNodes(Model* model, std::string cache_payload(buffer, buffer + buffer_size); ep_node.AddAttribute(EP_CACHE_CONTEXT, cache_payload); } else { - onnxruntime::PathString context_bin_path = context_cache_path + ToPathString("_" + graph_name + ".bin"); + onnxruntime::PathString context_bin_path; + auto pos = context_model_path.find_last_of(ORT_TSTR(".")); + if (pos != std::string::npos) { + context_bin_path = context_model_path.substr(0, pos); + } else { + context_bin_path = context_model_path; + } + std::string graph_name_in_file(graph_name); + auto name_pos = graph_name_in_file.find_first_of(kQnnExecutionProvider); + if (name_pos != std::string::npos) { + graph_name_in_file.replace(name_pos, strlen(kQnnExecutionProvider), ""); + } + context_bin_path = context_bin_path + ToPathString(graph_name_in_file + ".bin"); std::string context_cache_name(std::filesystem::path(context_bin_path).filename().string()); std::ofstream of_stream(context_bin_path.c_str(), std::ofstream::binary); if (!of_stream) { diff --git a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h index 3dfa0ae21001b..c54cd3ca6e90c 100644 --- a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h +++ b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h @@ -38,11 +38,6 @@ Status CreateNodeArgs(const std::vector& names, std::vector& node_args, onnxruntime::Graph& graph); -bool ValidateContextCacheFilePath(bool is_qnn_ctx_model, - const std::string& customer_context_cache_path, - const onnxruntime::PathString& model_pathstring, - onnxruntime::PathString& context_cache_path); - Status GetEpContextFromMainNode(const onnxruntime::Node& main_context_node, const onnxruntime::PathString& ctx_onnx_model_path, QnnBackendManager* qnn_backend_manager, @@ -67,7 +62,7 @@ Status CreateEPContextNodes(Model* model, const std::string& sdk_build_version, const std::vector& fused_nodes_and_graphs, const std::unordered_map>& qnn_models, - const onnxruntime::PathString& context_cache_path, + const onnxruntime::PathString& context_model_path, bool qnn_context_embed_mode, uint64_t max_spill_fill_buffer_size, const logging::Logger& logger); diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc index bcde69beceef7..26d792c008edc 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc @@ -470,8 +470,10 @@ Status QnnBackendManager::InitializeProfiling() { QnnProfile_Level_t qnn_profile_level = QNN_PROFILE_LEVEL_BASIC; if (ProfilingLevel::BASIC == profiling_level_merge_) { qnn_profile_level = QNN_PROFILE_LEVEL_BASIC; + LOGS_DEFAULT(VERBOSE) << "Profiling level set to basic."; } else if (ProfilingLevel::DETAILED == profiling_level_merge_) { qnn_profile_level = QNN_PROFILE_LEVEL_DETAILED; + LOGS_DEFAULT(VERBOSE) << "Profiling level set to detailed."; } Qnn_ErrorHandle_t result = qnn_interface_.profileCreate(backend_handle_, qnn_profile_level, &profile_backend_handle_); ORT_RETURN_IF(QNN_PROFILE_NO_ERROR != result, "Failed to create QNN profile! Error: ", QnnErrorHandleToString(result)); diff --git a/onnxruntime/core/providers/qnn/qnn_allocator.cc b/onnxruntime/core/providers/qnn/qnn_allocator.cc index 1fb8742f724cd..cb92e927ff65a 100644 --- a/onnxruntime/core/providers/qnn/qnn_allocator.cc +++ b/onnxruntime/core/providers/qnn/qnn_allocator.cc @@ -181,7 +181,9 @@ void HtpSharedMemoryAllocator::Free(void* allocation_address) { // Avoid throwing exceptions as this may be running from a destructor. try { // take ownership of shared memory and free at end of scope - auto shared_memory = WrapSharedMemoryWithUniquePtr(allocation_address, rpcmem_lib_->Api()); + const size_t allocation_offset = AllocationOffsetFromStartOfHeader(); + void* raw_allocation_address = (void*)((std::byte*)allocation_address - allocation_offset); + auto shared_memory = WrapSharedMemoryWithUniquePtr(raw_allocation_address, rpcmem_lib_->Api()); // destroy header allocation_header.~AllocationHeader(); diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc index 3fc537066ae0b..a5813dc2a4adc 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc @@ -195,6 +195,10 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio share_ep_contexts_ = config_options->GetConfigOrDefault(kOrtSessionOptionShareEpContexts, "0") == "1"; LOGS_DEFAULT(VERBOSE) << "User specified option - share EP contexts across sessions: " << share_ep_contexts_; + + stop_share_ep_contexts_ = + config_options->GetConfigOrDefault(kOrtSessionOptionStopShareEpContexts, "0") == "1"; + LOGS_DEFAULT(VERBOSE) << "User specified option - stop share EP contexts across sessions: " << stop_share_ep_contexts_; } static const std::string BACKEND_PATH = "backend_path"; @@ -384,17 +388,27 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio } } - qnn_backend_manager_ = qnn::QnnBackendManager::Create( - qnn::QnnBackendManagerConfig{backend_path, - profiling_level_etw, - profiling_level, - profiling_file_path, - context_priority, - qnn_saver_path, - device_id_, - htp_arch, - soc_model, - enable_htp_weight_sharing}); + // For context binary generation with weight sharing enabled, use the QnnBackendManager from the shared context if it exits + // So that all graphs from later sessions will be compiled into the same QNN context + if (context_cache_enabled_ && share_ep_contexts_ && SharedContext::GetInstance().GetSharedQnnBackendManager()) { + qnn_backend_manager_ = SharedContext::GetInstance().GetSharedQnnBackendManager(); + // Clear the QnnBackendManager from singleton to stop the resource share + if (stop_share_ep_contexts_) { + SharedContext::GetInstance().ResetSharedQnnBackendManager(); + } + } else { + qnn_backend_manager_ = qnn::QnnBackendManager::Create( + qnn::QnnBackendManagerConfig{backend_path, + profiling_level_etw, + profiling_level, + profiling_file_path, + context_priority, + qnn_saver_path, + device_id_, + htp_arch, + soc_model, + enable_htp_weight_sharing}); + } #if defined(_WIN32) if (onnxruntime::logging::EtwRegistrationManager::SupportsETW()) { @@ -655,6 +669,7 @@ static void PartitionCtxModel(const onnxruntime::GraphViewer& graph_viewer, std::vector> QNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer, const IKernelLookup& /*kernel_lookup*/, + const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* /* resource_accountant */) const { std::vector> result; @@ -904,25 +919,33 @@ Status QNNExecutionProvider::CompileFromOrtGraph(const std::vector& fused_nodes_and_graphs, std::vector& node_compute_funcs) { const auto& logger = *GetLogger(); bool is_qnn_ctx_model = qnn::IsFusedGraphHasCtxNode(fused_nodes_and_graphs); - onnxruntime::PathString context_cache_path; + onnxruntime::PathString context_model_path; bool is_ctx_file_exist = false; if (is_qnn_ctx_model || context_cache_enabled_) { const onnxruntime::GraphViewer& graph_viewer_0(fused_nodes_and_graphs[0].filtered_graph); - is_ctx_file_exist = qnn::ValidateContextCacheFilePath(is_qnn_ctx_model, - context_cache_path_cfg_, - graph_viewer_0.ModelPath().native(), - context_cache_path); + // Figure out the EP context model path from model path or session option + GetContextOnnxModelFilePath(context_cache_path_cfg_, + graph_viewer_0.ModelPath().native(), + context_model_path); } - ORT_RETURN_IF(is_ctx_file_exist && !is_qnn_ctx_model && context_cache_enabled_, - "The inference session is created from normal ONNX model. And an EP context model file is provided and existed. ", - "Please remove the EP context model manually if you want to re-generate it."); - if (is_qnn_ctx_model) { // Get QnnModel from EP shared contexts if (share_ep_contexts_ && SharedContext::GetInstance().HasSharedQnnModels()) { @@ -965,7 +988,7 @@ Status QNNExecutionProvider::Compile(const std::vector& fused const onnxruntime::GraphViewer& main_ctx_graph_viewer(fused_nodes_and_graphs[main_context_pos].filtered_graph); // Create QNN context from the cached binary, deserialize the QNN graph from the binary ORT_RETURN_IF_ERROR(qnn::LoadQnnCtxFromOnnxGraph(main_ctx_graph_viewer, - context_cache_path, + context_model_path, qnn_backend_manager_.get(), qnn_models, logger, @@ -1025,10 +1048,16 @@ Status QNNExecutionProvider::Compile(const std::vector& fused qnn_backend_manager_->GetSdkVersion(), fused_nodes_and_graphs, qnn_models_, - context_cache_path, + context_model_path, qnn_context_embed_mode_, max_spill_fill_buffer_size, logger)); + + if (share_ep_contexts_ && !stop_share_ep_contexts_ && + nullptr == SharedContext::GetInstance().GetSharedQnnBackendManager()) { + ORT_RETURN_IF_NOT(SharedContext::GetInstance().SetSharedQnnBackendManager(qnn_backend_manager_), + "Failed to set shared QnnBackendManager."); + } } return Status::OK(); } diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.h b/onnxruntime/core/providers/qnn/qnn_execution_provider.h index 31c34855ca4c0..d7a5d04d22692 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.h +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.h @@ -31,6 +31,7 @@ class QNNExecutionProvider : public IExecutionProvider { std::vector> GetCapability(const onnxruntime::GraphViewer& graph_view, const IKernelLookup& /*kernel_lookup*/, + const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* /* resource_accountant */) const override; Status Compile(const std::vector& fused_nodes_and_graphs, @@ -90,6 +91,7 @@ class QNNExecutionProvider : public IExecutionProvider { uint32_t default_rpc_control_latency_ = 0; bool enable_HTP_FP16_precision_ = true; bool share_ep_contexts_ = false; + bool stop_share_ep_contexts_ = false; bool enable_spill_fill_buffer_ = false; #if defined(_WIN32) onnxruntime::logging::EtwRegistrationManager::EtwInternalCallback callback_ETWSink_provider_ = nullptr; diff --git a/onnxruntime/core/providers/qnn/shared_context.h b/onnxruntime/core/providers/qnn/shared_context.h index 81de357dbe677..277a484ad8528 100644 --- a/onnxruntime/core/providers/qnn/shared_context.h +++ b/onnxruntime/core/providers/qnn/shared_context.h @@ -61,13 +61,39 @@ class SharedContext { return graph_exist; } + bool SetSharedQnnBackendManager(std::shared_ptr& qnn_backend_manager) { + const std::lock_guard lock(mtx_); + + if (qnn_backend_manager_ != nullptr) { + if (qnn_backend_manager_ == qnn_backend_manager) { + return true; + } + return false; + } + qnn_backend_manager_ = qnn_backend_manager; + return true; + } + + std::shared_ptr GetSharedQnnBackendManager() { + const std::lock_guard lock(mtx_); + return qnn_backend_manager_; + } + + void ResetSharedQnnBackendManager() { + const std::lock_guard lock(mtx_); + qnn_backend_manager_.reset(); + } + private: SharedContext() = default; ~SharedContext() = default; ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(SharedContext); + // Used for passing through QNN models (deserialized from context binary) across sessions std::vector> shared_qnn_models_; + // Used for compiling multiple models into same QNN context binary + std::shared_ptr qnn_backend_manager_; // Producer sessions can be in parallel // Consumer sessions have to be after producer sessions initialized std::mutex mtx_; diff --git a/onnxruntime/core/providers/rknpu/rknpu_execution_provider.cc b/onnxruntime/core/providers/rknpu/rknpu_execution_provider.cc index 10fd81786f977..e9343e2b2e06a 100644 --- a/onnxruntime/core/providers/rknpu/rknpu_execution_provider.cc +++ b/onnxruntime/core/providers/rknpu/rknpu_execution_provider.cc @@ -51,6 +51,7 @@ std::vector> RknpuExecutionProvider::GetSupportedNodes( std::vector> RknpuExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer, const IKernelLookup& /*kernel_lookup*/, + const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* /* resource_accountant */) const { // Find inputs, initializers and outputs for each supported subgraph std::vector> result; diff --git a/onnxruntime/core/providers/rknpu/rknpu_execution_provider.h b/onnxruntime/core/providers/rknpu/rknpu_execution_provider.h index ce16d63e111d9..75cae37d117a0 100644 --- a/onnxruntime/core/providers/rknpu/rknpu_execution_provider.h +++ b/onnxruntime/core/providers/rknpu/rknpu_execution_provider.h @@ -20,6 +20,7 @@ class RknpuExecutionProvider : public IExecutionProvider { std::vector> GetCapability(const onnxruntime::GraphViewer& graph, const IKernelLookup& /*kernel_lookup*/, + const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* /* resource_accountant */) const override; common::Status Compile(const std::vector& fused_nodes_and_graphs, std::vector& node_compute_funcs) override; diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc index 9d6e9df907ce3..49771488efc44 100644 --- a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc +++ b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc @@ -2441,6 +2441,7 @@ std::unique_ptr ROCMExecutionProvider::GetDataTransf std::vector> ROCMExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph, const IKernelLookup& kernel_lookup, + const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* /* resource_accountant */) const { InlinedVector candidates; // A subset of the above vector. A subset of the tentative_nodes might be moved to CPU. diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider.h b/onnxruntime/core/providers/rocm/rocm_execution_provider.h index ff2bff7c98723..2baaf2ff1a886 100644 --- a/onnxruntime/core/providers/rocm/rocm_execution_provider.h +++ b/onnxruntime/core/providers/rocm/rocm_execution_provider.h @@ -62,6 +62,7 @@ class ROCMExecutionProvider : public IExecutionProvider { std::vector> GetCapability( const onnxruntime::GraphViewer& graph, const IKernelLookup& kernel_lookup, + const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* /* resource_accountant */) const override; int GetDeviceId() const override { return info_.device_id; } diff --git a/onnxruntime/core/providers/shared_library/provider_api.h b/onnxruntime/core/providers/shared_library/provider_api.h index 6ff2572e5e668..9d61e1f12f5b6 100644 --- a/onnxruntime/core/providers/shared_library/provider_api.h +++ b/onnxruntime/core/providers/shared_library/provider_api.h @@ -200,6 +200,7 @@ struct SparseTensor; class TensorSeq; class SessionState; class ModelMetadefIdGenerator; +class GraphOptimizerRegistry; class If; class Loop; diff --git a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc index 2dab9f6a402a0..90fd36ea29956 100644 --- a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc +++ b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc @@ -332,8 +332,9 @@ bool IAllocator::CalcMemSizeForArrayWithAlignment(size_t nmemb, size_t size, siz std::vector> IExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer, const IKernelLookup& kernel_lookup, + const GraphOptimizerRegistry& graph_optimizer_registry, IResourceAccountant* resource_accountant) const { - return g_host->IExecutionProvider__GetCapability(this, graph_viewer, kernel_lookup, resource_accountant); + return g_host->IExecutionProvider__GetCapability(this, graph_viewer, kernel_lookup, graph_optimizer_registry, resource_accountant); } common::Status IExecutionProvider::Compile(const std::vector& fused_nodes_and_graphs, std::vector& node_compute_funcs) { diff --git a/onnxruntime/core/providers/shared_library/provider_interfaces.h b/onnxruntime/core/providers/shared_library/provider_interfaces.h index a77f0cb4c27b0..83d615c1bde0a 100644 --- a/onnxruntime/core/providers/shared_library/provider_interfaces.h +++ b/onnxruntime/core/providers/shared_library/provider_interfaces.h @@ -105,6 +105,8 @@ using ModelMetaData = std::unordered_map; using IOnnxRuntimeOpSchemaCollectionPtr = std::shared_ptr; using IOnnxRuntimeOpSchemaRegistryList = std::list; using InitializedTensorSet = std::unordered_map; +using KeyValueConfig = std::unordered_map; +using SelectionFunc = std::function>(const GraphViewer&, const KeyValueConfig&, const GraphOptimizerRegistry&)>; struct Node__NodeIterator { virtual ~Node__NodeIterator() {} @@ -151,6 +153,10 @@ struct ConstGraphNodes_Iterator { struct ProviderHost { virtual const OrtApiBase* OrtGetApiBase() = 0; + virtual Status GetOptimizerByName(const std::string& name, + const GraphOptimizerRegistry& graph_optimizer_registry, + SelectionFunc& selection_func) = 0; + virtual void* HeapAllocate(size_t size) = 0; virtual void HeapFree(void*) = 0; @@ -253,6 +259,7 @@ struct ProviderHost { // IExecutionProvider virtual std::vector> IExecutionProvider__GetCapability(const IExecutionProvider* p, const onnxruntime::GraphViewer& graph_viewer, const IExecutionProvider::IKernelLookup& kernel_lookup, + const GraphOptimizerRegistry& graph_optimizer_registry, IResourceAccountant* resource_accountant) = 0; virtual common::Status IExecutionProvider__Compile(IExecutionProvider* p, const std::vector& fused_nodes_and_graphs, std::vector& node_compute_funcs) = 0; @@ -627,6 +634,8 @@ struct ProviderHost { virtual std::unique_ptr ComputeCapability__construct(std::unique_ptr t_sub_graph) = 0; virtual void ComputeCapability__operator_delete(ComputeCapability* p) = 0; virtual std::unique_ptr& ComputeCapability__SubGraph(ComputeCapability* p) = 0; + virtual void ComputeCapability__copy_optimization_func(ComputeCapability* p, ComputeCapability* selection_cc) = 0; + virtual void ComputeCapability__add_nodes_to_optimize(ComputeCapability* p, std::unique_ptr optimization_cc) = 0; // DataTransferManager virtual Status DataTransferManager__CopyTensor(const DataTransferManager* p, const Tensor& src, Tensor& dst) = 0; diff --git a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h index a502ce9c66f69..e2af144f455e4 100644 --- a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h +++ b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h @@ -527,6 +527,9 @@ struct ComputeCapability final { std::unique_ptr& SubGraph() { return g_host->ComputeCapability__SubGraph(this); } + void copy_optimization_func(ComputeCapability* selection_cc) { g_host->ComputeCapability__copy_optimization_func(this, selection_cc); } + void add_nodes_to_optimize(std::unique_ptr optimization_cc) { g_host->ComputeCapability__add_nodes_to_optimize(this, std::move(optimization_cc)); } + ComputeCapability() = delete; ComputeCapability(const ComputeCapability&) = delete; void operator=(const ComputeCapability&) = delete; diff --git a/onnxruntime/core/providers/snpe/snpe_execution_provider.cc b/onnxruntime/core/providers/snpe/snpe_execution_provider.cc index c7fc6d3a556a7..4eae7c97f9ab0 100644 --- a/onnxruntime/core/providers/snpe/snpe_execution_provider.cc +++ b/onnxruntime/core/providers/snpe/snpe_execution_provider.cc @@ -72,6 +72,7 @@ SNPEExecutionProvider::~SNPEExecutionProvider() {} std::vector> SNPEExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph, const IKernelLookup& kernel_lookup, + const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* /* resource_accountant */) const { std::vector candidates; for (auto& node_index : graph.GetNodesInTopologicalOrder()) { diff --git a/onnxruntime/core/providers/snpe/snpe_execution_provider.h b/onnxruntime/core/providers/snpe/snpe_execution_provider.h index 99033649fcbbf..4b7987b38ee93 100644 --- a/onnxruntime/core/providers/snpe/snpe_execution_provider.h +++ b/onnxruntime/core/providers/snpe/snpe_execution_provider.h @@ -19,6 +19,7 @@ class SNPEExecutionProvider : public IExecutionProvider { std::vector> GetCapability( const onnxruntime::GraphViewer& graph, const IKernelLookup& kernel_lookup, + const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* /* resource_accountant */) const override; std::shared_ptr GetKernelRegistry() const override; diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index e59d252793532..523ebbfae807a 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -2459,6 +2459,7 @@ bool TensorrtExecutionProvider::DetectTensorRTGraphCycles(SubGraphCollection_t& std::vector> TensorrtExecutionProvider::GetCapability(const GraphViewer& graph, const IKernelLookup& /*kernel_lookup*/, + const GraphOptimizerRegistry& graph_optimizer_registry, IResourceAccountant* /* resource_accountant */) const { // Construct subgraph capability from node list std::vector> result; @@ -2664,11 +2665,61 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph, } } + /** + * Enable EP related L2+ graph optimizations: + * + * 1. Calls provider bridge API to lookup pre-defined optimizer by name and get selection function. + * - Example: g_host->GetOptimizerByName(optimizer_name, graph_optimizer_registry, selection_func) + * 2. Executes the selection function to obtain the selection ComputeCapability. + * - ComputeCapability.optimize_func would be set by the optimizer to the function that does the optimization. + * 3. Uses the selection ComputeCapability to create the optimization ComputeCapability. + * 4. Returns the final ComputeCapability, with nodes_to_optimize set to the optimization ComputeCapability. + * + * Current available optimizations: + * - (ConstantFoldingDQ) constant folding on DQ nodes, i.e. dequantize INT32, UINT16, INT16 constant to FP32. + */ + + SelectionFunc selection_func; + std::vector> selection_cc; + + // Prepare for ConstantFoldingDQ optimizer + // Note: The NodeIndex here is the node index in the graph, not the index in node vector in supported_nodes_vector. + std::unordered_set trt_selection_node_set; // The qualified dq nodes selected by TRT EP + std::unordered_map consumer_to_dq; // consumer node -> dq node + + if (dla_enable_) { + std::string optimizer_name = "ConstantFoldingDQ"; + const std::unordered_map key_value_config; + auto status = g_host->GetOptimizerByName(optimizer_name, graph_optimizer_registry, selection_func); + if (status == Status::OK()) { + if (selection_func) { + selection_cc = selection_func(graph, key_value_config, graph_optimizer_registry); + SelectQualifiedDQNode(graph, trt_selection_node_set, consumer_to_dq); + } + } else { + LOGS_DEFAULT(WARNING) << "[TensorRT EP] Can't get optimizer " << optimizer_name; + } + } + + // Create ComputeCapability int number_of_trt_nodes = 0, subgraph_index = 0; - for (const auto& group : supported_nodes_vector) { + for (auto& group : supported_nodes_vector) { if (!group.first.empty()) { + if (!selection_cc.empty()) { + // Include DQ nodes that are filtered out by TRT parser + UpdateSupportedNodeVectorForDQ(graph, group, supported_nodes_vector, consumer_to_dq); + } + std::unique_ptr sub_graph = GetSubGraph(group, graph, model_hash, subgraph_index); - result.push_back(ComputeCapability::Create(std::move(sub_graph))); + auto compute_capability = ComputeCapability::Create(std::move(sub_graph)); + + // add optimization ComputeCapability to node_to_optimize + for (auto& cc : selection_cc) { + std::unique_ptr optimization_cc = CreateOptimizationComputeCapability(cc.get(), trt_selection_node_set, compute_capability.get()); + compute_capability->add_nodes_to_optimize(std::move(optimization_cc)); + } + + result.push_back(std::move(compute_capability)); number_of_trt_nodes += static_cast(group.first.size()); subgraph_index++; } diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h index 873826a81c51b..934cc06eed45f 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h @@ -249,6 +249,7 @@ class TensorrtExecutionProvider : public IExecutionProvider { std::vector> GetCapability(const GraphViewer& graph, const IKernelLookup& /*kernel_lookup*/, + const GraphOptimizerRegistry& graph_optimizer_registry, IResourceAccountant* /* resource_accountant */) const override; int GetDeviceId() const { return device_id_; } @@ -592,5 +593,35 @@ class TensorrtExecutionProvider : public IExecutionProvider { * This function only creates the instance at the first time it's being called." */ nvinfer1::IBuilder* GetBuilder(TensorrtLogger& trt_logger) const; + + /** + * This is the helper function for ConstantFoldingDQ graph transformer. + * + * It selects the qualified/required DQ node to be optimized as well as provides a mapping table + * to help TRT EP later include the DQ node which is filtered out by TRT parser. + */ + void SelectQualifiedDQNode(const GraphViewer& graph, + std::unordered_set& selection_node_set, + std::unordered_map& consumer_to_dq) const; + + /** + * This function returns an optimization ComputeCapability that is limited to: + * 1. the DQ nodes in this individual TRT ComputeCapability + * 2. the DQ nodes that are qualified and selected by TRT EP + * + * It also needs to make sure the DQ nodes is a subset of the complete list of DQ nodes to optimize in original selection ComputeCapability. + * Finally, copy the optimization function from the original selection ComputeCapability. + */ + std::unique_ptr CreateOptimizationComputeCapability(ComputeCapability* selection_cc, + std::unordered_set& trt_selection_node_set, + ComputeCapability* trt_cc) const; + /** + * This function helps add back the DQ nodes that are filtered out by TRT parser. + * The reason is the DQ nodes can be optimized and dequantized by applying ConstantFoldingDQ optimizer by ORT L2+ optimization. + */ + void UpdateSupportedNodeVectorForDQ(const GraphViewer& graph, + SubGraph_t& supported_node_vector, + SubGraphCollection_t& supported_nodes_vector, + std::unordered_map consumer_to_dq) const; }; } // namespace onnxruntime diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_helper.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_helper.cc index 92fa101118506..71674f7c9c557 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_helper.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_helper.cc @@ -258,4 +258,133 @@ void TensorrtExecutionProvider::SetAllGraphInputs(Graph& graph) const { graph.SetInputs(graph_inputs_including_initializers); } + +/** + * This is the helper function for ConstantFoldingDQ graph transformer. + * + * It selects the qualified/required DQ node to be optimized as well as provides a mapping table + * to help TRT EP later include the DQ node which is filtered out by TRT parser. + */ +void TensorrtExecutionProvider::SelectQualifiedDQNode(const GraphViewer& graph, + std::unordered_set& selection_node_set, + std::unordered_map& consumer_to_dq) const { + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Select qualified DQ nodes ..."; + const std::vector& node_index = graph.GetNodesInTopologicalOrder(1 /*priority-based topological sort*/); + for (auto index : node_index) { + auto* node = graph.GetNode(index); + if (!node) { + continue; + } + + const auto* input_def = node->InputDefs()[0]; // Get NodeArg of the initializer of the DequantizeLinear node; + auto data_type = input_def->TypeAsProto()->tensor_type().elem_type(); + auto constant_initializer = graph.IsConstantInitializer(input_def->Name(), true); + + // Node selection: (i.e. initializer -> DQ -> bias of X) + // 1. DequantizeLinear op + // 2. DQ node does not produce graph output, single consumer + // 3. The first input of DQ is constant initializer. + // 4. The data type of initializer is INT32, UINT16 or INT16 + // 5. X should be Gemm, Conv or LayerNormalization ? + if (node->OpType() == "DequantizeLinear" && + node->GetOutputEdgesCount() == 1 && + (data_type == ONNX_NAMESPACE::TensorProto_DataType_INT32 || data_type == ONNX_NAMESPACE::TensorProto_DataType_INT16 || data_type == ONNX_NAMESPACE::TensorProto_DataType_UINT16) && + constant_initializer) { + const Node& consumer_node = *node->OutputNodesBegin(); + selection_node_set.insert(index); + consumer_to_dq[consumer_node.Index()] = index; + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] " << consumer_node.Name() << " <- " << node->Name(); + } + } + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Total " << selection_node_set.size() << " DequantizeLinear node(s) are selected."; +} + +/** + * This function returns an optimization ComputeCapability that is limited to: + * 1. the DQ nodes in this individual TRT ComputeCapability + * 2. the DQ nodes that are qualified and selected by TRT EP + * + * It also needs to make sure the DQ nodes is a subset of the complete list of DQ nodes to optimize in original selection ComputeCapability. + * Finally, copy the optimization function from the original selection ComputeCapability. + */ +std::unique_ptr TensorrtExecutionProvider::CreateOptimizationComputeCapability(ComputeCapability* selection_cc, + std::unordered_set& trt_selection_node_set, + ComputeCapability* trt_cc) const { + auto sub_graph = onnxruntime::IndexedSubGraph::Create(); + std::unordered_set selection_node_set; + + for (auto index : selection_cc->SubGraph()->Nodes()) { + selection_node_set.insert(index); + } + + for (auto index : trt_cc->SubGraph()->Nodes()) { + if (selection_node_set.find(index) == selection_node_set.end()) { + continue; + } + if (trt_selection_node_set.find(index) == trt_selection_node_set.end()) { + continue; + } + sub_graph->Nodes().push_back(index); + } + auto compute_capability = ComputeCapability::Create(std::move(sub_graph)); + compute_capability->copy_optimization_func(selection_cc); + return compute_capability; +} + +/** + * This function helps add back the DQ nodes that are filtered out by TRT parser. + * The reason is the DQ nodes can be optimized and dequantized by applying ConstantFoldingDQ optimizer by ORT L2+ optimization. + */ +void TensorrtExecutionProvider::UpdateSupportedNodeVectorForDQ(const GraphViewer& graph, + SubGraph_t& supported_node_vector, + SubGraphCollection_t& supported_nodes_vector, + std::unordered_map consumer_to_dq) const { + if (consumer_to_dq.empty()) { + return; + } + + if (!supported_node_vector.second) { + return; + } + + const std::vector& node_index = graph.GetNodesInTopologicalOrder(1); + auto supported_nodes = supported_node_vector.first; + for (auto index : supported_nodes) { + if (consumer_to_dq.find(node_index[index]) == consumer_to_dq.end()) { + continue; + } + + auto dq_node_index = consumer_to_dq[node_index[index]]; + + // Check if DQ node is included in one of the subgraphs + auto in_the_subgraph_collection = [&](NodeIndex node_idx) -> bool { + for (auto& node_vector : supported_nodes_vector) { + if (!node_vector.second) { + continue; + } + for (auto i : node_vector.first) { + if (node_index[i] == node_idx) { + return true; + } + } + } + return false; + }; + + // If the DQ node is already in the subgraph, do nothing. + if (in_the_subgraph_collection(dq_node_index)) { + continue; + } + + // Find the iterator pointing to the target element + auto it = std::find(node_index.begin(), node_index.end(), dq_node_index); + if (it != node_index.end()) { + // Calculate the index + size_t idx = std::distance(node_index.begin(), it); + supported_node_vector.first.push_back(idx); + auto node = graph.GetNode(dq_node_index); + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] " << node->Name() << " is included which is filtered out by TRT parser."; + } + } +} } // namespace onnxruntime diff --git a/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc b/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc index 5d2204b0b1979..ab8a95b38491d 100644 --- a/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc +++ b/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc @@ -51,7 +51,7 @@ const InlinedVector VitisAIExecutionProvider::GetEpContextNodes() c return ep_context_node_ptrs; } std::vector> VitisAIExecutionProvider::GetCapability( - const onnxruntime::GraphViewer& graph_viewer, const IKernelLookup& kernel_lookup, IResourceAccountant* /* resource_accountant */) const { + const onnxruntime::GraphViewer& graph_viewer, const IKernelLookup& kernel_lookup, const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* /* resource_accountant */) const { if (graph_viewer.IsSubgraph()) { // VITIS AI EP not support sungraph. Assigned to CPU. return {}; diff --git a/onnxruntime/core/providers/vitisai/vitisai_execution_provider.h b/onnxruntime/core/providers/vitisai/vitisai_execution_provider.h index 5b031ab882839..f72f8cc721fbd 100644 --- a/onnxruntime/core/providers/vitisai/vitisai_execution_provider.h +++ b/onnxruntime/core/providers/vitisai/vitisai_execution_provider.h @@ -29,6 +29,7 @@ class VitisAIExecutionProvider : public IExecutionProvider { std::vector> GetCapability(const onnxruntime::GraphViewer& graph_viewer, const IKernelLookup& /*kernel_lookup*/, + const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* /* resource_accountant */) const override; int GetDeviceId() const { return 0; } diff --git a/onnxruntime/core/providers/vsinpu/vsinpu_execution_provider.cc b/onnxruntime/core/providers/vsinpu/vsinpu_execution_provider.cc index 4b9f6fae86423..3b5daef04dd50 100644 --- a/onnxruntime/core/providers/vsinpu/vsinpu_execution_provider.cc +++ b/onnxruntime/core/providers/vsinpu/vsinpu_execution_provider.cc @@ -62,6 +62,7 @@ VSINPUExecutionProvider::~VSINPUExecutionProvider() {} std::vector> VSINPUExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer, const IKernelLookup& /*kernel_lookup*/, + const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* /* resource_accountant */) const { std::vector> result; diff --git a/onnxruntime/core/providers/vsinpu/vsinpu_execution_provider.h b/onnxruntime/core/providers/vsinpu/vsinpu_execution_provider.h index 16cfbc8a9c581..1c0b8b63a8e6c 100644 --- a/onnxruntime/core/providers/vsinpu/vsinpu_execution_provider.h +++ b/onnxruntime/core/providers/vsinpu/vsinpu_execution_provider.h @@ -40,6 +40,7 @@ class VSINPUExecutionProvider : public IExecutionProvider { std::vector> GetCapability( const onnxruntime::GraphViewer& graph_viewer, const IKernelLookup& kernel_lookup, + const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* /* resource_accountant */) const override; std::shared_ptr GetKernelRegistry() const override; Status Compile(const std::vector& fused_nodes_and_graphs, diff --git a/onnxruntime/core/providers/webgpu/external_data_loader.cc b/onnxruntime/core/providers/webgpu/external_data_loader.cc new file mode 100644 index 0000000000000..6da9598b146f5 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/external_data_loader.cc @@ -0,0 +1,40 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if defined(__wasm__) + +#include + +#include "core/framework/tensor.h" +#include "core/providers/webgpu/external_data_loader.h" + +namespace onnxruntime { +namespace webgpu { + +bool ExternalDataLoader::CanLoad(const OrtMemoryInfo& target_memory_info) const { + return target_memory_info.device.Type() == OrtDevice::CPU || + (target_memory_info.device.Type() == OrtDevice::GPU && target_memory_info.name == WEBGPU_BUFFER); +} + +common::Status ExternalDataLoader::LoadTensor(const Env& env, + const std::filesystem::path& data_file_path, + FileOffsetType data_offset, + SafeInt data_length, + Tensor& tensor) const { + ExternalDataLoadType load_type; + if (tensor.Location().device.Type() == OrtDevice::CPU) { + load_type = ExternalDataLoadType::CPU; + } else if (tensor.Location().device.Type() == OrtDevice::GPU && + tensor.Location().name == WEBGPU_BUFFER) { + load_type = ExternalDataLoadType::WEBGPU_BUFFER; + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported tensor location: ", tensor.Location().ToString()); + } + + return LoadWebAssemblyExternalData(env, data_file_path, data_offset, data_length, load_type, tensor.MutableDataRaw()); +} + +} // namespace webgpu +} // namespace onnxruntime + +#endif diff --git a/onnxruntime/core/providers/webgpu/external_data_loader.h b/onnxruntime/core/providers/webgpu/external_data_loader.h new file mode 100644 index 0000000000000..7ced4e930bf7a --- /dev/null +++ b/onnxruntime/core/providers/webgpu/external_data_loader.h @@ -0,0 +1,30 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#if defined(__wasm__) + +#include "core/framework/external_data_loader.h" + +namespace onnxruntime { +namespace webgpu { + +class ExternalDataLoader : public IExternalDataLoader { + public: + ExternalDataLoader() {}; + ~ExternalDataLoader() {}; + + bool CanLoad(const OrtMemoryInfo& target_memory_info) const override; + + common::Status LoadTensor(const Env& env, + const std::filesystem::path& data_file_path, + FileOffsetType data_offset, + SafeInt data_length, + Tensor& tensor) const override; +}; + +} // namespace webgpu +} // namespace onnxruntime + +#endif diff --git a/onnxruntime/core/providers/webgpu/generator/range.cc b/onnxruntime/core/providers/webgpu/generator/range.cc index a0b65f08a5b4e..99c5a1c1b5566 100644 --- a/onnxruntime/core/providers/webgpu/generator/range.cc +++ b/onnxruntime/core/providers/webgpu/generator/range.cc @@ -23,7 +23,7 @@ Status Range::ComputeInternal(ComputeContext& context) const { return Status::OK(); } - uint32_t output_size = gsl::narrow(n); + uint32_t output_size = onnxruntime::narrow(n); RangeProgram program{}; #if defined(__GNUC__) #pragma GCC diagnostic push diff --git a/onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.cc b/onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.cc index 75866513e2c7d..8a22e45f17047 100644 --- a/onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.cc +++ b/onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.cc @@ -141,7 +141,7 @@ Status BinaryElementwise::ComputeInternal(ComputeContext& context) const { } } - uint32_t vec_size = gsl::narrow((size + 3) / 4); + uint32_t vec_size = onnxruntime::narrow((size + 3) / 4); BinaryElementwiseProgram program{kernel_name_, expression_, is_broadcast, diff --git a/onnxruntime/core/providers/webgpu/math/softmax.cc b/onnxruntime/core/providers/webgpu/math/softmax.cc new file mode 100644 index 0000000000000..d06fc5a57eb8c --- /dev/null +++ b/onnxruntime/core/providers/webgpu/math/softmax.cc @@ -0,0 +1,238 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include + +#include "core/common/inlined_containers.h" +#include "core/providers/common.h" +#include "core/providers/webgpu/math/softmax.h" +#include "core/providers/webgpu/tensor/transpose.h" +#include "core/providers/cpu/tensor/utils.h" +#include "core/providers/webgpu/shader_variable.h" +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/webgpu_supported_types.h" +namespace onnxruntime { +namespace webgpu { + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Softmax, + kOnnxDomain, + 1, 10, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedNumberTypes()), + Softmax); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Softmax, + kOnnxDomain, + 11, 12, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedNumberTypes()), + Softmax); + +ONNX_OPERATOR_KERNEL_EX( + Softmax, + kOnnxDomain, + 13, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedNumberTypes()), + Softmax); + +static std::string MaxVector(const std::string& name, int components) { + switch (components) { + case 1: + return name; + case 2: + return "max(" + name + ".x, " + name + ".y)"; + case 3: + return "max(max(" + name + ".x, " + name + ".y), " + name + ".z)"; + case 4: + return "max(max(" + name + ".x, " + name + ".y), max(" + name + ".z, " + name + ".w))"; + default: + ORT_THROW("Unsupported number of components: ", components); + } +} + +static std::string SumVector(const std::string& x, int components) { + switch (components) { + case 1: + return x; + case 2: + return "(" + x + ".x + " + x + ".y" + ")"; + case 4: + return "(" + x + ".x + " + x + ".y + " + x + ".w + " + x + ".z" + ")"; + default: + ORT_THROW("Unsupported number of components: ", components); + } +} + +static int GetMaxComponents(int64_t size) { + if (size % 4 == 0) { + return 4; + } else if (size % 2 == 0) { + return 2; + } + return 1; +} + +Status SoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const { + // Add input and output variables + const auto& input = shader.AddInput("x", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias); + shader.AddOutput("result", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias); + int components = input.NumComponents(); + + const std::string thread_max_decl = is_fp32_ + ? "var thread_max = x_value_t(-3.402823e+38f);\n" + : "var thread_max = x_value_t(-65504.0h);\n"; + + // Define shared memory for row max and row sum + shader.AdditionalImplementation() + << "var row_max_shared : x_value_t;\n" + << "var row_sum_shared : x_value_t;\n" + << "var thread_shared : array;\n"; + + // Define helper functions to get and set values + shader.AdditionalImplementation() + << "fn getValue(row: i32, col: i32, row_stride: i32) -> x_value_t {\n" + << " let index = row * row_stride + col;\n" + << " return x[index];\n" + << "}\n" + << "fn setValue(row: i32, col: i32, row_stride: i32, value: x_value_t) {\n" + << " let index = row * row_stride + col;\n" + << " result[index] = value;\n" + << "}\n"; + + // Main function body + shader.MainFunctionBody() + << " let gindex = i32(global_idx);\n" + << " let lindex = i32(local_idx);\n" + << " const wg = " << wg_ << ";\n" + << " let row = gindex / wg;\n" + << " let cols = uniforms.packedCols;\n" + << " let row_stride : i32 = uniforms.packedCols;\n" + + // Find the row's max value + << thread_max_decl + << " for (var col = lindex; col < cols; col += wg) {\n" + << " let value = getValue(row, col, row_stride);\n" + << " thread_max = max(thread_max, value);\n" + << " }\n" + << " if (lindex < cols) {\n" + << " thread_shared[lindex] = thread_max;\n" + << " }\n" + << " workgroupBarrier();\n" + + // Reduce to find the max value + << " var reduce_size = min(cols, wg);\n" + << " for (var curr_size = reduce_size >> 1; curr_size > 0; curr_size = reduce_size >> 1) {\n" + << " reduce_size = curr_size + (reduce_size & 1);\n" + << " if (lindex < curr_size) {\n" + << " thread_shared[lindex] = max(thread_shared[lindex], thread_shared[lindex + reduce_size]);\n" + << " }\n" + << " workgroupBarrier();\n" + << " }\n" + << " if (lindex == 0) {\n" + << " row_max_shared = x_value_t(" << MaxVector("thread_shared[0]", components) << ");\n" + << " }\n" + << " workgroupBarrier();\n" + + // Find the row's sum of exponentials + << " var thread_sum = x_value_t(0.0);\n" + << " for (var col = lindex; col < cols; col += wg) {\n" + << " let sub_exp = exp(getValue(row, col, row_stride) - row_max_shared);\n" + << " thread_sum += sub_exp;\n" + << " }\n" + << " thread_shared[lindex] = thread_sum;\n" + << " workgroupBarrier();\n" + + // Reduce to find the sum of exponentials + << " for (var curr_size = wg >> 1; curr_size > 0; curr_size = curr_size >> 1) {\n" + << " if (lindex < curr_size) {\n" + << " thread_shared[lindex] = thread_shared[lindex] + thread_shared[lindex + curr_size];\n" + << " }\n" + << " workgroupBarrier();\n" + << " }\n" + << " if (lindex == 0) {\n" + << " row_sum_shared = x_value_t(" << SumVector("thread_shared[0]", components) << ");\n" + << " }\n" + << " workgroupBarrier();\n" + + // Calculate the final value for each element in the row + << " for (var col = lindex; col < cols; col += wg) {\n" + << " let value = exp(getValue(row, col, row_stride) - row_max_shared) / row_sum_shared;\n" + << " setValue(row, col, row_stride, value);\n" + << " }\n"; + + return Status::OK(); +} + +Status Softmax::ComputeInternal(ComputeContext& context) const { + const auto* input_tensor = context.Input(0); + const TensorShape& input_shape = input_tensor->Shape(); + size_t input_rank = input_shape.NumDimensions(); + auto* output_tensor = context.Output(0, input_shape); + + // normalize axis + size_t axis = static_cast(HandleNegativeAxis(axis_, input_rank)); + bool is_transpose_required = axis < input_rank - 1; + + TensorShape transposed_input_shape; + Tensor transposed_input_tensor; + Tensor intermediate_output; + InlinedVector perm(input_rank); + + if (is_transpose_required) { + std::iota(std::begin(perm), std::end(perm), 0); + perm[axis] = input_rank - 1; + perm[input_rank - 1] = axis; + + TensorShapeVector transposed_input_dims; + for (auto e : perm) { + transposed_input_dims.push_back(input_shape[e]); + } + + transposed_input_shape = TensorShape(transposed_input_dims); + transposed_input_tensor = context.CreateGPUTensor(input_tensor->DataType(), transposed_input_shape); + ORT_RETURN_IF_ERROR(Transpose::DoTranspose(context, perm, *input_tensor, transposed_input_tensor)); + intermediate_output = context.CreateGPUTensor(output_tensor->DataType(), transposed_input_shape); + } + + const int64_t cols = is_transpose_required ? transposed_input_shape[input_rank - 1] : input_shape[input_rank - 1]; + const int64_t rows = input_shape.Size() / cols; + const int64_t components = GetMaxComponents(cols); + const auto packed_cols = cols / components; + uint32_t workgroup_size = rows == 1 ? 256 : 64; + // check input tensor element type is float + const bool is_fp32 = input_tensor->GetElementType() == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; + + SoftmaxProgram program{workgroup_size, is_fp32}; + if (is_transpose_required) { + program + .AddInputs({{&transposed_input_tensor, ProgramTensorMetadataDependency::TypeAndRank, static_cast(components)}}) + .AddOutputs({{&intermediate_output, ProgramTensorMetadataDependency::TypeAndRank, static_cast(components)}}); + } else { + program + .AddInputs({{input_tensor, ProgramTensorMetadataDependency::TypeAndRank, static_cast(components)}}) + .AddOutputs({{output_tensor, ProgramTensorMetadataDependency::TypeAndRank, static_cast(components)}}); + } + + program + .CacheHint(std::to_string(components), std::to_string(workgroup_size)) + .SetWorkgroupSize(workgroup_size) + .SetDispatchGroupSize(static_cast(rows)) + .AddUniformVariables({{static_cast(packed_cols)}}); + + ORT_RETURN_IF_ERROR(context.RunProgram(program)); + + // If transpose was required, transpose the result back + if (is_transpose_required) { + ORT_RETURN_IF_ERROR(Transpose::DoTranspose(context, perm, intermediate_output, *output_tensor)); + } + + return Status::OK(); +} +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/math/softmax.h b/onnxruntime/core/providers/webgpu/math/softmax.h new file mode 100644 index 0000000000000..cc97611dcb4bc --- /dev/null +++ b/onnxruntime/core/providers/webgpu/math/softmax.h @@ -0,0 +1,54 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/webgpu/webgpu_supported_types.h" +#include "core/providers/webgpu/webgpu_kernel.h" +#include "core/providers/webgpu/program.h" +#include "core/framework/op_kernel.h" + +namespace onnxruntime { +namespace webgpu { + +class Softmax final : public WebGpuKernel { + public: + Softmax(const OpKernelInfo& info) : WebGpuKernel{info} { + int opset_ = info.node().SinceVersion(); + int64_t axis; + Status status = info.GetAttr("axis", &axis); + + if (status.IsOK()) { + axis_ = axis; + } else { + if (opset_ < 13) { + axis_ = 1; // opset-12 and below, the default axis value is 1 + } else { + axis_ = -1; // opset-13, the default axis value is -1 + } + } + } + + Status ComputeInternal(ComputeContext& context) const override; + + private: + int64_t axis_; +}; + +class SoftmaxProgram final : public Program { + public: + SoftmaxProgram(uint32_t wg, bool is_fp32) + : Program{"Softmax"}, wg_{wg}, is_fp32_{is_fp32} { + } + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"packedCols", ProgramUniformVariableDataType::Int32}); + + private: + uint32_t wg_; + bool is_fp32_; +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc index eaaad206ebaf5..189d7baafce6a 100644 --- a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc +++ b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc @@ -27,7 +27,7 @@ Status UnaryElementwise::ComputeInternal(ComputeContext& context) const { if (size == 0) { return Status::OK(); } - uint32_t vec_size = gsl::narrow((size + 3) / 4); + uint32_t vec_size = onnxruntime::narrow((size + 3) / 4); UnaryElementwiseProgram program{kernel_name_, expression_, additional_impl_, additional_usage_}; program .AddInputs({{input_tensor, ProgramTensorMetadataDependency::Type, {vec_size}, 4}}) diff --git a/onnxruntime/core/providers/webgpu/nn/layer_norm.cc b/onnxruntime/core/providers/webgpu/nn/layer_norm.cc index 64172021e82f1..28ad686909a47 100644 --- a/onnxruntime/core/providers/webgpu/nn/layer_norm.cc +++ b/onnxruntime/core/providers/webgpu/nn/layer_norm.cc @@ -23,7 +23,7 @@ static size_t NormalizeAxis(int64_t axis, size_t tensor_rank) { if (axis < -rank && axis >= rank) { ORT_THROW("invalid axis: ", axis); } - return gsl::narrow(axis < 0 ? axis + rank : axis); + return onnxruntime::narrow(axis < 0 ? axis + rank : axis); } static std::string SumVector(std::string x, int components) { @@ -92,10 +92,10 @@ Status LayerNorm::ComputeInternal(onnxruntime::webgpu::ComputeContex const bool is_fp16 = x->GetElementType() == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16; const size_t axis = NormalizeAxis(axis_, x_shape.NumDimensions()); - const uint32_t norm_count = gsl::narrow(x_shape.SizeToDimension(axis)); + const uint32_t norm_count = onnxruntime::narrow(x_shape.SizeToDimension(axis)); const int64_t norm_size = x_shape.SizeFromDimension(axis); const int components = GetMaxComponents(norm_size); - const uint32_t norm_size_vectorized = gsl::narrow((norm_size + components - 1) / components); + const uint32_t norm_size_vectorized = onnxruntime::narrow((norm_size + components - 1) / components); const auto scale_size = scale->Shape().Size(); const auto bias_size = (bias) ? bias->Shape().Size() : 0; diff --git a/onnxruntime/core/providers/webgpu/program.cc b/onnxruntime/core/providers/webgpu/program.cc index d1d4c242c4697..976b7927ac3dd 100644 --- a/onnxruntime/core/providers/webgpu/program.cc +++ b/onnxruntime/core/providers/webgpu/program.cc @@ -206,6 +206,26 @@ ProgramVariableDataType ToProgramVariableDataType(int32_t element_type, int comp } } +std::ostream& operator<<(std::ostream& os, ValidationMode mode) { + switch (mode) { + case ValidationMode::Disabled: + os << "Disabled"; + break; + case ValidationMode::WGPUOnly: + os << "WGPUOnly"; + break; + case ValidationMode::Basic: + os << "Basic"; + break; + case ValidationMode::Full: + os << "Full"; + break; + default: + os << "Unknown(" << static_cast(mode) << ")"; + } + return os; +} + namespace { TensorShape GetReducedShape(const TensorShape& shape, int component /* > 1 */) { ORT_ENFORCE(shape.NumDimensions() > 0 && shape.GetDims()[shape.NumDimensions() - 1] % component == 0, diff --git a/onnxruntime/core/providers/webgpu/program.h b/onnxruntime/core/providers/webgpu/program.h index 7bfd9e8800099..95fef36144025 100644 --- a/onnxruntime/core/providers/webgpu/program.h +++ b/onnxruntime/core/providers/webgpu/program.h @@ -237,6 +237,7 @@ enum class ValidationMode { Basic, Full }; +std::ostream& operator<<(std::ostream& os, ValidationMode mode); namespace details { class ProgramWrapper; diff --git a/onnxruntime/core/providers/webgpu/program_manager.cc b/onnxruntime/core/providers/webgpu/program_manager.cc index 1fdd312d4f0d8..7a4a873a1adf3 100644 --- a/onnxruntime/core/providers/webgpu/program_manager.cc +++ b/onnxruntime/core/providers/webgpu/program_manager.cc @@ -24,14 +24,14 @@ Status ProgramManager::NormalizeDispatchGroupSize(uint32_t& x, uint32_t& y, uint auto limit_per_dimension = limits_.maxComputeWorkgroupsPerDimension; if (x > limit_per_dimension || y > limit_per_dimension || z > limit_per_dimension) { - auto size = static_cast(x) * static_cast(y) * static_cast(z); - uint32_t dispatch_avg = gsl::narrow(std::ceil(std::sqrt(size))); + double size = static_cast(x) * static_cast(y) * static_cast(z); + double dispatch_avg = std::ceil(std::sqrt(size)); if (dispatch_avg > limit_per_dimension) { - dispatch_avg = gsl::narrow(std::ceil(std::cbrt(size))); + dispatch_avg = std::ceil(std::cbrt(size)); ORT_RETURN_IF(dispatch_avg > limit_per_dimension, "The dispatch group size exceeds WebGPU maximum."); - x = y = z = dispatch_avg; + x = y = z = static_cast(dispatch_avg); } else { - x = y = dispatch_avg; + x = y = static_cast(dispatch_avg); z = 1; } } diff --git a/onnxruntime/core/providers/webgpu/reduction/reduction_ops.cc b/onnxruntime/core/providers/webgpu/reduction/reduction_ops.cc new file mode 100644 index 0000000000000..eb7903e7903b6 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/reduction/reduction_ops.cc @@ -0,0 +1,168 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/webgpu/reduction/reduction_ops.h" +#include +#include "core/framework/data_transfer_manager.h" +#include "core/providers/webgpu/data_transfer.h" +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/webgpu_supported_types.h" + +namespace onnxruntime { +namespace webgpu { + +#define REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceOp, begin, end) \ + ONNX_OPERATOR_VERSIONED_KERNEL_EX( \ + ReduceOp, \ + kOnnxDomain, \ + begin, end, \ + kWebGpuExecutionProvider, \ + (*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedNumberTypes()), \ + ReduceOp); + +#define REGISTER_UNARY_ELEMENTWISE_KERNEL(ReduceOp, version) \ + ONNX_OPERATOR_KERNEL_EX( \ + ReduceOp, \ + kOnnxDomain, \ + version, \ + kWebGpuExecutionProvider, \ + (*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedNumberTypes()).InputMemoryType(OrtMemTypeCPUInput, 1), \ + ReduceOp); + +REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceMean, 1, 10); +REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceMean, 11, 12); +REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceMean, 13, 17); +REGISTER_UNARY_ELEMENTWISE_KERNEL(ReduceMean, 18); + +Status ReduceKernelProgram::GenerateShaderCode(ShaderHelper& shader) const { + const auto& input = shader.AddInput("input", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias); + const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias); + bool reduce_on_all_axes = no_op_with_empty_axes_ == false && axes_.empty(); + std::string loop_header = code_[0]; + std::string loop_body = "let current_element: input_value_t = " + input.GetByIndices("input_indices") + ";\n" + code_[1]; + std::string loop_footer = code_[2]; + const auto input_rank = input.Rank(); + for (int i = 0, l = 0; i < input_rank; ++i) { + if (reduce_on_all_axes || std::find(axes_.begin(), axes_.end(), i) != axes_.end()) { + if (keepdims_) { + l++; + } + std::stringstream ss; + std::string index = "i" + std::to_string(i); + ss << "for (var " << index << " : u32 = 0; " << index << " < " << input.IndicesGet("uniforms.input_shape", i) << "; " << index << "++) {\n"; + ss << input.IndicesSet("input_indices", i, index) << ";\n"; + ss << loop_body << "\n"; + ss << "}\n"; + loop_body = ss.str(); + } else { + std::stringstream ss; + ss << loop_header << "\n"; + std::string index = "i" + std::to_string(i); + ss << "let " << index << " = " << output.IndicesGet("output_indices", l) << ";\n"; + ss << input.IndicesSet("input_indices", i, index) << ";\n"; + loop_header = ss.str(); + l++; + } + } + std::stringstream input_indices_init_value; + for (int i = 0; i < input_rank - 1; ++i) { + input_indices_init_value << "0, "; + } + input_indices_init_value << "0"; + shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size") + << "let output_indices: output_indices_t = " << output.OffsetToIndices("global_idx") << ";\n" + << "var input_indices: input_indices_t = input_indices_t(" << input_indices_init_value.str() << ");\n" + << loop_header << loop_body << loop_footer; + shader.MainFunctionBody() << output.SetByOffset("global_idx", "output_value"); + return Status::OK(); +} + +template +Status ReduceKernel::ComputeInternal(ComputeContext& context) const { + const auto* input_tensor = context.Input(0); + InlinedVector input_axes; + auto rank = input_tensor->Shape().NumDimensions(); + auto transform_axis = [rank](int64_t axis) { + if (axis < 0) { + axis += rank; + } + if (axis < 0 || static_cast(axis) >= rank) { + ORT_THROW("Axes values must be in the range [-rank, rank-1]. Got: ", axis); + } + return static_cast(axis); + }; + // Check if axes input is provided and copy the axes values to input_axes + if (context.InputCount() > 1) { + ORT_ENFORCE(axes_.empty(), "Axes attribute may not be specified when axes input is also provided."); + const Tensor* axes_tensor = context.Input(1); + auto size = static_cast(axes_tensor->Shape()[0]); + const auto* data = axes_tensor->Data(); + input_axes.reserve(size); + std::transform(data, data + size, std::back_inserter(input_axes), transform_axis); + } else { + input_axes.reserve(axes_.size()); + std::transform(axes_.begin(), axes_.end(), std::back_inserter(input_axes), transform_axis); + } + if (input_axes.empty()) { + if (noop_with_empty_axes_ || rank == 0) { + // If axes is empty and noop_with_empty_axes_ is true, it is a no-op according to the spec + // If input tensor is a scalar, return the input tensor as is. + // This is not correct for ReduceLogSum and ReduceSumSquare + // TODO handle these cases separately. + auto output = context.Output(0, input_tensor->Shape()); + if (output->DataRaw() != input_tensor->DataRaw()) { + ORT_RETURN_IF_ERROR(Info().GetDataTransferManager().CopyTensor(*input_tensor, *output)); + } + return Status::OK(); + } else { + // If axes is empty and noop_with_empty_axes_ is false, it is a reduction over all axes + input_axes.resize(rank); + std::iota(input_axes.begin(), input_axes.end(), 0); + } + } + const auto code = GetOpSpecificCode(input_tensor, input_axes.size()); + // Compute output shape + std::vector output_shape; + for (size_t i = 0; i < input_tensor->Shape().NumDimensions(); ++i) { + if (std::find(input_axes.begin(), input_axes.end(), i) != input_axes.end()) { + if (keepdims_) { + output_shape.push_back(1); + } + } else { + output_shape.push_back(input_tensor->Shape()[i]); + } + } + TensorShape output_tensor_shape(output_shape); + int64_t output_size = output_tensor_shape.Size(); + ReduceKernelProgram program("ReduceMean", keepdims_, noop_with_empty_axes_, input_axes, code); + program.AddInput({input_tensor, ProgramTensorMetadataDependency::TypeAndRank}) + .AddOutput({context.Output(0, output_shape), ProgramTensorMetadataDependency::TypeAndRank}) + .SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) + .AddUniformVariables({{static_cast(output_size)}, + {static_cast(noop_with_empty_axes_ ? 1 : 0)}, + {input_axes}, + {static_cast(input_axes.size())}}); + + return context.RunProgram(program); +} + +ReduceOpSpecificCode ReduceMean::GetOpSpecificCode(const Tensor* input_tensor, size_t axes_size) const { + const TensorShape& input_shape = input_tensor->Shape(); + size_t input_rank = input_shape.NumDimensions(); + std::stringstream ss; + ss << "var size: u32 = 1;\n" + << "for (var i: u32 = 0; i < uniforms.axes_size; i += 1) { \n" + << " let index = " << GetElementAt("uniforms.axes", "i", axes_size) << ";\n" + << " size = size * " << GetElementAt("uniforms.input_shape", "index", input_rank) << ";\n" + << "}\n" + << "let output_value = output_value_t(sum / f32(size));"; + ReduceOpSpecificCode code({"var sum = f32(0);", "sum += f32(current_element);", ss.str()}); + return code; +} + +Status ReduceMean::ComputeInternal(ComputeContext& ctx) const { + return ReduceKernel::ComputeInternal(ctx); +} + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/reduction/reduction_ops.h b/onnxruntime/core/providers/webgpu/reduction/reduction_ops.h new file mode 100644 index 0000000000000..e93eb06f20886 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/reduction/reduction_ops.h @@ -0,0 +1,62 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "core/common/optional.h" +#include "core/providers/webgpu/webgpu_supported_types.h" +#include "core/providers/webgpu/webgpu_kernel.h" +#include "core/providers/cpu/reduction/reduction_kernel_base.h" +#include "core/providers/webgpu/program.h" +#include "core/providers/webgpu/shader_helper.h" +namespace onnxruntime { +namespace webgpu { +// reduceOpSpecificCode is a 3-element array of strings that represent the op specific code for the reduce operation. +// The first element is the loop header, the second element is the loop body, and the third element is the loop footer. +// The loop header is the code that is executed before the loop starts. The loop body is the code that is executed for each element in the loop. +// The loop footer is the code that is executed after the loop ends. +typedef std::array ReduceOpSpecificCode; +class ReduceKernelProgram final : public Program { + public: + ReduceKernelProgram(std::string name, bool keepdims, bool no_op_with_empty_axes, const InlinedVector& axes, ReduceOpSpecificCode code) : Program{name}, keepdims_(keepdims), no_op_with_empty_axes_(no_op_with_empty_axes), axes_(axes.begin(), axes.end()), code_(code) {} + Status GenerateShaderCode(ShaderHelper& wgpuShaderModuleAddRef) const override; + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"output_size", ProgramUniformVariableDataType::Uint32}, + {"no_op_with_empty_axes", ProgramUniformVariableDataType::Uint32}, + {"axes", ProgramUniformVariableDataType::Uint32}, + {"axes_size", ProgramUniformVariableDataType::Uint32}); + + private: + const bool keepdims_; + const bool no_op_with_empty_axes_; + InlinedVector axes_; + ReduceOpSpecificCode code_; +}; + +template +class ReduceKernel : public WebGpuKernel, public ReduceKernelBase { + protected: + using ReduceKernelBase::axes_; + using ReduceKernelBase::noop_with_empty_axes_; + using ReduceKernelBase::keepdims_; + using ReduceKernelBase::select_last_index_; + + ReduceKernel(const OpKernelInfo& info, std::string name, optional keepdims_override = {}) + : WebGpuKernel(info), + ReduceKernelBase(info, keepdims_override), + name_(name) { + } + Status ComputeInternal(ComputeContext& ctx) const; + virtual ReduceOpSpecificCode GetOpSpecificCode(const Tensor* input_tensor, size_t axes_size) const = 0; + + private: + std::string name_; +}; + +class ReduceMean final : public ReduceKernel { + public: + ReduceMean(const OpKernelInfo& info) : ReduceKernel(info, "ReduceMean") {} + ReduceOpSpecificCode GetOpSpecificCode(const Tensor* input_tensor, size_t axes_size) const override; + Status ComputeInternal(ComputeContext& ctx) const override; +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/shader_helper.cc b/onnxruntime/core/providers/webgpu/shader_helper.cc index 8fccbacac903b..19cab9b178b1f 100644 --- a/onnxruntime/core/providers/webgpu/shader_helper.cc +++ b/onnxruntime/core/providers/webgpu/shader_helper.cc @@ -345,9 +345,6 @@ Status ShaderHelper::GenerateSourceCode(std::string& code, std::vector& sha })) { ORT_RETURN_IF_NOT(device_.HasFeature(wgpu::FeatureName::ShaderF16), "Program ", program_.Name(), " requires f16 but the device does not support it."); ss << "enable f16;\n"; - if (device_.HasFeature(wgpu::FeatureName::SubgroupsF16)) { - ss << "enable subgroups_f16;\n"; - } } if (device_.HasFeature(wgpu::FeatureName::Subgroups)) { ss << "enable subgroups;\n"; diff --git a/onnxruntime/core/providers/webgpu/shader_variable.cc b/onnxruntime/core/providers/webgpu/shader_variable.cc index 5e5920f582251..f8e1e0b3b8d2b 100644 --- a/onnxruntime/core/providers/webgpu/shader_variable.cc +++ b/onnxruntime/core/providers/webgpu/shader_variable.cc @@ -91,7 +91,7 @@ ShaderIndicesHelper::ShaderIndicesHelper(std::string_view name, ProgramVariableD : name_(name), type_(type), num_components_{NumberOfComponents(type)}, - rank_{gsl::narrow(dims.NumDimensions())}, + rank_{static_cast(dims.NumDimensions())}, dims_{dims}, usage_(usage), indices_type_{GetIndicesType(rank_)}, diff --git a/onnxruntime/core/providers/webgpu/tensor/cast.cc b/onnxruntime/core/providers/webgpu/tensor/cast.cc index 8b5bede34e6d0..7f92ea4ed3776 100644 --- a/onnxruntime/core/providers/webgpu/tensor/cast.cc +++ b/onnxruntime/core/providers/webgpu/tensor/cast.cc @@ -69,7 +69,7 @@ Status Cast::ComputeInternal(ComputeContext& context) const { if (size == 0) { return Status::OK(); } - uint32_t vec_size = gsl::narrow((size + 3) / 4); + uint32_t vec_size = onnxruntime::narrow((size + 3) / 4); CastProgram program{to_}; program diff --git a/onnxruntime/core/providers/webgpu/tensor/cast.h b/onnxruntime/core/providers/webgpu/tensor/cast.h index ef5c4d5d0dabe..925cd200f0aba 100644 --- a/onnxruntime/core/providers/webgpu/tensor/cast.h +++ b/onnxruntime/core/providers/webgpu/tensor/cast.h @@ -26,7 +26,7 @@ class Cast final : public WebGpuKernel { int64_t to; Status status = info.GetAttr("to", &to); ORT_ENFORCE(status.IsOK(), "Attribute to is not set."); - to_ = gsl::narrow(to); + to_ = onnxruntime::narrow(to); // ignore attribute 'saturate' as float8 is not supported in WebGPU } diff --git a/onnxruntime/core/providers/webgpu/tensor/concat.cc b/onnxruntime/core/providers/webgpu/tensor/concat.cc index 5ed8099fde05e..5cfd6c78f8929 100644 --- a/onnxruntime/core/providers/webgpu/tensor/concat.cc +++ b/onnxruntime/core/providers/webgpu/tensor/concat.cc @@ -104,7 +104,7 @@ Status Concat::ComputeInternal(ComputeContext& context) const { return Status::OK(); } - uint32_t output_size = gsl::narrow_cast(prepare.output_tensor->Shape().Size()); + uint32_t output_size = onnxruntime::narrow(prepare.output_tensor->Shape().Size()); size_t axis = static_cast(prepare.axis); ConcatProgram program{axis}; diff --git a/onnxruntime/core/providers/webgpu/tensor/expand.cc b/onnxruntime/core/providers/webgpu/tensor/expand.cc index 809616660aa9e..9bdebe2c1e0d3 100644 --- a/onnxruntime/core/providers/webgpu/tensor/expand.cc +++ b/onnxruntime/core/providers/webgpu/tensor/expand.cc @@ -42,7 +42,7 @@ Status Expand::ComputeInternal(ComputeContext& context) const { : 1; const int components_o = output_shape.IsScalar() ? 1 : output_shape[output_shape.NumDimensions() - 1] % 4 == 0 ? 4 : 1; - uint32_t data_size = gsl::narrow(output_shape.Size() / components_o); + uint32_t data_size = onnxruntime::narrow(output_shape.Size() / components_o); ExpandProgram program{}; program diff --git a/onnxruntime/core/providers/webgpu/tensor/gather.cc b/onnxruntime/core/providers/webgpu/tensor/gather.cc index 9f6e5f2420d86..39d07991f3c5a 100644 --- a/onnxruntime/core/providers/webgpu/tensor/gather.cc +++ b/onnxruntime/core/providers/webgpu/tensor/gather.cc @@ -42,7 +42,7 @@ Status GatherProgram::GenerateShaderCode(ShaderHelper& shader) const { Status Gather::ComputeInternal(ComputeContext& context) const { Prepare p; ORT_RETURN_IF_ERROR(PrepareForCompute(&context.KernelContext(), p)); - uint32_t data_size = gsl::narrow(p.output_tensor->Shape().Size()); + uint32_t data_size = onnxruntime::narrow(p.output_tensor->Shape().Size()); if (data_size == 0) { return Status::OK(); } diff --git a/onnxruntime/core/providers/webgpu/tensor/pad.cc b/onnxruntime/core/providers/webgpu/tensor/pad.cc new file mode 100644 index 0000000000000..6a8bc6554b772 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/tensor/pad.cc @@ -0,0 +1,261 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include + +#include "core/util/math.h" +#include "core/providers/webgpu/tensor/pad.h" +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/webgpu_supported_types.h" + +namespace onnxruntime { +namespace webgpu { + +Status PadProgram::GenerateShaderCode(ShaderHelper& shader) const { + if (!dim_value_zero_) { + shader.AddInput("data", ShaderUsage::UseUniform | ShaderUsage::UseShapeAndStride); + } + const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseShapeAndStride | ShaderUsage::UseValueTypeAlias); + + shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size"); + std::string constant_value_str = std::string("let constant_value = ") + + (is_float16_ ? "bitcast>(uniforms.constant_value)[0];\n" : "bitcast(uniforms.constant_value);\n"); + if (dim_value_zero_) { + // Only Constant mode needs fill output if the one dim value or mores dims' values of input are zero. + shader.MainFunctionBody() << constant_value_str + << "output[global_idx] = constant_value;\n"; + return Status::OK(); + } + + shader.MainFunctionBody() << " let output_indices = " << output.OffsetToIndices("global_idx") << ";\n" + << " var input_index = u32(0);\n" + << " var use_pad_value = false;\n" + << " var in_coord = i32(0);\n"; + + const int rank = output.Rank(); + std::string output_indices_str = "i32(" + GetElementAt("output_indices", "dim", rank) + ")"; + std::string lower_pads_str = GetElementAt("uniforms.lower_pads", "dim", rank); + std::string data_shape_str = "i32(" + GetElementAt("uniforms.data_shape", "dim", rank) + ")"; + std::string data_stride_str = rank == 1 ? "" : " * " + GetElementAt("uniforms.data_stride", "dim", rank - 1); + std::string begin_axis_statement = "in_coord = "; + std::string end_axis_statement = "in_coord = "; + std::string in_axis_statement = "in_coord = " + output_indices_str + " - " + lower_pads_str + ";\n"; + switch (mode_) { + case Mode::Constant: + begin_axis_statement = "use_pad_value = true;\n"; + end_axis_statement = "use_pad_value = true;\n"; + break; + case Mode::Edge: + begin_axis_statement += "0;\n"; + end_axis_statement += data_shape_str + " - 1;\n"; + break; + case Mode::Reflect: + begin_axis_statement += lower_pads_str + " - " + output_indices_str + ";\n"; + end_axis_statement += data_shape_str + " - 2 - (" + output_indices_str + + " - (" + lower_pads_str + " + " + data_shape_str + "));\n"; + break; + case Mode::Wrap: + begin_axis_statement += data_shape_str + " + " + output_indices_str + " - " + lower_pads_str + ";\n"; + end_axis_statement += output_indices_str + " - " + lower_pads_str + " - " + data_shape_str + ";\n"; + break; + default: + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unsupported mode type: ", static_cast(mode_)); + } + + shader.MainFunctionBody() << " for (var dim = 0; dim < " << rank << " && !use_pad_value; dim++) {\n" + << " if (" << output_indices_str << " < " << lower_pads_str << ") {\n" + << " " << begin_axis_statement << " }\n" + << " else if (" << output_indices_str << " >= " << lower_pads_str << " + " << data_shape_str << ") {\n" + << " " << end_axis_statement << " }\n" + << " else {\n" + << " " << in_axis_statement << " }\n" + << " input_index += select(u32(in_coord)" << data_stride_str << ", u32(in_coord), dim == " << rank - 1 << ");\n" + << " }\n" + << " " << constant_value_str + << " " << output.SetByOffset("global_idx", "select(data[input_index], constant_value, use_pad_value)"); + + return Status::OK(); +} + +Status Pad::ComputeInternal(ComputeContext& context) const { + const Tensor* input_tensor = context.Input(0); + auto const& input_shape = input_tensor->Shape(); + size_t dimension_count = input_shape.NumDimensions(); + + const PadsVector* p_pads = &pads_; + const PadsVector* p_slices = &slices_; + + PadsVector pads; + PadsVector slices; + // kOnnxDomain Pad opset >= 11 (Or) kMsDomain opset == 1 + if (is_dynamic_) { + size_t data_rank = input_tensor->Shape().NumDimensions(); + + const Tensor* pads_tensor = context.Input(1); + auto pads_tensor_dims = pads_tensor->Shape().GetDims(); + ORT_ENFORCE(pads_tensor_dims.size() == 1 || (pads_tensor_dims.size() == 2 && pads_tensor_dims[0] == 1), + "Pads tensor should be a 1D tensor of shape [2 * num_axes] " + "or a 2D tensor of shape [1, 2 * num_axes]"); + + const auto pads_data = pads_tensor->DataAsSpan(); + + // Compute Pads by applying axes if specified otherwise copy the supplied pads. + PadBase::ComputePads(context.KernelContext(), data_rank, pads_data, pads); + + // Separate out any negative pads into the slices array + PadBase::SeparateNegativeToSlices(pads, slices); + + p_pads = &pads; + p_slices = &slices; + } + + auto output_dims(input_shape.AsShapeVector()); + ORT_ENFORCE(dimension_count * 2 == p_pads->size(), "'pads' attribute has wrong number of values"); + + // Calculate output dimensions, and handle any negative padding + std::vector lower_pads(dimension_count); + for (size_t i = 0; i < dimension_count; i++) { + int64_t lower_pad = (*p_pads)[i] + (*p_slices)[i]; + int64_t upper_pad = (*p_pads)[i + dimension_count] + (*p_slices)[i + dimension_count]; + lower_pads[i] = static_cast(lower_pad); + output_dims[i] += lower_pad + upper_pad; + } + TensorShape output_shape(output_dims); + + // special case when there is a dim value of 0 in the shape. behavior depends on mode + bool dim_value_zero = input_shape.Size() == 0; + if (dim_value_zero) { + ORT_RETURN_IF_ERROR(PadBase::HandleDimValueZero(mode_, input_shape, output_shape)); + } + + auto* output_tensor = context.Output(0, output_shape); + uint32_t output_size = onnxruntime::narrow(output_shape.Size()); + if (output_size == 0) { + // Do not need to fill output, return + return Status::OK(); + } + + // Read constant value and bitcast to uint32. + uint32_t value_uint32 = 0; + const auto data_type = input_tensor->GetElementType(); + bool is_float16 = data_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16; + const Tensor* value_tensor = context.Input(2); + if (!is_dynamic_) { + if (is_float16) { + uint16_t value = math::floatToHalf(value_); + std::memcpy(&value_uint32, &value, sizeof(value)); + } else { + value_uint32 = *reinterpret_cast(&value_); + } + } else if (value_tensor) { + ORT_ENFORCE(value_tensor->DataType() == input_tensor->DataType() && value_tensor->Shape().Size() == 1, + "Value tensor should be a 1D tensor of size 1 with the same type as that of the input tensor"); + switch (data_type) { + case ONNX_NAMESPACE::TensorProto_DataType_INT32: { + int32_t value = value_tensor->Data()[0]; + value_uint32 = *reinterpret_cast(&value); + } break; + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: { + float value = value_tensor->Data()[0]; + value_uint32 = *reinterpret_cast(&value); + } break; + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: { + uint16_t value = value_tensor->Data()[0].val; + std::memcpy(&value_uint32, &value, sizeof(value)); + } break; + case ONNX_NAMESPACE::TensorProto_DataType_UINT32: { + value_uint32 = value_tensor->Data()[0]; + } break; + default: + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unsupported input type: ", static_cast(data_type)); + } + } + + PadProgram program{mode_, dim_value_zero, is_float16}; + if (!dim_value_zero) { + program.AddInput({input_tensor, ProgramTensorMetadataDependency::TypeAndRank}); + } + program.AddOutput({output_tensor, ProgramTensorMetadataDependency::Rank}) + .SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) + .CacheHint(std::to_string(static_cast(mode_)), dim_value_zero) + .AddUniformVariables({{gsl::span(lower_pads.data(), lower_pads.size())}, {output_size}, {value_uint32}}); + + return context.RunProgram(program); +} + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Pad, + kOnnxDomain, + 2, 10, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedNumberTypes()), + Pad); +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Pad, + kOnnxDomain, + 11, 12, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .InputMemoryType(OrtMemTypeCPUInput, 1) + .InputMemoryType(OrtMemTypeCPUInput, 2) + .TypeConstraint("T", WebGpuSupportedNumberTypes()), + Pad); +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Pad, + kOnnxDomain, + 13, 17, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .InputMemoryType(OrtMemTypeCPUInput, 1) + .InputMemoryType(OrtMemTypeCPUInput, 2) + .TypeConstraint("T", WebGpuSupportedNumberTypes()), + Pad); +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Pad, + kOnnxDomain, + 18, 18, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .InputMemoryType(OrtMemTypeCPUInput, 1) + .InputMemoryType(OrtMemTypeCPUInput, 2) + .InputMemoryType(OrtMemTypeCPUInput, 3) + .TypeConstraint("T", WebGpuSupportedNumberTypes()), + Pad); +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Pad, + kOnnxDomain, + 19, 20, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .InputMemoryType(OrtMemTypeCPUInput, 1) + .InputMemoryType(OrtMemTypeCPUInput, 2) + .InputMemoryType(OrtMemTypeCPUInput, 3) + .TypeConstraint("T", WebGpuSupportedNumberTypes()), + Pad); +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Pad, + kOnnxDomain, + 21, 22, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .InputMemoryType(OrtMemTypeCPUInput, 1) + .InputMemoryType(OrtMemTypeCPUInput, 2) + .InputMemoryType(OrtMemTypeCPUInput, 3) + .TypeConstraint("T", WebGpuSupportedNumberTypes()), + Pad); +ONNX_OPERATOR_KERNEL_EX( + Pad, + kOnnxDomain, + 23, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .InputMemoryType(OrtMemTypeCPUInput, 1) + .InputMemoryType(OrtMemTypeCPUInput, 2) + .InputMemoryType(OrtMemTypeCPUInput, 3) + .TypeConstraint("T", WebGpuSupportedNumberTypes()), + Pad); + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/tensor/pad.h b/onnxruntime/core/providers/webgpu/tensor/pad.h new file mode 100644 index 0000000000000..58049ddb0e5ce --- /dev/null +++ b/onnxruntime/core/providers/webgpu/tensor/pad.h @@ -0,0 +1,40 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/webgpu/program.h" +#include "core/providers/webgpu/webgpu_kernel.h" +#include "core/providers/cpu/tensor/padbase.h" + +namespace onnxruntime { +namespace webgpu { + +class PadProgram final : public Program { + public: + PadProgram(const Mode mode, bool dim_value_zero, bool is_float16) : Program{"Pad"}, + mode_{mode}, + dim_value_zero_{dim_value_zero}, + is_float16_{is_float16} {} + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"lower_pads", ProgramUniformVariableDataType::Int32}, + {"output_size", ProgramUniformVariableDataType::Uint32}, + {"constant_value", ProgramUniformVariableDataType::Uint32}); + + private: + Mode mode_; + bool dim_value_zero_; + bool is_float16_; +}; + +class Pad final : public PadBase, public WebGpuKernel { + public: + Pad(const OpKernelInfo& info) : PadBase(info), WebGpuKernel(info) {} + + Status ComputeInternal(ComputeContext& context) const override; +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/tensor/resize_impl.cc b/onnxruntime/core/providers/webgpu/tensor/resize_impl.cc index 455e7dc54bf1d..f68ace3c1d8a1 100644 --- a/onnxruntime/core/providers/webgpu/tensor/resize_impl.cc +++ b/onnxruntime/core/providers/webgpu/tensor/resize_impl.cc @@ -211,7 +211,7 @@ Status ResizeNearestImpl(ComputeContext& context, onnxruntime::ResizeNearestMode nearest_mode) { TensorShape output_shape(output_dims); auto* output_tensor = context.Output(0, output_shape); - uint32_t output_size = gsl::narrow(output_shape.Size()); + uint32_t output_size = onnxruntime::narrow(output_shape.Size()); ResizeNearestProgram program{coordinate_transform_mode, nearest_mode, extrapolation_enabled, rank}; program.AddInput({input_tensor, ProgramTensorMetadataDependency::TypeAndRank}) @@ -299,7 +299,7 @@ Status ResizeBilinearImpl(ComputeContext& context, onnxruntime::ResizeCoordinateTransformationMode coordinate_transform_mode) { TensorShape output_shape(output_dims); auto* output_tensor = context.Output(0, output_shape); - uint32_t output_size = gsl::narrow(output_shape.Size()); + uint32_t output_size = onnxruntime::narrow(output_shape.Size()); ResizeBilinearProgram program{coordinate_transform_mode, extrapolation_enabled, rank}; program.AddInput({input_tensor, ProgramTensorMetadataDependency::TypeAndRank}) @@ -413,7 +413,7 @@ Status ResizeTrilinearImpl(ComputeContext& context, onnxruntime::ResizeCoordinateTransformationMode coordinate_transform_mode) { TensorShape output_shape(output_dims); auto* output_tensor = context.Output(0, output_shape); - uint32_t output_size = gsl::narrow(output_shape.Size()); + uint32_t output_size = onnxruntime::narrow(output_shape.Size()); ResizeTrilinearProgram program{coordinate_transform_mode, extrapolation_enabled, rank}; program.AddInput({input_tensor, ProgramTensorMetadataDependency::TypeAndRank}) @@ -534,7 +534,7 @@ Status ResizeBiCubicImpl(ComputeContext& context, onnxruntime::ResizeCoordinateTransformationMode coordinate_transform_mode) { TensorShape output_shape(output_dims); auto* output_tensor = context.Output(0, output_shape); - uint32_t output_size = gsl::narrow(output_shape.Size()); + uint32_t output_size = onnxruntime::narrow(output_shape.Size()); ResizeBiCubicProgram program{coordinate_transform_mode, extrapolation_enabled, exclude_outside, rank}; program.AddInput({input_tensor, ProgramTensorMetadataDependency::TypeAndRank}) diff --git a/onnxruntime/core/providers/webgpu/tensor/split.cc b/onnxruntime/core/providers/webgpu/tensor/split.cc index 83bf832cc5b11..d93b75fa21c16 100644 --- a/onnxruntime/core/providers/webgpu/tensor/split.cc +++ b/onnxruntime/core/providers/webgpu/tensor/split.cc @@ -107,7 +107,7 @@ Status Split::ComputeInternal(ComputeContext& context) const { ORT_RETURN_IF_ERROR(PrepareForCompute(input_shape, num_outputs, axis, before_dims, after_dims_including_split_axis, after_dims_excluding_split, split_sizes)); - SplitProgram program{gsl::narrow_cast(axis)}; + SplitProgram program{static_cast(axis)}; program.AddInput({input, ProgramTensorMetadataDependency::TypeAndRank}); auto output_dimensions = input_shape.AsShapeVector(); @@ -120,7 +120,7 @@ Status Split::ComputeInternal(ComputeContext& context) const { program.AddOutput({output, ProgramTensorMetadataDependency::Rank}); } - uint32_t input_size = gsl::narrow(input_shape.Size()); + uint32_t input_size = onnxruntime::narrow(input_shape.Size()); // Early return if the input tensor is empty. if (input_size == 0) { return Status::OK(); @@ -130,7 +130,7 @@ Status Split::ComputeInternal(ComputeContext& context) const { std::vector sizes_in_split_axis; // sizes_in_split_axis are the cumulative sizes of the splits in the split axis. for (auto split_size : split_sizes) { - previous_sum += gsl::narrow(split_size); + previous_sum += onnxruntime::narrow(split_size); sizes_in_split_axis.push_back(previous_sum); } diff --git a/onnxruntime/core/providers/webgpu/tensor/transpose.cc b/onnxruntime/core/providers/webgpu/tensor/transpose.cc index c40ec43dd0009..0df7d1ae9fa2f 100644 --- a/onnxruntime/core/providers/webgpu/tensor/transpose.cc +++ b/onnxruntime/core/providers/webgpu/tensor/transpose.cc @@ -47,7 +47,10 @@ ONNX_OPERATOR_KERNEL_EX( .TypeConstraint("T", WebGpuSupportedNumberTypes()), Transpose); -auto SqueezeShape(const gsl::span& shape, const gsl::span& adjusted_perm, InlinedVector& new_shape, InlinedVector& new_perm) { +auto SqueezeShape(const gsl::span& shape, + const gsl::span& adjusted_perm, + TensorShapeVector& new_shape, + TensorShapeVector& new_perm) { for (size_t i = 0; i < shape.size(); ++i) { if (shape[i] != 1) { new_shape.push_back(shape[i]); @@ -97,26 +100,28 @@ Status TransposeProgram::GenerateShaderCode(ShaderHelper& shader) const { return Status::OK(); } -Status Transpose::ComputeInternal(ComputeContext& context) const { - const auto* input_tensor = context.Input(0); - const TensorShape& input_shape = input_tensor->Shape(); - int32_t rank = gsl::narrow_cast(input_shape.NumDimensions()); +Status Transpose::DoTranspose(onnxruntime::webgpu::ComputeContext& context, + gsl::span permutations, + const Tensor& input, Tensor& output) { + const auto& input_shape = input.Shape(); + const auto& input_dims = input_shape.GetDims(); + int32_t rank = static_cast(input_shape.NumDimensions()); TensorShapeVector output_dims(rank); - InlinedVector default_perm(rank); - const InlinedVector* p_perm = nullptr; - ORT_RETURN_IF_ERROR(ComputeOutputShape(*input_tensor, output_dims, default_perm, p_perm)); - TensorShape output_shape(output_dims); - auto* output_tensor = context.Output(0, output_shape); - InlinedVector new_shape{}; - InlinedVector new_perm{}; - SqueezeShape(input_shape.GetDims(), *p_perm, new_shape, new_perm); - const bool channels_last = new_perm == InlinedVector({2, 3, 1}); - const bool channels_first = new_perm == InlinedVector({3, 1, 2}); + for (int32_t i = 0; i < rank; i++) { + output_dims[i] = input_dims[permutations[i]]; + } + + TensorShapeVector new_shape{}; + TensorShapeVector new_perm{}; + SqueezeShape(input_shape.GetDims(), permutations, new_shape, new_perm); + const bool channels_last = new_perm == TensorShapeVector({2, 3, 1}); + const bool channels_first = new_perm == TensorShapeVector({3, 1, 2}); const bool use_shared = (new_shape.size() == 2 && new_perm[0] > new_perm[1]) || channels_last || channels_first; auto new_input_shape = input_shape; TensorShape new_output_shape(output_dims); + if (use_shared) { new_input_shape = channels_last ? TensorShape({new_shape[0], new_shape[1] * new_shape[2]}) @@ -126,16 +131,16 @@ Status Transpose::ComputeInternal(ComputeContext& context) const { new_output_shape = TensorShape({new_input_shape[1], new_input_shape[0]}); } - uint32_t output_size = gsl::narrow_cast(input_tensor->Shape().Size()); - TransposeProgram program{*p_perm, use_shared}; + uint32_t output_size = onnxruntime::narrow(input_shape.Size()); + TransposeProgram program{permutations, use_shared}; + if (use_shared) { program.SetWorkgroupSize(TILE_SIZE, TILE_SIZE, 1); } - program - .CacheHint(absl::StrJoin(*p_perm, "-")) - .AddInputs({{input_tensor, ProgramTensorMetadataDependency::TypeAndRank, new_input_shape, 1}}) - .AddOutputs({{output_tensor, ProgramTensorMetadataDependency::None, new_output_shape, 1}}) + .CacheHint(absl::StrJoin(permutations, "-")) + .AddInputs({{&input, ProgramTensorMetadataDependency::TypeAndRank, new_input_shape, 1}}) + .AddOutputs({{&output, ProgramTensorMetadataDependency::None, new_output_shape, 1}}) .SetDispatchGroupSize(static_cast((new_output_shape[1] + TILE_SIZE - 1) / TILE_SIZE), static_cast(((new_output_shape[0] + TILE_SIZE - 1) / TILE_SIZE))) .AddUniformVariables({ @@ -148,5 +153,20 @@ Status Transpose::ComputeInternal(ComputeContext& context) const { return context.RunProgram(program); } +Status Transpose::ComputeInternal(ComputeContext& context) const { + const auto* input_tensor = context.Input(0); + const TensorShape& input_shape = input_tensor->Shape(); + int32_t rank = static_cast(input_shape.NumDimensions()); + + TensorShapeVector output_dims(rank); + InlinedVector default_perm(rank); + const InlinedVector* p_perm = nullptr; + ORT_RETURN_IF_ERROR(ComputeOutputShape(*input_tensor, output_dims, default_perm, p_perm)); + TensorShape output_shape(output_dims); + auto* output_tensor = context.Output(0, output_shape); + + return DoTranspose(context, *p_perm, *input_tensor, *output_tensor); +} + } // namespace webgpu } // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/tensor/transpose.h b/onnxruntime/core/providers/webgpu/tensor/transpose.h index 7cf5c1fe0865d..b62a419fa12bc 100644 --- a/onnxruntime/core/providers/webgpu/tensor/transpose.h +++ b/onnxruntime/core/providers/webgpu/tensor/transpose.h @@ -16,6 +16,8 @@ class Transpose final : public WebGpuKernel, public TransposeBase { Transpose(const OpKernelInfo& info) : WebGpuKernel{info}, TransposeBase{info} { } Status ComputeInternal(ComputeContext& context) const override; + static Status DoTranspose(onnxruntime::webgpu::ComputeContext& context, gsl::span permutations, const Tensor& input, Tensor& output); + constexpr static uint32_t TILE_SIZE = 16; }; diff --git a/onnxruntime/core/providers/webgpu/tensor/where.cc b/onnxruntime/core/providers/webgpu/tensor/where.cc index e8cdabb9dbe40..d7272ec525296 100644 --- a/onnxruntime/core/providers/webgpu/tensor/where.cc +++ b/onnxruntime/core/providers/webgpu/tensor/where.cc @@ -127,7 +127,7 @@ Status Where::ComputeInternal(ComputeContext& context) const { ORT_RETURN_IF_ERROR(ComputeOutputShape(cond_shape, x_shape, y_shape, output_shape)); auto* output_tensor = context.Output(0, output_shape); constexpr int component = 4; - uint32_t vec_size = gsl::narrow_cast((output_shape.Size() + 3) / component); + uint32_t vec_size = onnxruntime::narrow((output_shape.Size() + 3) / component); const auto is_broadcast = !(x_shape == y_shape && y_shape == cond_shape); WhereProgram program{is_broadcast}; diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.cc b/onnxruntime/core/providers/webgpu/webgpu_context.cc index 163dd691b7f16..97144573dde2d 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_context.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_context.cc @@ -134,6 +134,8 @@ void WebGpuContext::Initialize(const WebGpuBufferCacheConfig& buffer_cache_confi ORT_ENFORCE(device_ != nullptr, "Failed to get a WebGPU device."); } + LOGS_DEFAULT(VERBOSE) << "WebGPU EP Context is created for: Instance=" << instance_.Get() << ", Device=" << device_.Get() << "."; + // cache adapter info ORT_ENFORCE(Device().GetAdapterInfo(&adapter_info_)); // cache device limits @@ -165,7 +167,6 @@ void WebGpuContext::Initialize(const WebGpuBufferCacheConfig& buffer_cache_confi #if defined(ENABLE_PIX_FOR_WEBGPU_EP) // set pix frame generator pix_frame_generator_ = std::make_unique(instance_, - Adapter(), Device()); #else ORT_THROW("Support PIX capture requires extra build flags (--enable_pix_capture)"); @@ -321,9 +322,9 @@ Status WebGpuContext::Run(ComputeContext& context, const ProgramBase& program) { std::vector dims(expected_rank); std::vector stride(expected_rank - 1); for (size_t j = 0; j < expected_rank; ++j) { - dims[j] = gsl::narrow(shape[j]); + dims[j] = onnxruntime::narrow(shape[j]); if (j < expected_rank - 1) { - stride[j] = gsl::narrow(shape.SizeFromDimension(j + 1)); + stride[j] = onnxruntime::narrow(shape.SizeFromDimension(j + 1)); } } @@ -490,8 +491,7 @@ std::vector WebGpuContext::GetAvailableRequiredFeatures(const #endif wgpu::FeatureName::TimestampQuery, wgpu::FeatureName::ShaderF16, - wgpu::FeatureName::Subgroups, - wgpu::FeatureName::SubgroupsF16}; + wgpu::FeatureName::Subgroups}; for (auto feature : features) { if (adapter.HasFeature(feature)) { required_features.push_back(feature); @@ -708,45 +708,46 @@ WebGpuContext& WebGpuContextFactory::CreateContext(const WebGpuContextConfig& co WGPUInstance instance = config.instance; WGPUDevice device = config.device; - if (context_id == 0) { - // context ID is preserved for the default context. User cannot use context ID 0 as a custom context. - ORT_ENFORCE(instance == nullptr && device == nullptr, - "WebGPU EP default context (contextId=0) must not have custom WebGPU instance or device."); - - std::call_once(init_default_flag_, [ + std::call_once(init_default_flag_, [ #if !defined(__wasm__) - dawn_proc_table = config.dawn_proc_table + dawn_proc_table = config.dawn_proc_table #endif - ]() { - // Step.1 - setup dawn proc table (only for non-WASM build) + ]() { + // Step.1 - setup dawn proc table (only for non-WASM build) #if !defined(__wasm__) - const DawnProcTable* dawn_procs = reinterpret_cast(dawn_proc_table); + const DawnProcTable* dawn_procs = reinterpret_cast(dawn_proc_table); #if defined(BUILD_DAWN_MONOLITHIC_LIBRARY) - ORT_ENFORCE(dawn_procs == nullptr, "setting DawnProcTable is not allowed when dynamically linked to webgpu_dawn."); + ORT_ENFORCE(dawn_procs == nullptr, "setting DawnProcTable is not allowed when dynamically linked to webgpu_dawn."); #else #if !defined(USE_EXTERNAL_DAWN) - if (dawn_procs == nullptr) { - dawn_procs = &dawn::native::GetProcs(); - } + if (dawn_procs == nullptr) { + dawn_procs = &dawn::native::GetProcs(); + } #else - ORT_ENFORCE(dawn_procs != nullptr, "DawnProcTable must be provided."); + ORT_ENFORCE(dawn_procs != nullptr, "DawnProcTable must be provided."); #endif - dawnProcSetProcs(dawn_procs); + dawnProcSetProcs(dawn_procs); #endif #endif - // Step.2 - Create wgpu::Instance + // Step.2 - Create wgpu::Instance #if !defined(__wasm__) - wgpu::InstanceDescriptor instance_desc{}; - instance_desc.capabilities.timedWaitAnyEnable = true; - default_instance_ = wgpu::CreateInstance(&instance_desc); + wgpu::InstanceDescriptor instance_desc{}; + instance_desc.capabilities.timedWaitAnyEnable = true; + default_instance_ = wgpu::CreateInstance(&instance_desc); #else - default_instance_ = wgpu::CreateInstance(nullptr); + default_instance_ = wgpu::CreateInstance(nullptr); #endif - ORT_ENFORCE(default_instance_ != nullptr, "Failed to create wgpu::Instance."); - }); + ORT_ENFORCE(default_instance_ != nullptr, "Failed to create wgpu::Instance."); + }); + + if (context_id == 0) { + // context ID is preserved for the default context. User cannot use context ID 0 as a custom context. + ORT_ENFORCE(instance == nullptr && device == nullptr, + "WebGPU EP default context (contextId=0) must not have custom WebGPU instance or device."); + instance = default_instance_.Get(); } else { // for context ID > 0, user must provide custom WebGPU instance and device. @@ -800,5 +801,9 @@ void CleanupWebGpuContexts() { WebGpuContextFactory::Cleanup(); } +WGPUDevice GetDevice(int context_id) { + return WebGpuContextFactory::GetContext(context_id).Device().Get(); +} + } // namespace webgpu } // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc index d44cf4674d8a3..df7f2d6dcdeab 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc @@ -23,6 +23,7 @@ #include "core/providers/webgpu/webgpu_context.h" #include "core/providers/webgpu/data_transfer.h" +#include "core/providers/webgpu/external_data_loader.h" #include "core/providers/webgpu/webgpu_profiler.h" namespace onnxruntime { @@ -363,7 +364,9 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxD class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, Pad); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 17, Pad); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 18, 18, Pad); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 19, Pad); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 19, 20, Pad); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 21, 22, Pad); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 23, Pad); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, If); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, If); @@ -516,10 +519,10 @@ std::unique_ptr RegisterKernels() { // BuildKernelCreateInfo, // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -625,9 +628,9 @@ std::unique_ptr RegisterKernels() { // BuildKernelCreateInfo, // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -685,11 +688,13 @@ std::unique_ptr RegisterKernels() { // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, // BuildKernelCreateInfo, // BuildKernelCreateInfo, @@ -760,6 +765,7 @@ std::vector WebGpuExecutionProvider::CreatePreferredAllocators() { std::vector> WebGpuExecutionProvider::GetCapability( const onnxruntime::GraphViewer& graph, const IKernelLookup& kernel_lookup, + const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* /* resource_accountant */) const { InlinedVector candidates; // `tenative_candidates` is a subset of `candidates`. @@ -821,6 +827,12 @@ std::unique_ptr WebGpuExecutionProvider::GetDataTran return std::make_unique(context_); } +#if defined(__wasm__) +std::unique_ptr WebGpuExecutionProvider::GetExternalDataLoader() const { + return std::make_unique(); +} +#endif + WebGpuExecutionProvider::~WebGpuExecutionProvider() { WebGpuContextFactory::ReleaseContext(context_id_); } diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h index 7a0ade97aa3df..e2e23b6a307cf 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h @@ -45,10 +45,14 @@ class WebGpuExecutionProvider : public IExecutionProvider { std::vector> GetCapability( const onnxruntime::GraphViewer& graph_viewer, const IKernelLookup& /*kernel_lookup*/, + const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* /* resource_accountant */) const override; std::shared_ptr GetKernelRegistry() const override; std::unique_ptr GetDataTransfer() const override; +#if defined(__wasm__) + std::unique_ptr GetExternalDataLoader() const override; +#endif DataLayout GetPreferredLayout() const override { return preferred_data_layout_; } diff --git a/onnxruntime/core/providers/webgpu/webgpu_pix_frame_generator.cc b/onnxruntime/core/providers/webgpu/webgpu_pix_frame_generator.cc index 90b99b7b38bb1..9b287b7b7df99 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_pix_frame_generator.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_pix_frame_generator.cc @@ -11,7 +11,7 @@ namespace onnxruntime { namespace webgpu { -WebGpuPIXFrameGenerator::WebGpuPIXFrameGenerator(wgpu::Instance instance, wgpu::Adapter adapter, wgpu::Device device) { +WebGpuPIXFrameGenerator::WebGpuPIXFrameGenerator(wgpu::Instance instance, wgpu::Device device) { // Trivial window size for surface texture creation and provide frame concept for PIX. static constexpr uint32_t kWidth = 512u; static constexpr uint32_t kHeight = 512u; @@ -32,7 +32,7 @@ WebGpuPIXFrameGenerator::WebGpuPIXFrameGenerator(wgpu::Instance instance, wgpu:: wgpu::TextureFormat format; wgpu::SurfaceCapabilities capabilities; - surface_.GetCapabilities(adapter, &capabilities); + surface_.GetCapabilities(device.GetAdapter(), &capabilities); format = capabilities.formats[0]; wgpu::SurfaceConfiguration config; diff --git a/onnxruntime/core/providers/webgpu/webgpu_pix_frame_generator.h b/onnxruntime/core/providers/webgpu/webgpu_pix_frame_generator.h index 52a7459a81eba..0d9393321284d 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_pix_frame_generator.h +++ b/onnxruntime/core/providers/webgpu/webgpu_pix_frame_generator.h @@ -41,7 +41,7 @@ namespace webgpu { // WebGpuContext destruction. class WebGpuPIXFrameGenerator { public: - WebGpuPIXFrameGenerator(wgpu::Instance instance, wgpu::Adapter adapter, wgpu::Device device); + WebGpuPIXFrameGenerator(wgpu::Instance instance, wgpu::Device device); ~WebGpuPIXFrameGenerator(); void GeneratePIXFrame(); diff --git a/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc b/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc index 60c61b2ca5665..1d779152f91f3 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc @@ -151,6 +151,12 @@ std::shared_ptr WebGpuProviderFactoryCreator::Create( validation_mode, }; + LOGS_DEFAULT(VERBOSE) << "WebGPU EP Device ID: " << context_id; + LOGS_DEFAULT(VERBOSE) << "WebGPU EP WGPUInstance: " << webgpu_instance; + LOGS_DEFAULT(VERBOSE) << "WebGPU EP WGPUDevice: " << webgpu_device; + LOGS_DEFAULT(VERBOSE) << "WebGPU EP DawnProcTable: " << dawn_proc_table; + LOGS_DEFAULT(VERBOSE) << "WebGPU EP ValidationMode: " << validation_mode; + // // STEP.3 - prepare parameters for WebGPU context initialization. // diff --git a/onnxruntime/core/providers/webnn/builders/impl/rotaryEmbedding_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/rotaryEmbedding_op_builder.cc index cbaff79f4fd4f..966deb14196dd 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/rotaryEmbedding_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/rotaryEmbedding_op_builder.cc @@ -219,9 +219,17 @@ Status RotaryEmbeddingOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_build sign_buffer.set(0, -1.0f); sign_buffer.set(1, 1.0f); } else if (input_data_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) { - sign_buffer = emscripten::val::global("Uint16Array").new_(2); - sign_buffer.set(0, PackFloat32ToUint16AsFloat16(-1.0f)); - sign_buffer.set(1, PackFloat32ToUint16AsFloat16(1.0f)); + if (model_builder.IsFloat16ArrayAvailable()) { + // Float16Array is avaliable - use Float16Array. + sign_buffer = emscripten::val::global("Float16Array").new_(2); + sign_buffer.set(0, -1.0f); + sign_buffer.set(1, 1.0f); + } else { + // Float16Array is not available - use Uint16Array instead. + sign_buffer = emscripten::val::global("Uint16Array").new_(2); + sign_buffer.set(0, PackFloat32ToUint16AsFloat16(-1.0f)); + sign_buffer.set(1, PackFloat32ToUint16AsFloat16(1.0f)); + } } else { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unsupported input data type: ", input_data_type); } diff --git a/onnxruntime/core/providers/webnn/builders/model_builder.cc b/onnxruntime/core/providers/webnn/builders/model_builder.cc index ace6519a1fc11..cf4ce216ed5b3 100644 --- a/onnxruntime/core/providers/webnn/builders/model_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/model_builder.cc @@ -197,7 +197,8 @@ Status ModelBuilder::RegisterInitializers() { // Wasm memory grow will cause all array buffers reallocation, which will be treated as detached // buffers in JS side. Simply create a copy to fix it. - operand = wnn_builder_.call("constant", desc, view.call("slice")); + view = view.call("slice"); + operand = wnn_builder_.call("constant", desc, view["buffer"]); } } else { // TODO: support other type. @@ -350,7 +351,8 @@ Status ModelBuilder::AddOperandFromPersistMemoryBuffer( emscripten::val operand = emscripten::val::object(); // Wasm memory grow will cause all array buffers reallocation, which will be treated as detached // buffers in JS side. Simply create a copy to fix it. - operand = wnn_builder_.call("constant", desc, view.call("slice")); + view = view.call("slice"); + operand = wnn_builder_.call("constant", desc, view["buffer"]); AddOperand(name, operand); mem_persist_buffers_.push_back(std::move(persist_buffer)); diff --git a/onnxruntime/core/providers/webnn/builders/model_builder.h b/onnxruntime/core/providers/webnn/builders/model_builder.h index 4e2d84f481df0..1e5f859506d6b 100644 --- a/onnxruntime/core/providers/webnn/builders/model_builder.h +++ b/onnxruntime/core/providers/webnn/builders/model_builder.h @@ -30,6 +30,7 @@ class ModelBuilder { Status Compile(std::unique_ptr& model) ORT_MUST_USE_RESULT; // Accessors for members. + bool IsFloat16ArrayAvailable() const { return is_float16array_available_; } const GraphViewer& GetGraphViewer() const { return graph_viewer_; } InitializedTensorSet GetInitializerTensors(); @@ -68,6 +69,8 @@ class ModelBuilder { private: const GraphViewer& graph_viewer_; const logging::Logger& logger_; + const bool is_float16array_available_ = !emscripten::val::global("Float16Array").isUndefined() && + emscripten::val::global("Float16Array").hasOwnProperty("from"); emscripten::val wnn_context_ = emscripten::val::undefined(); emscripten::val wnn_builder_ = emscripten::val::undefined(); @@ -172,9 +175,12 @@ const emscripten::val& ModelBuilder::CreateOrGetConstant(const int32_t& data_typ } break; case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: - buffer = emscripten::val::global("Uint16Array").new_(num_elements); + buffer = is_float16array_available_ + ? emscripten::val::global("Float16Array").new_(num_elements) + : emscripten::val::global("Uint16Array").new_(num_elements); if (value) { - buffer.call("fill", emscripten::val(PackFloat32ToUint16AsFloat16(value))); + buffer.call("fill", + emscripten::val(is_float16array_available_ ? value : PackFloat32ToUint16AsFloat16(value))); } break; case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: diff --git a/onnxruntime/core/providers/webnn/webnn_execution_provider.cc b/onnxruntime/core/providers/webnn/webnn_execution_provider.cc index 39e6520e3912b..7410ff66add30 100644 --- a/onnxruntime/core/providers/webnn/webnn_execution_provider.cc +++ b/onnxruntime/core/providers/webnn/webnn_execution_provider.cc @@ -56,6 +56,7 @@ WebNNExecutionProvider::~WebNNExecutionProvider() {} std::vector> WebNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer, const IKernelLookup& /*kernel_registries*/, + const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* /* resource_accountant */) const { // For subgraph which is the attribute of the control flow nodes, part of its initializers are stored in its // ancestor graphs as common initializers shared for other subgraphs. We need to collect all of them used for diff --git a/onnxruntime/core/providers/webnn/webnn_execution_provider.h b/onnxruntime/core/providers/webnn/webnn_execution_provider.h index e806dc340d53e..b8775e717668a 100644 --- a/onnxruntime/core/providers/webnn/webnn_execution_provider.h +++ b/onnxruntime/core/providers/webnn/webnn_execution_provider.h @@ -25,6 +25,7 @@ class WebNNExecutionProvider : public IExecutionProvider { std::vector> GetCapability(const onnxruntime::GraphViewer& graph_viewer, const IKernelLookup& /*kernel_registries*/, + const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* /* resource_accountant */) const override; DataLayout GetPreferredLayout() const override { return preferred_layout_; } diff --git a/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.cc b/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.cc index 641f8b0729d0a..ab14c083884d3 100644 --- a/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.cc +++ b/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.cc @@ -258,6 +258,7 @@ static void AddComputeCapabilityForEachNodeInNodeUnit( std::vector> XnnpackExecutionProvider::GetCapability( const onnxruntime::GraphViewer& graph, const IKernelLookup& /*kernel_lookup*/, + const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* /* resource_accountant */) const { const auto& logger = *GetLogger(); std::vector> capabilities; diff --git a/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.h b/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.h index 152bef1a1c52c..9c4d2484f9f4b 100644 --- a/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.h +++ b/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.h @@ -33,6 +33,7 @@ class XnnpackExecutionProvider : public IExecutionProvider { std::vector> GetCapability( const onnxruntime::GraphViewer& graph_viewer, const IKernelLookup& /*kernel_lookup*/, + const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* /* resource_accountant */) const override; std::shared_ptr GetKernelRegistry() const override; diff --git a/onnxruntime/core/session/abi_session_options.cc b/onnxruntime/core/session/abi_session_options.cc index 7ef23d6c9e895..2e733f67a888c 100644 --- a/onnxruntime/core/session/abi_session_options.cc +++ b/onnxruntime/core/session/abi_session_options.cc @@ -1,17 +1,18 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "core/graph/onnx_protobuf.h" -#include "core/common/inlined_containers.h" -#include "core/session/onnxruntime_c_api.h" -#include "core/session/ort_apis.h" -#include "core/framework/error_code_helper.h" -#include #include +#include #include + +#include "core/common/inlined_containers.h" +#include "core/framework/error_code_helper.h" +#include "core/graph/onnx_protobuf.h" +#include "core/session/abi_session_options_impl.h" #include "core/session/inference_session.h" -#include "abi_session_options_impl.h" -#include "api_utils.h" +#include "core/session/onnxruntime_c_api.h" +#include "core/session/ort_apis.h" +#include "core/session/utils.h" OrtSessionOptions::~OrtSessionOptions() = default; diff --git a/onnxruntime/core/session/api_utils.cc b/onnxruntime/core/session/api_utils.cc deleted file mode 100644 index f7cb8520b1e5d..0000000000000 --- a/onnxruntime/core/session/api_utils.cc +++ /dev/null @@ -1,25 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "api_utils.h" - -onnxruntime::common::Status CopyStringToOutputArg(std::string_view str, const char* err_msg, char* out, size_t* size) { - const size_t str_len = str.size(); - const size_t req_size = str_len + 1; - - if (out == nullptr) { // User is querying the total output buffer size - *size = req_size; - return onnxruntime::common::Status::OK(); - } - - if (*size >= req_size) { // User provided a buffer of sufficient size - std::memcpy(out, str.data(), str_len); - out[str_len] = '\0'; - *size = req_size; - return onnxruntime::common::Status::OK(); - } - - // User has provided a buffer that is not large enough - *size = req_size; - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, err_msg); -} diff --git a/onnxruntime/core/session/api_utils.h b/onnxruntime/core/session/api_utils.h deleted file mode 100644 index 27c2bbd66f8d5..0000000000000 --- a/onnxruntime/core/session/api_utils.h +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include "core/common/common.h" -#include - -onnxruntime::common::Status CopyStringToOutputArg(std::string_view str, const char* err_msg, char* out, size_t* size); diff --git a/onnxruntime/core/session/custom_ops.cc b/onnxruntime/core/session/custom_ops.cc index 8492391172133..f583767346d88 100644 --- a/onnxruntime/core/session/custom_ops.cc +++ b/onnxruntime/core/session/custom_ops.cc @@ -20,7 +20,7 @@ #include "core/framework/tensorprotoutils.h" #include "core/graph/onnx_protobuf.h" #include "core/session/allocator_adapters.h" -#include "core/session/api_utils.h" +#include "core/session/utils.h" #include "core/session/custom_ops.h" #include "core/session/inference_session.h" #include "core/session/ort_apis.h" @@ -900,13 +900,14 @@ KernelCreateInfo CreateKernelCreateInfo(const std::string& domain, const OrtCust ONNX_NAMESPACE::OpSchema CreateSchema(const std::string& domain, const std::vector& ops) { // The function registers the first schema assuming all the other one are the same except the types constraints. ORT_ENFORCE(ops.size() > 0, "No kernels to registers."); - int undefined = 0; + int num_inputs_with_dynamic_type = 0; // Creation of the schema for the first kernel in ops. const OrtCustomOp* op = *ops.begin(); ONNX_NAMESPACE::OpSchema schema(op->GetName(op), "custom op registered at runtime", 0); - auto create_type_constraint = [&ops, &schema, &undefined](const OrtCustomOp* op, int count, int i, bool is_input) { + auto create_type_constraint = [&ops, &schema, &num_inputs_with_dynamic_type]( + const OrtCustomOp* op, int count, int i, bool is_input) { onnx::OpSchema::FormalParameterOption option = onnx::OpSchema::FormalParameterOption::Single; bool is_homogeneous = true; int min_arity = 1; @@ -976,7 +977,9 @@ ONNX_NAMESPACE::OpSchema CreateSchema(const std::string& domain, const std::vect } else { // all_types is empty. As mentioned in the previous loop, all types are allowed. schema.TypeConstraint(name, DataTypeImpl::ToString(SUPPORTED_TENSOR_TYPES), "all types"); - undefined++; + if (is_input) { + ++num_inputs_with_dynamic_type; + } } }; @@ -985,19 +988,21 @@ ONNX_NAMESPACE::OpSchema CreateSchema(const std::string& domain, const std::vect create_type_constraint(op, static_cast(input_count), static_cast(i), true); } + const bool have_shape_infer_fn = op->version >= min_ort_version_with_shape_inference && op->InferOutputShapeFn; + const size_t output_count = op->GetOutputTypeCount(op); for (size_t i = 0; i < output_count; i++) { const auto type = op->GetOutputType(op, i); if (ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED == type) { if (op->GetOutputCharacteristic(op, i) == OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED) { - ORT_ENFORCE(1 == undefined, - "There must be one (and only one) dynamic typed input to the custom op. " - "Its type info at runtime will be used to infer the type info of this dynamic typed output " - "which is required for the success of the model loading step. " - "More than one dynamic typed inputs are currently not supported as differing types at runtime " - "means the output type cannot be inferred without which model loading cannot proceed."); + // if there's a dynamically typed input and output we infer they both have the same type from the input. + // if that isn't the case the user must provide an output shape inference fn which must set the output type. + ORT_ENFORCE(num_inputs_with_dynamic_type == 1 || have_shape_infer_fn, + "The type of a dynamically typed output can be inferred from a single dynamically typed input, " + "or by a user provided OrtCustomOp->InferOutputShapeFn that sets the output type."); } } + create_type_constraint(op, static_cast(output_count), static_cast(i), false); } diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index a1903898ea7f0..e5ea562ce3535 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -38,9 +38,11 @@ #include "core/framework/utils.h" #include "core/graph/graph_viewer.h" #include "core/graph/model.h" +#include "core/graph/model_editor_api_types.h" #include "core/graph/model_saving_options.h" #include "core/optimizer/graph_transformer_utils.h" #include "core/optimizer/graph_transformer.h" +#include "core/optimizer/graph_optimizer_registry.h" #include "core/optimizer/layout_transformation/layout_transformation.h" #include "core/optimizer/insert_cast_transformer.h" #include "core/optimizer/qdq_transformer/ensure_unique_dq_for_node_unit.h" @@ -67,11 +69,11 @@ #include "core/optimizer/stft_decomposition.h" #endif #include "core/session/environment.h" -#include "core/session/user_logging_sink.h" #include "core/session/IOBinding.h" #include "core/session/inference_session_utils.h" #include "core/session/onnxruntime_session_options_config_keys.h" #include "core/session/onnxruntime_run_options_config_keys.h" +#include "core/session/user_logging_sink.h" #include "core/util/protobuf_parsing_utils.h" #include "core/util/thread_utils.h" @@ -1215,6 +1217,56 @@ common::Status InferenceSession::Load() { return LoadWithLoader(loader, "model_loading_from_saved_proto"); } +common::Status InferenceSession::Load(const OrtModel& model_editor_api_model) { + std::lock_guard l(session_mutex_); + + if (is_model_loaded_) { // already loaded + Status status(common::ONNXRUNTIME, common::MODEL_LOADED, "This session already contains a loaded model."); + LOGS(*session_logger_, ERROR) << status.ErrorMessage(); + return status; + } + + if (is_inited_) { + Status status(common::ONNXRUNTIME, common::MODEL_LOADED, "This session has already been initialized."); + LOGS(*session_logger_, ERROR) << status.ErrorMessage(); + return status; + } + + const bool strict_shape_type_inference = session_options_.config_options.GetConfigOrDefault( + kOrtSessionOptionsConfigStrictShapeTypeInference, "0") == "1"; + + // need to go from unique_ptr to shared_ptr when moving into model_ + std::unique_ptr tmp_model; + ORT_RETURN_IF_ERROR(Model::LoadFromModelEditorApiModel(model_editor_api_model, + HasLocalSchema() ? &custom_schema_registries_ : nullptr, + ModelOptions(true, strict_shape_type_inference), + *session_logger_, tmp_model)); + + model_ = std::move(tmp_model); + + is_model_loaded_ = true; + + return Status::OK(); +} + +common::Status InferenceSession::ApplyUpdates(const OrtModel& model_editor_api_model) { + std::lock_guard l(session_mutex_); + + if (!is_model_loaded_) { + Status status(common::ONNXRUNTIME, common::MODEL_LOADED, "This session does not contain a loaded model."); + LOGS(*session_logger_, ERROR) << status.ErrorMessage(); + return status; + } + + if (is_inited_) { + Status status(common::ONNXRUNTIME, common::MODEL_LOADED, "This session has already been initialized."); + LOGS(*session_logger_, ERROR) << status.ErrorMessage(); + return status; + } + + return model_->MainGraph().UpdateUsingModelEditorApiModel(model_editor_api_model); +} + common::Status InferenceSession::TransformGraph(onnxruntime::Graph& graph, bool saving_model_in_ort_format) { // The transformer order: // 1. Ensure we inline as many functions as possible. We refer to it as Ahead Of Time (AOT) function inlining. @@ -1227,8 +1279,13 @@ common::Status InferenceSession::TransformGraph(onnxruntime::Graph& graph, bool // 6. insert cast nodes (required transformer). // 7. insert copy nodes (required transformer). + // Create GraphOptimizerRegistry instance for providing predefined graph optimizers and selection functions for EPs to lookup + auto graph_optimizer_registry = std::make_unique(&session_options_, + execution_providers_.Get(onnxruntime::kCpuExecutionProvider), + session_logger_); + GraphPartitioner partitioner(kernel_registry_manager_, execution_providers_, std::move(graph_optimizer_registry)); + // Run Ahead Of time function inlining - GraphPartitioner partitioner(kernel_registry_manager_, execution_providers_); if (const bool disable_aot_function_inlining = session_options_.config_options.GetConfigOrDefault( kOrtSessionOptionsDisableAheadOfTimeFunctionInlining, "0") == "1"; @@ -1631,7 +1688,7 @@ Status PartitionOrtFormatModel(onnxruntime::Graph& graph, const ExecutionProviders& providers, KernelRegistryManager& kernel_registry_manager, SessionState& session_state, - const ConfigOptions& config_options, + const SessionOptions& sess_options, const logging::Logger& logger) { layout_transformation::TransformLayoutFunction transform_layout_fn = nullptr; @@ -1649,11 +1706,16 @@ Status PartitionOrtFormatModel(onnxruntime::Graph& graph, } #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) - GraphPartitioner partitioner(kernel_registry_manager, providers); + // Create GraphOptimizerRegistry instance for providing predefined graph optimizers and selection functions for EPs to lookup + auto graph_optimizer_registry = std::make_unique(&sess_options, + providers.Get(onnxruntime::kCpuExecutionProvider), + &logger); + + GraphPartitioner partitioner(kernel_registry_manager, providers, std::move(graph_optimizer_registry)); ORT_RETURN_IF_ERROR(partitioner.Partition(graph, session_state.GetMutableFuncMgr(), transform_layout_fn, - config_options, + sess_options.config_options, logger, GraphPartitioner::Mode::kOrtFormatLoad)); @@ -2096,7 +2158,7 @@ common::Status InferenceSession::Initialize() { #endif // !defined(ORT_MINIMAL_BUILD) } else { ORT_RETURN_IF_ERROR_SESSIONID_(PartitionOrtFormatModel(graph, execution_providers_, kernel_registry_manager_, - *session_state_, session_options_.config_options, *session_logger_)); + *session_state_, session_options_, *session_logger_)); #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) const auto& cpu_ep = *execution_providers_.Get(onnxruntime::kCpuExecutionProvider); @@ -3336,6 +3398,10 @@ common::Status InferenceSession::WaitForNotification(Notification* p_executor_do return Status::OK(); } +const Model& InferenceSession::GetModel() const { + return *model_; +} + SessionIOBinding::SessionIOBinding(InferenceSession* session) : sess_(session) { ORT_ENFORCE(session->NewIOBinding(&binding_).IsOK()); } diff --git a/onnxruntime/core/session/inference_session.h b/onnxruntime/core/session/inference_session.h index 2c0c09dfd3e51..5b484103c9ecf 100644 --- a/onnxruntime/core/session/inference_session.h +++ b/onnxruntime/core/session/inference_session.h @@ -47,6 +47,9 @@ namespace ONNX_NAMESPACE { class ModelProto; } // namespace ONNX_NAMESPACE +// OrtModelEditorApi Model. Used to dynamically construct a model via C API at runtime. +struct OrtModel; + namespace onnxruntime { // forward declarations class CustomRegistry; class Environment; @@ -320,6 +323,27 @@ class InferenceSession { * @return OK if success. */ [[nodiscard]] common::Status Load(); + + /** + * Load an OrtModel that was dynamically constructed via OrtModelEditorApi. + * + * @param graph_api_model OrtModel from OrtModelEditorApi + * @return OK if success. + */ + [[nodiscard]] common::Status Load(const OrtModel& graph_api_model); + + /** + * Apply updates from an OrtModel that was created via OrtModelEditorApi. + * This can: + * - add nodes at the start and end of the model + * - add initializers + * - update the graph inputs/outputs + * + * @param graph_api_model OrtModel from OrtModelEditorApi + * @return OK if success. + */ + [[nodiscard]] common::Status ApplyUpdates(const OrtModel& graph_api_model); + #endif // !defined(ORT_MINIMAL_BUILD) /** @@ -571,6 +595,8 @@ class InferenceSession { #endif + const Model& GetModel() const; + protected: #if !defined(ORT_MINIMAL_BUILD) @@ -627,6 +653,12 @@ class InferenceSession { /// convenience pointer to logger. should always be the same as session_state_.Logger(); const logging::Logger* session_logger_; + // The list of execution providers. + // This MUST be prior to model_ in case there are values in the model that were allocated using an allocator + // provided by the EP. If that is the case the allocator's `free` implementation may depend on other parts of the + // EP instance. + ExecutionProviders execution_providers_; + // The model served by this inference session instance. // Currently this has to be a shared ptr because the Model::Load method // returns a shared_ptr only. Ideally factory functions should always return @@ -637,9 +669,6 @@ class InferenceSession { // The file path of where the model was loaded. e.g. /tmp/test_squeezenet/model.onnx PathString model_location_; - // The list of execution providers. - ExecutionProviders execution_providers_; - private: ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(InferenceSession); void SetLoggingManager(const SessionOptions& session_options, diff --git a/onnxruntime/core/session/model_editor_api.h b/onnxruntime/core/session/model_editor_api.h new file mode 100644 index 0000000000000..71004866bc867 --- /dev/null +++ b/onnxruntime/core/session/model_editor_api.h @@ -0,0 +1,65 @@ +namespace OrtModelEditorAPI { + +// implementation that returns the API struct +ORT_API(const OrtModelEditorApi*, GetModelEditorApi); + +// APIs to create/edit type info +ORT_API_STATUS_IMPL(CreateTensorTypeInfo, _In_ const OrtTensorTypeAndShapeInfo* tensor_info, + _Out_ OrtTypeInfo** type_info); +ORT_API_STATUS_IMPL(CreateSparseTensorTypeInfo, _In_ const OrtTensorTypeAndShapeInfo* tensor_info, + _Out_ OrtTypeInfo** type_info); +ORT_API_STATUS_IMPL(CreateMapTypeInfo, ONNXTensorElementDataType map_key_type, _In_ const OrtTypeInfo* map_value_type, + _Out_ OrtTypeInfo** type_info); +ORT_API_STATUS_IMPL(CreateSequenceTypeInfo, _In_ const OrtTypeInfo* sequence_type, _Out_ OrtTypeInfo** type_info); +ORT_API_STATUS_IMPL(CreateOptionalTypeInfo, _In_ const OrtTypeInfo* contained_type, _Out_ OrtTypeInfo** type_info); + +ORT_API_STATUS_IMPL(CreateValueInfo, _In_ const char* name, _In_ const OrtTypeInfo* type_info, + _Outptr_ OrtValueInfo** value_info); + +ORT_API_STATUS_IMPL(CreateNode, const char* operator_name, const char* domain_name, _In_ const char* node_name, + _In_reads_(input_names_len) const char* const* input_names, size_t input_names_len, + _In_reads_(output_names_len) const char* const* output_names, size_t output_names_len, + _In_reads_(attribs_len) _Inout_opt_ OrtOpAttr** attributes, _In_opt_ size_t attribs_len, + _Outptr_ OrtNode** node); + +ORT_API_STATUS_IMPL(CreateGraph, _Outptr_ OrtGraph** graph); +ORT_API_STATUS_IMPL(SetGraphInputs, _In_ OrtGraph* graph, + _In_reads_(inputs_len) _In_ OrtValueInfo** inputs, _In_ size_t inputs_len); +ORT_API_STATUS_IMPL(SetGraphOutputs, _In_ OrtGraph* graph, + _In_reads_(outputs_len) _In_ OrtValueInfo** outputs, _In_ size_t outputs_len); +ORT_API_STATUS_IMPL(AddInitializerToGraph, _In_ OrtGraph* graph, _In_ const char* name, _Inout_ OrtValue* tensor, + bool data_is_external); +ORT_API_STATUS_IMPL(AddNodeToGraph, _In_ OrtGraph* graph, _Inout_ OrtNode* node); + +ORT_API_STATUS_IMPL(CreateModel, + _In_reads_(opset_entries_len) const char* const* domain_names, + _In_reads_(opset_entries_len) const int* opset_versions, + size_t opset_entries_len, + _Outptr_ OrtModel** model); +ORT_API_STATUS_IMPL(AddGraphToModel, _In_ OrtModel* model, _Inout_ OrtGraph* graph); + +ORT_API_STATUS_IMPL(CreateSessionFromModel, _In_ const OrtEnv* env, _In_ const OrtModel* model, + _In_ const OrtSessionOptions* options, _Outptr_ OrtSession** out); + +// +// Model editing APIs for updating existing model by adding node/s at start or end. +// +ORT_API_STATUS_IMPL(CreateModelEditorSession, _In_ const OrtEnv* env, + _In_ const ORTCHAR_T* model_path, + _In_ const OrtSessionOptions* options, + _Outptr_ OrtSession** out); + +ORT_API_STATUS_IMPL(CreateModelEditorSessionFromArray, _In_ const OrtEnv* env, + _In_ const void* model_data, size_t model_data_length, + _In_ const OrtSessionOptions* options, + _Outptr_ OrtSession** out); + +ORT_API_STATUS_IMPL(SessionGetOpsetForDomain, _In_ const OrtSession* session, _In_ const char* domain, + _Out_ int* opset); + +ORT_API_STATUS_IMPL(ApplyModelToModelEditorSession, _In_ OrtSession* session, _In_ OrtModel* model); + +ORT_API_STATUS_IMPL(FinalizeModelEditorSession, _In_ OrtSession* session, _In_ const OrtSessionOptions* options, + _Inout_ OrtPrepackedWeightsContainer* prepacked_weights_container); + +} // namespace OrtModelEditorAPI diff --git a/onnxruntime/core/session/model_editor_c_api.cc b/onnxruntime/core/session/model_editor_c_api.cc new file mode 100644 index 0000000000000..2f09b903ed941 --- /dev/null +++ b/onnxruntime/core/session/model_editor_c_api.cc @@ -0,0 +1,358 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if !defined(ORT_MINIMAL_BUILD) + +#include + +#include "core/framework/error_code_helper.h" +#include "core/framework/ort_value.h" +#include "core/framework/onnxruntime_typeinfo.h" +#include "core/framework/tensor_type_and_shape.h" +#include "core/graph/constants.h" +#include "core/graph/model.h" +#include "core/graph/model_editor_api_types.h" +#include "core/graph/onnx_protobuf.h" +#include "core/session/abi_session_options_impl.h" +#include "core/session/inference_session.h" +#include "core/session/model_editor_api.h" +#include "core/session/ort_apis.h" +#include "core/session/ort_env.h" +#include "core/session/utils.h" + +using namespace onnxruntime; + +ORT_API_STATUS_IMPL(OrtModelEditorAPI::CreateValueInfo, _In_ const char* name, _In_ const OrtTypeInfo* type_info, + _Outptr_ OrtValueInfo** value_info) { + API_IMPL_BEGIN + if (name == nullptr || *name == '\0') { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "name cannot be null or empty string"); + } + + if (type_info == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "type_info cannot be null"); + } + + if (type_info->type != ONNX_TYPE_TENSOR) { + return OrtApis::CreateStatus(ORT_FAIL, "Only tensor types are supported currently"); + } + + if (type_info->tensor_type_info == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "tensor_type_info cannot be null"); + } + + auto vi = std::make_unique(); + vi->name = name; + vi->type_info = type_info->Clone(); + + *value_info = vi.release(); + + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtModelEditorAPI::CreateNode, const char* operator_name, const char* domain_name, + _In_ const char* node_name, + _In_reads_(input_names_len) const char* const* input_names, size_t input_names_len, + _In_reads_(output_names_len) const char* const* output_names, size_t output_names_len, + _In_reads_(attribs_len) _Inout_opt_ OrtOpAttr** attributes, _In_opt_ size_t attribs_len, + _Outptr_ OrtNode** node) { + API_IMPL_BEGIN + auto n = std::make_unique(); + n->operator_name = operator_name; + n->domain_name = domain_name == kOnnxDomainAlias ? kOnnxDomain : domain_name; + n->node_name = node_name; + + n->input_names.reserve(input_names_len); + for (size_t i = 0; i < input_names_len; ++i) { + n->input_names.push_back(input_names[i]); + } + + n->output_names.reserve(output_names_len); + for (size_t i = 0; i < output_names_len; ++i) { + n->output_names.push_back(output_names[i]); + } + + if (attributes != nullptr) { + n->attributes.reserve(attribs_len); + for (size_t i = 0; i < attribs_len; ++i) { + n->attributes.push_back(*reinterpret_cast(attributes[i])); + // take ownership. as we took a copy that means releasing the original value + OrtApis::ReleaseOpAttr(attributes[i]); + attributes[i] = nullptr; + } + } + + *node = n.release(); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtModelEditorAPI::CreateGraph, _Outptr_ OrtGraph** graph) { + API_IMPL_BEGIN + auto g = std::make_unique(); + + // do some reserves to reduce reallocation. if we had a hint about sizes upfront that would be optimal + g->initializers.reserve(32); + g->external_initializers.reserve(32); + g->nodes.reserve(64); + + *graph = g.release(); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtModelEditorAPI::SetGraphInputs, _In_ OrtGraph* graph, + _In_reads_(inputs_len) _In_ OrtValueInfo** inputs, _In_ size_t inputs_len) { + API_IMPL_BEGIN + graph->inputs.clear(); + for (size_t i = 0; i < inputs_len; ++i) { + if (inputs[i] == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "inputs cannot contain null entries"); + } + + graph->inputs.push_back(std::unique_ptr(inputs[i])); // take ownership + inputs[i] = nullptr; + } + + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtModelEditorAPI::SetGraphOutputs, _In_ OrtGraph* graph, + _In_reads_(outputs_len) _In_ OrtValueInfo** outputs, _In_ size_t outputs_len) { + API_IMPL_BEGIN + graph->outputs.clear(); + for (size_t i = 0; i < outputs_len; ++i) { + if (outputs[i] == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "outputs cannot contain null entries"); + } + + graph->outputs.push_back(std::unique_ptr(outputs[i])); // take ownership + outputs[i] = nullptr; + } + + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtModelEditorAPI::AddInitializerToGraph, _In_ OrtGraph* graph, _In_ const char* name, + _Inout_ OrtValue* tensor, bool data_is_external) { + API_IMPL_BEGIN + if (!tensor->IsTensor()) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Only Tensor is currently supported."); + } + + if (!tensor->IsAllocated()) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Tensor must be allocated."); + } + + const auto& t = tensor->Get(); + if (t.Location().device.Type() != OrtDevice::CPU) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Only CPU based tensors are currently supported."); + } + + if (data_is_external) { + // enforce that an external initializer is not used if the data size is < 128 bytes. + // the reason for this is to avoid potential shape inferencing errors if this initializer is providing an + // input involved in that. the ONNX shape inferencing does not support external data for those values. + // e.g. Reshape's `shape` input, Reduce's `axes', Slice's `starts`, `ends`, `steps`, Clip's `min`, `max`, etc. + if (t.SizeInBytes() < 128) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, + "External initializer should only be used for data >= 128 bytes. " + "Please use CreateTensorAsOrtValue instead."); + } + + graph->external_initializers[name] = std::unique_ptr(tensor); // take ownership + } else { + graph->initializers[name] = std::unique_ptr(tensor); // take ownership + } + + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtModelEditorAPI::AddNodeToGraph, _In_ OrtGraph* graph, _Inout_ OrtNode* node) { + API_IMPL_BEGIN + graph->nodes.push_back(std::unique_ptr(node)); // take ownership + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtModelEditorAPI::CreateModel, + _In_reads_(opset_entries_len) const char* const* domain_names, + _In_reads_(opset_entries_len) const int* opset_versions, + size_t opset_entries_len, + _Outptr_ OrtModel** model) { + API_IMPL_BEGIN + auto m = std::make_unique(); + for (size_t i = 0; i < opset_entries_len; ++i) { + m->domain_to_version[domain_names[i]] = opset_versions[i]; + } + + *model = m.release(); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtModelEditorAPI::AddGraphToModel, _In_ OrtModel* model, _Inout_ OrtGraph* graph) { + API_IMPL_BEGIN + + if (graph == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "graph cannot be null"); + } + + model->graph = std::unique_ptr(graph); // take ownership + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtModelEditorAPI::CreateSessionFromModel, _In_ const OrtEnv* env, _In_ const OrtModel* model, + _In_ const OrtSessionOptions* options, _Outptr_ OrtSession** out) { + API_IMPL_BEGIN + + std::unique_ptr sess; + OrtStatus* status = nullptr; + *out = nullptr; + + ORT_TRY { + sess = std::make_unique( + options == nullptr ? onnxruntime::SessionOptions() : options->value, + env->GetEnvironment()); + + ORT_API_RETURN_IF_STATUS_NOT_OK(sess->Load(*model)); + + ORT_API_RETURN_IF_ERROR(InitializeSession(options, *sess)); + + *out = reinterpret_cast(sess.release()); + } + ORT_CATCH(const std::exception& e) { + ORT_HANDLE_EXCEPTION([&]() { + status = OrtApis::CreateStatus(ORT_FAIL, e.what()); + }); + } + + return status; + + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtModelEditorAPI::CreateModelEditorSession, + _In_ const OrtEnv* env, _In_ const ORTCHAR_T* model_path, _In_ const OrtSessionOptions* options, + _Outptr_ OrtSession** out) { + API_IMPL_BEGIN + std::unique_ptr session; + OrtStatus* status = nullptr; + *out = nullptr; + + ORT_TRY { + ORT_API_RETURN_IF_ERROR(CreateSessionAndLoadModel(options, env, model_path, nullptr, 0, session)); + *out = reinterpret_cast(session.release()); + } + ORT_CATCH(const std::exception& e) { + ORT_HANDLE_EXCEPTION([&]() { + status = OrtApis::CreateStatus(ORT_FAIL, e.what()); + }); + } + + return status; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtModelEditorAPI::CreateModelEditorSessionFromArray, _In_ const OrtEnv* env, + _In_ const void* model_data, size_t model_data_length, + _In_ const OrtSessionOptions* options, + _Outptr_ OrtSession** out) { + API_IMPL_BEGIN + std::unique_ptr session; + OrtStatus* status = nullptr; + *out = nullptr; + + ORT_TRY { + ORT_API_RETURN_IF_ERROR(CreateSessionAndLoadModel(options, env, nullptr, model_data, model_data_length, session)); + *out = reinterpret_cast(session.release()); + } + ORT_CATCH(const std::exception& e) { + ORT_HANDLE_EXCEPTION([&]() { + status = OrtApis::CreateStatus(ORT_FAIL, e.what()); + }); + } + + return status; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtModelEditorAPI::SessionGetOpsetForDomain, _In_ const OrtSession* ort_session, + _In_ const char* domain, _Out_ int* opset) { + const auto& session = *reinterpret_cast(ort_session); + const auto& domain_opset_map = session.GetModel().MainGraph().DomainToVersionMap(); + + auto it = domain_opset_map.find(domain); + if (it == domain_opset_map.cend()) { + return OrtApis::CreateStatus(ORT_FAIL, "Domain not used by model."); + } + + *opset = it->second; + return nullptr; +} + +ORT_API_STATUS_IMPL(OrtModelEditorAPI::ApplyModelToModelEditorSession, + _In_ OrtSession* session, _In_ OrtModel* model) { + API_IMPL_BEGIN + auto sess = reinterpret_cast(session); + ORT_API_RETURN_IF_STATUS_NOT_OK(sess->ApplyUpdates(*model)); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtModelEditorAPI::FinalizeModelEditorSession, _In_ OrtSession* session, + _In_ const OrtSessionOptions* options, + _Inout_ OrtPrepackedWeightsContainer* prepacked_weights_container) { + API_IMPL_BEGIN + auto sess = reinterpret_cast(session); + ORT_API_RETURN_IF_ERROR(InitializeSession(options, *sess, prepacked_weights_container)); + return nullptr; + API_IMPL_END +} + +static constexpr OrtModelEditorApi ort_model_editor_api = { + // NOTE: The C# bindings depend on the API order within this struct so all additions must be at the end, + // and no functions can be removed (the implementation needs to change to return an error). + + &OrtModelEditorAPI::CreateTensorTypeInfo, + &OrtModelEditorAPI::CreateSparseTensorTypeInfo, + &OrtModelEditorAPI::CreateMapTypeInfo, + &OrtModelEditorAPI::CreateSequenceTypeInfo, + &OrtModelEditorAPI::CreateOptionalTypeInfo, + + &OrtModelEditorAPI::CreateValueInfo, + + &OrtModelEditorAPI::CreateNode, + + &OrtModelEditorAPI::CreateGraph, + &OrtModelEditorAPI::SetGraphInputs, + &OrtModelEditorAPI::SetGraphOutputs, + &OrtModelEditorAPI::AddInitializerToGraph, + &OrtModelEditorAPI::AddNodeToGraph, + + &OrtModelEditorAPI::CreateModel, + &OrtModelEditorAPI::AddGraphToModel, + + &OrtModelEditorAPI::CreateSessionFromModel, + + &OrtModelEditorAPI::CreateModelEditorSession, + &OrtModelEditorAPI::CreateModelEditorSessionFromArray, + &OrtModelEditorAPI::SessionGetOpsetForDomain, + &OrtModelEditorAPI::ApplyModelToModelEditorSession, + &OrtModelEditorAPI::FinalizeModelEditorSession, +}; + +// checks that we don't violate the rule that the functions must remain in the slots they were originally assigned +static_assert(offsetof(OrtModelEditorApi, FinalizeModelEditorSession) / sizeof(void*) == 19, + "Size of version 21 API cannot change"); // initial version in ORT 1.21 + +ORT_API(const OrtModelEditorApi*, OrtModelEditorAPI::GetModelEditorApi) { + return &ort_model_editor_api; +} + +#endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 4eedcd591154f..0e23d7a791bec 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -1,45 +1,47 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "core/session/onnxruntime_c_api.h" -#include "core/session/allocator_adapters.h" -#include "core/session/inference_session_utils.h" -#include "core/session/IOBinding.h" -#include "core/framework/allocator.h" -#include "core/framework/error_code_helper.h" -#include "core/framework/execution_provider.h" -#include "core/framework/tensor_type_and_shape.h" -#include "core/framework/utils.h" #include #include #include +#include #include #include "core/common/common.h" #include "core/common/logging/logging.h" #include "core/common/narrow.h" -#include "core/common/status.h" #include "core/common/safeint.h" -#include "core/graph/constants.h" -#include "core/graph/graph.h" +#include "core/common/status.h" +#include "core/common/string_helper.h" #include "core/framework/allocator.h" -#include "core/framework/tensor.h" +#include "core/framework/allocator.h" +#include "core/framework/callback.h" +#include "core/framework/data_types.h" +#include "core/framework/error_code_helper.h" +#include "core/framework/execution_provider.h" +#include "core/framework/onnxruntime_typeinfo.h" #include "core/framework/ort_value.h" +#include "core/framework/tensor.h" +#include "core/framework/tensor_type_and_shape.h" +#include "core/framework/tensorprotoutils.h" +#include "core/framework/TensorSeq.h" +#include "core/framework/utils.h" +#include "core/graph/constants.h" +#include "core/graph/graph.h" +#include "core/graph/model_editor_api_types.h" #include "core/providers/get_execution_providers.h" +#include "core/session/abi_session_options_impl.h" +#include "core/session/allocator_adapters.h" #include "core/session/environment.h" -#include "core/framework/callback.h" -#include "core/framework/tensorprotoutils.h" -#include "core/framework/onnxruntime_typeinfo.h" #include "core/session/inference_session.h" +#include "core/session/inference_session_utils.h" +#include "core/session/IOBinding.h" +#include "core/session/lora_adapters.h" +#include "core/session/model_editor_api.h" +#include "core/session/onnxruntime_c_api.h" #include "core/session/ort_apis.h" #include "core/session/ort_env.h" -#include "core/framework/data_types.h" -#include "abi_session_options_impl.h" -#include "core/framework/TensorSeq.h" -#include -#include "core/common/string_helper.h" - -#include "core/session/lora_adapters.h" +#include "core/session/utils.h" #ifdef USE_CUDA #include "core/providers/cuda/cuda_provider_factory.h" @@ -114,6 +116,72 @@ using namespace onnxruntime; auto v = (value); \ auto tensor = v->GetMutable(); +namespace { +// Create tensor. Allocates memory. Tensor owns memory. Allocator is wrapped and stored in a shared_ptr in Tensor. +ORT_STATUS_PTR CreateTensorImpl(MLDataType ml_type, const int64_t* shape, size_t shape_len, + OrtAllocator* allocator, OrtValue& value) { + TensorShape tensor_shape(shape, shape_len); + AllocatorPtr alloc_ptr = std::make_shared(allocator); + Tensor::InitOrtValue(ml_type, tensor_shape, std::move(alloc_ptr), value); + return nullptr; +} + +// Create Tensor with existing data. Tensor does not own memory. +ORT_STATUS_PTR CreateTensorImpl(MLDataType ml_type, + const int64_t* shape, size_t shape_len, + const OrtMemoryInfo* info, + void* p_data, size_t p_data_len, + OrtValue& ort_value) { + TensorShape tensor_shape(shape, shape_len); + if (std::any_of(tensor_shape.GetDims().begin(), tensor_shape.GetDims().end(), [](int64_t v) { return v < 0; })) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "tried creating tensor with negative value in shape"); + } + + size_t size_to_allocate = 0; + Status status = Tensor::CalculateTensorStorageSize(ml_type, tensor_shape, 0 /*alignment*/, size_to_allocate); + if (!status.IsOK()) { + return ToOrtStatus(status); + } + if (size_to_allocate > p_data_len) { + std::ostringstream oss; + oss << "not enough space: expected " << size_to_allocate << ", got " << p_data_len; + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, oss.str().c_str()); + } + + Tensor::InitOrtValue(ml_type, tensor_shape, p_data, *info, ort_value); + return nullptr; +} + +ORT_STATUS_PTR CreateTensorImpl(MLDataType ml_type, + const int64_t* shape, size_t shape_len, + OrtAllocator* deleter, + void* p_data, size_t p_data_len, + OrtValue& ort_value) { + TensorShape tensor_shape(shape, shape_len); + if (std::any_of(tensor_shape.GetDims().begin(), tensor_shape.GetDims().end(), [](int64_t v) { return v < 0; })) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "tried creating tensor with negative value in shape"); + } + + size_t size_to_allocate = 0; + Status status = Tensor::CalculateTensorStorageSize(ml_type, tensor_shape, 0 /*alignment*/, size_to_allocate); + + if (!status.IsOK()) { + return ToOrtStatus(status); + } + + if (size_to_allocate > p_data_len) { + std::ostringstream oss; + oss << "p_data_len was smaller than expected. Expected:" << size_to_allocate << " Got:" << p_data_len; + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, oss.str().c_str()); + } + + AllocatorPtr alloc_ptr = std::make_shared(deleter); + Tensor::InitOrtValue(ml_type, tensor_shape, p_data, std::move(alloc_ptr), ort_value); + return nullptr; +} + +} // namespace + ORT_API_STATUS_IMPL(OrtApis::CreateEnvWithCustomLogger, OrtLoggingFunction logging_function, _In_opt_ void* logger_param, OrtLoggingLevel logging_level, _In_ const char* logid, _Outptr_ OrtEnv** out) { @@ -187,50 +255,6 @@ ORT_API_STATUS_IMPL(OrtApis::UpdateEnvWithCustomLogLevel, _In_ OrtEnv* ort_env, API_IMPL_END } -ORT_STATUS_PTR CreateTensorImpl(MLDataType ml_type, const int64_t* shape, size_t shape_len, - _Inout_ OrtAllocator* allocator, OrtValue& value) { - TensorShape tensor_shape(shape, shape_len); - AllocatorPtr alloc_ptr = std::make_shared(allocator); - Tensor::InitOrtValue(ml_type, tensor_shape, std::move(alloc_ptr), value); - return nullptr; -} - -ORT_STATUS_PTR CreateTensorImplForSeq(MLDataType elem_type, const int64_t* shape, size_t shape_len, Tensor& out) { - OrtAllocator* allocator; - // TODO(pranav): what allocator should be used to create the tensor here? - // for the sake of simplicity of the API using the default one here - ORT_API_RETURN_IF_ERROR(OrtApis::GetAllocatorWithDefaultOptions(&allocator)); - AllocatorPtr alloc_ptr = std::make_shared(allocator); - TensorShape tensor_shape(shape, shape_len); - out = Tensor(elem_type, tensor_shape, std::move(alloc_ptr)); - return nullptr; -} - -/** - * - * this function will create a copy of the allocator info - */ -ORT_STATUS_PTR CreateTensorImpl(MLDataType ml_type, const int64_t* shape, size_t shape_len, const OrtMemoryInfo* info, - void* p_data, size_t p_data_len, OrtValue& ort_value) { - TensorShape tensor_shape(shape, shape_len); - if (std::any_of(tensor_shape.GetDims().begin(), tensor_shape.GetDims().end(), [](int64_t v) { return v < 0; })) { - return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "tried creating tensor with negative value in shape"); - } - - size_t size_to_allocate = 0; - Status status = Tensor::CalculateTensorStorageSize(ml_type, tensor_shape, 0 /*alignment*/, size_to_allocate); - if (!status.IsOK()) { - return ToOrtStatus(status); - } - if (size_to_allocate > p_data_len) { - std::ostringstream oss; - oss << "not enough space: expected " << size_to_allocate << ", got " << p_data_len; - return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, oss.str().c_str()); - } - Tensor::InitOrtValue(ml_type, tensor_shape, p_data, *info, ort_value); - return nullptr; -} - ORT_API_STATUS_IMPL(OrtApis::CreateTensorWithDataAsOrtValue, _In_ const OrtMemoryInfo* info, _Inout_ void* p_data, size_t p_data_len, _In_ const int64_t* shape, size_t shape_len, ONNXTensorElementDataType type, _Outptr_ OrtValue** out) { @@ -243,6 +267,20 @@ ORT_API_STATUS_IMPL(OrtApis::CreateTensorWithDataAsOrtValue, _In_ const OrtMemor API_IMPL_END } +ORT_API_STATUS_IMPL(OrtApis::CreateTensorWithDataAndDeleterAsOrtValue, _In_ OrtAllocator* deleter, + _In_ void* p_data, size_t p_data_len, + _In_ const int64_t* shape, size_t shape_len, + ONNXTensorElementDataType type, + _Outptr_ OrtValue** out) { + API_IMPL_BEGIN + auto ml_type = DataTypeImpl::TensorTypeFromONNXEnum(type)->GetElementType(); + auto value = std::make_unique(); + ORT_API_RETURN_IF_ERROR(CreateTensorImpl(ml_type, shape, shape_len, deleter, p_data, p_data_len, *value)); + *out = value.release(); + return nullptr; + API_IMPL_END +} + ORT_API_STATUS_IMPL(OrtApis::CreateTensorAsOrtValue, _Inout_ OrtAllocator* allocator, _In_ const int64_t* shape, size_t shape_len, ONNXTensorElementDataType type, _Outptr_ OrtValue** out) { @@ -678,97 +716,6 @@ ORT_API_STATUS_IMPL(OrtApis::EnableOrtCustomOps, _Inout_ OrtSessionOptions* opti API_IMPL_END } -namespace { -// provider either model_path, or modal_data + model_data_length. -static ORT_STATUS_PTR CreateSessionAndLoadModel(_In_ const OrtSessionOptions* options, - _In_ const OrtEnv* env, - _In_opt_z_ const ORTCHAR_T* model_path, - _In_opt_ const void* model_data, - size_t model_data_length, - std::unique_ptr& sess) { - // quick check here to decide load path. InferenceSession will provide error message for invalid values. - // TODO: Could move to a helper - const Env& os_env = Env::Default(); // OS environment (!= ORT environment) - bool load_config_from_model = - os_env.GetEnvironmentVar(inference_session_utils::kOrtLoadConfigFromModelEnvVar) == "1"; - - if (load_config_from_model) { -#if !defined(ORT_MINIMAL_BUILD) - if (model_path != nullptr) { - sess = std::make_unique( - options == nullptr ? onnxruntime::SessionOptions() : options->value, - env->GetEnvironment(), - model_path); - } else { - sess = std::make_unique( - options == nullptr ? onnxruntime::SessionOptions() : options->value, - env->GetEnvironment(), - model_data, static_cast(model_data_length)); - } -#else - return OrtApis::CreateStatus(ORT_FAIL, "Loading config from ONNX models is not supported in this build."); -#endif - } else { - sess = std::make_unique( - options == nullptr ? onnxruntime::SessionOptions() : options->value, - env->GetEnvironment()); - } - -#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS) - // Add custom domains - if (options && !options->custom_op_domains_.empty()) { - ORT_API_RETURN_IF_STATUS_NOT_OK(sess->AddCustomOpDomains(options->custom_op_domains_)); - } -#endif - - // Finish load - if (load_config_from_model) { -#if !defined(ORT_MINIMAL_BUILD) - ORT_API_RETURN_IF_STATUS_NOT_OK(sess->Load()); -#endif - } else { - if (model_path != nullptr) { - ORT_API_RETURN_IF_STATUS_NOT_OK(sess->Load(model_path)); - } else { - ORT_API_RETURN_IF_STATUS_NOT_OK(sess->Load(model_data, static_cast(model_data_length))); - } - } - - return nullptr; -} - -static ORT_STATUS_PTR InitializeSession(_In_ const OrtSessionOptions* options, - _In_ std::unique_ptr<::onnxruntime::InferenceSession>& sess, - _Inout_opt_ OrtPrepackedWeightsContainer* prepacked_weights_container = nullptr) { - // we need to disable mem pattern if DML is one of the providers since DML doesn't have the concept of - // byte addressable memory - std::vector> provider_list; - if (options) { - for (auto& factory : options->provider_factories) { - auto provider = factory->CreateProvider(); - provider_list.push_back(std::move(provider)); - } - } - - // register the providers - for (auto& provider : provider_list) { - if (provider) { - ORT_API_RETURN_IF_STATUS_NOT_OK(sess->RegisterExecutionProvider(std::move(provider))); - } - } - - if (prepacked_weights_container != nullptr) { - ORT_API_RETURN_IF_STATUS_NOT_OK(sess->AddPrePackedWeightsContainer( - reinterpret_cast(prepacked_weights_container))); - } - - ORT_API_RETURN_IF_STATUS_NOT_OK(sess->Initialize()); - - return nullptr; -} - -} // namespace - ORT_API_STATUS_IMPL(OrtApis::CreateSession, _In_ const OrtEnv* env, _In_ const ORTCHAR_T* model_path, _In_ const OrtSessionOptions* options, _Outptr_ OrtSession** out) { API_IMPL_BEGIN @@ -778,7 +725,7 @@ ORT_API_STATUS_IMPL(OrtApis::CreateSession, _In_ const OrtEnv* env, _In_ const O ORT_TRY { ORT_API_RETURN_IF_ERROR(CreateSessionAndLoadModel(options, env, model_path, nullptr, 0, sess)); - ORT_API_RETURN_IF_ERROR(InitializeSession(options, sess)); + ORT_API_RETURN_IF_ERROR(InitializeSession(options, *sess)); *out = reinterpret_cast(sess.release()); } @@ -801,7 +748,7 @@ ORT_API_STATUS_IMPL(OrtApis::CreateSessionFromArray, _In_ const OrtEnv* env, _In ORT_TRY { ORT_API_RETURN_IF_ERROR(CreateSessionAndLoadModel(options, env, nullptr, model_data, model_data_length, sess)); - ORT_API_RETURN_IF_ERROR(InitializeSession(options, sess)); + ORT_API_RETURN_IF_ERROR(InitializeSession(options, *sess)); *out = reinterpret_cast(sess.release()); } @@ -1208,7 +1155,6 @@ ORT_API_STATUS_IMPL(OrtApis::GetResizedStringTensorElementBuffer, _Inout_ OrtVal } namespace { - OrtStatusPtr GetTensorStringSpan(const ::OrtValue& v, gsl::span& span) { if (!v.IsAllocated()) { return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "OrtValue should contain a Tensor or a Sparse Tensor"); @@ -2112,7 +2058,6 @@ ORT_API_STATUS_IMPL(OrtApis::GetOpaqueValue, _In_ const char* domain_name, _In_ } namespace { - struct ProviderBuffer { char** buffer_; char* next_write_; @@ -2342,7 +2287,7 @@ ORT_API_STATUS_IMPL(OrtApis::CreateSessionWithPrepackedWeightsContainer, _In_ co ORT_TRY { ORT_API_RETURN_IF_ERROR(CreateSessionAndLoadModel(options, env, model_path, nullptr, 0, sess)); - ORT_API_RETURN_IF_ERROR(InitializeSession(options, sess, prepacked_weights_container)); + ORT_API_RETURN_IF_ERROR(InitializeSession(options, *sess, prepacked_weights_container)); *out = reinterpret_cast(sess.release()); } @@ -2368,7 +2313,7 @@ ORT_API_STATUS_IMPL(OrtApis::CreateSessionFromArrayWithPrepackedWeightsContainer ORT_TRY { ORT_API_RETURN_IF_ERROR(CreateSessionAndLoadModel(options, env, nullptr, model_data, model_data_length, sess)); - ORT_API_RETURN_IF_ERROR(InitializeSession(options, sess, prepacked_weights_container)); + ORT_API_RETURN_IF_ERROR(InitializeSession(options, *sess, prepacked_weights_container)); *out = reinterpret_cast(sess.release()); } @@ -2410,6 +2355,39 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsSetCustomJoinThreadFn, _Inout_ OrtSes API_IMPL_END } +ORT_API(void, OrtApis::ReleaseValueInfo, _Frees_ptr_opt_ OrtValueInfo* value_info) { + delete value_info; +} + +ORT_API(void, OrtApis::ReleaseNode, _Frees_ptr_opt_ OrtNode* node) { + delete node; +} + +ORT_API(void, OrtApis::ReleaseGraph, _Frees_ptr_opt_ OrtGraph* graph) { + delete graph; +} + +ORT_API(void, OrtApis::ReleaseModel, _Frees_ptr_opt_ OrtModel* model) { + delete model; +} + +ORT_API_STATUS_IMPL(OrtApis::GetValueInfoName, _In_ const OrtValueInfo* value_info, + _Out_ const char** name) { + API_IMPL_BEGIN + *name = value_info->name.c_str(); + return nullptr; + API_IMPL_END +} +ORT_API_STATUS_IMPL(OrtApis::GetValueInfoTypeInfo, _In_ const OrtValueInfo* value_info, + _Outptr_ const OrtTypeInfo** type_info) { + API_IMPL_BEGIN + + *type_info = value_info->type_info.get(); + + return nullptr; + API_IMPL_END +} + ORT_API(const OrtTrainingApi*, OrtApis::GetTrainingApi, uint32_t version) { #ifdef ENABLE_TRAINING_APIS if (version >= 13 && version <= ORT_API_VERSION) @@ -2419,13 +2397,21 @@ ORT_API(const OrtTrainingApi*, OrtApis::GetTrainingApi, uint32_t version) { version, ORT_API_VERSION); return nullptr; #else - ORT_UNUSED_PARAMETER(version); return nullptr; #endif } +ORT_API(const OrtModelEditorApi*, OrtApis::GetModelEditorApi) { +#if !defined(ORT_MINIMAL_BUILD) + return OrtModelEditorAPI::GetModelEditorApi(); +#else + fprintf(stderr, "The Model Editor API is not supported in a minimal build.\n"); + return nullptr; +#endif +} + static constexpr OrtApiBase ort_api_base = { &OrtApis::GetApi, &OrtApis::GetVersionString}; @@ -2812,6 +2798,18 @@ static constexpr OrtApi ort_api_1_to_22 = { &OrtApis::SetEpDynamicOptions, // End of Version 20 - DO NOT MODIFY ABOVE (see above text for more information) + + &OrtApis::ReleaseValueInfo, + &OrtApis::ReleaseNode, + &OrtApis::ReleaseGraph, + &OrtApis::ReleaseModel, + + &OrtApis::GetValueInfoName, + &OrtApis::GetValueInfoTypeInfo, + + &OrtApis::GetModelEditorApi, + + &OrtApis::CreateTensorWithDataAndDeleterAsOrtValue, }; // OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase. diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index 52d3c98d526dc..9d8aeb18a782f 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -20,6 +20,10 @@ ORT_API(void, ReleaseCustomOpDomain, _Frees_ptr_opt_ OrtCustomOpDomain*); ORT_API(void, ReleaseMapTypeInfo, _Frees_ptr_opt_ OrtMapTypeInfo*); ORT_API(void, ReleaseSequenceTypeInfo, _Frees_ptr_opt_ OrtSequenceTypeInfo*); ORT_API(void, ReleaseModelMetadata, _Frees_ptr_opt_ OrtModelMetadata*); +ORT_API(void, ReleaseValueInfo, _Frees_ptr_opt_ OrtValueInfo*); +ORT_API(void, ReleaseNode, _Frees_ptr_opt_ OrtNode*); +ORT_API(void, ReleaseGraph, _Frees_ptr_opt_ OrtGraph*); +ORT_API(void, ReleaseModel, _Frees_ptr_opt_ OrtModel*); _Check_return_ _Ret_notnull_ [[nodiscard]] OrtStatus* ORT_API_CALL CreateStatus(OrtErrorCode code, _In_z_ const char* msg) NO_EXCEPTION; @@ -533,4 +537,16 @@ ORT_API_STATUS_IMPL(RunOptionsAddActiveLoraAdapter, _Inout_ OrtRunOptions* optio ORT_API_STATUS_IMPL(SetEpDynamicOptions, _Inout_ OrtSession* sess, _In_reads_(kv_len) const char* const* keys, _In_reads_(kv_len) const char* const* values, _In_ size_t kv_len); + +ORT_API_STATUS_IMPL(GetValueInfoName, _In_ const OrtValueInfo* value_info, _Out_ const char** name); +ORT_API_STATUS_IMPL(GetValueInfoTypeInfo, _In_ const OrtValueInfo* value_info, _Outptr_ const OrtTypeInfo** type_info); + +ORT_API(const OrtModelEditorApi*, GetModelEditorApi); + +ORT_API_STATUS_IMPL(CreateTensorWithDataAndDeleterAsOrtValue, _In_ OrtAllocator* deleter, + _In_ void* p_data, size_t p_data_len, + _In_ const int64_t* shape, size_t shape_len, + ONNXTensorElementDataType type, + _Outptr_ OrtValue** out); + } // namespace OrtApis diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index 77c6d4c371f69..2ea4a93d21f2e 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -4,6 +4,7 @@ // This is the Onnxruntime side of the bridge to allow providers to be built as a DLL // It implements onnxruntime::ProviderHost +#include #include "core/common/inlined_containers.h" #include "core/common/path_string.h" #include "core/framework/allocator_utils.h" @@ -35,6 +36,7 @@ #include "core/graph/graph_proto_serializer.h" #include "core/framework/murmurhash3.h" #include "core/framework/model_metadef_id_generator.h" +#include "core/optimizer/graph_optimizer_registry.h" #include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h" #include "core/optimizer/qdq_transformer/selectors_actions/shared/utils.h" @@ -237,6 +239,21 @@ common::Status LoadDynamicLibraryFromProvider(onnxruntime::PathString library_na struct ProviderHostImpl : ProviderHost { const OrtApiBase* OrtGetApiBase() override { return ::OrtGetApiBase(); } + Status GetOptimizerByName(const std::string& name, + const GraphOptimizerRegistry& graph_optimizer_registry, + SelectionFunc& selection_func) override { + std::string optimizer_name(name); + + auto func = graph_optimizer_registry.GetSelectionFunc(optimizer_name); + + if (func.has_value()) { + selection_func = func.value(); + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to get optimizer " + optimizer_name); + } + return Status::OK(); + }; + void* HeapAllocate(size_t size) override { return new uint8_t[size]; } void HeapFree(void* p) override { delete[] reinterpret_cast(p); } @@ -360,8 +377,9 @@ struct ProviderHostImpl : ProviderHost { std::vector> IExecutionProvider__GetCapability( const IExecutionProvider* p, const onnxruntime::GraphViewer& graph_viewer, const IExecutionProvider::IKernelLookup& kernel_lookup, + const GraphOptimizerRegistry& graph_optimizer_registry, IResourceAccountant* resource_accountant) override { - return p->IExecutionProvider::GetCapability(graph_viewer, kernel_lookup, resource_accountant); + return p->IExecutionProvider::GetCapability(graph_viewer, kernel_lookup, graph_optimizer_registry, resource_accountant); } common::Status IExecutionProvider__Compile(IExecutionProvider* p, const std::vector& fused_nodes_and_graphs, std::vector& node_compute_funcs) override { @@ -797,6 +815,8 @@ struct ProviderHostImpl : ProviderHost { std::unique_ptr ComputeCapability__construct(std::unique_ptr t_sub_graph) override { return std::make_unique(std::move(t_sub_graph)); } void ComputeCapability__operator_delete(ComputeCapability* p) override { delete p; } std::unique_ptr& ComputeCapability__SubGraph(ComputeCapability* p) override { return p->sub_graph; } + void ComputeCapability__copy_optimization_func(ComputeCapability* p, ComputeCapability* selection_cc) override { p->optimization_func = selection_cc->optimization_func; } + void ComputeCapability__add_nodes_to_optimize(ComputeCapability* p, std::unique_ptr optimization_cc) override { p->nodes_to_optimize.push_back(std::move(optimization_cc)); } // DataTransferManager (wrapped) Status DataTransferManager__CopyTensor(const DataTransferManager* p, const Tensor& src, Tensor& dst) override { return p->CopyTensor(src, dst); } @@ -1631,6 +1651,7 @@ struct ProviderHostImpl : ProviderHost { Status LoadDynamicLibrary(onnxruntime::PathString library_name) override { return LoadDynamicLibraryFromProvider(library_name); }; #endif } provider_host_; + #if defined(_MSC_VER) && !defined(__clang__) #pragma warning(pop) #endif diff --git a/onnxruntime/core/session/utils.cc b/onnxruntime/core/session/utils.cc new file mode 100644 index 0000000000000..afb1ed2696c9f --- /dev/null +++ b/onnxruntime/core/session/utils.cc @@ -0,0 +1,125 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/session/utils.h" + +#include "core/framework/error_code_helper.h" +#include "core/framework/execution_provider.h" +#include "core/session/abi_session_options_impl.h" +// #include "core/session/environment.h" +#include "core/session/inference_session.h" +#include "core/session/inference_session_utils.h" +#include "core/session/onnxruntime_c_api.h" +#include "core/session/ort_apis.h" +#include "core/session/ort_env.h" + +using namespace onnxruntime; + +common::Status CopyStringToOutputArg(std::string_view str, const char* err_msg, char* out, size_t* size) { + const size_t str_len = str.size(); + const size_t req_size = str_len + 1; + + if (out == nullptr) { // User is querying the total output buffer size + *size = req_size; + return onnxruntime::common::Status::OK(); + } + + if (*size >= req_size) { // User provided a buffer of sufficient size + std::memcpy(out, str.data(), str_len); + out[str_len] = '\0'; + *size = req_size; + return onnxruntime::common::Status::OK(); + } + + // User has provided a buffer that is not large enough + *size = req_size; + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, err_msg); +} + +// provider either model_path, or modal_data + model_data_length. +OrtStatus* CreateSessionAndLoadModel(_In_ const OrtSessionOptions* options, + _In_ const OrtEnv* env, + _In_opt_z_ const ORTCHAR_T* model_path, + _In_opt_ const void* model_data, + size_t model_data_length, + std::unique_ptr& sess) { + // quick check here to decide load path. InferenceSession will provide error message for invalid values. + // TODO: Could move to a helper + const Env& os_env = Env::Default(); // OS environment (!= ORT environment) + bool load_config_from_model = + os_env.GetEnvironmentVar(inference_session_utils::kOrtLoadConfigFromModelEnvVar) == "1"; + + if (load_config_from_model) { +#if !defined(ORT_MINIMAL_BUILD) + if (model_path != nullptr) { + sess = std::make_unique( + options == nullptr ? onnxruntime::SessionOptions() : options->value, + env->GetEnvironment(), + model_path); + } else { + sess = std::make_unique( + options == nullptr ? onnxruntime::SessionOptions() : options->value, + env->GetEnvironment(), + model_data, static_cast(model_data_length)); + } +#else + return OrtApis::CreateStatus(ORT_FAIL, "Loading config from ONNX models is not supported in this build."); +#endif + } else { + sess = std::make_unique( + options == nullptr ? onnxruntime::SessionOptions() : options->value, + env->GetEnvironment()); + } + +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS) + // Add custom domains + if (options && !options->custom_op_domains_.empty()) { + ORT_API_RETURN_IF_STATUS_NOT_OK(sess->AddCustomOpDomains(options->custom_op_domains_)); + } +#endif + + // Finish load + if (load_config_from_model) { +#if !defined(ORT_MINIMAL_BUILD) + ORT_API_RETURN_IF_STATUS_NOT_OK(sess->Load()); +#endif + } else { + if (model_path != nullptr) { + ORT_API_RETURN_IF_STATUS_NOT_OK(sess->Load(model_path)); + } else { + ORT_API_RETURN_IF_STATUS_NOT_OK(sess->Load(model_data, static_cast(model_data_length))); + } + } + + return nullptr; +} + +OrtStatus* InitializeSession(_In_ const OrtSessionOptions* options, + _In_ onnxruntime::InferenceSession& sess, + _Inout_opt_ OrtPrepackedWeightsContainer* prepacked_weights_container) { + // we need to disable mem pattern if DML is one of the providers since DML doesn't have the concept of + // byte addressable memory + std::vector> provider_list; + if (options) { + for (auto& factory : options->provider_factories) { + auto provider = factory->CreateProvider(); + provider_list.push_back(std::move(provider)); + } + } + + // register the providers + for (auto& provider : provider_list) { + if (provider) { + ORT_API_RETURN_IF_STATUS_NOT_OK(sess.RegisterExecutionProvider(std::move(provider))); + } + } + + if (prepacked_weights_container != nullptr) { + ORT_API_RETURN_IF_STATUS_NOT_OK(sess.AddPrePackedWeightsContainer( + reinterpret_cast(prepacked_weights_container))); + } + + ORT_API_RETURN_IF_STATUS_NOT_OK(sess.Initialize()); + + return nullptr; +} diff --git a/onnxruntime/core/session/utils.h b/onnxruntime/core/session/utils.h new file mode 100644 index 0000000000000..ac8ad60758b5b --- /dev/null +++ b/onnxruntime/core/session/utils.h @@ -0,0 +1,28 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include "core/common/common.h" +#include "core/session/onnxruntime_c_api.h" + +onnxruntime::common::Status CopyStringToOutputArg(std::string_view str, const char* err_msg, char* out, size_t* size); + +struct OrtSessionOptions; +struct OrtStatus; +struct OrtPrepackedWeightsContainer; +namespace onnxruntime { +class InferenceSession; +} + +OrtStatus* CreateSessionAndLoadModel(_In_ const OrtSessionOptions* options, + _In_ const OrtEnv* env, + _In_opt_z_ const ORTCHAR_T* model_path, + _In_opt_ const void* model_data, + size_t model_data_length, + std::unique_ptr& sess); + +OrtStatus* InitializeSession(_In_ const OrtSessionOptions* options, + _In_ onnxruntime::InferenceSession& sess, + _Inout_opt_ OrtPrepackedWeightsContainer* prepacked_weights_container = nullptr); diff --git a/onnxruntime/python/tools/quantization/execution_providers/qnn/quant_config.py b/onnxruntime/python/tools/quantization/execution_providers/qnn/quant_config.py index ea995d4707ba3..50da0025752aa 100644 --- a/onnxruntime/python/tools/quantization/execution_providers/qnn/quant_config.py +++ b/onnxruntime/python/tools/quantization/execution_providers/qnn/quant_config.py @@ -204,9 +204,9 @@ def get_qnn_qdq_config( calibrate_method=calibrate_method, activation_type=activation_type, weight_type=weight_type, - op_types_to_quantize=op_types_to_quantize - if op_types_to_quantize - else list(op_types.difference(OP_TYPES_TO_EXCLUDE)), + op_types_to_quantize=( + op_types_to_quantize if op_types_to_quantize else list(op_types.difference(OP_TYPES_TO_EXCLUDE)) + ), nodes_to_exclude=nodes_to_exclude, per_channel=per_channel, use_external_data_format=(model_has_external_data or model.ByteSize() >= MODEL_SIZE_THRESHOLD), diff --git a/onnxruntime/python/tools/quantization/quantize.py b/onnxruntime/python/tools/quantization/quantize.py index fa468a9676a65..d19bebad8a12c 100644 --- a/onnxruntime/python/tools/quantization/quantize.py +++ b/onnxruntime/python/tools/quantization/quantize.py @@ -240,6 +240,8 @@ def get_qdq_config( keep_removable_activations: bool = False, min_real_range: float | None = None, tensor_quant_overrides: dict[str, list[dict[str, Any]]] | None = None, + calibration_providers: list[str] | None = None, + op_types_to_quantize: list[str] | None = None, nodes_to_exclude: list[str] | Callable[[onnx.ModelProto, onnx.NodeProto], bool] | None = None, extra_options: dict | None = None, ) -> StaticQuantConfig: @@ -294,6 +296,10 @@ def get_qdq_config( 'convert["recv_nodes"] = Set : Set of node names that consume the converted activation, other nodes get the original type. If not specified, assume all consumer nodes get the converted type. + calibration_providers: Execution providers to run the session during calibration. Default is None which uses + [ "CPUExecutionProvider" ]. + op_types_to_quantize: List of operator types to quantize. If None, all operators other than Cast, DequantizeLinear, + and QuantizeLinear are quantized. nodes_to_exclude: List of nodes names to exclude from quantization. Alternatively, can provide a function that accepts an onnx.ModelProto and onnx.NodeProto as arguments and returns true if the give onnx.NodeProto should be excluded from quantization. @@ -324,17 +330,20 @@ def get_qdq_config( if onnx.external_data_helper.uses_external_data(initializer): model_has_external_data = True - final_nodes_to_exclude = [] - if nodes_to_exclude is not None and isinstance(nodes_to_exclude, list): - final_nodes_to_exclude.extend(nodes_to_exclude) + op_types_to_quantize_set = set(op_types_to_quantize) if op_types_to_quantize else None + nodes_to_exclude_set = set(nodes_to_exclude) if isinstance(nodes_to_exclude, list) else set() # Iterate through nodes to get all operator types in the model and # call user's function to filter out nodes from quantization. for node in model.graph.node: - op_types.add(node.op_type) - if nodes_to_exclude is not None and callable(nodes_to_exclude): - if nodes_to_exclude(model, node): - final_nodes_to_exclude.append(node.name) + if op_types_to_quantize_set and node.op_type not in op_types_to_quantize_set: + continue + if node.name in nodes_to_exclude_set: + continue + if callable(nodes_to_exclude) and nodes_to_exclude(model, node): + nodes_to_exclude_set.add(node.name) + else: + op_types.add(node.op_type) final_extra_options = { "MinimumRealRange": min_real_range, @@ -378,11 +387,14 @@ def get_qdq_config( quant_format=QuantFormat.QDQ, activation_type=activation_type, weight_type=weight_type, - op_types_to_quantize=list(op_types.difference(op_types_to_exclude)), - nodes_to_exclude=final_nodes_to_exclude, + op_types_to_quantize=( + op_types_to_quantize if op_types_to_quantize else list(op_types.difference(op_types_to_exclude)) + ), + nodes_to_exclude=list(nodes_to_exclude_set), per_channel=per_channel, reduce_range=reduce_range, use_external_data_format=(model_has_external_data or model.ByteSize() >= MODEL_SIZE_THRESHOLD), + calibration_providers=calibration_providers, extra_options=final_extra_options, ) @@ -442,7 +454,7 @@ def check_static_quant_arguments(quant_format: QuantFormat, activation_type: Qua if activation_type != QuantType.QFLOAT8E4M3FN and weight_type == QuantType.QFLOAT8E4M3FN: raise ValueError( f"ONNXRuntime quantization doesn't support data format: activation_type={activation_type} " - f"!=QuantType.QFLOAT8E4M3FN, weight_type=QuantType.QFLOAT8E4M3FN." + "!=QuantType.QFLOAT8E4M3FN, weight_type=QuantType.QFLOAT8E4M3FN." ) if activation_type == QuantType.QFLOAT8E4M3FN and weight_type != QuantType.QFLOAT8E4M3FN: diff --git a/onnxruntime/python/tools/transformers/models/sam2/README.md b/onnxruntime/python/tools/transformers/models/sam2/README.md index e7cafeffc6231..463d154525f8f 100644 --- a/onnxruntime/python/tools/transformers/models/sam2/README.md +++ b/onnxruntime/python/tools/transformers/models/sam2/README.md @@ -96,8 +96,7 @@ We can create a conda environment then run GPU benchmark like the following: conda create -n sam2_gpu python=3.11 -y conda activate sam2_gpu install_dir=$HOME -profiling=true -bash benchmark_sam2.sh $install_dir gpu $profiling +bash benchmark_sam2.sh $install_dir gpu ``` or create a new conda environment for CPU benchmark: @@ -107,16 +106,28 @@ conda activate sam2_cpu bash benchmark_sam2.sh $HOME cpu ``` -The first parameter is a directory to clone git repositories or install CUDA/cuDNN for benchmark. -The second parameter can be either "gpu" or "cpu", which indicates the device to run benchmark. -The third parameter is optional. Value "true" will enable profiling after running benchmarking on GPU. +The usage of the script like the following: +``` +bash benchmark_sam2.sh [profiling] [benchmarking] [nightly] [dynamo] +``` + +| Parameter| Default | Description | +|----------|----------| ------------| +| install_dir | $HOME | a directory to clone git repositories or install CUDA/cuDNN for benchmark | +| cpu_or_gpu | gpu | the device to run benchmark. The value can be either "gpu" or "cpu" | +| profiling | false | run gpu profiling | +| benchmarking | true | run benchmark | +| nightly | false | install onnxruntime nightly or official release package | +| dynamo | false | export image encoder using dynamo or not. | -The script will automatically install required packages in current conda environment, download checkpoints, export onnx, -and run demo, benchmark and optionally run profiling. +The dynamo export is experimental since graph optimization still need extra works for this model. -* The performance test result is in sam2_gpu.csv or sam2_cpu.csv, which can be loaded into Excel. -* The demo output is sam2_demo_fp16_gpu.png or sam2_demo_fp32_cpu.png. -* The profiling results are in *.nsys-rep or *.json files in current directory. Use Nvidia NSight System to view the *.nsys-rep file. +Output files: +* sam2_cpu_[timestamp].csv or sam2_gpu_[timestamp].csv has benchmark results. Use Excel to load the file to view it. +* onnxruntime_image_[encoder|decoder].json has ONNX Runtime profiling results. Use `chrome://tracing` in Chrome browser to view it. +* torch_image_[encoder|decoder].json has PyTorch profiling results. Use `chrome://tracing` in Chrome browser to view it. +* sam2_fp16_profile_image_[encoder|decoder]_[ort|torch]_gpu.[nsys-rep|sqlite] has NVTX profiling. Use Nvidia NSight System to view it. +* torch_image_encoder_compiled_code.txt has the compiled kernel code from Pytorch. ## Limitations - The exported image_decoder model does not support batch mode for now. diff --git a/onnxruntime/python/tools/transformers/models/sam2/benchmark_sam2.py b/onnxruntime/python/tools/transformers/models/sam2/benchmark_sam2.py index 16d71d5057b02..3fc24d157b0cf 100644 --- a/onnxruntime/python/tools/transformers/models/sam2/benchmark_sam2.py +++ b/onnxruntime/python/tools/transformers/models/sam2/benchmark_sam2.py @@ -46,6 +46,7 @@ def __init__( prefer_nhwc: bool = False, warm_up: int = 5, enable_nvtx_profile: bool = False, + enable_ort_profile: bool = False, enable_torch_profile: bool = False, repeats: int = 1000, verbose: bool = False, @@ -74,6 +75,7 @@ def __init__( self.prefer_nhwc = prefer_nhwc self.warm_up = warm_up self.enable_nvtx_profile = enable_nvtx_profile + self.enable_ort_profile = enable_ort_profile self.enable_torch_profile = enable_torch_profile self.repeats = repeats self.verbose = verbose @@ -317,6 +319,7 @@ def run_test( repeats=args.repeats, warm_up=args.warm_up, enable_nvtx_profile=args.enable_nvtx_profile, + enable_ort_profile=args.enable_ort_profile, enable_torch_profile=args.enable_torch_profile, torch_compile_mode=args.torch_compile_mode, verbose=False, @@ -325,7 +328,7 @@ def run_test( if args.engine == "ort": sess_options = SessionOptions() sess_options.intra_op_num_threads = args.intra_op_num_threads - if config.enable_nvtx_profile: + if config.enable_ort_profile: sess_options.enable_profiling = True sess_options.log_severity_level = 4 sess_options.log_verbosity_level = 0 @@ -349,6 +352,8 @@ def run_test( with nvtx.annotate("one_run"): _ = session.infer(input_dict) cudart.cudaProfilerStop() + + if config.enable_ort_profile: session.ort_session.end_profiling() if repeats == 0: @@ -554,6 +559,14 @@ def _parse_arguments(): help="Enable nvtx profiling. It will add an extra run for profiling before performance test.", ) + parser.add_argument( + "--enable_ort_profile", + required=False, + default=False, + action="store_true", + help="Enable ORT profiling.", + ) + parser.add_argument( "--enable_torch_profile", required=False, diff --git a/onnxruntime/python/tools/transformers/models/sam2/benchmark_sam2.sh b/onnxruntime/python/tools/transformers/models/sam2/benchmark_sam2.sh index 9e97867657ab9..c82b1ed31796e 100644 --- a/onnxruntime/python/tools/transformers/models/sam2/benchmark_sam2.sh +++ b/onnxruntime/python/tools/transformers/models/sam2/benchmark_sam2.sh @@ -5,7 +5,17 @@ # ------------------------------------------------------------------------- # Please refer to README.md for the prerequisites and usage of this script. -# bash benchmark_sam2.sh [profiling] +# bash benchmark_sam2.sh [profiling] [benchmarking] [nightly] [dynamo] +# Note that dynamo need onnxruntime 1.21 or later, or nightly build. +# Example: +# bash benchmark_sam2.sh $HOME gpu true true true false + +install_dir="${1:-$HOME}" +cpu_or_gpu="${2:-gpu}" +profiling="${3:-false}" +benchmarking="${4:-true}" +nightly="${5:-false}" +dynamo="${6:-false}" python="$CONDA_PREFIX/bin/python3" @@ -13,9 +23,6 @@ python="$CONDA_PREFIX/bin/python3" dir="$(cd "$(dirname "$0")" && pwd)" onnx_dir="$dir/sam2_onnx_models" -# Installation directory (default: $HOME) -install_dir="${1:-$HOME}" - if [ ! -d "$install_dir" ]; then echo "Error: install_dir '$install_dir' does not exist." exit 1 @@ -26,7 +33,6 @@ sam2_dir="$install_dir/segment-anything-2" model="sam2_hiera_large" # Default to GPU, switch to CPU if specified -cpu_or_gpu="${2:-gpu}" if [ "$cpu_or_gpu" != "gpu" ] && [ "$cpu_or_gpu" != "cpu" ]; then echo "Invalid option: $2. Please specify 'cpu' or 'gpu'." exit 1 @@ -35,52 +41,97 @@ fi echo "install_dir: $install_dir" echo "cpu_or_gpu: $cpu_or_gpu" -install_cuda_12() -{ - pushd $install_dir - wget https://developer.download.nvidia.com/compute/cuda/12.6.2/local_installers/cuda_12.6.2_560.35.03_linux.run - sh cuda_12.6.2_560.35.03_linux.run --toolkit --toolkitpath=$install_dir/cuda12.6 --silent --override --no-man-page +# Function to check if a command exists +command_exists() { + command -v "$1" >/dev/null 2>&1 +} + +# Ensure necessary tools are installed +if ! command_exists wget; then + echo "wget is not installed. Please install it and try again." + exit 1 +fi + +if ! command_exists git; then + echo "git is not installed. Please install it and try again." + exit 1 +fi + +if ! command_exists pip; then + echo "pip is not installed. Please install it and try again." + exit 1 +fi + +cuda_version=12.6 +cudnn_version=9.5 - export PATH="$install_dir/cuda12.6/bin:$PATH" - export LD_LIBRARY_PATH="$install_dir/cuda12.6/lib64:$LD_LIBRARY_PATH" - popd +# Install CUDA 12.6 +install_cuda_12() { + if ! [ -d "$install_dir/cuda${cuda_version}" ]; then + pushd "$install_dir" || exit + wget https://developer.download.nvidia.com/compute/cuda/12.6.2/local_installers/cuda_12.6.2_560.35.03_linux.run + sh cuda_12.6.2_560.35.03_linux.run --toolkit --toolkitpath="$install_dir/cuda${cuda_version}" --silent --override --no-man-page + popd || exit + fi + export PATH="$install_dir/cuda${cuda_version}/bin:$PATH" + export LD_LIBRARY_PATH="$install_dir/cuda${cuda_version}/lib64:$LD_LIBRARY_PATH" } -# Function to install cuDNN 9.4 +# Install cuDNN 9.5 install_cudnn_9() { - pushd "$install_dir" - wget -q https://developer.download.nvidia.com/compute/cudnn/redist/cudnn/linux-x86_64/cudnn-linux-x86_64-9.5.0.50_cuda12-archive.tar.xz - mkdir -p "$install_dir/cudnn9.5" - tar -Jxvf cudnn-linux-x86_64-9.5.0.50_cuda12-archive.tar.xz -C "$install_dir/cudnn9.5" --strip=1 - export LD_LIBRARY_PATH="$install_dir/cudnn9.5/lib:$LD_LIBRARY_PATH" - popd + if ! [ -d "$install_dir/cudnn${cudnn_version}" ]; then + pushd "$install_dir" || exit + wget -q https://developer.download.nvidia.com/compute/cudnn/redist/cudnn/linux-x86_64/cudnn-linux-x86_64-9.5.0.50_cuda12-archive.tar.xz + mkdir -p "$install_dir/cudnn${cudnn_version}" + tar -Jxvf cudnn-linux-x86_64-9.5.0.50_cuda12-archive.tar.xz -C "$install_dir/cudnn${cudnn_version}" --strip=1 + popd || exit + fi + export LD_LIBRARY_PATH="$install_dir/cudnn${cudnn_version}/lib:$LD_LIBRARY_PATH" +} + +install_ort() { + local ort="$1" + pip uninstall onnxruntime onnxruntime-gpu -y + + if [ "$nightly" = "true" ]; then + pip install flatbuffers numpy packaging protobuf sympy + pip install --pre --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/ORT-Nightly/pypi/simple/ "$ort" + else + pip install "$ort" + fi + + pip install onnx onnxscript opencv-python matplotlib } # Install GPU dependencies install_gpu() { - [ ! -d "$install_dir/cuda12.6" ] && install_cuda_12 - [ ! -d "$install_dir/cudnn9.5" ] && install_cudnn_9 + install_cuda_12 + install_cudnn_9 + echo "PATH: $PATH" + echo "LD_LIBRARY_PATH: $LD_LIBRARY_PATH" + + # The dynamo export need torch 2.6.0 or later. Use the latest one. + pip install torch torchvision --index-url https://download.pytorch.org/whl/cu124 --upgrade - pip install torch torchvision --index-url https://download.pytorch.org/whl/cu124 - pip install onnxruntime-gpu onnx opencv-python matplotlib + install_ort "onnxruntime-gpu" } # Install CPU dependencies install_cpu() { pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu - pip install onnxruntime onnx opencv-python matplotlib + install_ort "onnxruntime" } # Clone and install SAM2 if not already installed install_sam2() { - pushd "$install_dir" + pushd "$install_dir" || exit if [ ! -d "$sam2_dir" ]; then git clone https://github.com/facebookresearch/segment-anything-2.git fi - cd "$sam2_dir" + cd "$sam2_dir" || exit pip show SAM-2 > /dev/null 2>&1 || pip install -e . [ ! -f checkpoints/sam2_hiera_large.pt ] && (cd checkpoints && sh ./download_ckpts.sh) - popd + popd || exit } # Download test image if not available @@ -90,7 +141,12 @@ download_test_image() { run_cpu_benchmark() { local repeats="$1" - $python convert_to_onnx.py --sam2_dir "$sam2_dir" --optimize --demo + + if [ "$dynamo" = "true" ]; then + $python convert_to_onnx.py --sam2_dir "$sam2_dir" --optimize --demo --dynamo + else + $python convert_to_onnx.py --sam2_dir "$sam2_dir" --optimize --demo + fi for component in image_encoder image_decoder; do $python benchmark_sam2.py --model_type "$model" --engine torch --sam2_dir "$sam2_dir" --repeats "$repeats" --dtype fp32 --component "$component" @@ -103,65 +159,75 @@ run_cpu_benchmark() { done } -run_gpu_benchmark() { +run_ort_gpu_benchmark() { local repeats="$1" - $python convert_to_onnx.py --sam2_dir "$sam2_dir" --optimize --use_gpu --dtype fp32 - $python convert_to_onnx.py --sam2_dir "$sam2_dir" --optimize --use_gpu --dtype fp16 --demo - for component in image_encoder image_decoder; do - for dtype in bf16 fp32 fp16; do - $python benchmark_sam2.py --model_type "$model" --engine torch --sam2_dir "$sam2_dir" --repeats "$repeats" --use_gpu --dtype $dtype --component "$component" - done - done + if [ "$dynamo" = "true" ]; then + $python convert_to_onnx.py --sam2_dir "$sam2_dir" --optimize --use_gpu --dtype fp32 --dynamo + $python convert_to_onnx.py --sam2_dir "$sam2_dir" --optimize --use_gpu --dtype fp16 --demo --dynamo + else + $python convert_to_onnx.py --sam2_dir "$sam2_dir" --optimize --use_gpu --dtype fp32 + $python convert_to_onnx.py --sam2_dir "$sam2_dir" --optimize --use_gpu --dtype fp16 --demo + fi component="image_encoder" for dtype in fp32 fp16; do - #TODO: --prefer_nhwc does not help with performance - $python benchmark_sam2.py --model_type "$model" --engine ort --sam2_dir "$sam2_dir" --repeats "$repeats" --use_gpu --dtype $dtype --component "$component" --onnx_path "${onnx_dir}/${model}_${component}_${dtype}_gpu.onnx" --use_cuda_graph + $python benchmark_sam2.py --model_type "$model" --engine ort --sam2_dir "$sam2_dir" --repeats "$repeats" --use_gpu --dtype "$dtype" --component "$component" --onnx_path "${onnx_dir}/${model}_${component}_${dtype}_gpu.onnx" --use_cuda_graph done + # Test prefer_nhwc. + $python benchmark_sam2.py --model_type "$model" --engine ort --sam2_dir "$sam2_dir" --repeats "$repeats" --use_gpu --dtype fp16 --component "$component" --onnx_path "${onnx_dir}/${model}_${component}_${dtype}_gpu.onnx" --use_cuda_graph --prefer_nhwc component="image_decoder" for dtype in fp32 fp16; do # TODO: decoder does not work with cuda graph - $python benchmark_sam2.py --model_type "$model" --engine ort --sam2_dir "$sam2_dir" --repeats "$repeats" --use_gpu --dtype $dtype --component "$component" --onnx_path "${onnx_dir}/${model}_${component}_${dtype}_gpu.onnx" + $python benchmark_sam2.py --model_type "$model" --engine ort --sam2_dir "$sam2_dir" --repeats "$repeats" --use_gpu --dtype "$dtype" --component "$component" --onnx_path "${onnx_dir}/${model}_${component}_${dtype}_gpu.onnx" done + # Test prefer_nhwc. + $python benchmark_sam2.py --model_type "$model" --engine ort --sam2_dir "$sam2_dir" --repeats "$repeats" --use_gpu --dtype fp16 --component "$component" --onnx_path "${onnx_dir}/${model}_${component}_${dtype}_gpu.onnx" --prefer_nhwc } -run_torch_compile_gpu_benchmark() { +run_torch_gpu_benchmark() { local repeats="$1" + # Test PyTorch eager mode. + for component in image_encoder image_decoder; do + for dtype in bf16 fp32 fp16; do + $python benchmark_sam2.py --model_type "$model" --engine torch --sam2_dir "$sam2_dir" --repeats "$repeats" --use_gpu --dtype "$dtype" --component "$component" + done + done + # Test different torch compile modes on image encoder for torch_compile_mode in none max-autotune reduce-overhead max-autotune-no-cudagraphs do - $python benchmark_sam2.py --model_type $model --engine torch --sam2_dir "$sam2_dir" --repeats "$repeats" --use_gpu --dtype fp16 --component image_encoder --torch_compile_mode $torch_compile_mode + $python benchmark_sam2.py --model_type "$model" --engine torch --sam2_dir "$sam2_dir" --repeats "$repeats" --use_gpu --dtype fp16 --component image_encoder --torch_compile_mode $torch_compile_mode done } - -# Main script -run_benchmarks() { - if [ ! -v CONDA_PREFIX ]; then - echo "Please activate conda environment before running this script." - exit 1 +install_all() { + if [ "$cpu_or_gpu" = "gpu" ]; then + install_gpu + else + install_cpu fi - - # Install dependencies - [ "$cpu_or_gpu" = "gpu" ] && install_gpu || install_cpu install_sam2 download_test_image +} - # Run benchmarks - output_csv="sam2_${cpu_or_gpu}.csv" +run_benchmarks() { + suffix=$(date +"%Y_%m_%d_%H_%M_%S") + [ "$dynamo" = "true" ] && suffix="${suffix}_dynamo" + output_csv="sam2_${cpu_or_gpu}_${suffix}.csv" if [ ! -f "$output_csv" ]; then echo "Running $cpu_or_gpu benchmark..." if [ "$cpu_or_gpu" = "gpu" ]; then - run_gpu_benchmark 1000 - run_torch_compile_gpu_benchmark 1000 + run_ort_gpu_benchmark 1000 + run_torch_gpu_benchmark 1000 else run_cpu_benchmark 100 fi cat benchmark*.csv > combined_csv awk '!x[$0]++' combined_csv > "$output_csv" + rm benchmark*.csv rm combined_csv echo "Benchmark results saved in $output_csv" else @@ -169,7 +235,16 @@ run_benchmarks() { fi } -run_benchmarks +if [ ! -v CONDA_PREFIX ]; then + echo "Please activate conda environment before running this script." + exit 1 +fi + +install_all + +if [ "$benchmarking" = "true" ]; then + run_benchmarks +fi #-------------------------------------------------------------------------- # Below are for profiling @@ -177,79 +252,100 @@ run_benchmarks # Build onnxruntime-gpu from source for profiling build_onnxruntime_gpu_for_profiling() { - pushd "$install_dir" + pushd "$install_dir" || exit if ! [ -d onnxruntime ]; then git clone https://github.com/microsoft/onnxruntime fi - cd onnxruntime - CUDA_ARCH=$(python3 -c "import torch; cc = torch.cuda.get_device_capability(); print(f'{cc[0]}{cc[1]}')") - if [ -n "$CUDA_ARCH" ]; then - pip install --upgrade pip cmake psutil setuptools wheel packaging ninja numpy==1.26.4 - sh build.sh --config Release --build_dir build/cuda12 --build_shared_lib --parallel \ - --use_cuda --cuda_version 12.6 --cuda_home $install_dir/cuda12.6 \ - --cudnn_home $install_dir/cudnn9.5 \ - --build_wheel --skip_tests \ - --cmake_generator Ninja \ - --compile_no_warning_as_error \ - --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=$CUDA_ARCH \ - --cmake_extra_defines onnxruntime_ENABLE_NVTX_PROFILE=ON \ - --enable_cuda_line_info - - pip install build/cuda12/Release/dist/onnxruntime_gpu-*-linux_x86_64.whl numpy==1.26.4 - else - echo "No CUDA device found." - exit 1 - fi - popd + cd onnxruntime || exit + pip install --upgrade pip cmake psutil setuptools wheel packaging ninja numpy + build_dir=build/cuda${cuda_version} + rm -rf ${build_dir}/Release/dist + sh build.sh --config Release --build_dir "${build_dir}" --build_shared_lib --parallel \ + --use_cuda --cuda_version ${cuda_version} --cuda_home "$install_dir/cuda${cuda_version}" \ + --cudnn_home "$install_dir/cudnn${cudnn_version}" \ + --build_wheel --skip_tests \ + --cmake_generator Ninja \ + --compile_no_warning_as_error \ + --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=native \ + --cmake_extra_defines onnxruntime_ENABLE_NVTX_PROFILE=ON \ + --enable_cuda_line_info + pip uninstall onnxruntime-gpu -y + pip install "${build_dir}/Release/dist/onnxruntime_gpu-*-linux_x86_64.whl" + popd || exit } # Run profiling with NVTX. -run_nvtx_profile() -{ - pip install nvtx cuda-python==12.6.0 - +run_nvtx_profile() { + local engine="$1" # Only trace one device to avoid huge output file size. device_id=0 - envs="CUDA_VISIBLE_DEVICES=$device_id,ORT_ENABLE_CUDNN_FLASH_ATTENTION=1,LD_LIBRARY_PATH=$LD_LIBRARY_PATH" + envs="CUDA_VISIBLE_DEVICES=$device_id,ORT_ENABLE_CUDNN_FLASH_ATTENTION=1,LD_LIBRARY_PATH=$LD_LIBRARY_PATH,TORCHINDUCTOR_UNIQUE_KERNEL_NAMES=1" cuda_graph_trace=node - for engine in ort torch; do - for component in image_encoder image_decoder; do - sudo $install_dir/cuda12.6/bin/nsys profile --capture-range=nvtx --nvtx-capture='one_run' \ - --gpu-metrics-device $device_id --force-overwrite true \ - --sample process-tree --backtrace fp --stats true \ - -t cuda,cudnn,cublas,osrt,nvtx --cuda-memory-usage true --cudabacktrace all \ - --cuda-graph-trace $cuda_graph_trace \ - -e $envs,NSYS_NVTX_PROFILER_REGISTER_ONLY=0 \ - -o sam2_fp16_profile_${component}_${engine}_${cpu_or_gpu} \ - $python benchmark_sam2.py --model_type $model --engine $engine \ - --sam2_dir $sam2_dir --warm_up 1 --repeats 0 \ - --onnx_path ${onnx_dir}/${model}_${component}_fp16_gpu.onnx \ - --component $component \ - --use_gpu --dtype fp16 --enable_nvtx_profile - done + for component in image_encoder image_decoder; do + sudo "$install_dir/cuda${cuda_version}/bin/nsys" profile --capture-range=nvtx --nvtx-capture='one_run' \ + --gpu-metrics-devices $device_id --force-overwrite true \ + --sample process-tree --backtrace fp --stats true \ + -t cuda,cudnn,cublas,osrt,nvtx --cuda-memory-usage true --cudabacktrace all \ + --cuda-graph-trace "$cuda_graph_trace" \ + -e "$envs,NSYS_NVTX_PROFILER_REGISTER_ONLY=0" \ + -o "sam2_fp16_profile_${component}_${engine}_${cpu_or_gpu}" \ + $python benchmark_sam2.py --model_type "$model" --engine "$engine" \ + --sam2_dir "$sam2_dir" --warm_up 1 --repeats 0 \ + --onnx_path "${onnx_dir}/${model}_${component}_fp16_gpu.onnx" \ + --component "$component" \ + --use_gpu --dtype fp16 --enable_nvtx_profile done } -# Run profiling with PyTorch -run_torch_profile() { +run_ort_profile() { + export ORT_ENABLE_CUDNN_FLASH_ATTENTION=1 + rm -f onnxruntime_*.json for component in image_encoder image_decoder; do - $python benchmark_sam2.py --model_type $model --engine torch \ - --sam2_dir $sam2_dir --warm_up 1 --repeats 0 \ - --component $component \ - --use_gpu --dtype fp16 --enable_torch_profile + $python benchmark_sam2.py --model_type "$model" --engine ort \ + --sam2_dir "$sam2_dir" --warm_up 1 --repeats 0 \ + --onnx_path "${onnx_dir}/${model}_${component}_fp16_gpu.onnx" \ + --component "$component" \ + --use_gpu --dtype fp16 --enable_ort_profile + mv onnxruntime_profile*.json onnxruntime_$component.json done } -run_profilings() { - build_onnxruntime_gpu_for_profiling +# Run profiling with PyTorch +run_torch_profile() { + # Enable logging might could help get the code of compiled kernels. You can turn it off to reduce overhead. + export TORCH_LOGS="+inductor,+output_code" + export TORCHINDUCTOR_UNIQUE_KERNEL_NAMES=1 + component=image_encoder + $python benchmark_sam2.py --model_type "$model" --engine torch \ + --sam2_dir "$sam2_dir" --warm_up 1 --repeats 0 \ + --component "$component" \ + --torch_compile_mode max-autotune \ + --use_gpu --dtype fp16 --enable_torch_profile > "torch_${component}_compiled_code.txt" + + component=image_decoder + $python benchmark_sam2.py --model_type "$model" --engine torch \ + --sam2_dir "$sam2_dir" --warm_up 1 --repeats 0 \ + --component "$component" \ + --torch_compile_mode none \ + --use_gpu --dtype fp16 --enable_torch_profile +} +run_nvtx_profilings() { + build_onnxruntime_gpu_for_profiling rm -f *.nsys-rep *.sqlite - run_nvtx_profile + run_nvtx_profile ort + run_nvtx_profile torch +} +run_profilings() { + pip install nvtx cuda-python==${cuda_version}.0 + run_ort_profile run_torch_profile + + # NVTX profiling need to build onnxruntime-gpu from source so it is put as the last step. + run_nvtx_profilings } -profiling="${3:-false}" if [ "$profiling" = "true" ] && [ "$cpu_or_gpu" = "gpu" ]; then run_profilings fi diff --git a/onnxruntime/python/tools/transformers/models/sam2/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/sam2/convert_to_onnx.py index cacad717faf9c..3533a274b9972 100644 --- a/onnxruntime/python/tools/transformers/models/sam2/convert_to_onnx.py +++ b/onnxruntime/python/tools/transformers/models/sam2/convert_to_onnx.py @@ -113,6 +113,14 @@ def parse_arguments(): help="Optimize onnx models for GPU", ) + parser.add_argument( + "--dynamo", + required=False, + default=False, + action="store_true", + help="Use dynamo for exporting onnx model. Only image_encoder supports dynamo right now.", + ) + parser.add_argument( "--verbose", required=False, @@ -151,8 +159,10 @@ def main(): onnx_model_path = sam2_onnx_path(args.output_dir, args.model_type, component, args.multimask_output) if component == "image_encoder": if args.overwrite or not os.path.exists(onnx_model_path): - export_image_encoder_onnx(sam2_model, onnx_model_path, args.dynamic_batch_axes, args.verbose) - test_image_encoder_onnx(sam2_model, onnx_model_path, dynamic_batch_axes=False) + export_image_encoder_onnx( + sam2_model, onnx_model_path, args.dynamic_batch_axes, args.verbose, args.dynamo + ) + test_image_encoder_onnx(sam2_model, onnx_model_path, dynamic_batch_axes=args.dynamic_batch_axes) elif component == "mask_decoder": if args.overwrite or not os.path.exists(onnx_model_path): diff --git a/onnxruntime/python/tools/transformers/models/sam2/image_decoder.py b/onnxruntime/python/tools/transformers/models/sam2/image_decoder.py index 07ed150631f50..376e6ba7d802c 100644 --- a/onnxruntime/python/tools/transformers/models/sam2/image_decoder.py +++ b/onnxruntime/python/tools/transformers/models/sam2/image_decoder.py @@ -246,7 +246,7 @@ def test_decoder_onnx( import onnxruntime - ort_session = onnxruntime.InferenceSession(onnx_model_path, providers=onnxruntime.get_available_providers()) + ort_session = onnxruntime.InferenceSession(onnx_model_path, providers=["CPUExecutionProvider"]) model_inputs = ort_session.get_inputs() input_names = [model_inputs[i].name for i in range(len(model_inputs))] diff --git a/onnxruntime/python/tools/transformers/models/sam2/image_encoder.py b/onnxruntime/python/tools/transformers/models/sam2/image_encoder.py index c5ce339732063..79e9297788c36 100644 --- a/onnxruntime/python/tools/transformers/models/sam2/image_encoder.py +++ b/onnxruntime/python/tools/transformers/models/sam2/image_encoder.py @@ -90,6 +90,8 @@ def export_image_encoder_onnx( onnx_model_path: str, dynamic_batch_axes: bool = False, verbose: bool = False, + dynamo: bool = False, + clear_dynamo_metadata: bool = False, ): image = random_sam2_input_image() @@ -113,17 +115,65 @@ def export_image_encoder_onnx( if not verbose: warnings.filterwarnings("ignore", category=torch.jit.TracerWarning) warnings.filterwarnings("ignore", category=UserWarning) - torch.onnx.export( - sam2_encoder, - image, - onnx_model_path, - export_params=True, - opset_version=17, - do_constant_folding=True, - input_names=["image"], - output_names=["image_features_0", "image_features_1", "image_embeddings"], - dynamic_axes=dynamic_axes, - ) + + if not dynamo: + torch.onnx.export( + sam2_encoder, + image, + onnx_model_path, + export_params=True, + opset_version=17, + do_constant_folding=True, + input_names=["image"], + output_names=["image_features_0", "image_features_1", "image_embeddings"], + dynamic_axes=dynamic_axes, + ) + else: + torch._dynamo.config.capture_scalar_outputs = True + ep = torch.export.export( + sam2_encoder, + args=(image,), + strict=False, + dynamic_shapes=[ + {0: torch.export.Dim.AUTO}, + ], + ) + + onnx_program = torch.onnx.export( + ep, + (), + opset_version=17, + input_names=["image"], + output_names=["image_features_0", "image_features_1", "image_embeddings"], + dynamo=True, + ) + onnx_program.optimize() + onnx_program.save(onnx_model_path + ".dynamo.onnx", external_data=False) + import onnx + + from onnxruntime.transformers.dynamo_onnx_helper import DynamoOnnxHelper + + onnx_model = onnx.load_model(onnx_model_path + ".dynamo.onnx", load_external_data=True) + if dynamic_batch_axes: + # Fix labels of dynamic axes since they can't be specified during Dynamo export currently + onnx_model.graph.input[0].type.tensor_type.shape.dim[0].dim_param = "batch_size" + for i in range(3): + onnx_model.graph.output[i].type.tensor_type.shape.dim[0].dim_param = "batch_size" + + onnx_model_helper = DynamoOnnxHelper(onnx_model) + onnx_model_helper.convert_constants_to_initializers() + if clear_dynamo_metadata: + onnx_model_helper.clear_metadata() + + import os + + if os.path.exists(onnx_model_path): + os.remove(onnx_model_path) + if os.path.exists(onnx_model_path + ".data"): + os.remove(onnx_model_path + ".data") + onnx_model_helper.model.save_model_to_file( + onnx_model_path, use_external_data_format=True, all_tensors_to_one_file=True, convert_attribute=True + ) print("encoder onnx model saved to", onnx_model_path) @@ -133,7 +183,7 @@ def test_image_encoder_onnx( onnx_model_path: str, dynamic_batch_axes=False, ): - ort_session = onnxruntime.InferenceSession(onnx_model_path, providers=onnxruntime.get_available_providers()) + ort_session = onnxruntime.InferenceSession(onnx_model_path, providers=["CPUExecutionProvider"]) model_inputs = ort_session.get_inputs() input_names = [model_inputs[i].name for i in range(len(model_inputs))] diff --git a/onnxruntime/python/tools/transformers/models/sam2/mask_decoder.py b/onnxruntime/python/tools/transformers/models/sam2/mask_decoder.py index 56473c002d4ae..fa83e2f666d06 100644 --- a/onnxruntime/python/tools/transformers/models/sam2/mask_decoder.py +++ b/onnxruntime/python/tools/transformers/models/sam2/mask_decoder.py @@ -177,7 +177,7 @@ def test_mask_decoder_onnx( import onnxruntime - ort_session = onnxruntime.InferenceSession(onnx_model_path, providers=onnxruntime.get_available_providers()) + ort_session = onnxruntime.InferenceSession(onnx_model_path, providers=["CPUExecutionProvider"]) model_inputs = ort_session.get_inputs() input_names = [model_inputs[i].name for i in range(len(model_inputs))] diff --git a/onnxruntime/python/tools/transformers/models/sam2/prompt_encoder.py b/onnxruntime/python/tools/transformers/models/sam2/prompt_encoder.py index 883c51858346c..f25e6ff23324b 100644 --- a/onnxruntime/python/tools/transformers/models/sam2/prompt_encoder.py +++ b/onnxruntime/python/tools/transformers/models/sam2/prompt_encoder.py @@ -146,7 +146,7 @@ def test_prompt_encoder_onnx( import onnxruntime - ort_session = onnxruntime.InferenceSession(onnx_model_path, providers=onnxruntime.get_available_providers()) + ort_session = onnxruntime.InferenceSession(onnx_model_path, providers=["CPUExecutionProvider"]) model_inputs = ort_session.get_inputs() input_names = [model_inputs[i].name for i in range(len(model_inputs))] diff --git a/onnxruntime/test/qnn_ctx_gen/README.md b/onnxruntime/test/ep_weight_sharing_ctx_gen/README.md similarity index 82% rename from onnxruntime/test/qnn_ctx_gen/README.md rename to onnxruntime/test/ep_weight_sharing_ctx_gen/README.md index 97ab89d79cbd2..be1a1fe039366 100644 --- a/onnxruntime/test/qnn_ctx_gen/README.md +++ b/onnxruntime/test/ep_weight_sharing_ctx_gen/README.md @@ -2,17 +2,19 @@ This tool provides the way to generate Onnx models that wraps QNN context binary warpt with weight sharing enabled. The options to use with the tool are listed below: -`onnxruntime_qnn_ctx_gen [options...] model_path,model_path` +`ep_weight_sharing_ctx_gen [options...] model_1_path,model_2_path` -./onnxruntime_qnn_ctx_gen -v -i "soc_model|60 htp_graph_finalization_optimization_mode|3" -C "ep.context_enable|1 ep.context_embed_mode|0" /mnt/c/model1.onnx,/mnt/c/model2.onnx +./ep_weight_sharing_ctx_gen -e qnn -v -i "soc_model|60 htp_graph_finalization_optimization_mode|3" /mnt/c/model1.onnx,/mnt/c/model2.onnx Options: - + + -e [qnn|tensorrt|openvino|vitisai]: Specifies the compile based provider qnn, tensorrt, openvino, vitisai. Default is qnn. + -v: Show verbose information. -C: [session_config_entries]: Specify session configuration entries as key-value pairs: -C "| |" Refer to onnxruntime_session_options_config_keys.h for valid keys and values. - [Example] -C "ep.context_enable|1 ep.context_embed_mode|0" + [Example] -C "ep.context_enable|1 ep.context_embed_mode|0". These are set as default so can be ignored. -i: [provider_options]: Specify QNN EP specific runtime options as key value pairs. Different runtime options available are: [Usage]: -i '| |' diff --git a/onnxruntime/test/qnn_ctx_gen/command_args_parser.cc b/onnxruntime/test/ep_weight_sharing_ctx_gen/command_args_parser.cc similarity index 68% rename from onnxruntime/test/qnn_ctx_gen/command_args_parser.cc rename to onnxruntime/test/ep_weight_sharing_ctx_gen/command_args_parser.cc index 24c343c7b9541..bf21d54ccde41 100644 --- a/onnxruntime/test/qnn_ctx_gen/command_args_parser.cc +++ b/onnxruntime/test/ep_weight_sharing_ctx_gen/command_args_parser.cc @@ -1,5 +1,4 @@ // Copyright (c) Microsoft Corporation. All rights reserved. -// Copyright (c) 2023 NVIDIA Corporation. // Licensed under the MIT License. #include "command_args_parser.h" @@ -29,28 +28,30 @@ namespace qnnctxgen { /*static*/ void CommandLineParser::ShowUsage() { printf( - "onnxruntime_qnn_ctx_gen [options...] model1_path,model2_path\n" - "Example: ./onnxruntime_qnn_ctx_gen -i \"soc_model|60 htp_graph_finalization_optimization_mode|3\" -C \"ep.context_node_name_prefix|_part1\" ./model1.onnx,./model2.onnx\n" + "ep_weight_sharing_ctx_gen [options...] model1_path,model2_path\n" + "Example: ./ep_weight_sharing_ctx_gen -i \"soc_model|60 htp_graph_finalization_optimization_mode|3\" -C \"ep.context_node_name_prefix|_part1\" ./model1.onnx,./model2.onnx\n" "Options:\n" + "\t-e [qnn|tensorrt|openvino|vitisai]: Specifies the compile based provider 'qnn','tensorrt','openvino', 'vitisai'. " + "Default:'qnn'.\n" "\t-v: Show verbose information.\n" "\t-C: Specify session configuration entries as key-value pairs: -C \"| |\" \n" "\t Refer to onnxruntime_session_options_config_keys.h for valid keys and values. \n" "\t Force ep.context_enable to 1 and ep.context_embed_mode to 0. Change ep.context_file_path is not allowed." "\t [Example] -C \"ep.context_node_name_prefix|_part1\" \n" - "\t-i: Specify QNN EP specific runtime options as key value pairs. Different runtime options available are: \n" + "\t-i: Specify EP specific runtime options as key value pairs. Different runtime options available are: \n" "\t [Usage]: -i '| |'\n" "\n" - "\t [backend_path]: QNN backend path. e.g '/folderpath/libQnnHtp.so', '/winfolderpath/QnnHtp.dll'. default to HTP backend\n" - "\t [vtcm_mb]: QNN VTCM size in MB. default to 0(not set).\n" - "\t [htp_graph_finalization_optimization_mode]: QNN graph finalization optimization mode, options: '0', '1', '2', '3', default is '0'.\n" - "\t [soc_model]: The SoC Model number. Refer to QNN SDK documentation for specific values. Defaults to '0' (unknown). \n" - "\t [htp_arch]: The minimum HTP architecture. The driver will use ops compatible with this architecture. eg: '0', '68', '69', '73', '75'. Defaults to '0' (none). \n" - "\t [enable_htp_fp16_precision]: Enable the HTP_FP16 precision so that the float32 model will be inferenced with fp16 precision. \n" + "\t [QNN only] [backend_path]: QNN backend path. e.g '/folderpath/libQnnHtp.so', '/winfolderpath/QnnHtp.dll'. default to HTP backend\n" + "\t [QNN only] [vtcm_mb]: QNN VTCM size in MB. default to 0(not set).\n" + "\t [QNN only] [htp_graph_finalization_optimization_mode]: QNN graph finalization optimization mode, options: '0', '1', '2', '3', default is '0'.\n" + "\t [QNN only] [soc_model]: The SoC Model number. Refer to QNN SDK documentation for specific values. Defaults to '0' (unknown). \n" + "\t [QNN only] [htp_arch]: The minimum HTP architecture. The driver will use ops compatible with this architecture. eg: '0', '68', '69', '73', '75'. Defaults to '0' (none). \n" + "\t [QNN only] [enable_htp_fp16_precision]: Enable the HTP_FP16 precision so that the float32 model will be inferenced with fp16 precision. \n" "\t Otherwise, it will be fp32 precision. Works for float32 model for HTP backend. Defaults to '1' (with FP16 precision.). \n" - "\t [enable_htp_weight_sharing]: Allows common weights across graphs to be shared and stored in a single context binary. Defaults to '1' (enabled).\n" - "\t [offload_graph_io_quantization]: Offload graph input quantization and graph output dequantization to another EP (typically CPU EP). \n" - "\t Defaults to '0' (QNN EP handles the graph I/O quantization and dequantization). \n" - "\t [enable_htp_spill_fill_buffer]: Enable HTP spill file buffer, used while generating QNN context binary." + "\t [QNN only] [enable_htp_weight_sharing]: Allows common weights across graphs to be shared and stored in a single context binary. Defaults to '1' (enabled).\n" + "\t [QNN only] [offload_graph_io_quantization]: Offload graph input quantization and graph output dequantization to another EP (typically CPU EP). \n" + "\t Defaults to '1' (QNN EP handles the graph I/O quantization and dequantization). \n" + "\t [QNN only] [enable_htp_spill_fill_buffer]: Enable HTP spill file buffer, used while generating QNN context binary." "\t [Example] -i \"vtcm_mb|8 htp_arch|73\" \n" "\n" "\t-h: help\n"); @@ -109,8 +110,22 @@ static bool ParseSessionConfigs(const std::string& configs_string, /*static*/ bool CommandLineParser::ParseArguments(TestConfig& test_config, int argc, ORTCHAR_T* argv[]) { int ch; - while ((ch = getopt(argc, argv, ORT_TSTR("o:u:i:C:vh"))) != -1) { + while ((ch = getopt(argc, argv, ORT_TSTR("e:o:u:i:C:vh"))) != -1) { switch (ch) { + case 'e': + if (!CompareCString(optarg, ORT_TSTR("qnn"))) { + test_config.machine_config.provider_type_name = onnxruntime::kQnnExecutionProvider; + } else if (!CompareCString(optarg, ORT_TSTR("openvino"))) { + test_config.machine_config.provider_type_name = onnxruntime::kOpenVINOExecutionProvider; + } else if (!CompareCString(optarg, ORT_TSTR("tensorrt"))) { + test_config.machine_config.provider_type_name = onnxruntime::kTensorrtExecutionProvider; + } else if (!CompareCString(optarg, ORT_TSTR("vitisai"))) { + test_config.machine_config.provider_type_name = onnxruntime::kVitisAIExecutionProvider; + } else { + fprintf(stderr, "The execution provider is not included in this tool.\n"); + return false; + } + break; case 'v': test_config.run_config.f_verbose = true; break; @@ -162,7 +177,7 @@ static bool ParseSessionConfigs(const std::string& configs_string, 'offload_graph_io_quantization', 'enable_htp_spill_fill_buffer'])"); } - test_config.run_config.qnn_options[key] = value; + test_config.run_config.provider_options[key] = value; } break; } diff --git a/onnxruntime/test/qnn_ctx_gen/command_args_parser.h b/onnxruntime/test/ep_weight_sharing_ctx_gen/command_args_parser.h similarity index 100% rename from onnxruntime/test/qnn_ctx_gen/command_args_parser.h rename to onnxruntime/test/ep_weight_sharing_ctx_gen/command_args_parser.h diff --git a/onnxruntime/test/ep_weight_sharing_ctx_gen/main.cc b/onnxruntime/test/ep_weight_sharing_ctx_gen/main.cc new file mode 100644 index 0000000000000..104cdbdfd5abc --- /dev/null +++ b/onnxruntime/test/ep_weight_sharing_ctx_gen/main.cc @@ -0,0 +1,247 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "test_configuration.h" +#include "command_args_parser.h" + +// onnxruntime dependencies +#include "core/session/onnxruntime_cxx_api.h" +#include "core/session/onnxruntime_session_options_config_keys.h" + +// onnx dependencies +#include "onnx/onnx_pb.h" +#include + +using namespace onnxruntime; +using ProviderOptions = std::unordered_map; + +// from the last context cache Onnx model, find the EPContext node with main_context=1, +// and get the QNN context binary file name, this context binary contains all graphs from all Onnx models +// get the max spill fill buffer size +static void GetLastContextBinaryFileName(const std::basic_string last_onnx_ctx_file, + std::string& last_ctx_bin_file, + int64_t& max_size) { + max_size = 0; + + onnx::ModelProto model; + std::ifstream onnx_file_stream(last_onnx_ctx_file, std::ios_base::binary); + model.ParseFromIstream(&onnx_file_stream); + + for (auto& node : model.graph().node()) { + if (node.op_type() == "EPContext") { + int64_t is_main_context = 0; + for (auto& attr : node.attribute()) { + if (attr.name() == "main_context") { + is_main_context = attr.i(); + } + if (attr.name() == "max_size") { + max_size = attr.i(); + } + if (attr.name() == "ep_cache_context") { + last_ctx_bin_file = attr.s(); + } + } + if (is_main_context) { + return; + } + } + } + + onnx_file_stream.close(); +} + +// Update generated context cache Onnx model to make the main EPContext node point to +// the last QNN context binary file +// Remove not used QNN context binary file, only keep the last one which contains all graphs +static void UpdateEpContextModel(const std::vector>& ep_ctx_files, + const std::string& last_qnn_ctx_binary_file_name, + int64_t max_size) { + for (auto ep_ctx_file : ep_ctx_files) { + onnx::ModelProto model; + std::ifstream onnx_file_stream(ep_ctx_file, std::ios_base::binary); + model.ParseFromIstream(&onnx_file_stream); + onnx_file_stream.close(); + + for (auto& node : *(model.mutable_graph()->mutable_node())) { + if (node.op_type() == "EPContext") { + int64_t is_main_context = 0; + std::string old_qnn_ctx_binary_file_name; + int max_size_index = 0; + int ep_context_index = 0; + for (auto i = 0; i < node.attribute_size(); ++i) { + auto& attr = node.attribute()[i]; + if (attr.name() == "main_context") { + is_main_context = attr.i(); + } + if (attr.name() == "max_size") { + max_size = attr.i(); + max_size_index = i; + } + if (attr.name() == "ep_cache_context") { + old_qnn_ctx_binary_file_name = attr.s(); + ep_context_index = 0; + } + } + if (is_main_context) { + auto path_str = ToPathString(ep_ctx_file); + auto path = std::filesystem::path(path_str); + auto file_path = path.replace_filename(old_qnn_ctx_binary_file_name); + std::remove(file_path.string().c_str()); + + node.mutable_attribute(max_size_index)->set_i(max_size); + node.mutable_attribute(ep_context_index)->set_s(last_qnn_ctx_binary_file_name); + } + } + } + + // re-write the onnx ctx file + std::ofstream onnx_file_ostream(ep_ctx_file, std::ios_base::binary); + model.SerializeToOstream(&onnx_file_ostream); + onnx_file_ostream.close(); + } +} + +#ifdef _WIN32 +int real_main(int argc, wchar_t* argv[]) { +#else +int real_main(int argc, char* argv[]) { +#endif + qnnctxgen::TestConfig test_config; + if (!qnnctxgen::CommandLineParser::ParseArguments(test_config, argc, argv)) { + qnnctxgen::CommandLineParser::ShowUsage(); + return -1; + } + + OrtLoggingLevel logging_level = test_config.run_config.f_verbose + ? ORT_LOGGING_LEVEL_VERBOSE + : ORT_LOGGING_LEVEL_ERROR; + Ort::Env env(logging_level, "ep_weight_sharing"); + + ORT_TRY { + Ort::SessionOptions so; + so.SetLogId("ep_weight_sharing_ctx_gen_session_logger"); + // Set default session option to dump EPContext model with non-embed mode + so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1"); + so.AddConfigEntry(kOrtSessionOptionEpContextEmbedMode, "0"); + // enable ep.share_ep_contexts + so.AddConfigEntry(kOrtSessionOptionShareEpContexts, "1"); + + ProviderOptions provider_options; + + for (auto it : test_config.run_config.provider_options) { + provider_options[it.first] = it.second; + } + + for (auto it : test_config.run_config.session_config_entries) { + if (it.first == kOrtSessionOptionEpContextEnable && it.second != "1") { + std::cerr << "Need to enable ep context cache." << std::endl; + continue; + } + if (it.first == kOrtSessionOptionEpContextEmbedMode && it.second != "0") { + std::cerr << "Only support non-embed model for weight sharing." << std::endl; + continue; + } + if (it.first == kOrtSessionOptionEpContextFilePath) { + std::cout << "Not support to specify the generated Onnx context cache file name." << std::endl; + continue; + } + so.AddConfigEntry(it.first.c_str(), it.second.c_str()); + } + + for (auto model_path : test_config.model_file_paths) { + std::cout << "Model file path: " << ToUTF8String(model_path) << std::endl; + } + + // Generate context cache model files with QNN context binary files + // The context binary file generated later includes all graphs from previous models + { + std::string provider_name_ = test_config.machine_config.provider_type_name; + if (provider_name_ == onnxruntime::kQnnExecutionProvider) { +#ifdef USE_QNN +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + + // set default QNN EP option to enable weight sharing if not set by user + const std::string enable_htp_weight_sharing = "enable_htp_weight_sharing"; + if (provider_options.find(enable_htp_weight_sharing) == provider_options.end()) { + provider_options[enable_htp_weight_sharing] = "1"; + } + so.AppendExecutionProvider("QNN", provider_options); +#else + ORT_THROW("QNN is not supported in this build\n"); +#endif + } else if (!provider_name_.empty()) { + ORT_THROW("This execution provider is not included in this tool.\n"); + } + + size_t total_file_count = test_config.model_file_paths.size(); + for (size_t i = 0; i < total_file_count; ++i) { + auto model_path = test_config.model_file_paths[i]; + std::cout << "Generating context cache model for: " << ToUTF8String(model_path) << std::endl; + if (i == total_file_count - 1) { + so.AddConfigEntry(kOrtSessionOptionStopShareEpContexts, "1"); + } + Ort::Session session(env, model_path.c_str(), so); + } + } + + std::cout << "Start to update the generated Onnx model." << std::endl; + std::vector> ep_ctx_files; + ep_ctx_files.reserve(test_config.model_file_paths.size()); + for (auto model_path : test_config.model_file_paths) { + auto pos = model_path.find_last_of(ORT_TSTR(".")); + if (pos != std::string::npos) { + model_path = model_path.substr(0, pos) + ORT_TSTR("_ctx.onnx"); + } else { + model_path = model_path + ORT_TSTR("_ctx.onnx"); + } + ep_ctx_files.push_back(model_path); + } + + // Get the last context binary file name + std::string last_qnn_ctx_binary_file_name; + int64_t max_size = 0; + GetLastContextBinaryFileName(ep_ctx_files.back(), last_qnn_ctx_binary_file_name, max_size); + std::cout << "The last context binary file: " << last_qnn_ctx_binary_file_name << std::endl; + if (last_qnn_ctx_binary_file_name.empty()) { + throw Ort::Exception("Can't find QNN context binary file from the Onnx model.", OrtErrorCode::ORT_FAIL); + } + ep_ctx_files.pop_back(); + + // Update generated context cache Onnx model to make the main EPContext node point to + // the last QNN context binary file + // Remove not used QNN context binary file, only keep the last one only which contains all graphs + UpdateEpContextModel(ep_ctx_files, last_qnn_ctx_binary_file_name, max_size); + } + ORT_CATCH(const Ort::Exception& e) { + std::cerr << "Failed to generate context cache file: " << e.what(); + return -1; + } + + std::cout << "Generation done!"; + return 0; +} + +#ifdef _WIN32 +int wmain(int argc, wchar_t* argv[]) { +#else +int main(int argc, char* argv[]) { +#endif + int retval = -1; + ORT_TRY { + retval = real_main(argc, argv); + } + ORT_CATCH(const std::exception& ex) { + ORT_HANDLE_EXCEPTION([&]() { + fprintf(stderr, "%s\n", ex.what()); + retval = -1; + }); + } + + ::google::protobuf::ShutdownProtobufLibrary(); + + return retval; +} diff --git a/onnxruntime/test/qnn_ctx_gen/test_configuration.h b/onnxruntime/test/ep_weight_sharing_ctx_gen/test_configuration.h similarity index 75% rename from onnxruntime/test/qnn_ctx_gen/test_configuration.h rename to onnxruntime/test/ep_weight_sharing_ctx_gen/test_configuration.h index bf4c7061a3484..198d03211f561 100644 --- a/onnxruntime/test/qnn_ctx_gen/test_configuration.h +++ b/onnxruntime/test/ep_weight_sharing_ctx_gen/test_configuration.h @@ -14,15 +14,20 @@ namespace onnxruntime { namespace qnnctxgen { +struct MachineConfig { + std::string provider_type_name{onnxruntime::kQnnExecutionProvider}; +}; + struct RunConfig { bool f_verbose{false}; std::unordered_map session_config_entries; - std::unordered_map qnn_options; + std::unordered_map provider_options; }; struct TestConfig { std::vector> model_file_paths; RunConfig run_config; + MachineConfig machine_config; }; } // namespace qnnctxgen diff --git a/onnxruntime/test/framework/inference_session_test.cc b/onnxruntime/test/framework/inference_session_test.cc index 1b06eb55afbd2..95101c8075fc2 100644 --- a/onnxruntime/test/framework/inference_session_test.cc +++ b/onnxruntime/test/framework/inference_session_test.cc @@ -138,6 +138,7 @@ class FuseExecutionProvider : public IExecutionProvider { std::vector> GetCapability(const onnxruntime::GraphViewer& graph, const IKernelLookup& /*kernel_lookup*/, + const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* /* resource_accountant */) const override { // Fuse two add into one. std::vector> result; diff --git a/onnxruntime/test/framework/session_state_test.cc b/onnxruntime/test/framework/session_state_test.cc index b6b915f90d99a..8f4eede76b905 100644 --- a/onnxruntime/test/framework/session_state_test.cc +++ b/onnxruntime/test/framework/session_state_test.cc @@ -27,6 +27,7 @@ #include "test/util/include/default_providers.h" #include "test/util/include/file_util.h" #include "core/optimizer/layout_transformation/layout_transformation.h" +#include "core/optimizer/graph_optimizer_registry.h" using namespace ONNX_NAMESPACE; namespace onnxruntime { @@ -264,7 +265,11 @@ TEST_P(SessionStateTestP, TestInitializerProcessing) { SessionState session_state(graph, execution_providers, tp.get(), nullptr, dtm, edlm, DefaultLoggingManager().DefaultLogger(), profiler, sess_options); - GraphPartitioner partitioner(krm, execution_providers); + // Create GraphOptimizerRegistry instance for providing predefined graph optimizers and selection functions for EPs to lookup + auto graph_optimizer_registry = std::make_unique(&sess_options, + execution_providers.Get(onnxruntime::kCpuExecutionProvider), + &DefaultLoggingManager().DefaultLogger()); + GraphPartitioner partitioner(krm, execution_providers, std::move(graph_optimizer_registry)); ASSERT_STATUS_OK( partitioner.Partition( graph, session_state.GetMutableFuncMgr(), @@ -350,8 +355,12 @@ TEST(SessionStateTest, TestInitializerMemoryAllocatedUsingNonArenaMemory) { SessionState session_state(graph, execution_providers, nullptr, nullptr, dtm, edlm, DefaultLoggingManager().DefaultLogger(), profiler, sess_options); + // Create GraphOptimizerRegistry instance for providing predefined graph optimizers and selection functions for EPs to lookup + auto graph_optimizer_registry = std::make_unique(&sess_options, + execution_providers.Get(onnxruntime::kCpuExecutionProvider), + &DefaultLoggingManager().DefaultLogger()); // Partition the graph - GraphPartitioner partitioner(krm, execution_providers); + GraphPartitioner partitioner(krm, execution_providers, std::move(graph_optimizer_registry)); ASSERT_STATUS_OK(partitioner.Partition( graph, session_state.GetMutableFuncMgr(), [&cpu_allocator](Graph& graph, bool& modified, const IExecutionProvider& execution_provider, @@ -409,8 +418,13 @@ TEST(SessionStateTest, TestInitializerMemoryAllocatedUsingNonArenaMemory) { SessionState session_state(graph, execution_providers, nullptr, nullptr, dtm, edlm, DefaultLoggingManager().DefaultLogger(), profiler, sess_options); + // Create GraphOptimizerRegistry instance for providing predefined graph optimizers and selection functions for EPs to lookup + auto graph_optimizer_registry = std::make_unique(&sess_options, + execution_providers.Get(onnxruntime::kCpuExecutionProvider), + &DefaultLoggingManager().DefaultLogger()); + // Partition the graph - GraphPartitioner partitioner(krm, execution_providers); + GraphPartitioner partitioner(krm, execution_providers, std::move(graph_optimizer_registry)); ASSERT_STATUS_OK(partitioner.Partition( graph, session_state.GetMutableFuncMgr(), [&cpu_allocator](Graph& graph, bool& modified, @@ -479,7 +493,12 @@ void LoadWithResourceAwarePartitioning(const ORTCHAR_T* model_path, SessionState session_state(model->MainGraph(), execution_providers, tp.get(), nullptr, dtm, edlm, default_logger, profiler, sess_options); - GraphPartitioner partitioner(krm, execution_providers); + // Create GraphOptimizerRegistry instance for providing predefined graph optimizers and selection functions for EPs to lookup + auto graph_optimizer_registry = std::make_unique(&sess_options, + execution_providers.Get(onnxruntime::kCpuExecutionProvider), + &DefaultLoggingManager().DefaultLogger()); + + GraphPartitioner partitioner(krm, execution_providers, std::move(graph_optimizer_registry)); layout_transformation::TransformLayoutFunction transform_layout_fn; layout_transformation::DebugGraphFn debug_graph_fn; ASSERT_STATUS_OK( diff --git a/onnxruntime/test/framework/type_info_test.cc b/onnxruntime/test/framework/type_info_test.cc index ee787fb071d97..d8ef668bf1c7e 100644 --- a/onnxruntime/test/framework/type_info_test.cc +++ b/onnxruntime/test/framework/type_info_test.cc @@ -22,9 +22,9 @@ TEST(TypeInfoTests, TensorProto) { auto tensor_type_info = OrtTypeInfo::FromTypeProto(tensor_type.value); ASSERT_EQ(ONNX_TYPE_TENSOR, tensor_type_info->type); - ASSERT_NE(nullptr, tensor_type_info->data); - ASSERT_EQ(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, tensor_type_info->data->type); - ASSERT_TRUE(SpanEq(AsSpan({1, 2, 3, 4}), tensor_type_info->data->shape.GetDims())); + ASSERT_NE(nullptr, tensor_type_info->tensor_type_info); + ASSERT_EQ(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, tensor_type_info->tensor_type_info->type); + ASSERT_TRUE(SpanEq(AsSpan({1, 2, 3, 4}), tensor_type_info->tensor_type_info->shape.GetDims())); } TEST(TypeInfoTests, SequenceWithTensorElement) { @@ -37,9 +37,9 @@ TEST(TypeInfoTests, SequenceWithTensorElement) { const auto& tensor_type_info = *seq_type_info->sequence_type_info->sequence_key_type_; ASSERT_EQ(ONNX_TYPE_TENSOR, tensor_type_info.type); - ASSERT_NE(nullptr, tensor_type_info.data); - ASSERT_EQ(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, tensor_type_info.data->type); - ASSERT_TRUE(SpanEq(AsSpan({1, 2, 3, 4}), tensor_type_info.data->shape.GetDims())); + ASSERT_NE(nullptr, tensor_type_info.tensor_type_info); + ASSERT_EQ(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, tensor_type_info.tensor_type_info->type); + ASSERT_TRUE(SpanEq(AsSpan({1, 2, 3, 4}), tensor_type_info.tensor_type_info->shape.GetDims())); } TEST(TypeInfoTests, OptionalWithTensorProto) { @@ -54,9 +54,9 @@ TEST(TypeInfoTests, OptionalWithTensorProto) { const auto& contained_type = *optional_type_info->optional_type_info->contained_type_; ASSERT_EQ(ONNX_TYPE_TENSOR, contained_type.type); - ASSERT_NE(nullptr, contained_type.data); - ASSERT_EQ(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, contained_type.data->type); - ASSERT_TRUE(SpanEq(AsSpan({1, 2, 3, 4}), contained_type.data->shape.GetDims())); + ASSERT_NE(nullptr, contained_type.tensor_type_info); + ASSERT_EQ(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, contained_type.tensor_type_info->type); + ASSERT_TRUE(SpanEq(AsSpan({1, 2, 3, 4}), contained_type.tensor_type_info->shape.GetDims())); } #if !defined(DISABLE_ML_OPS) @@ -74,11 +74,11 @@ TEST(TypeInfoTests, MapWithTensorValue) { const auto& tensor_type_info = *map_info.map_value_type_; ASSERT_EQ(ONNX_TYPE_TENSOR, tensor_type_info.type); - ASSERT_NE(nullptr, tensor_type_info.data); - ASSERT_EQ(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, tensor_type_info.data->type); - ASSERT_TRUE(SpanEq(AsSpan({1, 2, 3, 4}), tensor_type_info.data->shape.GetDims())); + ASSERT_NE(nullptr, tensor_type_info.tensor_type_info); + ASSERT_EQ(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, tensor_type_info.tensor_type_info->type); + ASSERT_TRUE(SpanEq(AsSpan({1, 2, 3, 4}), tensor_type_info.tensor_type_info->shape.GetDims())); } #endif } // namespace test -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime diff --git a/onnxruntime/test/providers/base_tester.cc b/onnxruntime/test/providers/base_tester.cc index 6bfe7bc3856ba..eecff3fa4d8ff 100644 --- a/onnxruntime/test/providers/base_tester.cc +++ b/onnxruntime/test/providers/base_tester.cc @@ -174,7 +174,7 @@ static std::unique_ptr MakeSparseTensor(MLDataType data_type, cons return p_tensor; } -void BaseTester::CopyDataToTensor(gsl::span data, Tensor& dst) { +void BaseTester::CopyDataToTensor(gsl::span data, Tensor& dst) { ORT_ENFORCE(dst.SizeInBytes() >= data.size_bytes(), "Not enough space in the destination tensor"); memcpy(dst.MutableDataRaw(), data.data(), data.size_bytes()); } @@ -203,7 +203,7 @@ void BaseTester::AddSparseCooTensorData(std::vector& data, MLDataType data_type, const char* name, gsl::span dims, - gsl::span values, + gsl::span values, gsl::span indices, const ValidateOutputParams& check_params, const std::vector* dim_params) { @@ -247,7 +247,7 @@ void BaseTester::AddSparseCsrTensorData(std::vector& data, MLDataType data_type, const char* name, gsl::span dims, - gsl::span values, + gsl::span values, gsl::span inner_indices, gsl::span outer_indices, const ValidateOutputParams& check_params, diff --git a/onnxruntime/test/providers/base_tester.h b/onnxruntime/test/providers/base_tester.h index 512b3402c5986..d39cc3c750dec 100644 --- a/onnxruntime/test/providers/base_tester.h +++ b/onnxruntime/test/providers/base_tester.h @@ -868,7 +868,7 @@ class BaseTester { void AddShapeToTensorData(NodeArg& node_arg, gsl::span dims, const std::vector* dim_params); - void CopyDataToTensor(gsl::span data, Tensor& dst); + void CopyDataToTensor(gsl::span data, Tensor& dst); #if !defined(DISABLE_SPARSE_TENSORS) NodeArg MakeSparseNodeArg(int32_t dtype, const char* name, @@ -879,7 +879,7 @@ class BaseTester { MLDataType data_type, const char* name, gsl::span dims, - gsl::span values, + gsl::span values, gsl::span indices, const ValidateOutputParams& check_params, const std::vector* dim_params = nullptr); @@ -895,7 +895,7 @@ class BaseTester { MLDataType data_type, const char* name, gsl::span dims, - gsl::span values, + gsl::span values, gsl::span inner_indices, gsl::span outer_indices, const ValidateOutputParams& check_params, diff --git a/onnxruntime/test/providers/cpu/math/softmax_test.cc b/onnxruntime/test/providers/cpu/math/softmax_test.cc index 6f7930f722564..1c6375ebdb0b1 100644 --- a/onnxruntime/test/providers/cpu/math/softmax_test.cc +++ b/onnxruntime/test/providers/cpu/math/softmax_test.cc @@ -170,11 +170,11 @@ TEST(SoftmaxOperator, ThreeAndFourDimsAxis0) { RunTest(input_vals_60, expected_vals, three_dimensions, /*opset*/ 7, /*axis*/ 0, // axis=0 is not supported by TensorRT - {kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kDnnlExecutionProvider}); + {kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kDnnlExecutionProvider, kWebGpuExecutionProvider}); RunTest(input_vals_60, expected_vals, four_dimensions, /*opset*/ 7, /*axis*/ 0, // axis=0 is not supported by TensorRT - {kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kDnnlExecutionProvider}); + {kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kDnnlExecutionProvider, kWebGpuExecutionProvider}); } TEST(SoftmaxOperator, ThreeAndFourDimsSecondLastAxis) { @@ -201,10 +201,10 @@ TEST(SoftmaxOperator, ThreeAndFourDimsSecondLastAxis) { 0.040478885f, 0.033857856f, 0.080346674f, 0.06199841f, 0.040481992f}; RunTest(input_vals_60, expected_vals, three_dimensions, /*opset*/ 7, /*axis*/ 1, - {kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kDnnlExecutionProvider}); + {kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kDnnlExecutionProvider, kWebGpuExecutionProvider}); RunTest(input_vals_60, expected_vals, four_dimensions, /*opset*/ 7, /*axis*/ 2, - {kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kDnnlExecutionProvider}); + {kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kDnnlExecutionProvider, kWebGpuExecutionProvider}); } TEST(SoftmaxOperator, ThreeAndFourDimsSecondLastAxis_opset13) { @@ -376,8 +376,9 @@ TEST(SoftmaxOperator, DimWithZero) { RunTest(x_vals, expected_vals, dimensions, /*opset*/ -1, /*axis*/ 0, {kTensorrtExecutionProvider, - kNnapiExecutionProvider, // NNAPI softmax does not support empty input - kQnnExecutionProvider} // QNN doesn't support dim 0 + kNnapiExecutionProvider, // NNAPI softmax does not support empty input + kWebGpuExecutionProvider, // WebGPU does not support dim 0 + kQnnExecutionProvider} // QNN doesn't support dim 0 ); } diff --git a/onnxruntime/test/providers/cpu/nn/conv_integer_test.cc b/onnxruntime/test/providers/cpu/nn/conv_integer_test.cc index a5378fa3cefd7..c98d9e28b2f46 100644 --- a/onnxruntime/test/providers/cpu/nn/conv_integer_test.cc +++ b/onnxruntime/test/providers/cpu/nn/conv_integer_test.cc @@ -254,5 +254,45 @@ TEST(ConvIntegerTest, WithStride3_2D_u8u8) { test.Run(); } +TEST(ConvIntegerTest, NoXZeroPoint) { + OpTester test("ConvInteger", 10); + std::vector x_dims{1, 1, 3, 3}; + test.AddInput("x", x_dims, + {2, 3, 4, + 5, 6, 7, + 8, 9, 10}); + std::vector w_dims{1, 1, 2, 2}; + test.AddInput("w", w_dims, + {2, 2, + 2, 2}); + test.AddOptionalInputEdge(); + test.AddInput("w_zero_point", {}, {1}); + std::vector y_dims{1, 1, 2, 2}; + test.AddOutput("y", y_dims, + {16, 20, + 28, 32}); + test.Run(); +} + +// provide optional input with empty name for w. tests that input args == 4 but the w_zero_point does not exist. +TEST(ConvIntegerTest, NoWZeroPoint) { + OpTester test("ConvInteger", 10); + std::vector x_dims{1, 1, 3, 3}; + test.AddInput("x", x_dims, + {2, 3, 4, + 5, 6, 7, + 8, 9, 10}); + std::vector w_dims{1, 1, 2, 2}; + test.AddInput("w", w_dims, + {2, 2, + 2, 2}); + test.AddInput("x_zero_point", {}, {1}); + test.AddOptionalInputEdge(); + std::vector y_dims{1, 1, 2, 2}; + test.AddOutput("y", y_dims, + {24, 32, + 48, 56}); + test.Run(); +} } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.cc b/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.cc index b753bc386d722..ee0aff6d26444 100644 --- a/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.cc +++ b/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.cc @@ -111,6 +111,7 @@ DataLayout InternalTestingExecutionProvider::GetPreferredLayout() const { std::vector> InternalTestingExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer, const IKernelLookup& kernel_lookup, + const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* /* resource_accountant */) const { // find nodes that have ops in our supported list std::unordered_set supported_static_nodes; diff --git a/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.h b/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.h index d2ed8259ee974..0caa0febc2796 100644 --- a/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.h +++ b/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.h @@ -20,6 +20,7 @@ class InternalTestingExecutionProvider : public IExecutionProvider { std::vector> GetCapability(const onnxruntime::GraphViewer& graph_view, const IKernelLookup& /*kernel_lookup*/, + const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* /* resource_accountant */) const override; common::Status Compile(const std::vector& fused_nodes, diff --git a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc index 07843c30a61df..3dec74599abdf 100644 --- a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc +++ b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc @@ -43,6 +43,35 @@ static const std::string& GetNodeAttr(const Node& node, const std::string& attr_ return default_val; } +// from the context cache Onnx model, find the EPContext node with main_context=1, +// and get the QNN context binary file name +static void GetContextBinaryFileName(const std::string onnx_ctx_file, + std::string& last_ctx_bin_file, + const Logger& logger) { + std::shared_ptr ctx_model; + ASSERT_STATUS_OK(Model::Load(ToPathString(onnx_ctx_file), ctx_model, nullptr, logger)); + auto& ctx_graph = ctx_model->MainGraph(); + for (auto& node : ctx_graph.Nodes()) { + if (node.OpType() == "EPContext") { + int64_t is_main_context = GetNodeAttr(node, "main_context", static_cast(0)); + if (1 == is_main_context) { + last_ctx_bin_file = GetNodeAttr(node, "ep_cache_context", ""); + return; + } + } + } +} + +// Get context binary file name from Context model file and remove it with the context model file +void CleanUpCtxFile(std::string context_file_path) { + std::string qnn_ctx_binary_file_name; + GetContextBinaryFileName(context_file_path, qnn_ctx_binary_file_name, + DefaultLoggingManager().DefaultLogger()); + + ASSERT_EQ(std::remove(qnn_ctx_binary_file_name.c_str()), 0); + ASSERT_EQ(std::remove(context_file_path.c_str()), 0); +} + // Create a model with FusedMatMul + Add (quantized) // input1 -> Add -> Q -> DQ ---- // | @@ -123,22 +152,22 @@ void QnnContextBinaryMultiPartitionTestBody(bool single_ep_node = true) { const auto model_data_span = AsByteSpan(model_data.data(), model_data.size()); - const std::string context_binary_file = "./qnn_context_binary_multi_partition_test.onnx"; - std::remove(context_binary_file.c_str()); + const std::string context_model_file = "./qnn_context_binary_multi_partition_test.onnx"; + std::remove(context_model_file.c_str()); Ort::SessionOptions so; so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1"); - so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, context_binary_file.c_str()); + so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, context_model_file.c_str()); so.AppendExecutionProvider("QNN", provider_options); Ort::Session session(*ort_env, model_data_span.data(), model_data_span.size(), so); // Make sure the Qnn context cache binary file is generated - EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str())); + EXPECT_TRUE(std::filesystem::exists(context_model_file.c_str())); int ep_context_node_count = 0; int non_ep_context_node_count = 0; std::shared_ptr ctx_model; - ASSERT_STATUS_OK(Model::Load(ToPathString(context_binary_file), ctx_model, nullptr, DefaultLoggingManager().DefaultLogger())); + ASSERT_STATUS_OK(Model::Load(ToPathString(context_model_file), ctx_model, nullptr, DefaultLoggingManager().DefaultLogger())); auto& ctx_graph = ctx_model->MainGraph(); for (auto& node : ctx_graph.Nodes()) { if (node.OpType() == "EPContext") { @@ -156,7 +185,7 @@ void QnnContextBinaryMultiPartitionTestBody(bool single_ep_node = true) { Ort::SessionOptions so2; // context file path is required if it's non-embed mode and the model is loaded from memory - so2.AddConfigEntry(kOrtSessionOptionEpContextFilePath, context_binary_file.c_str()); + so2.AddConfigEntry(kOrtSessionOptionEpContextFilePath, context_model_file.c_str()); so2.AppendExecutionProvider("QNN", provider_options); std::string ctx_model_data; @@ -164,7 +193,7 @@ void QnnContextBinaryMultiPartitionTestBody(bool single_ep_node = true) { Ort::Session session2(*ort_env, ctx_model_data.data(), ctx_model_data.size(), so2); // clean up - ASSERT_EQ(std::remove(context_binary_file.c_str()), 0); + CleanUpCtxFile(context_model_file); } // Test that models with 1 non-quantized FusedMatMul node and 1 quantized Add node can still generate the context binary @@ -237,7 +266,7 @@ void EpCtxCpuNodeWithExternalIniFileTestBody(bool expect_external_ini_file) { // clean up ASSERT_EQ(std::remove(model_with_ext.c_str()), 0); ASSERT_EQ(std::remove(model_ext_file_full_path.c_str()), 0); - ASSERT_EQ(std::remove(ep_context_model_file.c_str()), 0); + CleanUpCtxFile(ep_context_model_file); } // Set the external initializer size threshold to 1024 so FusedMatMul (which fallback on CPU) @@ -333,7 +362,7 @@ TEST_F(QnnHTPBackendTests, QnnContextBinaryGenerationNoOverWrite) { const auto model_data_span = AsByteSpan(model_data.data(), model_data.size()); const std::string ep_context_onnx_file = "./ep_context_no_over_write.onnx"; - const std::string ep_context_binary_file = "./ep_context_no_over_write.onnx_QNNExecutionProvider_QNN_10880527342279992768_1_0.bin"; + const std::string ep_context_binary_file = "./ep_context_no_over_write_QNN_10880527342279992768_1_0.bin"; std::remove(ep_context_onnx_file.c_str()); Ort::SessionOptions so; @@ -444,21 +473,21 @@ TEST_F(QnnHTPBackendTests, QnnContextBinaryGeneration2InputTypes) { const auto model_data_span = AsByteSpan(model_data.data(), model_data.size()); - const std::string context_binary_file = "./qnn_context_binary_int32_fp32_inputs_test.onnx"; - std::remove(context_binary_file.c_str()); + const std::string context_model_file = "./qnn_context_binary_int32_fp32_inputs_test.onnx"; + std::remove(context_model_file.c_str()); Ort::SessionOptions so; so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1"); - so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, context_binary_file.c_str()); + so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, context_model_file.c_str()); so.AppendExecutionProvider("QNN", provider_options); Ort::Session session(*ort_env, model_data_span.data(), model_data_span.size(), so); // Make sure the Qnn context cache binary file is generated - EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str())); + EXPECT_TRUE(std::filesystem::exists(context_model_file.c_str())); // clean up - ASSERT_EQ(std::remove(context_binary_file.c_str()), 0); + CleanUpCtxFile(context_model_file); } // Generate context cache model from the ONNX models with 2 inputs. @@ -481,26 +510,26 @@ TEST_F(QnnHTPBackendTests, QnnContextGeneration2InputsOrderIssue) { auto& logging_manager = DefaultLoggingManager(); logging_manager.SetDefaultLoggerSeverity(logging::Severity::kERROR); - const std::string context_binary_file = "./qnn_ctx_2_inputs_order_test_gen.onnx"; + const std::string context_model_file = "./qnn_ctx_2_inputs_order_test_gen.onnx"; Ort::SessionOptions so; so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1"); - so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, context_binary_file.c_str()); + so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, context_model_file.c_str()); so.AppendExecutionProvider("QNN", provider_options); Ort::Session session(*ort_env, ORT_TSTR("testdata/qnn_ctx_2_inputs_order_test.onnx"), so); // Make sure the Qnn context cache binary file is generated - EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str())); + EXPECT_TRUE(std::filesystem::exists(context_model_file.c_str())); std::shared_ptr model; - ASSERT_STATUS_OK(Model::Load(ToPathString(context_binary_file), model, nullptr, DefaultLoggingManager().DefaultLogger())); + ASSERT_STATUS_OK(Model::Load(ToPathString(context_model_file), model, nullptr, DefaultLoggingManager().DefaultLogger())); auto inputs = model->MainGraph().GetInputs(); EXPECT_TRUE(inputs.size() == 2); EXPECT_TRUE(inputs[0]->Name() == "attention_mask"); EXPECT_TRUE(inputs[1]->Name() == "Add_input_0"); // clean up - ASSERT_EQ(std::remove(context_binary_file.c_str()), 0); + CleanUpCtxFile(context_model_file); } TEST_F(QnnHTPBackendTests, QnnContextGenerationNodeNamePrefix) { @@ -519,20 +548,20 @@ TEST_F(QnnHTPBackendTests, QnnContextGenerationNodeNamePrefix) { auto& logging_manager = DefaultLoggingManager(); logging_manager.SetDefaultLoggerSeverity(logging::Severity::kERROR); - const std::string context_binary_file = "./qnn_ctx_2_inputs_order_test_gen.onnx"; + const std::string context_model_file = "./qnn_ctx_2_inputs_order_test_gen.onnx"; Ort::SessionOptions so; so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1"); - so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, context_binary_file.c_str()); + so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, context_model_file.c_str()); so.AddConfigEntry(kOrtSessionOptionEpContextNodeNamePrefix, node_name_prefix.c_str()); so.AppendExecutionProvider("QNN", provider_options); Ort::Session session(*ort_env, ORT_TSTR("testdata/qnn_ctx_2_inputs_order_test.onnx"), so); // Make sure the Qnn context cache binary file is generated - EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str())); + EXPECT_TRUE(std::filesystem::exists(context_model_file.c_str())); std::shared_ptr model; - ASSERT_STATUS_OK(Model::Load(ToPathString(context_binary_file), model, nullptr, DefaultLoggingManager().DefaultLogger())); + ASSERT_STATUS_OK(Model::Load(ToPathString(context_model_file), model, nullptr, DefaultLoggingManager().DefaultLogger())); for (auto& node : model->MainGraph().Nodes()) { if (node.OpType() == "EPContext") { EXPECT_TRUE(node.Name().find(node_name_prefix) != std::string::npos); @@ -540,7 +569,7 @@ TEST_F(QnnHTPBackendTests, QnnContextGenerationNodeNamePrefix) { } // clean up - ASSERT_EQ(std::remove(context_binary_file.c_str()), 0); + CleanUpCtxFile(context_model_file); } // Run QDQ model on HTP 3 times @@ -554,12 +583,12 @@ TEST_F(QnnHTPBackendTests, QnnContextBinaryCacheEmbedModeTest) { provider_options["backend_path"] = "libQnnHtp.so"; #endif provider_options["offload_graph_io_quantization"] = "0"; - const std::string context_binary_file = "./qnn_context_binary_test.onnx"; - std::remove(context_binary_file.c_str()); + const std::string context_model_file = "./qnn_context_binary_test.onnx"; + std::remove(context_model_file.c_str()); std::unordered_map session_option_pairs; session_option_pairs.emplace(kOrtSessionOptionEpContextEnable, "1"); - session_option_pairs.emplace(kOrtSessionOptionEpContextFilePath, context_binary_file); + session_option_pairs.emplace(kOrtSessionOptionEpContextFilePath, context_model_file); const TestInputDef input_def({1, 2, 3}, false, -10.0f, 10.0f); const std::string op_type = "Atan"; @@ -577,9 +606,11 @@ TEST_F(QnnHTPBackendTests, QnnContextBinaryCacheEmbedModeTest) { session_option_pairs); // Make sure the Qnn context cache binary file is generated - EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str())); + EXPECT_TRUE(std::filesystem::exists(context_model_file.c_str())); // 2nd run directly loads and run from Qnn context cache model + std::unordered_map session_option_pairs2; + session_option_pairs2.emplace(kOrtSessionOptionEpContextFilePath, context_model_file); TestQDQModelAccuracy(BuildOpTestCase(op_type, {input_def}, {}, {}), BuildQDQOpTestCase(op_type, {input_def}, {}, {}), provider_options, @@ -587,9 +618,10 @@ TEST_F(QnnHTPBackendTests, QnnContextBinaryCacheEmbedModeTest) { ExpectedEPNodeAssignment::All, QDQTolerance(), logging::Severity::kERROR, - context_binary_file); + context_model_file, + session_option_pairs2); // Clean up - ASSERT_EQ(std::remove(context_binary_file.c_str()), 0); + CleanUpCtxFile(context_model_file); } // Run QDQ model on HTP 3 times @@ -604,7 +636,7 @@ TEST_F(QnnHTPBackendTests, QnnContextBinaryCacheNonEmbedModeTest) { #endif provider_options["offload_graph_io_quantization"] = "0"; const std::string context_binary_file = "./testdata/qnn_context_cache_non_embed.onnx"; - std::string qnn_ctx_bin = "./testdata/qnn_context_cache_non_embed.onnx_QNNExecutionProvider_QNN_8283143575221199085_1_0.bin"; + std::string qnn_ctx_bin = "./testdata/qnn_context_cache_non_embed_QNN_8283143575221199085_1_0.bin"; std::unordered_map session_option_pairs; session_option_pairs.emplace(kOrtSessionOptionEpContextEnable, "1"); @@ -686,7 +718,7 @@ TEST_F(QnnHTPBackendTests, QnnContextBinaryCache_InvalidGraph) { #endif provider_options["offload_graph_io_quantization"] = "0"; const std::string context_binary_file = "./qnn_context_cache_non_embed.onnx"; - std::filesystem::path context_bin = "qnn_context_cache_non_embed.onnx_QNNExecutionProvider_QNN_8283143575221199085_1_0.bin"; + std::filesystem::path context_bin = "qnn_context_cache_non_embed_QNN_8283143575221199085_1_0.bin"; std::remove(context_binary_file.c_str()); std::remove(context_bin.string().c_str()); @@ -828,6 +860,7 @@ TEST_F(QnnHTPBackendTests, QnnContextBinaryFileNotExistTest) { SessionOptions so; so.session_logid = "qnn_ctx_model_logger"; + ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kOrtSessionOptionEpContextFilePath, "./qnn_context_not_exist.onnx")); RunOptions run_options; run_options.run_tag = so.session_logid; @@ -841,7 +874,7 @@ TEST_F(QnnHTPBackendTests, QnnContextBinaryFileNotExistTest) { #endif provider_options["offload_graph_io_quantization"] = "0"; - ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(QnnExecutionProviderWithOptions(provider_options))); + ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(QnnExecutionProviderWithOptions(provider_options, &so))); ASSERT_STATUS_OK(session_object.Load(model_data.data(), static_cast(model_data.size()))); // Verify the return status with code INVALID_GRAPH ASSERT_TRUE(session_object.Initialize().Code() == common::StatusCode::INVALID_GRAPH); @@ -854,6 +887,7 @@ TEST_F(QnnHTPBackendTests, QnnContextBinaryFileEmptyStringTest) { SessionOptions so; so.session_logid = "qnn_ctx_model_logger"; + ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kOrtSessionOptionEpContextFilePath, "./test_ctx.onnx")); RunOptions run_options; run_options.run_tag = so.session_logid; @@ -867,7 +901,7 @@ TEST_F(QnnHTPBackendTests, QnnContextBinaryFileEmptyStringTest) { #endif provider_options["offload_graph_io_quantization"] = "0"; - ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(QnnExecutionProviderWithOptions(provider_options))); + ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(QnnExecutionProviderWithOptions(provider_options, &so))); ASSERT_STATUS_OK(session_object.Load(model_data.data(), static_cast(model_data.size()))); // Verify the return status with code INVALID_GRAPH ASSERT_TRUE(session_object.Initialize().Code() == common::StatusCode::INVALID_GRAPH); @@ -884,12 +918,12 @@ TEST_F(QnnHTPBackendTests, QnnContextBinary2InputsTest) { provider_options["backend_path"] = "libQnnHtp.so"; #endif provider_options["offload_graph_io_quantization"] = "0"; - const std::string context_binary_file = "./qnn_context_binary_2inputs_test.onnx"; - std::remove(context_binary_file.c_str()); + const std::string context_model_file = "./qnn_context_binary_2inputs_test.onnx"; + std::remove(context_model_file.c_str()); std::unordered_map session_option_pairs; session_option_pairs.emplace(kOrtSessionOptionEpContextEnable, "1"); - session_option_pairs.emplace(kOrtSessionOptionEpContextFilePath, context_binary_file); + session_option_pairs.emplace(kOrtSessionOptionEpContextFilePath, context_model_file); const TestInputDef input_def1({1, 2, 3}, false, -10.0f, 10.0f); const TestInputDef input_def2({1, 2, 3}, false, -10.0f, 10.0f); @@ -908,9 +942,11 @@ TEST_F(QnnHTPBackendTests, QnnContextBinary2InputsTest) { session_option_pairs); // Make sure the Qnn context cache binary file is generated - EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str())); + EXPECT_TRUE(std::filesystem::exists(context_model_file.c_str())); // 2nd run directly loads and run from Qnn context cache model + std::unordered_map session_option_pairs2; + session_option_pairs2.emplace(kOrtSessionOptionEpContextFilePath, context_model_file); TestQDQModelAccuracy(BuildOpTestCase(op_type, {input_def1, input_def2}, {}, {}), BuildQDQOpTestCase(op_type, {input_def1, input_def2}, {}, {}), provider_options, @@ -918,9 +954,10 @@ TEST_F(QnnHTPBackendTests, QnnContextBinary2InputsTest) { ExpectedEPNodeAssignment::All, QDQTolerance(), logging::Severity::kERROR, - context_binary_file); + context_model_file, + session_option_pairs2); // Clean up - ASSERT_EQ(std::remove(context_binary_file.c_str()), 0); + CleanUpCtxFile(context_model_file); } // Context binary only contains a single QNN graph, generated context cache model (detached mode) only has 1 EPContext node @@ -936,14 +973,14 @@ TEST_F(QnnHTPBackendTests, QnnContextBinaryCache_SingleNodeNameNotMatchGraphName provider_options["backend_path"] = "libQnnHtp.so"; #endif provider_options["offload_graph_io_quantization"] = "0"; - const std::string context_binary_file = "./qnn_context_cache_non_embed.onnx"; - std::filesystem::path context_bin = "qnn_context_cache_non_embed.onnx_QNNExecutionProvider_QNN_8283143575221199085_1_0.bin"; - std::remove(context_binary_file.c_str()); + const std::string context_model_file = "./qnn_context_cache_non_embed.onnx"; + std::filesystem::path context_bin = "qnn_context_cache_non_embed_QNN_8283143575221199085_1_0.bin"; + std::remove(context_model_file.c_str()); std::remove(context_bin.string().c_str()); std::unordered_map session_option_pairs; session_option_pairs.emplace(kOrtSessionOptionEpContextEnable, "1"); - session_option_pairs.emplace(kOrtSessionOptionEpContextFilePath, context_binary_file); + session_option_pairs.emplace(kOrtSessionOptionEpContextFilePath, context_model_file); session_option_pairs.emplace(kOrtSessionOptionEpContextEmbedMode, "0"); const TestInputDef input_def({1, 2, 3}, false, -10.0f, 10.0f); @@ -962,7 +999,7 @@ TEST_F(QnnHTPBackendTests, QnnContextBinaryCache_SingleNodeNameNotMatchGraphName session_option_pairs); // Check the Onnx skeleton file is generated - EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str())); + EXPECT_TRUE(std::filesystem::exists(context_model_file.c_str())); // Check the Qnn context cache binary file is generated EXPECT_TRUE(std::filesystem::exists(context_bin)); @@ -990,18 +1027,19 @@ TEST_F(QnnHTPBackendTests, QnnContextBinaryCache_SingleNodeNameNotMatchGraphName SessionOptions so; so.session_logid = "qnn_ctx_model_logger"; + ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kOrtSessionOptionEpContextFilePath, context_model_file.c_str())); RunOptions run_options; run_options.run_tag = so.session_logid; InferenceSessionWrapper session_object{so, GetEnvironment()}; - ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(QnnExecutionProviderWithOptions(provider_options))); + ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(QnnExecutionProviderWithOptions(provider_options, &so))); ASSERT_STATUS_OK(session_object.Load(model_data.data(), static_cast(model_data.size()))); // Verify the return status with code INVALID_GRAPH ASSERT_TRUE(session_object.Initialize().Code() == common::StatusCode::OK); // Clean up - ASSERT_EQ(std::remove(context_binary_file.c_str()), 0); + ASSERT_EQ(std::remove(context_model_file.c_str()), 0); ASSERT_EQ(std::remove(context_bin.string().c_str()), 0); } @@ -1053,44 +1091,20 @@ static void CreateQdqModel(const std::string& model_file_name, const Logger& log static void DumpModelWithSharedCtx(const ProviderOptions& provider_options, const std::string& onnx_model_path1, const std::string& onnx_model_path2) { - SessionOptions so; - so.session_logid = "qnn_ctx_model_logger"; - ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1")); - ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kOrtSessionOptionEpContextEmbedMode, "0")); - RunOptions run_options; - run_options.run_tag = so.session_logid; - - auto qnn_ep = QnnExecutionProviderWithOptions(provider_options, &so); - std::shared_ptr qnn_ep_shared(std::move(qnn_ep)); + Ort::SessionOptions so; + so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1"); + so.AddConfigEntry(kOrtSessionOptionEpContextEmbedMode, "0"); + // enable ep.share_ep_contexts so that QNNEP share the QnnBackendManager across sessions + so.AddConfigEntry(kOrtSessionOptionShareEpContexts, "1"); - InferenceSessionWrapper session_object1{so, GetEnvironment()}; - ASSERT_STATUS_OK(session_object1.RegisterExecutionProvider(qnn_ep_shared)); - ASSERT_STATUS_OK(session_object1.Load(ToPathString(onnx_model_path1))); - ASSERT_STATUS_OK(session_object1.Initialize()); + so.AppendExecutionProvider("QNN", provider_options); - InferenceSessionWrapper session_object2{so, GetEnvironment()}; - ASSERT_STATUS_OK(session_object2.RegisterExecutionProvider(qnn_ep_shared)); - ASSERT_STATUS_OK(session_object2.Load(ToPathString(onnx_model_path2))); - ASSERT_STATUS_OK(session_object2.Initialize()); -} + // Create 2 sessions to generate context binary models, the 1st session will share the QnnBackendManager + // to the 2nd session, so graphs from these 2 models are all included in the 2nd context binary + Ort::Session session1(*ort_env, ToPathString(onnx_model_path1).c_str(), so); -// from the last context ache Onnx model, find the EPContext node with main_context=1, -// and get the QNN context binary file name, thie context binary contains all graphs from all Onnx models -static void GetLastContextBinaryFileName(const std::string last_onnx_ctx_file, - std::string& last_ctx_bin_file, - const Logger& logger) { - std::shared_ptr ctx_model; - ASSERT_STATUS_OK(Model::Load(ToPathString(last_onnx_ctx_file), ctx_model, nullptr, logger)); - auto& ctx_graph = ctx_model->MainGraph(); - for (auto& node : ctx_graph.Nodes()) { - if (node.OpType() == "EPContext") { - int64_t is_main_context = GetNodeAttr(node, "main_context", static_cast(0)); - if (1 == is_main_context) { - last_ctx_bin_file = GetNodeAttr(node, "ep_cache_context", ""); - return; - } - } - } + so.AddConfigEntry(kOrtSessionOptionStopShareEpContexts, "1"); + Ort::Session session2(*ort_env, ToPathString(onnx_model_path2).c_str(), so); } // Update generated context cache Onnx model to make the main EPContext node point to @@ -1167,15 +1181,21 @@ TEST_F(QnnHTPBackendTests, QnnContextShareAcrossSessions1) { for (auto model_path : onnx_model_paths) { CreateQdqModel(model_path, DefaultLoggingManager().DefaultLogger()); EXPECT_TRUE(std::filesystem::exists(model_path.c_str())); - ctx_model_paths.push_back(model_path + "_ctx.onnx"); + auto pos = model_path.find_last_of("."); + if (pos != std::string::npos) { + model_path = model_path.substr(0, pos) + "_ctx.onnx"; + } else { + model_path = model_path + "_ctx.onnx"; + } + ctx_model_paths.push_back(model_path); } DumpModelWithSharedCtx(provider_options, onnx_model_paths[0], onnx_model_paths[1]); - // Get the last context binary file name + // Get the last context binary file name, the latest context binary file holds all graphs generated from all models std::string last_qnn_ctx_binary_file_name; - GetLastContextBinaryFileName(ctx_model_paths.back(), last_qnn_ctx_binary_file_name, - DefaultLoggingManager().DefaultLogger()); + GetContextBinaryFileName(ctx_model_paths.back(), last_qnn_ctx_binary_file_name, + DefaultLoggingManager().DefaultLogger()); EXPECT_TRUE(!last_qnn_ctx_binary_file_name.empty()); // Update generated context cache Onnx model to make the main EPContext node point to @@ -1265,15 +1285,21 @@ TEST_F(QnnHTPBackendTests, QnnContextShareAcrossSessions2) { for (auto model_path : onnx_model_paths) { CreateQdqModel(model_path, DefaultLoggingManager().DefaultLogger()); EXPECT_TRUE(std::filesystem::exists(model_path.c_str())); - ctx_model_paths.push_back(model_path + "_ctx.onnx"); + auto pos = model_path.find_last_of("."); + if (pos != std::string::npos) { + model_path = model_path.substr(0, pos) + "_ctx.onnx"; + } else { + model_path = model_path + "_ctx.onnx"; + } + ctx_model_paths.push_back(model_path); } DumpModelWithSharedCtx(provider_options, onnx_model_paths[0], onnx_model_paths[1]); // Get the last context binary file name std::string last_qnn_ctx_binary_file_name; - GetLastContextBinaryFileName(ctx_model_paths.back(), last_qnn_ctx_binary_file_name, - DefaultLoggingManager().DefaultLogger()); + GetContextBinaryFileName(ctx_model_paths.back(), last_qnn_ctx_binary_file_name, + DefaultLoggingManager().DefaultLogger()); EXPECT_TRUE(!last_qnn_ctx_binary_file_name.empty()); // Update generated context cache Onnx model to make the main EPContext node point to @@ -1336,6 +1362,69 @@ TEST_F(QnnHTPBackendTests, QnnContextShareAcrossSessions2) { } std::remove(last_qnn_ctx_binary_file_name.c_str()); } + +// For Ort sessions to generate the context binary, with session option ep.share_ep_contexts enabled +// Ort sessions will share the QnnBackendManager, so that all graphs from all models compile into the same Qnn context +TEST_F(QnnHTPBackendTests, QnnContextGenWeightSharingSessionAPI) { + ProviderOptions provider_options; +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + provider_options["offload_graph_io_quantization"] = "0"; + + // Create QDQ models + std::vector onnx_model_paths{"./weight_share1.onnx", "./weight_share2.onnx"}; + std::vector ctx_model_paths; + for (auto model_path : onnx_model_paths) { + CreateQdqModel(model_path, DefaultLoggingManager().DefaultLogger()); + EXPECT_TRUE(std::filesystem::exists(model_path.c_str())); + auto pos = model_path.find_last_of("."); + if (pos != std::string::npos) { + model_path = model_path.substr(0, pos) + "_ctx.onnx"; + } else { + model_path = model_path + "_ctx.onnx"; + } + ctx_model_paths.push_back(model_path); + } + + Ort::SessionOptions so; + so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1"); + so.AddConfigEntry(kOrtSessionOptionEpContextEmbedMode, "0"); + // enable ep.share_ep_contexts so that QNNEP share the QnnBackendManager across sessions + so.AddConfigEntry(kOrtSessionOptionShareEpContexts, "1"); + + so.AppendExecutionProvider("QNN", provider_options); + + Ort::Session session1(*ort_env, ToPathString(onnx_model_paths[0]).c_str(), so); + std::string qnn_ctx_binary_file_name1; + GetContextBinaryFileName(ctx_model_paths[0], qnn_ctx_binary_file_name1, + DefaultLoggingManager().DefaultLogger()); + EXPECT_TRUE(!qnn_ctx_binary_file_name1.empty()); + + // Tell the EP stop share the QnnBackendManager from this session then on + so.AddConfigEntry(kOrtSessionOptionStopShareEpContexts, "1"); + Ort::Session session2(*ort_env, ToPathString(onnx_model_paths[1]).c_str(), so); + std::string qnn_ctx_binary_file_name2; + GetContextBinaryFileName(ctx_model_paths[1], qnn_ctx_binary_file_name2, + DefaultLoggingManager().DefaultLogger()); + EXPECT_TRUE(!qnn_ctx_binary_file_name2.empty()); + + auto file_size_1 = std::filesystem::file_size(qnn_ctx_binary_file_name1); + auto file_size_2 = std::filesystem::file_size(qnn_ctx_binary_file_name2); + EXPECT_TRUE(file_size_2 > file_size_1); + + // clean up + for (auto model_path : onnx_model_paths) { + ASSERT_EQ(std::remove(model_path.c_str()), 0); + } + for (auto ctx_model_path : ctx_model_paths) { + ASSERT_EQ(std::remove(ctx_model_path.c_str()), 0); + } + ASSERT_EQ(std::remove(qnn_ctx_binary_file_name1.c_str()), 0); + ASSERT_EQ(std::remove(qnn_ctx_binary_file_name2.c_str()), 0); +} #endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) } // namespace test diff --git a/onnxruntime/test/providers/qnn/qnn_test_utils.cc b/onnxruntime/test/providers/qnn/qnn_test_utils.cc index e2deccc4fff0f..2361e179d1cf1 100644 --- a/onnxruntime/test/providers/qnn/qnn_test_utils.cc +++ b/onnxruntime/test/providers/qnn/qnn_test_utils.cc @@ -14,6 +14,7 @@ #include "core/framework/compute_capability.h" #include "core/graph/graph.h" #include "core/session/onnxruntime_session_options_config_keys.h" +#include "core/optimizer/graph_optimizer_registry.h" namespace onnxruntime { namespace test { @@ -279,9 +280,10 @@ static BackendSupport GetHTPSupport(const onnxruntime::logging::Logger& logger) onnxruntime::GraphViewer graph_viewer(graph); std::unique_ptr qnn_ep = QnnExecutionProviderWithOptions( {{"backend_path", "QnnHtp.dll"}, {"offload_graph_io_quantization", "0"}}); + GraphOptimizerRegistry graph_optimizer_registry(nullptr, nullptr, nullptr); // as a placeholder to feed into GetCapability qnn_ep->SetLogger(&logger); - auto result = qnn_ep->GetCapability(graph_viewer, kernel_lookup, nullptr); + auto result = qnn_ep->GetCapability(graph_viewer, kernel_lookup, graph_optimizer_registry, nullptr); return result.empty() ? BackendSupport::UNSUPPORTED : BackendSupport::SUPPORTED; } @@ -342,9 +344,10 @@ static BackendSupport GetCPUSupport(const onnxruntime::logging::Logger& logger) onnxruntime::GraphViewer graph_viewer(graph); std::unique_ptr qnn_ep = QnnExecutionProviderWithOptions( {{"backend_path", "QnnCpu.dll"}, {"offload_graph_io_quantization", "0"}}); + GraphOptimizerRegistry graph_optimizer_registry(nullptr, nullptr, nullptr); // as a placeholder to feed into GetCapability qnn_ep->SetLogger(&logger); - auto result = qnn_ep->GetCapability(graph_viewer, kernel_lookup, nullptr); + auto result = qnn_ep->GetCapability(graph_viewer, kernel_lookup, graph_optimizer_registry, nullptr); return result.empty() ? BackendSupport::UNSUPPORTED : BackendSupport::SUPPORTED; } diff --git a/onnxruntime/test/python/quantization/test_get_qdq_config.py b/onnxruntime/test/python/quantization/test_get_qdq_config.py index 25f058d8f6eac..4a71b3694822c 100644 --- a/onnxruntime/test/python/quantization/test_get_qdq_config.py +++ b/onnxruntime/test/python/quantization/test_get_qdq_config.py @@ -156,6 +156,62 @@ def should_exclude_node_(model: onnx.ModelProto, node: onnx.NodeProto) -> bool: self.assertTrue(bool(expected_excluded_nodes)) self.assertEqual(set(qdq_config.nodes_to_exclude), expected_excluded_nodes) + def test_op_types_to_quantize(self): + """ + Test that get_qdq_config() returns a config that sets the op_types_to_quantize arg. + """ + shape = [1, 8, 8] + tensor_type = onnx.TensorProto.FLOAT + np_dtype = onnx.helper.tensor_dtype_to_np_dtype(tensor_type) + weight = onnx.numpy_helper.from_array(np.ones(shape, dtype=np_dtype), "weight") + float_model = self.build_add_model(shape, tensor_type, weight) + + input_data_list = [ + {"input_0": np.ones(shape, dtype=np_dtype) * np.array(-2, dtype=np_dtype)}, + {"input_0": np.ones(shape, dtype=np_dtype) * np.array(2, dtype=np_dtype)}, + ] + data_reader = TestDataFeeds(input_data_list) + + # No op_types_to_quantize arg means all ops are quantized. + qdq_config = get_qdq_config(float_model, data_reader, op_types_to_quantize=None) + self.assertEqual(set(qdq_config.op_types_to_quantize), {"Add"}) + + # specify custom op_types_to_quantize arg. + qdq_config = get_qdq_config(float_model, data_reader, op_types_to_quantize=["Mul"]) + self.assertEqual(set(qdq_config.op_types_to_quantize), {"Mul"}) + + # exclude op_type indirectly by specifying nodes_to_exclude arg. + qdq_config = get_qdq_config( + float_model, + data_reader, + nodes_to_exclude=[node.name for node in float_model.graph.node if node.op_type == "Add"], + ) + self.assertEqual(set(qdq_config.op_types_to_quantize), set()) + + def test_calibration_providers(self): + """ + Test that get_qdq_config() returns a config that sets the calibration providers arg. + """ + + shape = [1, 8, 8] + tensor_type = onnx.TensorProto.FLOAT + np_dtype = onnx.helper.tensor_dtype_to_np_dtype(tensor_type) + weight = onnx.numpy_helper.from_array(np.ones(shape, dtype=np_dtype), "weight") + float_model = self.build_add_model(shape, tensor_type, weight) + + input_data_list = [ + {"input_0": np.ones(shape, dtype=np_dtype) * np.array(-2, dtype=np_dtype)}, + {"input_0": np.ones(shape, dtype=np_dtype) * np.array(2, dtype=np_dtype)}, + ] + data_reader = TestDataFeeds(input_data_list) + + qdq_config = get_qdq_config( + float_model, + data_reader, + calibration_providers=["CPUExecutionProvider"], + ) + self.assertEqual(qdq_config.calibration_providers, ["CPUExecutionProvider"]) + def test_external_data(self): """ Test that get_qdq_config() returns a config that enables external data diff --git a/onnxruntime/test/qnn_ctx_gen/main.cc b/onnxruntime/test/qnn_ctx_gen/main.cc deleted file mode 100644 index bb5007b40b072..0000000000000 --- a/onnxruntime/test/qnn_ctx_gen/main.cc +++ /dev/null @@ -1,250 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -// onnxruntime dependencies -#include "test_configuration.h" -#include -#include -#include -#include "command_args_parser.h" -#include - -#include "core/session/onnxruntime_session_options_config_keys.h" -#include "core/session/inference_session.h" -#include "core/session/ort_env.h" -#include "core/providers/provider_factory_creators.h" -#include "core/common/logging/sinks/clog_sink.h" - -#include "core/graph/model.h" -#include "core/session/environment.h" -#include "core/common/logging/logging.h" - -using namespace onnxruntime; -const OrtApi* g_ort = NULL; -std::unique_ptr ort_env; - -static void CheckStatus(const Status& status) { - if (status.Code() != common::StatusCode::OK) { - std::string msg = status.ErrorMessage(); - throw Ort::Exception(std::move(msg), OrtErrorCode::ORT_FAIL); - } -} - -static int64_t GetNodeAttr(const Node& node, const std::string& attr_name, int64_t default_val) { - const auto& attributes = node.GetAttributes(); - if (auto entry = attributes.find(attr_name); entry != attributes.end()) { - return entry->second.i(); - } - - return default_val; -} - -static const std::string& GetNodeAttr(const Node& node, const std::string& attr_name, const std::string& default_val) { - const auto& attributes = node.GetAttributes(); - if (auto entry = attributes.find(attr_name); entry != attributes.end()) { - return entry->second.s(); - } - - return default_val; -} - -// from the last context cache Onnx model, find the EPContext node with main_context=1, -// and get the QNN context binary file name, this context binary contains all graphs from all Onnx models -// get the max spill fill buffer size -static void GetLastContextBinaryFileName(const std::basic_string last_onnx_ctx_file, - std::string& last_ctx_bin_file, - int64_t& max_size) { - max_size = 0; - std::shared_ptr ctx_model; - CheckStatus(Model::Load(ToPathString(last_onnx_ctx_file), ctx_model, nullptr, - (*((OrtEnv*)*ort_env.get())->GetEnvironment().GetLoggingManager()).DefaultLogger())); - auto& ctx_graph = ctx_model->MainGraph(); - for (auto& node : ctx_graph.Nodes()) { - if (node.OpType() == "EPContext") { - int64_t is_main_context = GetNodeAttr(node, "main_context", static_cast(0)); - max_size = GetNodeAttr(node, "max_size", static_cast(0)); - if (1 == is_main_context) { - last_ctx_bin_file = GetNodeAttr(node, "ep_cache_context", ""); - return; - } - } - } -} - -// Update generated context cache Onnx model to make the main EPContext node point to -// the last QNN context binary file -// Remove not used QNN context binary file, only keep the last one which contains all graphs -static void UpdateEpContextModel(const std::vector>& ep_ctx_files, - const std::string& last_qnn_ctx_binary_file_name, - int64_t max_size) { - for (auto ep_ctx_file : ep_ctx_files) { - std::shared_ptr ctx_model; - auto path_str = ToPathString(ep_ctx_file); - CheckStatus(Model::Load(path_str, ctx_model, nullptr, - (*((OrtEnv*)*ort_env.get())->GetEnvironment().GetLoggingManager()).DefaultLogger())); - auto& ctx_graph = ctx_model->MainGraph(); - GraphViewer graph_viewer(ctx_graph); - auto path = std::filesystem::path(path_str); - - for (auto& node : ctx_graph.Nodes()) { - if (node.OpType() == "EPContext") { - int64_t is_main_context = GetNodeAttr(node, "main_context", static_cast(0)); - if (1 == is_main_context) { - std::string old_qnn_ctx_binary_file_name = GetNodeAttr(node, "ep_cache_context", ""); - auto file_path = path.replace_filename(old_qnn_ctx_binary_file_name); - std::remove(file_path.string().c_str()); - node.ClearAttribute("ep_cache_context"); - node.AddAttribute("ep_cache_context", last_qnn_ctx_binary_file_name); - node.ClearAttribute("max_size"); - node.AddAttribute("max_size", max_size); - } - } - } - std::remove(ToUTF8String(ep_ctx_file).c_str()); - CheckStatus(Model::Save(*ctx_model.get(), ToPathString(ep_ctx_file))); - } -} - -#ifdef _WIN32 -int real_main(int argc, wchar_t* argv[]) { -#else -int real_main(int argc, char* argv[]) { -#endif - g_ort = OrtGetApiBase()->GetApi(ORT_API_VERSION); - qnnctxgen::TestConfig test_config; - if (!qnnctxgen::CommandLineParser::ParseArguments(test_config, argc, argv)) { - qnnctxgen::CommandLineParser::ShowUsage(); - return -1; - } - - { - bool failed = false; - ORT_TRY { - OrtLoggingLevel logging_level = test_config.run_config.f_verbose - ? ORT_LOGGING_LEVEL_VERBOSE - : ORT_LOGGING_LEVEL_WARNING; - - ort_env = std::make_unique(logging_level, "Default"); - } - ORT_CATCH(const Ort::Exception& e) { - ORT_HANDLE_EXCEPTION([&]() { - fprintf(stderr, "Error creating environment. Error-> %s \n", e.what()); - failed = true; - }); - } - - if (failed) - return -1; - } - - ORT_TRY { - SessionOptions so; - so.session_logid = "qnn_ctx_gen_session_logger"; - // Set default session option to dump QNN context model with non-embed mode - CheckStatus(so.config_options.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1")); - CheckStatus(so.config_options.AddConfigEntry(kOrtSessionOptionEpContextEmbedMode, "0")); - RunOptions run_options; - run_options.run_tag = so.session_logid; - - ProviderOptions provider_options; -#if defined(_WIN32) - provider_options["backend_path"] = "QnnHtp.dll"; -#else - provider_options["backend_path"] = "libQnnHtp.so"; -#endif - // set default QNN EP option to enable weight sharing - provider_options["enable_htp_weight_sharing"] = "1"; - - for (auto it : test_config.run_config.qnn_options) { - provider_options[it.first] = it.second; - } - - for (auto it : test_config.run_config.session_config_entries) { - if (it.first == kOrtSessionOptionEpContextEnable && it.second != "1") { - std::cerr << "Need to enable ep context cache." << std::endl; - continue; - } - if (it.first == kOrtSessionOptionEpContextEmbedMode && it.second != "0") { - std::cerr << "Only support non-embed model for weight sharing." << std::endl; - continue; - } - if (it.first == kOrtSessionOptionEpContextFilePath) { - std::cout << "Not support to specify the generated Onnx context cache file name." << std::endl; - continue; - } - CheckStatus(so.config_options.AddConfigEntry(it.first.c_str(), it.second.c_str())); - } - - for (auto model_path : test_config.model_file_paths) { - std::cout << "Model file path: " << ToUTF8String(model_path) << std::endl; - } - - // Generate context cache model files with QNN context binary files - // The context binary file generated later includes all graphs from previous models - { - auto ep = QNNProviderFactoryCreator::Create(provider_options, &so)->CreateProvider(); - std::shared_ptr qnn_ep(std::move(ep)); - - for (auto model_path : test_config.model_file_paths) { - std::cout << "Generate context cache model for: " << ToUTF8String(model_path) << std::endl; - InferenceSession session_object{so, ((OrtEnv*)*ort_env.get())->GetEnvironment()}; - CheckStatus(session_object.RegisterExecutionProvider(qnn_ep)); - CheckStatus(session_object.Load(ToPathString(model_path))); - CheckStatus(session_object.Initialize()); - } - } - - std::cout << "Start to update the generated Onnx model." << std::endl; - std::vector> ep_ctx_files; - ep_ctx_files.reserve(test_config.model_file_paths.size()); - for (auto model_path : test_config.model_file_paths) { - ep_ctx_files.push_back(model_path + ORT_TSTR("_ctx.onnx")); - } - - // Get the last context binary file name - std::string last_qnn_ctx_binary_file_name; - int64_t max_size = 0; - GetLastContextBinaryFileName(ep_ctx_files.back(), last_qnn_ctx_binary_file_name, max_size); - std::cout << "The last context binary file: " << last_qnn_ctx_binary_file_name << std::endl; - if (last_qnn_ctx_binary_file_name.empty()) { - throw Ort::Exception("Can't find QNN context binary file from the Onnx model.", OrtErrorCode::ORT_FAIL); - } - ep_ctx_files.pop_back(); - - // Update generated context cache Onnx model to make the main EPContext node point to - // the last QNN context binary file - // Remove not used QNN context binary file, only keep the last one which contains all graphs - UpdateEpContextModel(ep_ctx_files, last_qnn_ctx_binary_file_name, max_size); - } - ORT_CATCH(const Ort::Exception& e) { - fprintf(stderr, "Failed to generate context cache file: %s \n", e.what()); - - ort_env.reset(); - return -1; - } - - ort_env.reset(); - - return 0; -} - -#ifdef _WIN32 -int wmain(int argc, wchar_t* argv[]) { -#else -int main(int argc, char* argv[]) { -#endif - int retval = -1; - ORT_TRY { - retval = real_main(argc, argv); - } - ORT_CATCH(const std::exception& ex) { - ORT_HANDLE_EXCEPTION([&]() { - fprintf(stderr, "%s\n", ex.what()); - retval = -1; - }); - } - - ::google::protobuf::ShutdownProtobufLibrary(); - - return retval; -} diff --git a/onnxruntime/test/shared_lib/custom_op_utils.cc b/onnxruntime/test/shared_lib/custom_op_utils.cc index bf7efacdbb505..a624479bcd00b 100644 --- a/onnxruntime/test/shared_lib/custom_op_utils.cc +++ b/onnxruntime/test/shared_lib/custom_op_utils.cc @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include #include "gtest/gtest.h" #include "custom_op_utils.h" @@ -639,3 +640,22 @@ void StandaloneCustomKernel::Compute(OrtKernelContext* context) { StandaloneCustomKernel::~StandaloneCustomKernel() { } + +OrtStatusPtr CustomCastKernel::ComputeV2(OrtKernelContext* context) { + Ort::KernelContext ctx(context); + + auto in = ctx.GetInput(0); + std::vector shape = in.GetTensorTypeAndShapeInfo().GetShape(); + int64_t num_elements = std::accumulate(shape.cbegin(), shape.cend(), int64_t(1), std::multiplies()); + + // CustomCast::GetInputType constraint ensures we only get float input + const float* data = in.GetTensorData(); + double* out_data = ctx.GetOutput(0, shape).GetTensorMutableData(); + gsl::span input_span(data, num_elements); + gsl::span output_span(out_data, num_elements); + + std::transform(input_span.begin(), input_span.end(), output_span.begin(), + [](float val) { return static_cast(val); }); + + return nullptr; +} diff --git a/onnxruntime/test/shared_lib/custom_op_utils.h b/onnxruntime/test/shared_lib/custom_op_utils.h index e11540aaa5691..424c2e2fe3a08 100644 --- a/onnxruntime/test/shared_lib/custom_op_utils.h +++ b/onnxruntime/test/shared_lib/custom_op_utils.h @@ -8,12 +8,6 @@ #include #endif -struct Input { - const char* name = nullptr; - std::vector dims; - std::vector values; -}; - struct MyCustomKernel { MyCustomKernel(const OrtApi& ort_api, const OrtKernelInfo* /*info*/) : ort_(ort_api) { @@ -464,4 +458,63 @@ struct MulTopOpFloat16 : Ort::CustomOpBase OrtCustomOpInputOutputCharacteristic GetInputCharacteristic(size_t) const { return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_OPTIONAL; } -}; \ No newline at end of file +}; + +// +// Example overriding an operator where type inference is required for the output so kernel matching works correctly +// +struct CustomCastKernel { + CustomCastKernel(const OrtApi& /*ort_api*/, const OrtKernelInfo* /*info*/) + /*: ort_(ort_api)*/ { + } + + OrtStatusPtr ComputeV2(OrtKernelContext* context); + + private: + // const OrtApi& ort_; +}; + +// Custom Cast op that takes float input and converts based on 'to' attribute. +// Example implementation only supports cast to double. +struct CustomCast : Ort::CustomOpBase { + explicit CustomCast(const char* provider) : provider_(provider) { + // if overriding an ONNX op you need to set the opset versions you are overriding + start_ver_ = 7; // should match minimum ONNX schema you implement + // end_ver_ = ...; should match maximum ONNX schema you implement or unset for unlimited. + } + + // static method used by Ort::CustomOpBase::SetShapeInferFn + static OrtStatusPtr InferOutputShape(Ort::ShapeInferContext& context) { + auto shape = context.GetInputShape(0); + + // infer output type based on 'to'. + auto to = context.GetAttrInt("to"); + if (to != ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE) { + return Ort::Status("Unexpected type", ORT_INVALID_ARGUMENT).release(); + } + + context.SetOutputShape(0, shape, ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE); + return nullptr; + } + + OrtStatusPtr CreateKernelV2(const OrtApi& api, const OrtKernelInfo* info, void** op_kernel) const { + Ort::ConstKernelInfo ki(info); + *op_kernel = new CustomCastKernel(api, info); + return nullptr; + }; + + const char* GetName() const { return "Cast"; }; + const char* GetExecutionProviderType() const { return provider_; }; + + size_t GetInputTypeCount() const { return 1; }; + ONNXTensorElementDataType GetInputType(size_t /*index*/) const { + // example only accepts float input + return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; + }; + + size_t GetOutputTypeCount() const { return 1; }; + ONNXTensorElementDataType GetOutputType(size_t /*index*/) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; }; + + private: + const char* provider_{"CPUExecutionProvider"}; +}; diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc index ca9ca0f82a25a..b517ba7032886 100644 --- a/onnxruntime/test/shared_lib/test_inference.cc +++ b/onnxruntime/test/shared_lib/test_inference.cc @@ -1,17 +1,19 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include -#include -#include -#include -#include +#include #include +#include +#include +#include #include -#include +#include #include +#include #include +#include + #include "gtest/gtest.h" #include "gmock/gmock.h" @@ -25,13 +27,13 @@ #include "core/session/onnxruntime_run_options_config_keys.h" #include "core/util/thread_utils.h" -#include "onnxruntime_config.h" -#include "providers.h" -#include "test_allocator.h" -#include "test_fixture.h" -#include "utils.h" -#include "custom_op_utils.h" -#include +#include "test/shared_lib/custom_op_utils.h" +#include "test/shared_lib/test_fixture.h" +#include "test/shared_lib/utils.h" +#include "test/util/include/providers.h" +#include "test/util/include/test_allocator.h" + +#include "onnxruntime_config.h" // generated file in build output dir #ifdef _WIN32 #include @@ -63,48 +65,6 @@ constexpr size_t countof(T (&)[N]) { return N; } extern std::unique_ptr ort_env; -template -void RunSession(OrtAllocator* allocator, Ort::Session& session_object, - const std::vector& inputs, - const char* output_name, - const std::vector& dims_y, - const std::vector& values_y, - Ort::Value* output_tensor) { - std::vector ort_inputs; - std::vector input_names; - for (size_t i = 0; i < inputs.size(); i++) { - input_names.emplace_back(inputs[i].name); - ort_inputs.emplace_back( - Ort::Value::CreateTensor(allocator->Info(allocator), const_cast(inputs[i].values.data()), - inputs[i].values.size(), inputs[i].dims.data(), inputs[i].dims.size())); - } - - std::vector ort_outputs; - if (output_tensor) - session_object.Run(Ort::RunOptions{nullptr}, input_names.data(), ort_inputs.data(), ort_inputs.size(), - &output_name, output_tensor, 1); - else { - ort_outputs = session_object.Run(Ort::RunOptions{}, input_names.data(), ort_inputs.data(), ort_inputs.size(), - &output_name, 1); - ASSERT_EQ(ort_outputs.size(), 1u); - output_tensor = &ort_outputs[0]; - } - - auto type_info = output_tensor->GetTensorTypeAndShapeInfo(); - ASSERT_EQ(type_info.GetShape(), dims_y); - size_t total_len = type_info.GetElementCount(); - ASSERT_EQ(values_y.size(), total_len); - - OutT* f = output_tensor->GetTensorMutableData(); - for (size_t i = 0; i != total_len; ++i) { - if constexpr (std::is_same::value || std::is_same::value) { - ASSERT_NEAR(values_y[i], f[i], 1e-3); - } else { - ASSERT_EQ(values_y[i], f[i]); - } - } -} - #ifdef USE_DML struct DmlObjects { ComPtr d3d12_device; @@ -300,12 +260,12 @@ Ort::Value CreateTensorValueFromExistingD3DResource( #endif -template +template > static void TestInference(Ort::Env& env, const std::basic_string& model_uri, const std::vector& inputs, const char* output_name, const std::vector& expected_dims_y, - const std::vector& expected_values_y, + const std::vector& expected_values_y, int provider_type, OrtCustomOpDomain* custom_op_domain_ptr, const ORTCHAR_T* custom_op_library_filename, @@ -362,26 +322,26 @@ static void TestInference(Ort::Env& env, const std::basic_string& mod auto default_allocator = std::make_unique(); // without preallocated output tensor - RunSession(default_allocator.get(), - session, - inputs, - output_name, - expected_dims_y, - expected_values_y, - nullptr); + RunSession(default_allocator.get(), + session, + inputs, + output_name, + expected_dims_y, + expected_values_y, + nullptr); // with preallocated output tensor - Ort::Value value_y = Ort::Value::CreateTensor(default_allocator.get(), - expected_dims_y.data(), expected_dims_y.size()); + Ort::Value value_y = Ort::Value::CreateTensor(default_allocator.get(), + expected_dims_y.data(), expected_dims_y.size()); // test it twice for (int i = 0; i != 2; ++i) - RunSession(default_allocator.get(), - session, - inputs, - output_name, - expected_dims_y, - expected_values_y, - &value_y); + RunSession(default_allocator.get(), + session, + inputs, + output_name, + expected_dims_y, + expected_values_y, + &value_y); } } @@ -450,8 +410,8 @@ class CApiTestWithProvider : public testing::Test, public ::testing::WithParamIn TEST_P(CApiTestWithProvider, simple) { // simple inference test // prepare inputs - std::vector inputs(1); - Input& input = inputs.back(); + std::vector> inputs(1); + auto& input = inputs.back(); input.name = "X"; input.dims = {3, 2}; input.values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; @@ -621,8 +581,8 @@ TEST(CApiTest, SparseInputModel) { TEST(CApiTest, custom_op_handler) { std::cout << "Running custom op inference" << std::endl; - std::vector inputs(1); - Input& input = inputs[0]; + std::vector> inputs(1); + auto& input = inputs[0]; input.name = "X"; input.dims = {3, 2}; input.values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; @@ -657,8 +617,8 @@ TEST(CApiTest, custom_op_handler) { TEST(CApiTest, custom_op_set_input_memory_type) { std::cout << "Running custom op inference" << std::endl; - std::vector inputs(1); - Input& input = inputs[0]; + std::vector> inputs(1); + auto& input = inputs[0]; input.name = "X"; input.dims = {3, 2}; input.values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; @@ -687,8 +647,8 @@ TEST(CApiTest, custom_op_set_input_memory_type) { #if !defined(ORT_MINIMAL_BUILD) TEST(CApiTest, StandaloneOpHandler) { - std::vector inputs(1); - Input& input = inputs[0]; + std::vector> inputs(1); + auto& input = inputs[0]; input.name = "X"; input.dims = {3, 2}; input.values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; @@ -811,7 +771,7 @@ TEST(CApiTest, test_enable_ort_customops_stringlower) { // test custom op which accepts float and double as inputs TEST(CApiTest, varied_input_custom_op_handler) { - std::vector inputs(2); + std::vector> inputs(2); inputs[0].name = "X"; inputs[0].dims = {3}; inputs[0].values = {2.0f, 3.0f, 4.0f}; @@ -1422,8 +1382,8 @@ TEST(CApiTest, custom_op_with_attributes_handler) { TEST(CApiTest, RegisterCustomOpForCPUAndCUDA) { std::cout << "Tests registration of a custom op of the same name for both CPU and CUDA EPs" << std::endl; - std::vector inputs(1); - Input& input = inputs[0]; + std::vector> inputs(1); + auto& input = inputs[0]; input.name = "X"; input.dims = {3, 2}; input.values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; @@ -1531,7 +1491,7 @@ TEST(CApiTest, test_custom_op_openvino_wrapper_library) { // The custom op extracts the serialized .xml/.bin bytes and creates an in-memory OpenVINO model // during kernel creation. The custom op is passed an image of a hand-drawn "1" as an input during computation, which // is then inferenced using OpenVINO C++ APIs. - std::vector inputs(1); + std::vector> inputs(1); inputs[0].name = "Input3"; inputs[0].dims = {1, 1, 28, 28}; @@ -1630,7 +1590,7 @@ TEST(CApiTest, test_custom_op_library) { #endif std::cout << "Running inference using custom op shared library" << std::endl; - std::vector inputs(2); + std::vector> inputs(2); inputs[0].name = "input_1"; inputs[0].dims = {3, 5}; inputs[0].values = {1.1f, 2.2f, 3.3f, 4.4f, 5.5f, @@ -1682,7 +1642,7 @@ TEST(CApiTest, DISABLED_test_custom_op_shape_infer_attr) { #else TEST(CApiTest, test_custom_op_shape_infer_attr) { #endif - std::vector inputs(1); + std::vector> inputs(1); inputs[0].name = "input_0"; inputs[0].dims = {5}; inputs[0].values = {1.f, 2.f, 3.f, 4.f, 5.f}; @@ -1715,7 +1675,7 @@ TEST(CApiTest, test_custom_op_library_copy_variadic) { #endif std::cout << "Running inference using custom op shared library" << std::endl; - std::vector inputs(2); + std::vector> inputs(2); inputs[0].name = "input_0"; inputs[0].dims = {15}; inputs[0].values = {1.1f, 2.2f, 3.3f, 4.4f, 5.5f, @@ -1869,8 +1829,8 @@ void PrepareModule() { TEST(CApiTest, test_pyop) { std::call_once(my_module_flag, PrepareModule); - std::vector inputs(1); - Input& input = inputs[0]; + std::vector> inputs(1); + auto& input = inputs[0]; input.name = "X"; input.dims = {2, 2}; input.values = {1.0f, 2.0f, 3.0f, 4.0f}; @@ -1882,8 +1842,8 @@ TEST(CApiTest, test_pyop) { TEST(CApiTest, test_pyop_multi) { std::call_once(my_module_flag, PrepareModule); - std::vector inputs(1); - Input& input = inputs[0]; + std::vector> inputs(1); + auto& input = inputs[0]; input.name = "X"; input.dims = {2, 2}; input.values = {1.0f, 2.0f, 3.0f, 4.0f}; @@ -1895,8 +1855,8 @@ TEST(CApiTest, test_pyop_multi) { TEST(CApiTest, test_pyop_kwarg) { std::call_once(my_module_flag, PrepareModule); - std::vector inputs(1); - Input& input = inputs[0]; + std::vector> inputs(1); + auto& input = inputs[0]; input.name = "X"; input.dims = {2, 2}; input.values = {1.0f, 2.0f, 3.0f, 4.0f}; @@ -1920,7 +1880,7 @@ TEST(ReducedOpsBuildTest, test_excluded_ops) { // In reduced ops build, test a model containing ops not included in required_ops.config cannot be loaded. // See onnxruntime/test/testdata/reduced_build_test.readme.txt for more details of the setup constexpr PATH_TYPE model_uri = TSTR("testdata/reduced_build_test.onnx_model_with_excluded_ops"); - std::vector inputs = {{"X", {3}, {-1.0f, 2.0f, -3.0f}}}; + std::vector> inputs = {{"X", {3}, {-1.0f, 2.0f, -3.0f}}}; std::vector expected_dims_y = {3}; std::vector expected_values_y = {0.1f, 0.1f, 0.1f}; bool failed = false; @@ -3322,8 +3282,8 @@ TEST(CApiTest, TestSharedAllocators) { OrtEnv* env_ptr = (OrtEnv*)(*ort_env); // prepare inputs - std::vector inputs(1); - Input& input = inputs.back(); + std::vector> inputs(1); + auto& input = inputs.back(); input.name = "X"; input.dims = {3, 2}; input.values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; @@ -3509,8 +3469,8 @@ TEST(CApiTest, TestSharedAllocators) { TEST(CApiTest, TestSharingOfInitializerAndItsPrepackedVersion) { // simple inference test // prepare inputs - std::vector inputs(1); - Input& input = inputs.back(); + std::vector> inputs(1); + auto& input = inputs.back(); input.name = "X"; input.dims = {3, 2}; input.values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; @@ -3905,8 +3865,8 @@ TEST_P(CApiTensorRTTest, TestConfigureTensorRTProviderOptions) { // simple inference test // prepare inputs - std::vector inputs(1); - Input& input = inputs.back(); + std::vector> inputs(1); + auto& input = inputs.back(); input.name = "X"; input.dims = {3, 2}; input.values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; @@ -4845,4 +4805,32 @@ TEST(CApiTest, GenerateNodeStatsFile) { output_names, 1); } -#endif \ No newline at end of file +#endif + +// Test that creates a custom Cast kernel which requires type inference of the output type to work. +// Also demonstrates overriding an ONNX operator as we register the custom op in the ONNX domain. +TEST(CApiTest, custom_cast) { + std::vector> inputs(1); + auto& input = inputs[0]; + input.name = "input"; + input.dims = {3, 4}; + input.values = {1.0f, 2.0f, 3.0f, 4.0f, + -1.0f, -2.0f, -3.0f, -4.0f, + 1.0f, 2.0f, 3.0f, 4.0f}; + + // prepare expected inputs and outputs + std::vector expected_dims_y = {3, 4}; + std::vector expected_values_y = {1.0, 2.0, 3.0, 4.0, + -1.0, -2.0, -3.0, -4.0, + 1.0, 2.0, 3.0, 4.0}; + + CustomCast custom_op{onnxruntime::kCpuExecutionProvider}; + + Ort::CustomOpDomain custom_op_domain(""); // onnx domain is empty string + custom_op_domain.Add(&custom_op); + + // model with Cast from ONNX test data + TestInference(*ort_env, TSTR("testdata/cast_float_to_double.onnx"), + inputs, "output", expected_dims_y, expected_values_y, 0, + custom_op_domain, nullptr); +} diff --git a/onnxruntime/test/shared_lib/test_model_builder_api.cc b/onnxruntime/test/shared_lib/test_model_builder_api.cc new file mode 100644 index 0000000000000..9807fcca06ed4 --- /dev/null +++ b/onnxruntime/test/shared_lib/test_model_builder_api.cc @@ -0,0 +1,701 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include + +#include "gtest/gtest.h" +#include "gmock/gmock.h" + +#include "core/common/narrow.h" +#include "core/graph/constants.h" +#include "core/session/onnxruntime_c_api.h" +#include "core/session/onnxruntime_cxx_api.h" +#include "core/session/onnxruntime_lite_custom_op.h" +#include "core/session/onnxruntime_session_options_config_keys.h" + +#include "test/shared_lib/test_fixture.h" +#include "test/shared_lib/utils.h" +#include "test/util/include/test_allocator.h" + +#include "onnxruntime_config.h" // generated file in build output dir + +extern std::unique_ptr ort_env; + +using namespace Ort; + +namespace { + +Ort::Session CreateSession(Ort::Env& env, + Model& graph_api_model, + Ort::SessionOptions* session_options_for_test = nullptr) { + Ort::SessionOptions default_session_options; + Ort::SessionOptions& session_options = session_options_for_test ? *session_options_for_test + : default_session_options; + + // Set this to save the model if you want to debug. + // session_options.SetOptimizedModelFilePath(ORT_TSTR("model_builder_output.onnx")); + + Ort::Session session(env, graph_api_model, session_options); + + // Session should not require the model to stay alive so free it now to validate. + graph_api_model = Model(nullptr); + + return session; +} + +template +void TestInference(Ort::Session& session, + const std::vector>& inputs, + const char* output_name, + const std::vector& expected_dims, + const std::vector& expected_values) { + auto default_allocator = std::make_unique(); + + // without preallocated output tensor + RunSession(default_allocator.get(), + session, + inputs, + output_name, + expected_dims, + expected_values, + nullptr); +} + +// Create OrtNode using the C API +OrtNode* CreateNode(const OrtModelEditorApi& api, + const char* operator_name, const char* node_name, + const gsl::span input_names, + const gsl::span output_names, + const gsl::span attributes = {}, + const char* domain_name = onnxruntime::kOnnxDomain) { + OrtNode* node = nullptr; + Ort::ThrowOnError(api.CreateNode(operator_name, domain_name, node_name, + input_names.data(), input_names.size(), + output_names.data(), output_names.size(), + attributes.data(), attributes.size(), + &node)); + return node; +} + +// convenience func to convert initalizer lists to gsl::span +OrtNode* CreateNode(const OrtModelEditorApi& api, + const char* operator_name, const char* node_name, + const std::initializer_list input_names, + const std::initializer_list output_names, + const std::initializer_list attributes = {}, + const char* domain_name = onnxruntime::kOnnxDomain) { + std::vector inputs(input_names); + std::vector outputs(output_names); + std::vector attrs(attributes); + return CreateNode(api, operator_name, node_name, inputs, outputs, attrs, domain_name); +} +} // namespace + +struct TestAllocator : public OrtAllocator { + TestAllocator() { + version = ORT_API_VERSION; + Info = [](const struct OrtAllocator* this_ptr) -> const struct OrtMemoryInfo* { + auto* test_allocator = static_cast(this_ptr); + return test_allocator->memory_info; + }; + + Free = [](struct OrtAllocator* allocator, void* p) -> void { + auto* test_allocator = static_cast(allocator); + // find the matching pointer and remove it + auto it = std::find_if(test_allocator->weights.begin(), test_allocator->weights.end(), + [p](const std::unique_ptr>& v) { return v->data() == p; }); + if (it == test_allocator->weights.end()) { + throw std::runtime_error("Free called with unknown pointer"); + } + + test_allocator->weights.erase(it); + }; + + Alloc = [](struct OrtAllocator* /*this*/, size_t /*size*/) -> void* { + throw std::runtime_error("This should not be used"); + }; + + Reserve = [](struct OrtAllocator* /*this*/, size_t /*size*/) -> void* { + throw std::runtime_error("This should not be used"); + }; + } + + // initializers that are used directly by the model. as there's no copy they must remain valid. + // we store them in the test allocator so we can validate that Free is called + std::vector>> weights; + Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtAllocatorType::OrtDeviceAllocator, + OrtMemType::OrtMemTypeDefault); +}; + +// Test the ModelEditorAPI C api +// Uses the ORT C++ api for the rest for simplicity +TEST(ModelEditorAPITest, Basic_CApi) { + const auto& api = Ort::GetApi(); + const auto& model_editor_api = Ort::GetModelEditorApi(); + + TestAllocator deleter; + + // return void so we can use ASSERT_* in the lambda + const auto build_model = [&](bool use_constant_node, OrtModel*& model) -> void { + OrtGraph* graph = nullptr; + Ort::ThrowOnError(model_editor_api.CreateGraph(&graph)); + + // + // Create OrtModel with a Gemm. X input is 3x4, Y input is 4x8, Z output is 3x8. + // X is model input. Y is initializer. + // Set the alpha attribute of the Gemm node to 2.0 to test attribute handling. + // + + // model input + OrtTensorTypeAndShapeInfo* tensor_type_info = nullptr; + std::vector input_dims = {3, 4}; + // can use api.SetSymbolicDimensions to set symbolic dimensions. + // the input array should have the same rank as the call to SetDimensions. + // e.g. call SetDimensions with {-1, 3, 2} and SetSymbolicDimensions with {"N", nullptr, nullptr} to create + // a shape of {"N", 3, 2} + + Ort::ThrowOnError(api.CreateTensorTypeAndShapeInfo(&tensor_type_info)); + Ort::ThrowOnError(api.SetTensorElementType(tensor_type_info, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT)); + Ort::ThrowOnError(api.SetDimensions(tensor_type_info, input_dims.data(), input_dims.size())); + + OrtTypeInfo* input_type_info = nullptr; + Ort::ThrowOnError(model_editor_api.CreateTensorTypeInfo(tensor_type_info, &input_type_info)); + api.ReleaseTensorTypeAndShapeInfo(tensor_type_info); // input_type_info took a copy + + // create ValueInfo and release the type info as CreateValueInfo takes a copy. + OrtValueInfo* input_value_info = nullptr; + Ort::ThrowOnError(model_editor_api.CreateValueInfo("X", input_type_info, &input_value_info)); + api.ReleaseTypeInfo(input_type_info); // input_value_info took a copy + tensor_type_info = nullptr; + + // model outputs + OrtTypeInfo* output_type_info = nullptr; + std::vector output_dims = {3, 8}; + + Ort::ThrowOnError(api.CreateTensorTypeAndShapeInfo(&tensor_type_info)); + Ort::ThrowOnError(api.SetTensorElementType(tensor_type_info, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT)); + Ort::ThrowOnError(api.SetDimensions(tensor_type_info, output_dims.data(), output_dims.size())); + + Ort::ThrowOnError(model_editor_api.CreateTensorTypeInfo(tensor_type_info, &output_type_info)); + api.ReleaseTensorTypeAndShapeInfo(tensor_type_info); // input_type_info took a copy + + OrtValueInfo* output_value_info = nullptr; + Ort::ThrowOnError(model_editor_api.CreateValueInfo("Z", output_type_info, &output_value_info)); + api.ReleaseTypeInfo(output_type_info); + + std::vector graph_inputs = {input_value_info}; + std::vector graph_outputs = {output_value_info}; + Ort::ThrowOnError(model_editor_api.SetGraphInputs(graph, graph_inputs.data(), graph_inputs.size())); + Ort::ThrowOnError(model_editor_api.SetGraphOutputs(graph, graph_outputs.data(), graph_outputs.size())); + input_value_info = nullptr; // graph now owns the input/output values + output_value_info = nullptr; + + // + // Gemm node + // + + OrtOpAttr* alpha_attr = nullptr; + float alpha_value = 2.0; + Ort::ThrowOnError(api.CreateOpAttr("alpha", &alpha_value, 1, OrtOpAttrType::ORT_OP_ATTR_FLOAT, &alpha_attr)); + + std::vector node_input_names = {"X", "Y"}; + const std::string gemm_output_name = use_constant_node ? "Z_temp" : "Z"; + std::vector node_output_names = {gemm_output_name.c_str()}; + std::vector node_attributes{alpha_attr}; + OrtNode* node = CreateNode(model_editor_api, "Gemm", "Gemm1", node_input_names, node_output_names, node_attributes); + alpha_attr = nullptr; // Node now owns + + Ort::ThrowOnError(model_editor_api.AddNodeToGraph(graph, node)); + node = nullptr; // graph now owns node + + // Y input + // As it's 128 bytes it could either be allocated using CreateTensorAsOrtValue or use existing memory. + // Under 128 bytes must use CreateTensorAsOrtValue. + std::vector y_dims = {4, 8}; + + deleter.weights.emplace_back(std::make_unique>(32)); + auto& y_values = *deleter.weights.back(); + std::iota(y_values.begin(), y_values.end(), 1.0f); + + // create an initializer for the Y input. add to `weights` so the memory remains valid. + OrtValue* y_tensor = nullptr; + Ort::ThrowOnError( + api.CreateTensorWithDataAndDeleterAsOrtValue(&deleter, + y_values.data(), y_values.size() * sizeof(y_values[0]), + y_dims.data(), y_dims.size(), + ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, + &y_tensor)); + + Ort::ThrowOnError(model_editor_api.AddInitializerToGraph(graph, "Y", y_tensor, /*data is external*/ true)); + y_tensor = nullptr; // graph now owns + + if (use_constant_node) { + // Test that a Constant node is converted to an initializer + + // create Constant nodes for min/max to limit output range + OrtOpAttr* min_attr = nullptr; + float min = 400.0f; + Ort::ThrowOnError(api.CreateOpAttr("value", &min, sizeof(min), ORT_OP_ATTR_FLOAT, &min_attr)); + node = CreateNode(model_editor_api, "Constant", "clip_min", {}, {"min"}, {min_attr}); + Ort::ThrowOnError(model_editor_api.AddNodeToGraph(graph, node)); + node = nullptr; // graph now owns node + + OrtOpAttr* max_attr = nullptr; + float max = 900.0f; + Ort::ThrowOnError(api.CreateOpAttr("value", &max, sizeof(max), ORT_OP_ATTR_FLOAT, &max_attr)); + node = CreateNode(model_editor_api, "Constant", "clip_max", {}, {"max"}, {max_attr}); + Ort::ThrowOnError(model_editor_api.AddNodeToGraph(graph, node)); + node = nullptr; // graph now owns node + + node = CreateNode(model_editor_api, "Clip", "Clip1", {gemm_output_name.c_str(), "min", "max"}, {"Z"}); + Ort::ThrowOnError(model_editor_api.AddNodeToGraph(graph, node)); + node = nullptr; // graph now owns node + } + + std::vector domain_names = {onnxruntime::kOnnxDomain}; + std::vector opset_versions = {18}; + Ort::ThrowOnError(model_editor_api.CreateModel(domain_names.data(), opset_versions.data(), domain_names.size(), + &model)); + Ort::ThrowOnError(model_editor_api.AddGraphToModel(model, graph)); + graph = nullptr; // model now owns + }; + + auto run_test = [&](bool use_constant_node) -> void { + OrtModel* model = nullptr; + build_model(use_constant_node, model); + + ASSERT_NE(model, nullptr) << "build_model should have created a model"; + + std::vector> inputs(1); + auto& input = inputs[0]; + input.name = "X"; + input.dims = {3, 4}; + input.values = {1.0f, 2.0f, 3.0f, 4.0f, + 8.0f, 7.0f, 6.0f, 5.0f, + 9.0f, 3.0f, 5.0f, 7.0f}; + + std::vector expected_dims = {3, 8}; + Model cxx_model(model); + auto session = CreateSession(*ort_env, cxx_model); + + std::vector expected_output; + if (use_constant_node) { + // clipped with min 400 and max 900 + expected_output = {400.0f, 400.0f, 400.0f, 400.0f, 420.0f, 440.0f, 460.0f, 480.0f, + 596.0f, 648.0f, 700.0f, 752.0f, 804.0f, 856.0f, 900.0f, 900.0f, + 592.0f, 640.0f, 688.0f, 736.0f, 784.0f, 832.0f, 880.0f, 900.0f}; + } else { + expected_output = {340.0f, 360.0f, 380.0f, 400.0f, 420.0f, 440.0f, 460.0f, 480.0f, + 596.0f, 648.0f, 700.0f, 752.0f, 804.0f, 856.0f, 908.0f, 960.0f, + 592.0f, 640.0f, 688.0f, 736.0f, 784.0f, 832.0f, 880.0f, 928.0f}; + } + + TestInference(session, inputs, "Z", expected_dims, expected_output); + + api.ReleaseSession(session.release()); + + ASSERT_EQ(deleter.weights.size(), size_t(0)) << "All weights should have been freed"; + }; + + run_test(false); + run_test(true); // use Constant node for initializer +} + +TEST(ModelEditorAPITest, Basic_CxxApi) { + // initializers that are used directly by the model. as there's no copy they must remain valid + std::vector>> weights; + + Ort::Graph graph; + + // + // Create OrtModel with a Gemm. X input is 3x4, Y input is 4x8, Z output is 3x8. + // X is model input. Y is initializer. + // Set the alpha attribute of the Gemm node to 2.0 to test attribute handling. + // + + std::vector graph_inputs; + std::vector graph_outputs; + + // model input. it's {3, 4} but use a symbolic dim to test that works. + std::vector input_dims({-1, 4}); + std::vector input_symbolic_dims({"multiple_of_3", ""}); + TensorTypeAndShapeInfo input_tensor_info(ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, + input_dims, + &input_symbolic_dims); + auto input_type_info = TypeInfo::CreateTensorInfo(input_tensor_info.GetConst()); + graph_inputs.emplace_back("X", input_type_info.GetConst()); + + // model outputs + std::vector output_dims = {-1, 8}; + std::vector output_symbolic_dims({"multiple_of_3", ""}); + TensorTypeAndShapeInfo output_tensor_info(ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, + output_dims, + &output_symbolic_dims); + auto output_type_info = TypeInfo::CreateTensorInfo(output_tensor_info.GetConst()); + graph_outputs.emplace_back("Z", output_type_info.GetConst()); + + graph.SetInputs(graph_inputs); + graph.SetOutputs(graph_outputs); + + // + // Gemm node + // + + std::vector attributes; + float alpha_value = 2.0; + attributes.push_back(OpAttr("alpha", &alpha_value, 1, OrtOpAttrType::ORT_OP_ATTR_FLOAT)); + + Node node("Gemm", onnxruntime::kOnnxDomain, "Gemm1", {"X", "Y"}, {"Z"}, attributes); + + graph.AddNode(node); + + // create an initializer for the Y input. + // add to `weights` so it remains valid for the lifetime of the session and we can avoid copying the data. + // As it's 128 bytes it could either be allocated using CreateTensorAsOrtValue or use existing memory. + // Under 128 bytes must use CreateTensorAsOrtValue. + std::vector y_dims = {4, 8}; + + weights.emplace_back(std::make_unique>(32)); + auto& y_values = *weights.back(); + std::iota(y_values.begin(), y_values.end(), 1.0f); + + auto info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); + + // if you use this API the initializer data MUST remain valid for the lifetime of the InferenceSession + auto y_tensor = Value::CreateTensor(info, y_values.data(), y_values.size(), y_dims.data(), y_dims.size()); + graph.AddInitializer("Y", y_tensor, /*data is external*/ true); + + std::vector opsets{{onnxruntime::kOnnxDomain, 18}}; + Model model(opsets); + model.AddGraph(graph); + + std::vector> inputs(1); + auto& input = inputs[0]; + input.name = "X"; + input.dims = {3, 4}; + input.values = {1.0f, 2.0f, 3.0f, 4.0f, + 8.0f, 7.0f, 6.0f, 5.0f, + 9.0f, 3.0f, 5.0f, 7.0f}; + + std::vector expected_dims = {3, 8}; + + auto session = CreateSession(*ort_env, model); + TestInference(session, inputs, "Z", expected_dims, + {340.0f, 360.0f, 380.0f, 400.0f, 420.0f, 440.0f, 460.0f, 480.0f, + 596.0f, 648.0f, 700.0f, 752.0f, 804.0f, 856.0f, 908.0f, 960.0f, + 592.0f, 640.0f, 688.0f, 736.0f, 784.0f, 832.0f, 880.0f, 928.0f}); +} + +TEST(ModelEditorAPITest, BasicModelEdit_CxxApi) { + // + // Load existing model + // Add Cast to change the model input from float to int64 + // Update model inputs to match + // Run + // + + SessionOptions so; + + // Set this to save the model if you want to debug. + // so.SetOptimizedModelFilePath(ORT_TSTR("model_builder_edited.onnx")); + + Session session = Session::CreateModelEditorSession(*ort_env, TSTR("testdata/mnist.onnx"), so); + + ASSERT_EQ(session.GetOpset(""), 8); // ONNX domain is empty string + + // we augment the original model with nodes, initializers and the updated model inputs/outputs from this model. + // the original graph is unchanged. nodes can be added before/after it. initializers can be added. + // new nodes must conform to the original domain:opset of the model. + // additional operator domain:opset pairs can be added. + std::vector opsets; // no additional opsets required + Model model(opsets); + + std::vector graph_inputs = session.GetInputs(); + ASSERT_EQ(graph_inputs.size(), size_t(1)); + ASSERT_EQ(graph_inputs[0].TypeInfo().GetTensorTypeAndShapeInfo().GetElementType(), + ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT); + + // typically this isn't needed. we replace this input but need to read info from it later on in the test + // validation so we save the info locally to keep it accessible. + auto orig_input_name = graph_inputs[0].Name(); + auto input_shape = graph_inputs[0].TypeInfo().GetTensorTypeAndShapeInfo().GetShape(); + const std::string new_input_name = "Int64Input"; + + // Add Cast node to convert input from float to int64 + std::vector attributes; + int64_t to = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; + attributes.push_back(OpAttr("to", &to, 1, OrtOpAttrType::ORT_OP_ATTR_INT)); + + Ort::Node node("Cast", onnxruntime::kOnnxDomain, new_input_name, {"Int64Input"}, + // the existing node will now consume the output from the Cast instead of a graph input + {orig_input_name}, + attributes); + + // we're replacing the only input. the shape is the same but the name and data type change. + TensorTypeAndShapeInfo input_tensor_info(ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, + input_shape); + auto input_type_info = TypeInfo::CreateTensorInfo(input_tensor_info.GetConst()); + graph_inputs[0] = ValueInfo(new_input_name, input_type_info.GetConst()); + + Graph graph; // new info to augment the model with + + graph.AddNode(node); + graph.SetInputs(graph_inputs); + + // the node we added does not require any new opsets. + model.AddGraph(graph); + session.FinalizeModelEditorSession(model, so); + + std::vector> inputs(1); + auto& input = inputs[0]; + input.name = new_input_name.c_str(); + input.dims = input_shape; + + auto num_values = std::accumulate(input.dims.begin(), input.dims.end(), int64_t(1), std::multiplies()); + input.values.resize(size_t(num_values)); + std::iota(input.values.begin(), input.values.end(), 1); + + std::vector expected_dims = {1, 10}; + std::vector expected_output = {-48.5088f, -1040.2948f, -347.0959f, 101.7392f, 421.3352f, + 750.92145f, 231.5060f, -1694.4152f, 681.5623f, 378.1689f}; + + TestInference(session, inputs, session.GetOutputNames()[0].c_str(), expected_dims, expected_output); + + // double check with original model + { + SessionOptions expected_so; + Session expected_session = Session(*ort_env, TSTR("testdata/mnist.onnx"), expected_so); + std::vector> expected_inputs(1); + auto& expected_input = expected_inputs[0]; + expected_input.name = orig_input_name.c_str(); + expected_input.dims = input_shape; + expected_input.values.reserve(size_t(num_values)); + std::transform(input.values.begin(), input.values.end(), std::back_inserter(expected_input.values), + [&](int64_t value) { return float(value); }); + + TestInference(expected_session, expected_inputs, session.GetOutputNames()[0].c_str(), + expected_dims, expected_output); + } +} + +TEST(ModelEditorAPITest, InvalidDimension) { + try { + std::vector input_dims = {-2, 2}; + TensorTypeAndShapeInfo tensor_type_info(ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, + input_dims); + // invalid dim of -2 should cause exception + TypeInfo::CreateTensorInfo(tensor_type_info.GetConst()); + FAIL() << "Expected exception for invalid dimension"; + } catch (const Ort::Exception& e) { + ASSERT_STREQ(e.what(), "dim_values must be -1 (symbolic dimension) or larger."); + } +} + +TEST(ModelEditorAPITest, CreateInvalidModel_NoOpsets) { + Ort::Graph graph; + std::vector graph_inputs; + std::vector graph_outputs; + + std::vector dims({4}); + TensorTypeAndShapeInfo tensor_info(ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, dims); + auto type_info = TypeInfo::CreateTensorInfo(tensor_info.GetConst()); + graph_inputs.emplace_back("X", type_info.GetConst()); + graph_outputs.emplace_back("Z", type_info.GetConst()); + + graph.SetInputs(graph_inputs); + graph.SetOutputs(graph_outputs); + + Ort::Node node("Add", onnxruntime::kOnnxDomain, "Add1", {"X", "X"}, {"Z"}); + + graph.AddNode(node); + + std::vector opsets; + Model model(opsets); + model.AddGraph(graph); + + try { + auto session = CreateSession(*ort_env, model); + FAIL(); + } catch (const Ort::Exception& e) { + ASSERT_THAT(e.what(), ::testing::HasSubstr("Error No opset import for domain")); + } +} + +TEST(ModelEditorAPITest, CreateInvalidModel_MissingValue) { + Ort::Graph graph; + + std::vector graph_inputs; + std::vector graph_outputs; + + std::vector dims({4}); + TensorTypeAndShapeInfo tensor_info(ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, dims); + auto type_info = TypeInfo::CreateTensorInfo(tensor_info.GetConst()); + graph_inputs.emplace_back("X", type_info.GetConst()); + graph_outputs.emplace_back("Z", type_info.GetConst()); + + graph.SetInputs(graph_inputs); + graph.SetOutputs(graph_outputs); + + Ort::Node node("Add", onnxruntime::kOnnxDomain, "Add1", {"X", "missing"}, {"Z"}); + graph.AddNode(node); + + std::vector opsets{{onnxruntime::kOnnxDomain, 18}}; + Model model(opsets); + model.AddGraph(graph); + + try { + auto session = CreateSession(*ort_env, model); + FAIL(); + } catch (const Ort::Exception& e) { + ASSERT_THAT(e.what(), ::testing::HasSubstr("Node input 'missing' is not a graph input, " + "initializer, or output of a previous node.")); + } +} + +TEST(ModelEditorAPITest, InvalidModelEdit) { + // Add a node but make the edit invalid in various ways + // - add node but don't update graph inputs + // - add node with invalid domain + const auto edit_model = [](bool invalid_domain) { + SessionOptions so; + + // Set this to save the model if you want to debug. + // so.SetOptimizedModelFilePath(ORT_TSTR("model_builder_edited.onnx")); + + Session session = Session::CreateModelEditorSession(*ort_env, TSTR("testdata/mnist.onnx"), so); + + ASSERT_EQ(session.GetOpset(""), 8); // ONNX domain is empty string + + std::vector opsets; // no additional opsets required + Model model(opsets); + Graph graph; // new info to augment the model with + + const char* domain = invalid_domain ? "invalid_domain" : onnxruntime::kOnnxDomain; + + std::vector graph_inputs = session.GetInputs(); + ASSERT_EQ(graph_inputs.size(), size_t(1)); + ASSERT_EQ(graph_inputs[0].TypeInfo().GetTensorTypeAndShapeInfo().GetElementType(), + ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT); + + const std::string new_input_name = "Int64Input"; + + // Add Cast node to convert input from float to int64 + std::vector attributes; + int64_t to = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; + attributes.push_back(OpAttr("to", &to, 1, OrtOpAttrType::ORT_OP_ATTR_INT)); + + Node node("Cast", domain, "NewInputNode", {new_input_name}, + // the existing node will now consume the output from the Cast instead of a graph input + {graph_inputs[0].Name()}, + attributes); + graph.AddNode(node); + + if (invalid_domain) { + // we're replacing the only input. the shape is the same but the name and data type change. + TensorTypeAndShapeInfo input_tensor_info(ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, + graph_inputs[0].TypeInfo().GetTensorTypeAndShapeInfo().GetShape()); + auto input_type_info = TypeInfo::CreateTensorInfo(input_tensor_info.GetConst()); + graph_inputs[0] = ValueInfo(new_input_name, input_type_info.GetConst()); + graph.SetInputs(graph_inputs); + } else { + // model should be invalid as we didn't connect the new node up to the graph inputs + } + + // the node we added does not require any new opsets. + model.AddGraph(graph); + + try { + session.FinalizeModelEditorSession(model, so); + FAIL() << "Should have failed to resolve graph due to invalid edits."; + } catch (const Ort::Exception& e) { + if (invalid_domain) { + ASSERT_THAT(e.what(), ::testing::HasSubstr("Error No opset import for domain 'invalid_domain'")); + } else { + ASSERT_THAT(e.what(), ::testing::HasSubstr("This is an invalid model")); + } + } + }; + + edit_model(false); + edit_model(true); // add node with invalid domain +} + +TEST(ModelEditorAPITest, CreateTypeInfo) { + const auto& api = Ort::GetApi(); + const auto& model_editor_api = Ort::GetModelEditorApi(); + + TensorTypeAndShapeInfo base_tensor_info(ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, + {2, 4}); + + OrtTypeInfo* base_tensor_type_info = nullptr; + Ort::ThrowOnError(model_editor_api.CreateTensorTypeInfo(base_tensor_info, &base_tensor_type_info)); + + ONNXType onnx_type = ONNX_TYPE_UNKNOWN; + const OrtTensorTypeAndShapeInfo* tensor_info = nullptr; + ONNXTensorElementDataType onnx_element_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; + + // sparse tensor + OrtTypeInfo* sparse_tensor_type_info = nullptr; + Ort::ThrowOnError(model_editor_api.CreateSparseTensorTypeInfo(base_tensor_info, &sparse_tensor_type_info)); + Ort::ThrowOnError(api.GetOnnxTypeFromTypeInfo(sparse_tensor_type_info, &onnx_type)); + ASSERT_EQ(onnx_type, ONNXType::ONNX_TYPE_SPARSETENSOR); + Ort::ThrowOnError(api.CastTypeInfoToTensorInfo(sparse_tensor_type_info, &tensor_info)); + Ort::ThrowOnError(api.GetTensorElementType(tensor_info, &onnx_element_type)); + ASSERT_EQ(onnx_element_type, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT); + api.ReleaseTypeInfo(sparse_tensor_type_info); + + // sequence + OrtTypeInfo* sequence_type_info = nullptr; + const OrtSequenceTypeInfo* sequence_info = nullptr; + OrtTypeInfo* sequence_element_type_info = nullptr; + + Ort::ThrowOnError(model_editor_api.CreateSequenceTypeInfo(base_tensor_type_info, &sequence_type_info)); + Ort::ThrowOnError(api.GetOnnxTypeFromTypeInfo(sequence_type_info, &onnx_type)); + ASSERT_EQ(onnx_type, ONNXType::ONNX_TYPE_SEQUENCE); + Ort::ThrowOnError(api.CastTypeInfoToSequenceTypeInfo(sequence_type_info, &sequence_info)); + Ort::ThrowOnError(api.GetSequenceElementType(sequence_info, &sequence_element_type_info)); + Ort::ThrowOnError(api.GetOnnxTypeFromTypeInfo(sequence_element_type_info, &onnx_type)); + ASSERT_EQ(onnx_type, ONNXType::ONNX_TYPE_TENSOR); + Ort::ThrowOnError(api.CastTypeInfoToTensorInfo(sequence_element_type_info, &tensor_info)); + Ort::ThrowOnError(api.GetTensorElementType(tensor_info, &onnx_element_type)); + ASSERT_EQ(onnx_element_type, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT); + api.ReleaseTypeInfo(sequence_element_type_info); + api.ReleaseTypeInfo(sequence_type_info); + + // map + OrtTypeInfo* map_type_info = nullptr; + const OrtMapTypeInfo* map_info = nullptr; + ONNXTensorElementDataType map_key_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; + OrtTypeInfo* map_value_type_info = nullptr; + Ort::ThrowOnError(model_editor_api.CreateMapTypeInfo(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, base_tensor_type_info, + &map_type_info)); // clones map_type_info + Ort::ThrowOnError(api.GetOnnxTypeFromTypeInfo(map_type_info, &onnx_type)); + ASSERT_EQ(onnx_type, ONNXType::ONNX_TYPE_MAP); + Ort::ThrowOnError(api.CastTypeInfoToMapTypeInfo(map_type_info, &map_info)); + Ort::ThrowOnError(api.GetMapKeyType(map_info, &map_key_type)); + ASSERT_EQ(map_key_type, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64); + Ort::ThrowOnError(api.GetMapValueType(map_info, &map_value_type_info)); + Ort::ThrowOnError(api.GetOnnxTypeFromTypeInfo(map_value_type_info, &onnx_type)); + ASSERT_EQ(onnx_type, ONNXType::ONNX_TYPE_TENSOR); + Ort::ThrowOnError(api.CastTypeInfoToTensorInfo(map_value_type_info, &tensor_info)); + Ort::ThrowOnError(api.GetTensorElementType(tensor_info, &onnx_element_type)); + ASSERT_EQ(onnx_element_type, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT); + api.ReleaseTypeInfo(map_value_type_info); + api.ReleaseTypeInfo(map_type_info); + + // optional + OrtTypeInfo* optional_type_info = nullptr; + const OrtOptionalTypeInfo* optional_info = nullptr; + OrtTypeInfo* optional_contained_type_info = nullptr; + Ort::ThrowOnError(model_editor_api.CreateOptionalTypeInfo(base_tensor_type_info, &optional_type_info)); + Ort::ThrowOnError(api.GetOnnxTypeFromTypeInfo(optional_type_info, &onnx_type)); + ASSERT_EQ(onnx_type, ONNXType::ONNX_TYPE_OPTIONAL); + Ort::ThrowOnError(api.CastTypeInfoToOptionalTypeInfo(optional_type_info, &optional_info)); + Ort::ThrowOnError(api.GetOptionalContainedTypeInfo(optional_info, &optional_contained_type_info)); + Ort::ThrowOnError(api.GetOnnxTypeFromTypeInfo(optional_contained_type_info, &onnx_type)); + ASSERT_EQ(onnx_type, ONNXType::ONNX_TYPE_TENSOR); + api.ReleaseTypeInfo(optional_contained_type_info); + api.ReleaseTypeInfo(optional_type_info); + + api.ReleaseTypeInfo(base_tensor_type_info); +} diff --git a/onnxruntime/test/shared_lib/test_ort_format_models.cc b/onnxruntime/test/shared_lib/test_ort_format_models.cc index 99a9ebc3362ae..b3491e3476f23 100644 --- a/onnxruntime/test/shared_lib/test_ort_format_models.cc +++ b/onnxruntime/test/shared_lib/test_ort_format_models.cc @@ -17,7 +17,7 @@ extern std::unique_ptr ort_env; [[maybe_unused]] static void TestInference(Ort::Env& env, const std::basic_string& model_uri, - const std::vector& inputs, const char* output_name, + const std::vector>& inputs, const char* output_name, const std::vector& expected_dims_y, const std::vector& expected_values_y, Ort::CustomOpDomain& custom_op_domain, void* cuda_compute_stream = nullptr) { Ort::SessionOptions session_options; @@ -100,8 +100,8 @@ TEST(OrtFormatCustomOpTests, ConvertOnnxModelToOrt) { } // now load the ORT format model and execute it - std::vector inputs(1); - Input& input = inputs[0]; + std::vector> inputs(1); + auto& input = inputs[0]; input.name = "X"; input.dims = {3, 2}; input.values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; @@ -130,8 +130,8 @@ TEST(OrtFormatCustomOpTests, LoadOrtModel) { custom_op_domain.Add(&custom_op); // load the ORT format model and execute it - std::vector inputs(1); - Input& input = inputs[0]; + std::vector> inputs(1); + auto& input = inputs[0]; input.name = "X"; input.dims = {3, 2}; input.values = {6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f}; @@ -151,8 +151,8 @@ TEST(OrtFormatCustomOpTests, LoadOrtModelStandaloneCustomOpImplementation) { custom_op_domain.Add(&standalone_op); // load the ORT format model and execute it - std::vector inputs(1); - Input& input = inputs[0]; + std::vector> inputs(1); + auto& input = inputs[0]; input.name = "X"; input.dims = {3, 2}; input.values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; diff --git a/onnxruntime/test/shared_lib/utils.h b/onnxruntime/test/shared_lib/utils.h index 483753f2ae6b2..5d15582b86cb9 100644 --- a/onnxruntime/test/shared_lib/utils.h +++ b/onnxruntime/test/shared_lib/utils.h @@ -5,4 +5,56 @@ #include "core/session/onnxruntime_cxx_api.h" +#include "gtest/gtest.h" + OrtCUDAProviderOptions CreateDefaultOrtCudaProviderOptionsWithCustomStream(void* cuda_compute_stream = nullptr); + +template +struct Input { + const char* name = nullptr; + std::vector dims; + std::vector values; +}; + +template > +void RunSession(OrtAllocator* allocator, + Ort::Session& session_object, + const std::vector& inputs, + const char* output_name, + const std::vector& output_dims, + const std::vector& expected_output, + Ort::Value* output_tensor) { + std::vector ort_inputs; + std::vector input_names; + for (size_t i = 0; i < inputs.size(); i++) { + input_names.emplace_back(inputs[i].name); + ort_inputs.emplace_back( + Ort::Value::CreateTensor(allocator->Info(allocator), const_cast(inputs[i].values.data()), + inputs[i].values.size(), inputs[i].dims.data(), inputs[i].dims.size())); + } + + std::vector ort_outputs; + if (output_tensor) + session_object.Run(Ort::RunOptions{nullptr}, input_names.data(), ort_inputs.data(), ort_inputs.size(), + &output_name, output_tensor, 1); + else { + ort_outputs = session_object.Run(Ort::RunOptions{}, input_names.data(), ort_inputs.data(), ort_inputs.size(), + &output_name, 1); + ASSERT_EQ(ort_outputs.size(), 1u); + output_tensor = &ort_outputs[0]; + } + + auto type_info = output_tensor->GetTensorTypeAndShapeInfo(); + ASSERT_EQ(type_info.GetShape(), output_dims); + size_t total_len = type_info.GetElementCount(); + ASSERT_EQ(expected_output.size(), total_len); + + auto* actual = output_tensor->GetTensorMutableData(); + for (size_t i = 0; i != total_len; ++i) { + if constexpr (std::is_same::value || std::is_same::value) { + EXPECT_NEAR(expected_output[i], actual[i], 1e-3) << "i=" << i; + } else { + EXPECT_EQ(expected_output[i], actual[i]) << "i=" << i; + } + } +} diff --git a/onnxruntime/test/testdata/cast_float_to_double.onnx b/onnxruntime/test/testdata/cast_float_to_double.onnx new file mode 100644 index 0000000000000000000000000000000000000000..dc7997cddd8a8c762e354316662fb0d734e25e86 GIT binary patch literal 136 zcmdfpOwLZtOVKS!EiSPt;8NgX&CDw(EfHeNFD(JmN-WNa#U)ytTudeT65I-kD&v!ZqVaA%{*EE>CHe6#{-I7ju2JGJ&3s%u9E?I7TudCyK+KXP!38x=2qeRe Mka1$+Vh|7o0L&R4`v3p{ literal 0 HcmV?d00001 diff --git a/onnxruntime/test/testdata/custom_execution_provider_library/my_execution_provider.cc b/onnxruntime/test/testdata/custom_execution_provider_library/my_execution_provider.cc index 57471f7c029c2..27a4b06a99e64 100644 --- a/onnxruntime/test/testdata/custom_execution_provider_library/my_execution_provider.cc +++ b/onnxruntime/test/testdata/custom_execution_provider_library/my_execution_provider.cc @@ -1,5 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. -// Confidential and Proprietary. +// Licensed under the MIT License. #include "my_execution_provider.h" #include "my_allocator.h" diff --git a/onnxruntime/test/testdata/custom_execution_provider_library/my_execution_provider.h b/onnxruntime/test/testdata/custom_execution_provider_library/my_execution_provider.h index ff0c7e80c4eeb..efb359a9e5e43 100644 --- a/onnxruntime/test/testdata/custom_execution_provider_library/my_execution_provider.h +++ b/onnxruntime/test/testdata/custom_execution_provider_library/my_execution_provider.h @@ -1,5 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. -// Confidential and Proprietary. +// Licensed under the MIT License. #pragma once diff --git a/onnxruntime/wasm/api.cc b/onnxruntime/wasm/api.cc index 7adfc6a2b2ccb..1ad35b51bb1c1 100644 --- a/onnxruntime/wasm/api.cc +++ b/onnxruntime/wasm/api.cc @@ -8,6 +8,14 @@ #include "core/session/onnxruntime_cxx_api.h" #include "api.h" +#ifdef USE_WEBGPU +namespace onnxruntime { +namespace webgpu { +WGPUDevice GetDevice(int); +} +} // namespace onnxruntime +#endif + #include #include #include @@ -164,8 +172,12 @@ OrtSessionOptions* OrtCreateSessionOptions(size_t graph_optimization_level, return UNREGISTER_AUTO_RELEASE(session_options); } -int OrtAppendExecutionProvider(ort_session_options_handle_t session_options, const char* name) { - return CHECK_STATUS(SessionOptionsAppendExecutionProvider, session_options, name, nullptr, nullptr, 0); +int OrtAppendExecutionProvider(ort_session_options_handle_t session_options, + const char* name, + const char* const* provider_options_keys, + const char* const* provider_options_values, + size_t num_keys) { + return CHECK_STATUS(SessionOptionsAppendExecutionProvider, session_options, name, provider_options_keys, provider_options_values, num_keys); } int OrtAddFreeDimensionOverride(ort_session_options_handle_t session_options, @@ -507,6 +519,16 @@ char* OrtEndProfiling(ort_session_handle_t session) { : nullptr; } +// WebGPU API Section + +#ifdef USE_WEBGPU + +WGPUDevice OrtGetWebGpuDevice(int device_id) { + return onnxruntime::webgpu::GetDevice(device_id); +} + +#endif + // Training API Section #ifdef ENABLE_TRAINING_APIS diff --git a/onnxruntime/wasm/api.h b/onnxruntime/wasm/api.h index f44c515d98f6b..9ff1eb55ecedc 100644 --- a/onnxruntime/wasm/api.h +++ b/onnxruntime/wasm/api.h @@ -10,6 +10,10 @@ #include +#ifdef USE_WEBGPU +#include +#endif + #include struct OrtSession; @@ -85,7 +89,10 @@ ort_session_options_handle_t EMSCRIPTEN_KEEPALIVE OrtCreateSessionOptions(size_t * @returns ORT error code. If not zero, call OrtGetLastError() to get detailed error message. */ int EMSCRIPTEN_KEEPALIVE OrtAppendExecutionProvider(ort_session_options_handle_t session_options, - const char* name); + const char* name, + const char* const* provider_options_keys, + const char* const* provider_options_values, + size_t num_keys); /** * add a free dimension override for one dimension of a session's input. @@ -294,6 +301,21 @@ int EMSCRIPTEN_KEEPALIVE OrtRun(ort_session_handle_t session, */ char* EMSCRIPTEN_KEEPALIVE OrtEndProfiling(ort_session_handle_t session); +// WebGPU API Section + +#ifdef USE_WEBGPU + +/** + * get the GPU Device by device ID. + * + * This function is only available after the GPU Device is initialized in WebGpuContextFactory. + * + * @returns a WGPUDevice handle. + */ +WGPUDevice EMSCRIPTEN_KEEPALIVE OrtGetWebGpuDevice(int device_id); + +#endif + // Training API Section #ifdef ENABLE_TRAINING_APIS diff --git a/onnxruntime/wasm/js_post_js.js b/onnxruntime/wasm/js_post_js.js index b77d82fbd7d10..56d3246fd07f0 100644 --- a/onnxruntime/wasm/js_post_js.js +++ b/onnxruntime/wasm/js_post_js.js @@ -2,6 +2,4 @@ // Licensed under the MIT License. -'use strict'; - Module["PTR_SIZE"] = 4; diff --git a/onnxruntime/wasm/js_post_js_64.js b/onnxruntime/wasm/js_post_js_64.js index b140df927ebbd..cfd79523f7900 100644 --- a/onnxruntime/wasm/js_post_js_64.js +++ b/onnxruntime/wasm/js_post_js_64.js @@ -2,6 +2,4 @@ // Licensed under the MIT License. -'use strict'; - Module["PTR_SIZE"] = 8; diff --git a/onnxruntime/wasm/post-webgpu.js b/onnxruntime/wasm/post-webgpu.js new file mode 100644 index 0000000000000..146355f6a44d3 --- /dev/null +++ b/onnxruntime/wasm/post-webgpu.js @@ -0,0 +1,261 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// +// This file contains the post-run code for the ORT WebAssembly module. The code in this file will be injected into the +// final module using Emscripten's `--post-js` option. +// +// This file will only be used in build with flag `--use_webgpu`. + +/** + * This function is called only once when initializing the WebGPU backend. + * + * @param {(gpuDevice: GPUDevice) => void} setDefaultDevice A callback function to set the default device. + */ +Module["webgpuInit"] = (setDefaultDevice) => { + /** + * a map from GPUDevice to [deviceId, instanceHandle, deviceHandle] + * + * only stores custom devices (ie. devices created by the user, not the default device created by ORT) + * + * key is the GPUDevice object. + * + * value is a tuple of 3 elements: + * - deviceId: a unique ID for the device. Must be positive integer. + * - instanceHandle: the instance handle(pointer) of the device. + * - deviceHandle: the device handle(pointer) of the device. + * + * @type {WeakMap} + */ + const webgpuActiveDevices = new WeakMap(); + /** + * a number that is used to assign a unique ID to the next custom device. + */ + let webgpuNextDeviceId = 1; + /** + * a function to set the default device. + * + * @type {(gpuDevice: GPUDevice) => void} + */ + const webgpuSetDefaultDevice = setDefaultDevice; + /** + * the current device that is being used to create a WebGPU EP inference session. + * + * the value of this variable is only valid during the creation of a WebGPU EP inference session. + * + * @type {GPUDevice|undefined} + */ + let webgpuCurrentDevice = undefined; + /** + * the current device ID that is being used to create a WebGPU EP inference session. + * + * the value of this variable is only valid during the creation of a WebGPU EP inference session. + * + * @type {number|undefined} + */ + let webgpuCurrentDeviceId = undefined; + + /** + * This function is called only when a custom device is used, during preparation of session options. + * + * @param {GPUDevice} device the user provided device object. + * @returns {undefined|[number, number, number]} a tuple of device id, instance handle, and device handle. + */ + Module["webgpuRegisterDevice"] = (device) => { + if (webgpuCurrentDeviceId !== undefined) { + throw new Error("another WebGPU EP inference session is being created."); + } + + if (device) { + let deviceInfo = webgpuActiveDevices.get(device); + if (!deviceInfo) { + const instanceHandle = _wgpuCreateInstance(0); + const deviceHandle = WebGPU.importJsDevice(device, instanceHandle); + deviceInfo = [webgpuNextDeviceId++, instanceHandle, deviceHandle]; + webgpuActiveDevices.set(device, deviceInfo); + } + + // The current device ID is a temporary storage for the device ID to be used in the session that is being created. + // + // Soon after `webgpuRegisterDevice` (this function) is called, `webgpuOnCreateSession` will be called so that the + // value of `webgpuCurrentDeviceId` is used and reset then. + webgpuCurrentDevice = device; + webgpuCurrentDeviceId = deviceInfo[0]; + return deviceInfo; + } else { + webgpuCurrentDevice = undefined; + webgpuCurrentDeviceId = 0; + return undefined; + } + }; + + const webgpuActiveSessions = new Map(); + Module["webgpuOnCreateSession"] = (sessionHandle) => { + if (webgpuCurrentDeviceId === undefined) { + // do nothing if webgpuCurrentDeviceId is undefined. + // this means no WebGPU EP is being created. + return; + } + + const deviceId = webgpuCurrentDeviceId; + webgpuCurrentDeviceId = undefined; + + if (sessionHandle) { + // when session created successfully + const deviceHandle = _OrtGetWebGpuDevice(deviceId); + webgpuActiveSessions.set(sessionHandle, deviceHandle); + + if (deviceId === 0) { + const device = webgpuCurrentDevice ?? WebGPU.getJsObject(deviceHandle); + webgpuSetDefaultDevice(device); + } + } + webgpuCurrentDevice = undefined; + }; + + Module["webgpuOnReleaseSession"] = (sessionHandle) => { + webgpuActiveSessions.delete(sessionHandle); + }; + + const gpuBufferMetadataSymbol = Symbol("gpuBufferMetadata"); + + Module["webgpuRegisterBuffer"] = (buffer, sessionHandle, bufferHandle) => { + if (bufferHandle) { + // This is a buffer that was created by ORT. Metadata is [bufferHandle, NaN] + + buffer[gpuBufferMetadataSymbol] = [bufferHandle, NaN]; + return bufferHandle; + } else { + // This is a buffer that was created by the user. Metadata is [bufferHandle, refCount] + + const metadata = buffer[gpuBufferMetadataSymbol]; + if (metadata) { + metadata[1]++; + return metadata[0]; + } + + const deviceHandle = webgpuActiveSessions.get(sessionHandle); + if (deviceHandle === undefined) { + throw new Error( + "Invalid session handle passed to webgpuRegisterBuffer" + ); + } + + const bufferHandle = WebGPU.importJsBuffer(buffer, deviceHandle); + buffer[gpuBufferMetadataSymbol] = [bufferHandle, 1]; + return bufferHandle; + } + }; + + Module["webgpuUnregisterBuffer"] = (buffer) => { + const metadata = buffer[gpuBufferMetadataSymbol]; + if (!metadata) { + throw new Error("Buffer is not registered"); + } + metadata[1]--; + // For buffers created by ORT, metadata[1] will always be NaN. This function will not release the buffer. + // Instead, the buffer will be released when user calls `Tensor.dispose()` in JavaScript. + if (metadata[1] === 0) { + _wgpuBufferRelease(metadata[0]); + delete buffer[gpuBufferMetadataSymbol]; + } + }; + + Module["webgpuGetBuffer"] = (bufferHandle) => { + return WebGPU.getJsObject(bufferHandle); + }; + + Module["webgpuCreateDownloader"] = (gpuBuffer, bufferSize, sessionHandle) => { + const deviceHandle = webgpuActiveSessions.get(sessionHandle); + if (deviceHandle === undefined) { + throw new Error("Invalid session handle passed to webgpuRegisterBuffer"); + } + + const buffer = gpuBuffer; + const device = WebGPU.getJsObject(deviceHandle); + const originalSize = bufferSize; + const size = Math.ceil(Number(originalSize) / 16) * 16; + + return async () => { + // prettier-ignore + // + // the line above is used to force prettier to skip formatting the next statement. + // this is because prettier will remove the quotes around the property names, but we need to keep them + // because otherwise closure compiler may rename them and break the code. + const gpuReadBufferDescriptor = { + "size": size, + "usage": 9 /* GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ */, + }; + const gpuReadBuffer = device.createBuffer(gpuReadBufferDescriptor); + try { + const commandEncoder = device.createCommandEncoder(); + commandEncoder.copyBufferToBuffer( + buffer /* source buffer */, + 0 /* source offset */, + gpuReadBuffer /* destination buffer */, + 0 /* destination offset */, + size /* size */ + ); + device.queue.submit([commandEncoder.finish()]); + + await gpuReadBuffer.mapAsync(GPUMapMode.READ); + + const arrayBuffer = gpuReadBuffer.getMappedRange(); + return arrayBuffer.slice(0, originalSize); + } finally { + gpuReadBuffer.destroy(); + } + }; + }; + + // Setup a callback function for loading external buffers (model weights). + Module.webgpuUploadExternalBuffer = (bufferHandle, data) => { + const srcArrayBuffer = data.buffer; + const srcOffset = data.byteOffset; + const srcLength = data.byteLength; + const size = Math.ceil(Number(srcLength) / 16) * 16; + + const gpuBuffer = WebGPU.getJsObject(bufferHandle); + + // get current device + if (!webgpuCurrentDevice) { + const deviceHandle = _OrtGetWebGpuDevice(webgpuCurrentDeviceId); + webgpuCurrentDevice = WebGPU.getJsObject(deviceHandle); + } + + // create gpu buffer + + // prettier-ignore + // + // the line above is used to force prettier to skip formatting the next statement. + // this is because prettier will remove the quotes around the property names, but we need to keep them + // because otherwise closure compiler may rename them and break the code. + const gpuBufferForUploadingDescriptor = { + "mappedAtCreation": true, + "size": size, + "usage": 6 /* GPUBufferUsage.MAP_WRITE | GPUBufferUsage.COPY_SRC */, + }; + const gpuBufferForUploading = webgpuCurrentDevice.createBuffer( + gpuBufferForUploadingDescriptor + ); + + // copy (upload) data + const arrayBuffer = gpuBufferForUploading.getMappedRange(); + new Uint8Array(arrayBuffer).set( + new Uint8Array(srcArrayBuffer, srcOffset, srcLength) + ); + gpuBufferForUploading.unmap(); + + // GPU copy + const commandEncoder = webgpuCurrentDevice.createCommandEncoder(); + commandEncoder.copyBufferToBuffer( + gpuBufferForUploading, + 0, + gpuBuffer, + 0, + size + ); + webgpuCurrentDevice.queue.submit([commandEncoder.finish()]); + gpuBufferForUploading.destroy(); + }; +}; diff --git a/onnxruntime/wasm/pre-async.js b/onnxruntime/wasm/pre-async.js new file mode 100644 index 0000000000000..8c75dc7c5cf1e --- /dev/null +++ b/onnxruntime/wasm/pre-async.js @@ -0,0 +1,132 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// +// This file contains the pre-run code for the ORT WebAssembly module. The code in this file will be injected into the +// final module using Emscripten's `--pre-js` option. +// +// This file will only be used in build with flag `-s ASYNCIFY=1`. + +/** + * initialize for asyncify support. + */ +let initAsyncImpl = () => { + // This is a simplified version of cwrap() with options.async === true (-sASYNCIFY=1) + // It removes some overhead in cwarp() and ccall() that we don't need. + // + // Currently in ASYNCIFY build, we only use this for the following functions: + // - OrtCreateSession() + // - OrtRun() + // - OrtRunWithBinding() + // - OrtBindInput() + // + // Note: about parameters "getFunc" and "setFunc": + // - Emscripten has different behaviors for Debug and Release builds for generating exported function wrapper. + // + // - In Debug build, it will generate a wrapper function for each exported function. For example, it generates a + // wrapper for OrtRun() like this (minified): + // ``` + // var _OrtRun = Module["_OrtRun"] = createExportWrapper("OrtRun"); + // ``` + // + // - In Release build, it will generate a lazy loading wrapper for each exported function. For example, it generates + // a wrapper for OrtRun() like this (minified): + // ``` + // d._OrtRun = (a, b, c, e, f, h, l, q) => (d._OrtRun = J.ka)(a, b, c, e, f, h, l, q); + // ``` + // + // The behavior of these two wrappers are different. The debug build will assign `Module["_OrtRun"]` only once + // because `createExportWrapper()` does not reset `Module["_OrtRun"]` inside. The release build, however, will + // reset d._OrtRun to J.ka when the first time it is called. + // + // The difference is important because we need to design the async wrapper in a way that it can handle both cases. + // + // Now, let's look at how the async wrapper is designed to work for both cases: + // + // - Debug build: + // 1. When Web assembly is being loaded, `Module["_OrtRun"]` is assigned to `createExportWrapper("OrtRun")`. + // 2. When the first time `Module["initAsync"]` is called, `Module["_OrtRun"]` is re-assigned to a new async + // wrapper function. + // Value of `Module["_OrtRun"]` will not be changed again. + // + // - Release build: + // 1. When Web assembly is being loaded, `Module["_OrtRun"]` is assigned to a lazy loading wrapper function. + // 2. When the first time `Module["initAsync"]` is called, `Module["_OrtRun"]` is re-assigned to a new async + // wrapper function. + // 3. When the first time `Module["_OrtRun"]` is called, the async wrapper will be called. It will call into this + // function: + // ``` + // (a, b, c, e, f, h, l, q) => (d._OrtRun = J.ka)(a, b, c, e, f, h, l, q); + // ``` + // This function will assign d._OrtRun (ie. the minimized `Module["_OrtRun"]`) to the real function (J.ka). + // 4. Since d._OrtRun is re-assigned, we need to update the async wrapper to re-assign its stored + // function to the updated value (J.ka), and re-assign the value of `d._OrtRun` back to the async wrapper. + // Value of `Module["_OrtRun"]` will not be changed again. + // + // The value of `Module["_OrtRun"]` will need to be assigned for 2 times for debug build and 4 times for release + // build. + // + // This is why we need this `getFunc` and `setFunc` parameters. They are used to get the current value of an + // exported function and set the new value of an exported function. + // + const wrapAsync = (func, getFunc, setFunc) => { + return (...args) => { + // cache the async data before calling the function. + const previousAsync = Asyncify.currData; + + const previousFunc = getFunc?.(); + const ret = func(...args); + const newFunc = getFunc?.(); + if (previousFunc !== newFunc) { + // The exported function has been updated. + // Set the sync function reference to the new function. + func = newFunc; + // Set the exported function back to the async wrapper. + setFunc(previousFunc); + // Remove getFunc and setFunc. They are no longer needed. + setFunc = null; + getFunc = null; + } + + // If the async data has been changed, it means that the function started an async operation. + if (Asyncify.currData != previousAsync) { + // returns the promise + return Asyncify.whenDone(); + } + // the function is synchronous. returns the result. + return ret; + }; + }; + + // replace the original functions with asyncified versions + const wrapAsyncAPIs = (funcNames) => { + for (const funcName of funcNames) { + Module[funcName] = wrapAsync( + Module[funcName], + () => Module[funcName], + (v) => (Module[funcName] = v) + ); + } + }; + + wrapAsyncAPIs([ + "_OrtAppendExecutionProvider", + "_OrtCreateSession", + "_OrtRun", + "_OrtRunWithBinding", + "_OrtBindInput", + ]); + + // If JSEP is enabled, wrap OrtRun() and OrtRunWithBinding() with asyncify. + if (typeof jsepRunAsync !== "undefined") { + Module["_OrtRun"] = jsepRunAsync(Module["_OrtRun"]); + Module["_OrtRunWithBinding"] = jsepRunAsync(Module["_OrtRunWithBinding"]); + } + + // remove this function to make sure it is called only once. + initAsyncImpl = undefined; +}; + +Module["asyncInit"] = () => { + initAsyncImpl?.(); +}; diff --git a/onnxruntime/wasm/pre-jsep.js b/onnxruntime/wasm/pre-jsep.js index 0c83e71a921cb..5b2f044d4c27b 100644 --- a/onnxruntime/wasm/pre-jsep.js +++ b/onnxruntime/wasm/pre-jsep.js @@ -1,255 +1,157 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -'use strict'; - // // This file contains the pre-run code for the ORT WebAssembly module. The code in this file will be injected into the // final module using Emscripten's `--pre-js` option. // // This file will only be used in build with flag `--use_jsep`. - -/** - * initialize JSEP for asyncify support. - */ -let jsepInitAsync = () => { - // This is a simplified version of cwrap() with options.async === true (-sASYNCIFY=1) - // It removes some overhead in cwarp() and ccall() that we don't need. - // - // Currently in JSEP build, we only use this for the following functions: - // - OrtRun() - // - OrtRunWithBinding() - // - OrtBindInput() - // - // Note: about parameters "getFunc" and "setFunc": - // - Emscripten has different behaviors for Debug and Release builds for generating exported function wrapper. - // - // - In Debug build, it will generate a wrapper function for each exported function. For example, it generates a - // wrapper for OrtRun() like this (minified): - // ``` - // var _OrtRun = Module["_OrtRun"] = createExportWrapper("OrtRun"); - // ``` - // - // - In Release build, it will generate a lazy loading wrapper for each exported function. For example, it generates - // a wrapper for OrtRun() like this (minified): - // ``` - // d._OrtRun = (a, b, c, e, f, h, l, q) => (d._OrtRun = J.ka)(a, b, c, e, f, h, l, q); - // ``` - // - // The behavior of these two wrappers are different. The debug build will assign `Module["_OrtRun"]` only once - // because `createExportWrapper()` does not reset `Module["_OrtRun"]` inside. The release build, however, will - // reset d._OrtRun to J.ka when the first time it is called. - // - // The difference is important because we need to design the async wrapper in a way that it can handle both cases. - // - // Now, let's look at how the async wrapper is designed to work for both cases: - // - // - Debug build: - // 1. When Web assembly is being loaded, `Module["_OrtRun"]` is assigned to `createExportWrapper("OrtRun")`. - // 2. When the first time `Module["jsepInit"]` is called, `Module["_OrtRun"]` is re-assigned to a new async - // wrapper function. - // Value of `Module["_OrtRun"]` will not be changed again. - // - // - Release build: - // 1. When Web assembly is being loaded, `Module["_OrtRun"]` is assigned to a lazy loading wrapper function. - // 2. When the first time `Module["jsepInit"]` is called, `Module["_OrtRun"]` is re-assigned to a new async - // wrapper function. - // 3. When the first time `Module["_OrtRun"]` is called, the async wrapper will be called. It will call into this - // function: - // ``` - // (a, b, c, e, f, h, l, q) => (d._OrtRun = J.ka)(a, b, c, e, f, h, l, q); - // ``` - // This function will assign d._OrtRun (ie. the minimized `Module["_OrtRun"]`) to the real function (J.ka). - // 4. Since d._OrtRun is re-assigned, we need to update the async wrapper to re-assign its stored - // function to the updated value (J.ka), and re-assign the value of `d._OrtRun` back to the async wrapper. - // Value of `Module["_OrtRun"]` will not be changed again. - // - // The value of `Module["_OrtRun"]` will need to be assigned for 2 times for debug build and 4 times for release - // build. - // - // This is why we need this `getFunc` and `setFunc` parameters. They are used to get the current value of an - // exported function and set the new value of an exported function. - // - const jsepWrapAsync = (func, getFunc, setFunc) => { - return (...args) => { - // cache the async data before calling the function. - const previousAsync = Asyncify.currData; - - const previousFunc = getFunc?.(); - const ret = func(...args); - const newFunc = getFunc?.(); - if (previousFunc !== newFunc) { - // The exported function has been updated. - // Set the sync function reference to the new function. - func = newFunc; - // Set the exported function back to the async wrapper. - setFunc(previousFunc); - // Remove getFunc and setFunc. They are no longer needed. - setFunc = null; - getFunc = null; +// This is a wrapper for OrtRun() and OrtRunWithBinding() to ensure that Promises are handled correctly. +const jsepRunAsync = (runAsyncFunc) => { + return async (...args) => { + try { + // Module.jsepSessionState should be null, unless we are in the middle of a session. + // If it is not null, it means that the previous session has not finished yet. + if (Module.jsepSessionState) { + throw new Error("Session already started"); } + const state = (Module.jsepSessionState = { + sessionHandle: args[0], + errors: [], + }); - // If the async data has been changed, it means that the function started an async operation. - if (Asyncify.currData != previousAsync) { - // returns the promise - return Asyncify.whenDone(); - } - // the function is synchronous. returns the result. - return ret; - }; - }; - - // This is a wrapper for OrtRun() and OrtRunWithBinding() to ensure that Promises are handled correctly. - const runAsync = (runAsyncFunc) => { - return async (...args) => { - try { - // Module.jsepSessionState should be null, unless we are in the middle of a session. - // If it is not null, it means that the previous session has not finished yet. - if (Module.jsepSessionState) { - throw new Error('Session already started'); - } - const state = Module.jsepSessionState = {sessionHandle: args[0], errors: []}; - - // Run the acyncified function: OrtRun() or OrtRunWithBinding() - const ret = await runAsyncFunc(...args); + // Run the acyncified function: OrtRun() or OrtRunWithBinding() + const ret = await runAsyncFunc(...args); - // Check if the session is still valid. this object should be the same as the one we set above. - if (Module.jsepSessionState !== state) { - throw new Error('Session mismatch'); - } + // Check if the session is still valid. this object should be the same as the one we set above. + if (Module.jsepSessionState !== state) { + throw new Error("Session mismatch"); + } - // Flush the backend. This will submit all pending commands to the GPU. - Module.jsepBackend?.['flush'](); + // Flush the backend. This will submit all pending commands to the GPU. + Module.jsepBackend?.["flush"](); - // Await all pending promises. This includes GPU validation promises for diagnostic purposes. - const errorPromises = state.errors; - if (errorPromises.length > 0) { - let errors = await Promise.all(errorPromises); - errors = errors.filter(e => e); - if (errors.length > 0) { - throw new Error(errors.join('\n')); - } + // Await all pending promises. This includes GPU validation promises for diagnostic purposes. + const errorPromises = state.errors; + if (errorPromises.length > 0) { + let errors = await Promise.all(errorPromises); + errors = errors.filter((e) => e); + if (errors.length > 0) { + throw new Error(errors.join("\n")); } - - return ret; - } finally { - Module.jsepSessionState = null; } - }; - }; - // replace the original functions with asyncified versions - Module['_OrtCreateSession'] = jsepWrapAsync( - Module['_OrtCreateSession'], - () => Module['_OrtCreateSession'], - v => Module['_OrtCreateSession'] = v); - Module['_OrtRun'] = runAsync(jsepWrapAsync( - Module['_OrtRun'], - () => Module['_OrtRun'], - v => Module['_OrtRun'] = v)); - Module['_OrtRunWithBinding'] = runAsync(jsepWrapAsync( - Module['_OrtRunWithBinding'], - () => Module['_OrtRunWithBinding'], - v => Module['_OrtRunWithBinding'] = v)); - Module['_OrtBindInput'] = jsepWrapAsync( - Module['_OrtBindInput'], - () => Module['_OrtBindInput'], - v => Module['_OrtBindInput'] = v); - - // remove this function to make sure it is called only once. - jsepInitAsync = undefined; + return ret; + } finally { + Module.jsepSessionState = null; + } + }; }; - /** - * initialize JSEP for WebGPU. + * initialize JSEP for WebGPU and WebNN. */ -Module['jsepInit'] = (name, params) => { - jsepInitAsync?.(); - - if (name === 'webgpu') { - [Module.jsepBackend, - Module.jsepAlloc, - Module.jsepFree, - Module.jsepCopy, - Module.jsepCopyAsync, - Module.jsepCreateKernel, - Module.jsepReleaseKernel, - Module.jsepRunKernel, - Module.jsepCaptureBegin, - Module.jsepCaptureEnd, - Module.jsepReplay] = params; +Module["jsepInit"] = (name, params) => { + if (name === "webgpu") { + [ + Module.jsepBackend, + Module.jsepAlloc, + Module.jsepFree, + Module.jsepCopy, + Module.jsepCopyAsync, + Module.jsepCreateKernel, + Module.jsepReleaseKernel, + Module.jsepRunKernel, + Module.jsepCaptureBegin, + Module.jsepCaptureEnd, + Module.jsepReplay, + ] = params; // expose webgpu backend functions const backend = Module.jsepBackend; - Module['jsepRegisterBuffer'] = (sessionId, index, buffer, size) => { - return backend['registerBuffer'](sessionId, index, buffer, size); + Module["jsepRegisterBuffer"] = (sessionId, index, buffer, size) => { + return backend["registerBuffer"](sessionId, index, buffer, size); }; - Module['jsepGetBuffer'] = (dataId) => { - return backend['getBuffer'](dataId); + Module["jsepGetBuffer"] = (dataId) => { + return backend["getBuffer"](dataId); }; - Module['jsepCreateDownloader'] = (gpuBuffer, size, type) => { - return backend['createDownloader'](gpuBuffer, size, type); + Module["jsepCreateDownloader"] = (gpuBuffer, size, type) => { + return backend["createDownloader"](gpuBuffer, size, type); }; - Module['jsepOnCreateSession'] = sessionId => { - backend['onCreateSession'](sessionId); + Module["jsepOnCreateSession"] = (sessionId) => { + backend["onCreateSession"](sessionId); }; - Module['jsepOnReleaseSession'] = sessionId => { - backend['onReleaseSession'](sessionId); + Module["jsepOnReleaseSession"] = (sessionId) => { + backend["onReleaseSession"](sessionId); }; - Module['jsepOnRunStart'] = sessionId => { - return backend['onRunStart'](sessionId); + Module["jsepOnRunStart"] = (sessionId) => { + return backend["onRunStart"](sessionId); }; Module.jsepUploadExternalBuffer = (dataId, buffer) => { - backend['upload'](dataId, buffer); + backend["upload"](dataId, buffer); }; - } else if (name === 'webnn') { + } else if (name === "webnn") { // Functions called from EM_ASM need to be assigned in a way that can be minified. // Functions called via emscripten::val::module_property need to be assigned by name so that the minifier doesn't // change the name. - [Module.jsepBackend, - Module.jsepReserveTensorId, - Module.jsepReleaseTensorId, - Module['jsepEnsureTensor'], - Module.jsepUploadTensor, - Module['jsepDownloadTensor'], + [ + Module.jsepBackend, + Module.jsepReserveTensorId, + Module.jsepReleaseTensorId, + Module["jsepEnsureTensor"], + Module.jsepUploadTensor, + Module["jsepDownloadTensor"], ] = params; // This function is called from both JS and an EM_ASM block, it needs both a minifiable name and an explicit name. - Module['jsepReleaseTensorId'] = Module.jsepReleaseTensorId; - Module['jsepUploadTensor'] = Module.jsepUploadTensor; + Module["jsepReleaseTensorId"] = Module.jsepReleaseTensorId; + Module["jsepUploadTensor"] = Module.jsepUploadTensor; // Functions called from JS also need to have explicit names. const backend = Module.jsepBackend; - Module['jsepOnRunStart'] = sessionId => { - return backend['onRunStart'](sessionId); + Module["jsepOnRunStart"] = (sessionId) => { + return backend["onRunStart"](sessionId); }; - Module['jsepOnRunEnd'] = backend['onRunEnd'].bind(backend); - Module['jsepRegisterMLContext'] = (sessionId, mlContext) => { - backend['registerMLContext'](sessionId, mlContext); + Module["jsepOnRunEnd"] = backend["onRunEnd"].bind(backend); + Module["jsepRegisterMLContext"] = (sessionId, mlContext) => { + backend["registerMLContext"](sessionId, mlContext); }; - Module['jsepOnReleaseSession'] = sessionId => { - backend['onReleaseSession'](sessionId); + Module["jsepOnReleaseSession"] = (sessionId) => { + backend["onReleaseSession"](sessionId); }; - Module['jsepCreateMLTensorDownloader'] = (tensorId, type) => { - return backend['createMLTensorDownloader'](tensorId, type); - } - Module['jsepRegisterMLTensor'] = (sessionId, tensor, dataType, shape) => { - return backend['registerMLTensor'](sessionId, tensor, dataType, shape); + Module["jsepCreateMLTensorDownloader"] = (tensorId, type) => { + return backend["createMLTensorDownloader"](tensorId, type); + }; + Module["jsepRegisterMLTensor"] = (sessionId, tensor, dataType, shape) => { + return backend["registerMLTensor"](sessionId, tensor, dataType, shape); }; - Module['jsepCreateMLContext'] = (optionsOrGpuDevice) => { - return backend['createMLContext'](optionsOrGpuDevice); + Module["jsepCreateMLContext"] = (optionsOrGpuDevice) => { + return backend["createMLContext"](optionsOrGpuDevice); }; - Module['jsepRegisterMLConstant'] = (externalFilePath, dataOffset, dataLength, builder, desc) => { - return backend['registerMLConstant']( - externalFilePath, dataOffset, dataLength, builder, desc, Module.MountedFiles); + Module["jsepRegisterMLConstant"] = ( + externalFilePath, + dataOffset, + dataLength, + builder, + desc + ) => { + return backend["registerMLConstant"]( + externalFilePath, + dataOffset, + dataLength, + builder, + desc, + Module.MountedFiles + ); }; - Module['jsepRegisterGraphInput'] = backend['registerGraphInput'].bind(backend); - Module['jsepIsGraphInput'] = backend['isGraphInput'].bind(backend); + Module["jsepRegisterGraphInput"] = + backend["registerGraphInput"].bind(backend); + Module["jsepIsGraphInput"] = backend["isGraphInput"].bind(backend); - Module['jsepCreateTemporaryTensor'] = backend['createTemporaryTensor'].bind(backend); + Module["jsepCreateTemporaryTensor"] = + backend["createTemporaryTensor"].bind(backend); } }; diff --git a/onnxruntime/wasm/pre.js b/onnxruntime/wasm/pre.js index 9b5f3ce545b78..636a9713519a7 100644 --- a/onnxruntime/wasm/pre.js +++ b/onnxruntime/wasm/pre.js @@ -1,21 +1,18 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -'use strict'; - // // This file contains the pre-run code for the ORT WebAssembly module. The code in this file will be injected into the // final module using Emscripten's `--pre-js` option. - /** * Mount external data files of a model to an internal map, which will be used during session initialization. * * @param {string} externalDataFilesPath * @param {Uint8Array} externalDataFilesData */ -Module['mountExternalData'] = (externalDataFilePath, externalDataFileData) => { - if (externalDataFilePath.startsWith('./')) { +Module["mountExternalData"] = (externalDataFilePath, externalDataFileData) => { + if (externalDataFilePath.startsWith("./")) { externalDataFilePath = externalDataFilePath.substring(2); } const files = Module.MountedFiles || (Module.MountedFiles = new Map()); @@ -25,7 +22,7 @@ Module['mountExternalData'] = (externalDataFilePath, externalDataFileData) => { /** * Unmount external data files of a model. */ -Module['unmountExternalData'] = () => { +Module["unmountExternalData"] = () => { delete Module.MountedFiles; }; @@ -48,5 +45,7 @@ Module['unmountExternalData'] = () => { * * @suppress {checkVars} */ -var SharedArrayBuffer = globalThis.SharedArrayBuffer ?? - new WebAssembly.Memory({'initial': 0, 'maximum': 0, 'shared': true}).buffer.constructor; +var SharedArrayBuffer = + globalThis.SharedArrayBuffer ?? + new WebAssembly.Memory({ initial: 0, maximum: 0, shared: true }).buffer + .constructor; diff --git a/setup.py b/setup.py index ced2f28e38778..53e533050b245 100644 --- a/setup.py +++ b/setup.py @@ -356,7 +356,7 @@ def finalize_options(self): "libQnnSaver.so", "libQnnSystem.so", "libHtpPrepare.so", - "onnxruntime_qnn_ctx_gen", + "ep_weight_sharing_ctx_gen", ] dl_libs.extend(qnn_deps) if nightly_build: diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index 8607887072347..db7dbed23a2d2 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -35,7 +35,8 @@ def version_to_tuple(version: str) -> tuple: import util.android as android # noqa: E402 from util import ( # noqa: E402 generate_android_triplets, - generate_posix_triplets, + generate_linux_triplets, + generate_macos_triplets, generate_vcpkg_triplets_for_emscripten, generate_windows_triplets, get_logger, @@ -1115,7 +1116,6 @@ def generate_build_tree( cmake_extra_args, ): log.info("Generating CMake build tree") - cmake_dir = os.path.join(source_dir, "cmake") cmake_args = [cmake_path, cmake_dir] if not use_dev_mode(args): @@ -1330,8 +1330,16 @@ def generate_build_tree( generate_android_triplets(build_dir, args.android_cpp_shared, args.android_api) elif is_windows(): generate_windows_triplets(build_dir) + elif is_macOS(): + osx_target = args.apple_deploy_target + if args.apple_deploy_target is None: + osx_target = os.environ.get("MACOSX_DEPLOYMENT_TARGET") + if osx_target is not None: + log.info(f"Setting VCPKG_OSX_DEPLOYMENT_TARGET to {osx_target}") + generate_macos_triplets(build_dir, osx_target) else: - generate_posix_triplets(build_dir) + # Linux, *BSD, AIX or other platforms + generate_linux_triplets(build_dir) add_default_definition(cmake_extra_defines, "CMAKE_TOOLCHAIN_FILE", str(vcpkg_toolchain_path)) vcpkg_install_options = generate_vcpkg_install_options(build_dir, args) @@ -1592,8 +1600,11 @@ def generate_build_tree( raise BuildError("WebNN is only available for WebAssembly build.") cmake_args += ["-Donnxruntime_USE_WEBNN=ON"] - if args.use_jsep and args.use_webgpu: - raise BuildError("JSEP (--use_jsep) and WebGPU (--use_webgpu) cannot be enabled at the same time.") + # TODO: currently we allows building with both --use_jsep and --use_webgpu in this working branch. + # This situation is temporary. Eventually, those two flags will be mutually exclusive. + # + # if args.use_jsep and args.use_webgpu: + # raise BuildError("JSEP (--use_jsep) and WebGPU (--use_webgpu) cannot be enabled at the same time.") if args.use_external_dawn and not args.use_webgpu: raise BuildError("External Dawn (--use_external_dawn) must be enabled with WebGPU (--use_webgpu).") diff --git a/tools/ci_build/github/azure-pipelines/custom-nuget-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/custom-nuget-packaging-pipeline.yml new file mode 100644 index 0000000000000..8aaaa0e85585a --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/custom-nuget-packaging-pipeline.yml @@ -0,0 +1,142 @@ +parameters: +- name: CudaVersion + type: string + default: '12.2' + +- name: QnnSdk + displayName: QNN SDK Version + type: string + default: 2.31.0.250130 + +- name: IsReleaseBuild + displayName: Is a release build? Set it to true if you are doing an Onnx Runtime release. + type: boolean + default: false + +- name: PackageName + displayName: What is the package name? + type: string + default: 'Microsoft.ML.OnnxRuntime.Flamingo' + +variables: + - template: templates/common-variables.yml + - name: ReleaseVersionSuffix + value: '' + - name: win_cuda_home + ${{ if eq(parameters.CudaVersion, '11.8') }}: + value: $(Agent.TempDirectory)\v11.8 + ${{ if eq(parameters.CudaVersion, '12.2') }}: + value: $(Agent.TempDirectory)\v12.2 + +stages: + - template: templates/win-ci.yml + parameters: + ort_build_pool_name: 'onnxruntime-Win2022-GPU-A10' + DoCompliance: false + DoEsrp: true + stage_name_suffix: CUDA + buildArch: x64 + msbuildPlatform: x64 + packageName: x64-cuda + CudaVersion: ${{ parameters.CudaVersion }} + buildparameter: --use_cuda --cuda_home=${{ variables.win_cuda_home }} --enable_onnx_tests --enable_wcos --use_webgpu --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=52-real;61-real;75-real;86-real;89-real;90-virtual" + runTests: false + buildJava: false + java_artifact_id: onnxruntime_gpu + UseIncreasedTimeoutForTests: false + SpecificArtifact: false + BuildId: '0' + + - template: templates/qnn-ep-win.yml + parameters: + qnn_ep_build_pool_name: 'Onnxruntime-QNNEP-Windows-2022-CPU' + QnnSdk: ${{ parameters.QnnSdk }} + IsReleaseBuild: ${{ parameters.IsReleaseBuild }} + DoEsrp: true + ArtifactName: 'drop-nuget-qnn-arm64' + # Add --use_webgpu to enable WebGPU + buildParameter: '--arm64' + buildPlatform: 'ARM64' + buildArch: 'ARM64' + StageName: 'OnnxRuntime_QNN_Nuget_Win_Arm64' + build_config: 'RelWithDebInfo' + Is1ES: false + PublishArchive: true + + - stage: NugetPackaging + dependsOn: [Windows_Packaging_CUDA, OnnxRuntime_QNN_Nuget_Win_Arm64] + jobs: + - job: CreateNugetPackage + pool: 'Onnxruntime-Win2022-GPU-A10' + timeoutInMinutes: 120 + steps: + - checkout: self + clean: true + submodules: none + + - task: UsePythonVersion@0 + inputs: + versionSpec: '3.12' + addToPath: true + + - task: DownloadPipelineArtifact@0 + displayName: 'Download Pipeline Artifact - managed nuget' + inputs: + artifactName: 'drop-nuget-qnn-arm64' + targetPath: '$(Build.BinariesDirectory)/managed-nuget' + + - task: DownloadPipelineArtifact@0 + displayName: 'Download Pipeline Artifact - win-x64' + inputs: + artifactName: 'onnxruntime-win-x64-cuda' + targetPath: '$(Build.BinariesDirectory)/win-x64' + + - task: DownloadPipelineArtifact@0 + displayName: 'Download Pipeline Artifact - win-arm64' + inputs: + artifactName: 'onnxruntime-win-ARM64-qnn' + targetPath: '$(Build.BinariesDirectory)/win-arm64' + + - task: PowerShell@2 + displayName: 'Extract Nuget Package Version' + inputs: + targetType: 'inline' + script: | + $nupkgs = (Get-ChildItem $(Build.BinariesDirectory)/managed-nuget -Filter Microsoft.ML.OnnxRuntime.Managed.*.nupkg -Recurse) + $package_name = $nupkgs[0].Name + $version_length = $package_name.Length - "Microsoft.ML.OnnxRuntime.Managed.".Length - ".nupkg".Length + $package_version = $package_name.Substring("Microsoft.ML.OnnxRuntime.Managed.".Length, $version_length) + Write-Host "##vso[task.setvariable variable=package_version;]$package_version" + workingDirectory: $(Build.BinariesDirectory) + + - task: PowerShell@2 + displayName: 'Extract Archives' + inputs: + targetType: 'inline' + script: | + Expand-Archive -Path $(Build.BinariesDirectory)/win-x64/onnxruntime-win-x64-cuda*.zip -DestinationPath $(Build.BinariesDirectory)/win-x64 + Expand-Archive -Path $(Build.BinariesDirectory)/win-arm64/onnxruntime-win-ARM64-qnn*.zip -DestinationPath $(Build.BinariesDirectory)/win-arm64 + $win_x64 = (Get-ChildItem -Path $(Build.BinariesDirectory)/win-x64 -Filter onnxruntime-win-x64-cuda*)[0].FullName + $win_arm64 = (Get-ChildItem -Path $(Build.BinariesDirectory)/win-arm64 -Filter onnxruntime-win-ARM64-qnn*)[0].FullName + Write-Host "##vso[task.setvariable variable=win_x64;]$win_x64" + Write-Host "##vso[task.setvariable variable=win_arm64;]$win_arm64" + workingDirectory: $(Build.BinariesDirectory) + + - task: PythonScript@0 + displayName: 'Generate Nuget Package' + inputs: + scriptPath: '$(Build.SourcesDirectory)/tools/nuget/generate_nuspec_for_custom_nuget.py' + arguments: '--nuspec_path "$(Build.BinariesDirectory)/${{ parameters.PackageName }}.nuspec" --root_dir "$(Build.SourcesDirectory)" --commit_id "$(Build.SourceVersion)" --win_arm64 "$(win_arm64)" --win_x64 "$(win_x64)" --package_version "$(package_version)" --package_name "${{ parameters.PackageName }}"' + + - task: NuGetCommand@2 + displayName: 'Pack Nuget Package' + inputs: + command: 'pack' + packagesToPack: '$(Build.BinariesDirectory)/${{ parameters.PackageName }}.nuspec' + packDestination: $(Build.ArtifactStagingDirectory)\ + + - task: PublishBuildArtifacts@1 + displayName: 'Publish Artifact: Nuget' + inputs: + pathtoPublish: '$(Build.ArtifactStagingDirectory)' + artifactName: '${{ parameters.PackageName }}' diff --git a/tools/ci_build/github/azure-pipelines/py-package-test-pipeline.yml b/tools/ci_build/github/azure-pipelines/py-package-test-pipeline.yml index a0e49692220f9..7a78c6ba0fcdf 100644 --- a/tools/ci_build/github/azure-pipelines/py-package-test-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/py-package-test-pipeline.yml @@ -31,10 +31,12 @@ stages: machine_pool: vmImage: 'macOS-13' itemPattern: '*/*mac*x86_64.whl' + arch: 'x86_64' - template: templates/py-package-smoking-test.yml parameters: job_name: Test_LINUX_x86_64_Wheels itemPattern: '*/*manylinux*x86_64.whl' + arch: 'x86_64' machine_pool: name: 'onnxruntime-Ubuntu2204-AMD-CPU' diff --git a/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml index 01d30d0e1ba86..28ddd29ec63e6 100644 --- a/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml @@ -50,10 +50,10 @@ parameters: displayName: 'Linux packages cmake build type. Linux Only.' default: 'Release' values: - - Debug - - Release - - RelWithDebInfo - - MinSizeRel + - Debug + - Release + - RelWithDebInfo + - MinSizeRel # Only applies to QNN packages. - name: qnn_sdk_version @@ -63,17 +63,33 @@ parameters: trigger: none -stages: -- template: stages/py-cpu-packaging-stage.yml +resources: + repositories: + - repository: 1esPipelines + type: git + name: 1ESPipelineTemplates/1ESPipelineTemplates + ref: refs/tags/release +extends: + # The pipeline extends the 1ES PT which will inject different SDL and compliance tasks. + # For non-production pipelines, use "Unofficial" as defined below. + # For productions pipelines, use "Official". + template: v1/1ES.Official.PipelineTemplate.yml@1esPipelines parameters: - enable_linux_cpu: ${{ parameters.enable_linux_cpu }} - enable_windows_cpu: ${{ parameters.enable_windows_cpu }} - enable_mac_cpu: ${{ parameters.enable_mac_cpu }} - enable_linux_arm: ${{ parameters.enable_linux_arm }} - enable_windows_arm64_qnn: ${{ parameters.enable_windows_arm64_qnn }} - enable_windows_arm64ec_qnn: ${{ parameters.enable_windows_arm64ec_qnn }} - enable_windows_x64_qnn: ${{ parameters.enable_windows_x64_qnn }} - enable_linux_x64_qnn: ${{ parameters.enable_linux_x64_qnn }} - build_py_parameters: ${{ parameters.build_py_parameters }} - cmake_build_type: ${{ parameters.cmake_build_type }} - qnn_sdk_version: ${{ parameters.qnn_sdk_version }} + sdl: + sourceAnalysisPool: + name: onnxruntime-Win-CPU-2022 + os: windows + stages: + - template: stages/py-cpu-packaging-stage.yml + parameters: + enable_linux_cpu: ${{ parameters.enable_linux_cpu }} + enable_windows_cpu: ${{ parameters.enable_windows_cpu }} + enable_mac_cpu: ${{ parameters.enable_mac_cpu }} + enable_linux_arm: ${{ parameters.enable_linux_arm }} + enable_windows_arm64_qnn: ${{ parameters.enable_windows_arm64_qnn }} + enable_windows_arm64ec_qnn: ${{ parameters.enable_windows_arm64ec_qnn }} + enable_windows_x64_qnn: ${{ parameters.enable_windows_x64_qnn }} + enable_linux_x64_qnn: ${{ parameters.enable_linux_x64_qnn }} + build_py_parameters: ${{ parameters.build_py_parameters }} + cmake_build_type: ${{ parameters.cmake_build_type }} + qnn_sdk_version: ${{ parameters.qnn_sdk_version }} diff --git a/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml index 055ef58e4524a..cfca998e0f06c 100644 --- a/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml @@ -29,108 +29,58 @@ parameters: displayName: Pipeline BuildId, you could find it in the URL type: string default: '0' - -stages: - -- template: templates/qnn-ep-win.yml - parameters: - qnn_ep_build_pool_name: 'Onnxruntime-QNNEP-Windows-2022-CPU' - QnnSdk: ${{ parameters.QnnSdk }} - IsReleaseBuild: ${{ parameters.IsReleaseBuild }} - DoEsrp: ${{ parameters.DoEsrp }} - ArtifactName: 'drop-nuget-qnn-x64' - StageName: 'OnnxRuntime_QNN_Nuget_Win_x64' - build_config: ${{ parameters.build_config }} - -- template: templates/qnn-ep-win.yml +resources: + repositories: + - repository: 1esPipelines + type: git + name: 1ESPipelineTemplates/1ESPipelineTemplates + ref: refs/tags/release +extends: + # The pipeline extends the 1ES PT which will inject different SDL and compliance tasks. + # For non-production pipelines, use "Unofficial" as defined below. + # For productions pipelines, use "Official". + template: v1/1ES.Official.PipelineTemplate.yml@1esPipelines parameters: - qnn_ep_build_pool_name: 'Onnxruntime-QNNEP-Windows-2022-CPU' - QnnSdk: ${{ parameters.QnnSdk }} - IsReleaseBuild: ${{ parameters.IsReleaseBuild }} - DoEsrp: ${{ parameters.DoEsrp }} - ArtifactName: 'drop-nuget-qnn-arm64' - buildParameter: '--arm64' - buildPlatform: 'ARM64' - buildArch: 'ARM64' - StageName: 'OnnxRuntime_QNN_Nuget_Win_Arm64' - build_config: ${{ parameters.build_config }} - -- stage: NuGet_Packaging_QNN - pool: 'Onnxruntime-QNNEP-Windows-2022-CPU' - dependsOn: - - OnnxRuntime_QNN_Nuget_Win_x64 - - OnnxRuntime_QNN_Nuget_Win_Arm64 - condition: succeeded() - jobs: - - job: NuGet_Packaging_QNN - workspace: - clean: all - steps: - - task: DownloadPipelineArtifact@0 - displayName: 'Download Pipeline Artifact - QNN NuGet x64' - inputs: - artifactName: 'drop-nuget-qnn-x64' - targetPath: '$(Build.BinariesDirectory)/nuget-artifact-x64' - - - task: DownloadPipelineArtifact@0 - displayName: 'Download Pipeline Artifact - QNN NuGet arm64' - inputs: - artifactName: 'drop-nuget-qnn-arm64' - targetPath: '$(Build.BinariesDirectory)/nuget-artifact-arm64' - - - task: PowerShell@2 - displayName: 'Bundle NuGet' - inputs: - targetType: 'inline' - script: | - - $x64_nupkgs = (Get-ChildItem $(Build.BinariesDirectory)/nuget-artifact-x64 -Filter Microsoft.ML.OnnxRuntime.QNN*.nupkg -Recurse) - $nuget_package_name = $x64_nupkgs[0].Name - $x64_nuget_package = $x64_nupkgs[0].FullName - - $nupkg_unzipped_directory = [System.IO.Path]::Combine($Env:BUILD_ARTIFACTSTAGINGDIRECTORY, 'nuget_unzip_merged', [System.IO.Path]::GetFileNameWithoutExtension($nuget_package_name)) - - $x64_unzip_cmd = "7z.exe x $x64_nuget_package -y -o$nupkg_unzipped_directory" - Invoke-Expression -Command $x64_unzip_cmd - - $arm64_nupkgs = (Get-ChildItem $(Build.BinariesDirectory)/nuget-artifact-arm64 -Filter Microsoft.ML.OnnxRuntime.QNN*.nupkg -Recurse) - $arm64_nuget_package = $arm64_nupkgs[0].FullName + sdl: + sourceAnalysisPool: + name: onnxruntime-Win-CPU-2022 + os: windows + stages: - $arm64_unzip_cmd = "7z.exe x $arm64_nuget_package -y -o$nupkg_unzipped_directory" - Invoke-Expression -Command $arm64_unzip_cmd - - $merged_nuget_path = [System.IO.Path]::Combine($Env:BUILD_ARTIFACTSTAGINGDIRECTORY, 'nuget-artifact-merged') - if (!(Test-Path $merged_nuget_path)) { - New-Item -Path $merged_nuget_path -ItemType Directory - } - - $merged_zip = [System.IO.Path]::Combine($merged_nuget_path, 'qnn_nuget.zip') - $zip_cmd = "7z.exe a -r $merged_zip $nupkg_unzipped_directory/*" - Invoke-Expression -Command $zip_cmd - - $merged_nuget = [System.IO.Path]::Combine($merged_nuget_path, $nuget_package_name) - move $merged_zip $merged_nuget - workingDirectory: $(Build.BinariesDirectory) - - - template: templates/esrp_nuget.yml + - template: templates/qnn-ep-win.yml parameters: - DisplayName: 'ESRP - sign NuGet package' - FolderPath: '$(Build.ArtifactStagingDirectory)/nuget-artifact-merged' + qnn_ep_build_pool_name: 'Onnxruntime-QNNEP-Windows-2022-CPU' + QnnSdk: ${{ parameters.QnnSdk }} + IsReleaseBuild: ${{ parameters.IsReleaseBuild }} DoEsrp: ${{ parameters.DoEsrp }} + ArtifactName: 'drop-nuget-qnn-x64' + StageName: 'OnnxRuntime_QNN_Nuget_Win_x64' + build_config: ${{ parameters.build_config }} - - task: PublishPipelineArtifact@0 - displayName: 'Publish Pipeline NuGet Artifact' - inputs: - artifactName: 'drop-signed-nuget-qnn' - targetPath: '$(Build.ArtifactStagingDirectory)/nuget-artifact-merged' + - template: templates/qnn-ep-win.yml + parameters: + qnn_ep_build_pool_name: 'Onnxruntime-QNNEP-Windows-2022-CPU' + QnnSdk: ${{ parameters.QnnSdk }} + IsReleaseBuild: ${{ parameters.IsReleaseBuild }} + DoEsrp: ${{ parameters.DoEsrp }} + ArtifactName: 'drop-nuget-qnn-arm64' + buildParameter: '--arm64' + buildPlatform: 'ARM64' + buildArch: 'ARM64' + StageName: 'OnnxRuntime_QNN_Nuget_Win_Arm64' + build_config: ${{ parameters.build_config }} + + - template: stages/nuget-qnn-packaging-stage.yml + parameters: + DoEsrp: ${{ parameters.DoEsrp }} -- template: templates/publish-nuget-steps.yml - parameters: - download_artifacts_steps: - - template: templates/flex-downloadPipelineArtifact.yml - parameters: - StepName: 'Download Pipeline Artifact - Signed NuGet Qnn Package' - ArtifactName: 'drop-signed-nuget-qnn' - targetPath: '$(Build.BinariesDirectory)/nuget-artifact/final-package' - SpecificArtifact: ${{ parameters.specificArtifact }} - BuildId: ${{ parameters.BuildId }} + - template: templates/publish-nuget-steps.yml + parameters: + download_artifacts_steps: + - template: templates/flex-downloadPipelineArtifact.yml + parameters: + StepName: 'Download Pipeline Artifact - Signed NuGet Qnn Package' + ArtifactName: 'drop-signed-nuget-qnn' + targetPath: '$(Build.BinariesDirectory)/nuget-artifact/final-package' + SpecificArtifact: ${{ parameters.specificArtifact }} + BuildId: ${{ parameters.BuildId }} diff --git a/tools/ci_build/github/azure-pipelines/rocm-nuget-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/rocm-nuget-packaging-pipeline.yml deleted file mode 100644 index f7f5c7b1494e8..0000000000000 --- a/tools/ci_build/github/azure-pipelines/rocm-nuget-packaging-pipeline.yml +++ /dev/null @@ -1,339 +0,0 @@ -parameters: -- name: RunOnnxRuntimeTests - displayName: Run Tests? - type: boolean - default: true - -- name: UseIncreasedTimeoutForTests - displayName: Increase timeout for tests? Set it to false if you are doing an Onnx Runtime release. - type: boolean - default: false - -- name: DoCompliance - displayName: Run Compliance Tasks? - type: boolean - default: true - -- name: DoEsrp - displayName: Run code sign tasks? Must be true if you are doing an ONNX Runtime release - type: boolean - default: true - -- name: IsReleaseBuild - displayName: Is a release build? Set it to true if you are doing an ONNX Runtime release. - type: boolean - default: false - -- name: PreReleaseVersionSuffixString - displayName: Suffix added to pre-release package version. Only used if IsReleaseBuild is true. Denotes the type of pre-release package. - type: string - values: - - alpha - - beta - - rc - - none - default: none - -- name: PreReleaseVersionSuffixNumber - displayName: Number added to pre-release package version. Only used if IsReleaseBuild is true. Denotes the sequence of a pre-release package. - type: number - default: 0 - -# these 2 parameters are used for debugging. -- name: SpecificArtifact - displayName: Use Specific Artifact (Debugging only) - type: boolean - default: false - -- name: BuildId - displayName: Pipeline BuildId, you could find it in the URL - type: string - default: '0' - -- name: NugetPackageSuffix - displayName: Suffix to append to nuget package - type: string - default: 'NONE' - -resources: - repositories: - - repository: onnxruntime-inference-examples # The name used to reference this repository in the checkout step - type: github - endpoint: ort-examples - name: microsoft/onnxruntime-inference-examples - - repository: manylinux - type: Github - endpoint: Microsoft - name: pypa/manylinux - ref: 5eda9aded5462201e6310105728d33016e637ea7 - -variables: -- name: ReleaseVersionSuffix - value: '' - -stages: -- template: stages/set_packaging_variables_stage.yml - parameters: - IsReleaseBuild: ${{ parameters.IsReleaseBuild }} - PreReleaseVersionSuffixString: ${{ parameters.PreReleaseVersionSuffixString }} - PreReleaseVersionSuffixNumber: ${{ parameters.PreReleaseVersionSuffixNumber }} - -# ROCm -- stage: Linux_C_API_Packaging_ROCm_x64 - dependsOn: [] - jobs: - - job: Linux_C_API_Packaging_ROCm_x64 - workspace: - clean: all - timeoutInMinutes: 480 - pool: onnxruntime-Ubuntu2204-AMD-CPU - variables: - RocmVersion: '6.2' - RocmVersionPatchSuffix: '' - steps: - - checkout: self # due to checkout multiple repos, the root directory is $(Build.SourcesDirectory)/onnxruntime - submodules: recursive - - checkout: manylinux # due to checkout multiple repos, the root directory is $(Build.SourcesDirectory)/manylinux, for get-docker-image-steps.yml - submodules: false - - # get-docker-image-steps.yml will move the $(Build.SourcesDirectory)/manylinux into $(Build.SourcesDirectory)/onnxruntime, - # then rename $(Build.SourcesDirectory)/onnxruntime as $(Build.SourcesDirectory) - - template: templates/get-docker-image-steps.yml - parameters: - Dockerfile: tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_rocm - Context: tools/ci_build/github/linux/docker - DockerBuildArgs: >- - --build-arg INSTALL_DEPS_EXTRA_ARGS=-tmur - --build-arg BUILD_UID=$(id -u) - --network=host --build-arg POLICY=manylinux_2_28 --build-arg PLATFORM=x86_64 - --build-arg ROCM_VERSION=$(RocmVersion)$(RocmVersionPatchSuffix) - --build-arg DEVTOOLSET_ROOTPATH=/opt/rh/gcc-toolset-12/root - --build-arg PREPEND_PATH=/opt/rh/gcc-toolset-12/root/usr/bin: - --build-arg LD_LIBRARY_PATH_ARG=/opt/rh/gcc-toolset-12/root/usr/lib64:/opt/rh/gcc-toolset-12/root/usr/lib:/opt/rh/gcc-toolset-12/root/usr/lib64/dyninst:/opt/rh/gcc-toolset-12/root/usr/lib/dyninst:/usr/local/lib64:/usr/local/lib - Repository: onnxruntimetrainingrocmbuild-rocm$(RocmVersion) - CheckOutManyLinux: true - - - template: templates/set-version-number-variables-step.yml - - - task: Bash@3 - displayName: 'Build' - inputs: - targetType: filePath - filePath: tools/ci_build/github/linux/build_rocm_c_api_package.sh - arguments: >- - -S $(Build.SourcesDirectory) - -B $(Build.BinariesDirectory) - -V $(RocmVersion) - -I onnxruntimetrainingrocmbuild-rocm$(RocmVersion) - -P python3.10 - - - script: | - set -e -x - mkdir $(Build.ArtifactStagingDirectory)/testdata - cp $(Build.BinariesDirectory)/Release/libcustom_op_library.so* $(Build.ArtifactStagingDirectory)/testdata - ls -al $(Build.ArtifactStagingDirectory) - displayName: 'Create Artifacts for CustomOp' # libcustom_op_library.so from cpu build is built with fp8, ROCm does not support it. - - - template: templates/c-api-artifacts-package-and-publish-steps-posix.yml - parameters: - buildConfig: 'Release' - artifactName: 'onnxruntime-linux-x64-rocm-$(OnnxRuntimeVersion)' - artifactNameNoVersionString: 'onnxruntime-linux-x64-rocm' - libraryName: 'libonnxruntime.so.$(OnnxRuntimeVersion)' - - - template: templates/component-governance-component-detection-steps.yml - parameters: - condition: 'succeeded' - - template: templates/clean-agent-build-directory-step.yml - -- stage: NuGet_Packaging_ROCm - dependsOn: - - Setup - - Linux_C_API_Packaging_ROCm_x64 - condition: succeeded() - jobs: - - job: NuGet_Packaging_ROCm - workspace: - clean: all - # we need to use a 2022 pool to create the nuget package with MAUI targets. - # VS2019 has no support for net6/MAUI and we need to use msbuild (from the VS install) to do the packing - pool: 'Onnxruntime-Win-CPU-2022' - variables: - breakCodesignValidationInjection: ${{ parameters.DoEsrp }} - ReleaseVersionSuffix: $[stageDependencies.Setup.Set_Variables.outputs['Set_Release_Version_Suffix.ReleaseVersionSuffix']] - BuildDate : $[stageDependencies.Setup.Set_Variables.outputs['Set_Build_Date.BuildDate']] - BuildTime : $[stageDependencies.Setup.Set_Variables.outputs['Set_Build_Time.BuildTime']] - - steps: - - checkout: self - submodules: true - fetchDepth: 1 - - - template: templates/flex-downloadPipelineArtifact.yml - parameters: - StepName: 'Download Pipeline Artifact - NuGet' - ArtifactName: 'onnxruntime-linux-x64-rocm' - targetPath: '$(Build.BinariesDirectory)/nuget-artifact' - SpecificArtifact: ${{ parameters.specificArtifact }} - BuildId: ${{ parameters.BuildId }} - - - task: PowerShell@2 - displayName: 'Reconstruct Build Directory' - inputs: - targetType: inline - script: | - Get-ChildItem $(Build.BinariesDirectory)\nuget-artifact -Filter *.tgz | % { - # *.tar will be created after *.tgz is extracted - $cmd = "7z.exe x $($_.FullName) -y -o$(Build.BinariesDirectory)\nuget-artifact" - Write-Output $cmd - Invoke-Expression -Command $cmd - } - - Get-ChildItem $(Build.BinariesDirectory)\nuget-artifact -Filter *.tar | % { - $cmd = "7z.exe x $($_.FullName) -y -o$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\nuget-artifacts" - Write-Output $cmd - Invoke-Expression -Command $cmd - } - - $ort_dirs = Get-ChildItem -Path $(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\nuget-artifacts\onnxruntime-* -Directory - foreach ($ort_dir in $ort_dirs) - { - $dirname = Split-Path -Path $ort_dir -Leaf - $dirname = $dirname.SubString(0, $dirname.LastIndexOf('-')) - Write-Output "Renaming $ort_dir to $dirname" - Rename-Item -Path $ort_dir -NewName $(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\nuget-artifacts\$dirname - } - - Copy-Item -Path $(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\nuget-artifacts\onnxruntime-linux-x64-rocm\lib\* -Destination $(Build.BinariesDirectory)\RelWithDebInfo - - - script: | - tree /F - workingDirectory: '$(Build.BinariesDirectory)' - displayName: 'Inspect Build Binaries Directory' - - - script: | - mklink /D /J models C:\local\models - workingDirectory: '$(Build.BinariesDirectory)' - displayName: 'Create models link' - - - task: NuGetToolInstaller@0 - displayName: Use Nuget 6.10.x - inputs: - versionSpec: 6.10.x - - - task: MSBuild@1 - displayName: 'Restore NuGet Packages and create project.assets.json' - inputs: - solution: '$(Build.SourcesDirectory)\csharp\OnnxRuntime.CSharp.sln' - platform: 'Any CPU' - configuration: RelWithDebInfo - msbuildArguments: '-t:restore -p:OrtPackageId="Microsoft.ML.OnnxRuntime.ROCm"' - workingDirectory: '$(Build.SourcesDirectory)\csharp' - - - task: MSBuild@1 - displayName: 'Build C# bindings' - inputs: - solution: '$(Build.SourcesDirectory)\csharp\OnnxRuntime.CSharp.sln' - platform: 'Any CPU' - configuration: RelWithDebInfo - msbuildArguments: > - -p:OnnxRuntimeBuildDirectory="$(Build.BinariesDirectory)" - -p:OrtPackageId="Microsoft.ML.OnnxRuntime.ROCm" - -p:IsReleaseBuild=${{ parameters.IsReleaseBuild }} - -p:ReleaseVersionSuffix=$(ReleaseVersionSuffix) - -p:IsLinuxBuild=true - -p:IsWindowsBuild=false - -p:IsMacOSBuild=false - workingDirectory: '$(Build.SourcesDirectory)\csharp' - - - template: templates/win-esrp-dll.yml - parameters: - FolderPath: '$(Build.SourcesDirectory)\csharp\src\Microsoft.ML.OnnxRuntime\bin\RelWithDebInfo' - DisplayName: 'ESRP - Sign C# dlls' - DoEsrp: ${{ parameters.DoEsrp }} - - - task: UsePythonVersion@0 - displayName: 'Use Python' - inputs: - versionSpec: 3.12 - - - task: MSBuild@1 - displayName: 'Build Nuget Packages' - inputs: - solution: '$(Build.SourcesDirectory)\csharp\OnnxRuntime.CSharp.proj' - configuration: RelWithDebInfo - platform: 'Any CPU' - msbuildArguments: > - -t:CreatePackage - -p:OnnxRuntimeBuildDirectory="$(Build.BinariesDirectory)" - -p:OrtPackageId=Microsoft.ML.OnnxRuntime.ROCm - -p:IsReleaseBuild=${{ parameters.IsReleaseBuild }} - -p:ReleaseVersionSuffix=$(ReleaseVersionSuffix) - -p:CurrentTime=$(BuildTime) - -p:CurrentDate=$(BuildDate) - workingDirectory: '$(Build.SourcesDirectory)\csharp' - - - task: CopyFiles@2 - displayName: 'Copy nuget packages to: $(Build.ArtifactStagingDirectory)' - inputs: - SourceFolder: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo' - Contents: '*.snupkg' - TargetFolder: '$(Build.ArtifactStagingDirectory)' - - - task: CopyFiles@2 - displayName: 'Copy nuget packages to: $(Build.ArtifactStagingDirectory)' - inputs: - SourceFolder: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo' - Contents: '*.nupkg' - TargetFolder: '$(Build.ArtifactStagingDirectory)' - - - task: CopyFiles@2 - displayName: 'Copy nuget packages to: $(Build.ArtifactStagingDirectory)' - inputs: - SourceFolder: '$(Build.SourcesDirectory)\csharp\src\Microsoft.ML.OnnxRuntime\bin\RelWithDebInfo' - Contents: '*.nupkg' - TargetFolder: '$(Build.ArtifactStagingDirectory)' - - - template: templates/esrp_nuget.yml - parameters: - DisplayName: 'ESRP - sign NuGet package' - FolderPath: '$(Build.ArtifactStagingDirectory)' - DoEsrp: ${{ parameters.DoEsrp }} - - - template: templates/validate-package.yml - parameters: - PackageType: 'nuget' - PackagePath: '$(Build.ArtifactStagingDirectory)' - PackageName: 'Microsoft.ML.OnnxRuntime.*nupkg' - PlatformsSupported: 'linux-x64' - VerifyNugetSigning: false - - - task: PublishPipelineArtifact@0 - displayName: 'Publish Pipeline NuGet Artifact' - inputs: - artifactName: 'drop-signed-nuget-ROCm' - targetPath: '$(Build.ArtifactStagingDirectory)' - - - task: MSBuild@1 - displayName: 'Clean C#' - inputs: - solution: '$(Build.SourcesDirectory)\csharp\OnnxRuntime.CSharp.sln' - platform: 'Any CPU' - configuration: RelWithDebInfo - msbuildArguments: '-t:Clean -p:OnnxRuntimeBuildDirectory="$(Build.BinariesDirectory)" -p:OrtPackageId=Microsoft.ML.OnnxRuntime.ROCm' - workingDirectory: '$(Build.SourcesDirectory)\csharp' - - - template: templates/component-governance-component-detection-steps.yml - parameters : - condition : 'succeeded' - - -- template: nuget/templates/test_linux.yml - parameters: - AgentPool: AMD-GPU - ArtifactSuffix: 'ROCm' - StageSuffix: 'ROCm' - NugetPackageName: 'Microsoft.ML.OnnxRuntime.ROCm' - SpecificArtifact: ${{ parameters.specificArtifact }} - CustomOpArtifactName: 'onnxruntime-linux-x64-rocm' - BuildId: ${{ parameters.BuildId }} diff --git a/tools/ci_build/github/azure-pipelines/rocm-publish-nuget-pipeline.yml b/tools/ci_build/github/azure-pipelines/rocm-publish-nuget-pipeline.yml deleted file mode 100644 index 1d2393d8f96d5..0000000000000 --- a/tools/ci_build/github/azure-pipelines/rocm-publish-nuget-pipeline.yml +++ /dev/null @@ -1,21 +0,0 @@ -resources: - pipelines: - - pipeline: build - source: 'Nuget ROCM Packaging pipeline' - trigger: - branches: - include: - - main - - rel-* - branch: main - -# ROCm -stages: -- template: templates/publish-nuget-steps.yml - parameters: - stage_name: 'Publish_ROCM_NuGet_Package' - download_artifacts_steps: - - download: build - displayName: 'Download Pipeline Artifact - Signed NuGet Package' - artifact: 'drop-signed-nuget-ROCm' - - script: move "$(Pipeline.Workspace)\build\drop-signed-nuget-ROCm\*" $(Build.BinariesDirectory)\nuget-artifact\final-package diff --git a/tools/ci_build/github/azure-pipelines/stages/nuget-cuda-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/nuget-cuda-packaging-stage.yml index 8fabb80a73869..5ae60aac8f9b4 100644 --- a/tools/ci_build/github/azure-pipelines/stages/nuget-cuda-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/nuget-cuda-packaging-stage.yml @@ -96,18 +96,10 @@ stages: inputs: versionSpec: 6.10.x - - task: PowerShell@2 - displayName: Install MAUI workloads - inputs: - targetType: 'inline' - script: | - dotnet workload install android ios maccatalyst - workingDirectory: '$(Build.SourcesDirectory)\csharp' - - task: MSBuild@1 displayName: 'Restore NuGet Packages and create project.assets.json' inputs: - solution: '$(Build.SourcesDirectory)\csharp\OnnxRuntime.CSharp.sln' + solution: '$(Build.SourcesDirectory)\csharp\OnnxRuntime.DesktopOnly.CSharp.sln' platform: 'Any CPU' configuration: RelWithDebInfo msbuildArguments: '-t:restore -p:OrtPackageId="Microsoft.ML.OnnxRuntime.Gpu"' @@ -116,7 +108,7 @@ stages: - task: MSBuild@1 displayName: 'Build C# bindings' inputs: - solution: '$(Build.SourcesDirectory)\csharp\OnnxRuntime.CSharp.sln' + solution: '$(Build.SourcesDirectory)\csharp\OnnxRuntime.DesktopOnly.CSharp.sln' configuration: RelWithDebInfo platform: 'Any CPU' msbuildArguments: > @@ -208,7 +200,7 @@ stages: - task: MSBuild@1 displayName: 'Clean C#' inputs: - solution: '$(Build.SourcesDirectory)\csharp\OnnxRuntime.CSharp.sln' + solution: '$(Build.SourcesDirectory)\csharp\OnnxRuntime.DesktopOnly.CSharp.sln' platform: 'Any CPU' configuration: RelWithDebInfo msbuildArguments: '-t:Clean -p:OnnxRuntimeBuildDirectory="$(Build.BinariesDirectory)" -p:OrtPackageId=Microsoft.ML.OnnxRuntime.Gpu' @@ -223,4 +215,3 @@ stages: inputs: artifactName: 'drop-signed-nuget-GPU' targetPath: '$(Build.ArtifactStagingDirectory)' - diff --git a/tools/ci_build/github/azure-pipelines/stages/nuget-qnn-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/nuget-qnn-packaging-stage.yml new file mode 100644 index 0000000000000..03802746cec3d --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/stages/nuget-qnn-packaging-stage.yml @@ -0,0 +1,76 @@ +parameters: +- name: DoEsrp + displayName: Run code sign tasks? Must be true if you are doing an Onnx Runtime release. + type: boolean + default: true + +stages: +- stage: NuGet_Packaging_QNN + pool: + name: 'Onnxruntime-QNNEP-Windows-2022-CPU' + dependsOn: + - OnnxRuntime_QNN_Nuget_Win_x64 + - OnnxRuntime_QNN_Nuget_Win_Arm64 + condition: succeeded() + jobs: + - job: NuGet_Packaging_QNN + workspace: + clean: all + steps: + - task: DownloadPipelineArtifact@0 + displayName: 'Download Pipeline Artifact - QNN NuGet x64' + inputs: + artifactName: 'drop-nuget-qnn-x64' + targetPath: '$(Build.BinariesDirectory)/nuget-artifact-x64' + + - task: DownloadPipelineArtifact@0 + displayName: 'Download Pipeline Artifact - QNN NuGet arm64' + inputs: + artifactName: 'drop-nuget-qnn-arm64' + targetPath: '$(Build.BinariesDirectory)/nuget-artifact-arm64' + + - task: PowerShell@2 + displayName: 'Bundle NuGet' + inputs: + targetType: 'inline' + script: | + + $x64_nupkgs = (Get-ChildItem $(Build.BinariesDirectory)/nuget-artifact-x64 -Filter Microsoft.ML.OnnxRuntime.QNN*.nupkg -Recurse) + $nuget_package_name = $x64_nupkgs[0].Name + $x64_nuget_package = $x64_nupkgs[0].FullName + + $nupkg_unzipped_directory = [System.IO.Path]::Combine($Env:BUILD_ARTIFACTSTAGINGDIRECTORY, 'nuget_unzip_merged', [System.IO.Path]::GetFileNameWithoutExtension($nuget_package_name)) + + $x64_unzip_cmd = "7z.exe x $x64_nuget_package -y -o$nupkg_unzipped_directory" + Invoke-Expression -Command $x64_unzip_cmd + + $arm64_nupkgs = (Get-ChildItem $(Build.BinariesDirectory)/nuget-artifact-arm64 -Filter Microsoft.ML.OnnxRuntime.QNN*.nupkg -Recurse) + $arm64_nuget_package = $arm64_nupkgs[0].FullName + + $arm64_unzip_cmd = "7z.exe x $arm64_nuget_package -y -o$nupkg_unzipped_directory" + Invoke-Expression -Command $arm64_unzip_cmd + + $merged_nuget_path = [System.IO.Path]::Combine($Env:BUILD_ARTIFACTSTAGINGDIRECTORY, 'nuget-artifact-merged') + if (!(Test-Path $merged_nuget_path)) { + New-Item -Path $merged_nuget_path -ItemType Directory + } + + $merged_zip = [System.IO.Path]::Combine($merged_nuget_path, 'qnn_nuget.zip') + $zip_cmd = "7z.exe a -r $merged_zip $nupkg_unzipped_directory/*" + Invoke-Expression -Command $zip_cmd + + $merged_nuget = [System.IO.Path]::Combine($merged_nuget_path, $nuget_package_name) + move $merged_zip $merged_nuget + workingDirectory: $(Build.BinariesDirectory) + + - template: ../templates/esrp_nuget.yml + parameters: + DisplayName: 'ESRP - sign NuGet package' + FolderPath: '$(Build.ArtifactStagingDirectory)/nuget-artifact-merged' + DoEsrp: ${{ parameters.DoEsrp }} + + - task: 1ES.PublishPipelineArtifact@1 + displayName: 'Publish Pipeline NuGet Artifact' + inputs: + artifactName: 'drop-signed-nuget-qnn' + targetPath: '$(Build.ArtifactStagingDirectory)/nuget-artifact-merged' diff --git a/tools/ci_build/github/azure-pipelines/stages/py-cpu-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/py-cpu-packaging-stage.yml index 4ff539df9f914..5e783607e3622 100644 --- a/tools/ci_build/github/azure-pipelines/stages/py-cpu-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/py-cpu-packaging-stage.yml @@ -123,7 +123,7 @@ stages: --skip_submodule_sync --cmake_generator "Visual Studio 17 2022" --enable_pybind - --enable_onnx_tests + --enable_onnx_tests --use_vcpkg --use_vcpkg_ms_internal_asset_cache ${{ parameters.build_py_parameters }} --parallel --use_binskim_compliant_compile_flags --update --build $(TelemetryOption) @@ -151,10 +151,11 @@ stages: Contents: '*.whl' TargetFolder: '$(Build.ArtifactStagingDirectory)' - - task: PublishBuildArtifacts@1 + - task: 1ES.PublishPipelineArtifact@1 displayName: 'Publish Artifact: ONNXRuntime python wheel' inputs: - ArtifactName: onnxruntime + artifactName: onnxruntime-win-$(PythonVersion) + targetPath: '$(Build.ArtifactStagingDirectory)' - script: | 7z x *.whl @@ -199,7 +200,9 @@ stages: workspace: clean: all pool: - vmImage: 'macOS-13' + name: "Azure Pipelines" + image: "macOS-13" + os: macOS variables: MACOSX_DEPLOYMENT_TARGET: '13.3' strategy: @@ -251,74 +254,81 @@ stages: Contents: '*.whl' TargetFolder: '$(Build.ArtifactStagingDirectory)' - - task: PublishBuildArtifacts@1 + - task: 1ES.PublishPipelineArtifact@1 displayName: 'Publish Artifact: ONNXRuntime python wheel' inputs: - ArtifactName: onnxruntime + artifactName: onnxruntime-macos-$(PythonVersion) + targetPath: '$(Build.ArtifactStagingDirectory)' - template: ../templates/component-governance-component-detection-steps.yml parameters: condition: 'succeeded' - - ${{ if eq(parameters.enable_linux_arm, true) }}: - - stage: Python_Packaging_Linux_ARM - dependsOn: [] - jobs: - - template: ../templates/py-linux.yml - parameters: - arch: 'aarch64' - machine_pool: 'onnxruntime-linux-ARM64-CPU-2019' - extra_build_arg: ${{ parameters.build_py_parameters }} - cmake_build_type: ${{ parameters.cmake_build_type }} - - - ${{ if eq(parameters.enable_linux_cpu, true) }}: - - stage: Python_Packaging_Linux_CPU - dependsOn: [] - jobs: +- ${{ if eq(parameters.enable_linux_arm, true) }}: + - stage: Python_Packaging_Linux_ARM + dependsOn: [] + jobs: - template: ../templates/py-linux.yml parameters: - arch: 'x86_64' - machine_pool: 'onnxruntime-Ubuntu2204-AMD-CPU-Large' + arch: 'aarch64' + machine_pool: 'onnxruntime-linux-ARM64-CPU-2019' extra_build_arg: ${{ parameters.build_py_parameters }} cmake_build_type: ${{ parameters.cmake_build_type }} + is1ES: true - - ${{ if eq(parameters.enable_windows_arm64_qnn, true) }}: - - stage: Python_Packaging_Windows_ARM64_QNN - dependsOn: [] - jobs: - - template: ../templates/py-win-arm64-qnn.yml +- ${{ if eq(parameters.enable_linux_cpu, true) }}: + - stage: Python_Packaging_Linux_CPU + dependsOn: [] + jobs: + - template: ../templates/py-linux.yml + parameters: + arch: 'x86_64' + machine_pool: 'onnxruntime-Ubuntu2204-AMD-CPU-Large' + extra_build_arg: ${{ parameters.build_py_parameters }} + cmake_build_type: ${{ parameters.cmake_build_type }} + is1ES: true + +- ${{ if eq(parameters.enable_windows_arm64_qnn, true) }}: + - stage: Python_Packaging_Windows_ARM64_QNN + dependsOn: [] + jobs: + - template: ../templates/py-win-arm64-qnn.yml + parameters: + MACHINE_POOL: 'onnxruntime-qnn-windows-vs-2022-arm64' + QNN_SDK: ${{ parameters.qnn_sdk_version }} + BUILD_PY_PARAMETERS: ${{ parameters.build_py_parameters }} + is1ES: true + +- ${{ if eq(parameters.enable_windows_arm64ec_qnn, true) }}: + - stage: Python_Packaging_Windows_arm64ec_QNN + dependsOn: [] + jobs: + - template: ../templates/py-win-arm64ec-qnn.yml parameters: - MACHINE_POOL: 'onnxruntime-qnn-windows-vs-2022-arm64' + MACHINE_POOL: 'Onnxruntime-QNNEP-Windows-2022-CPU' QNN_SDK: ${{ parameters.qnn_sdk_version }} BUILD_PY_PARAMETERS: ${{ parameters.build_py_parameters }} + is1ES: true - - ${{ if eq(parameters.enable_windows_arm64ec_qnn, true) }}: - - stage: Python_Packaging_Windows_arm64ec_QNN - dependsOn: [] - jobs: - - template: ../templates/py-win-arm64ec-qnn.yml - parameters: - MACHINE_POOL: 'Onnxruntime-QNNEP-Windows-2022-CPU' - QNN_SDK: ${{ parameters.qnn_sdk_version }} - BUILD_PY_PARAMETERS: ${{ parameters.build_py_parameters }} - - - ${{ if eq(parameters.enable_windows_x64_qnn, true) }}: - - stage: Python_Packaging_Windows_x64_QNN - dependsOn: [] - jobs: - - template: ../templates/py-win-x64-qnn.yml - parameters: - MACHINE_POOL: 'Onnxruntime-QNNEP-Windows-2022-CPU' - QNN_SDK: ${{ parameters.qnn_sdk_version }} - BUILD_PY_PARAMETERS: ${{ parameters.build_py_parameters }} - - - ${{ if eq(parameters.enable_linux_x64_qnn, true) }}: - - stage: Python_Packaging_Linux_x64_QNN - dependsOn: [] - jobs: - - template: ../templates/py-linux-qnn.yml +- ${{ if eq(parameters.enable_windows_x64_qnn, true) }}: + - stage: Python_Packaging_Windows_x64_QNN + dependsOn: [] + jobs: + - template: ../templates/py-win-x64-qnn.yml parameters: - machine_pool: 'onnxruntime-Ubuntu2204-AMD-CPU' - extra_build_arg: ${{ parameters.build_py_parameters }} - cmake_build_type: ${{ parameters.cmake_build_type }} + MACHINE_POOL: 'Onnxruntime-QNNEP-Windows-2022-CPU' + QNN_SDK: ${{ parameters.qnn_sdk_version }} + BUILD_PY_PARAMETERS: ${{ parameters.build_py_parameters }} + is1ES: true + +- ${{ if eq(parameters.enable_linux_x64_qnn, true) }}: + - stage: Python_Packaging_Linux_x64_QNN + dependsOn: [] + jobs: + - template: ../templates/py-linux-qnn.yml + parameters: + machine_pool: 'onnxruntime-Ubuntu2204-AMD-CPU' + extra_build_arg: ${{ parameters.build_py_parameters }} + cmake_build_type: ${{ parameters.cmake_build_type }} + is1ES: true diff --git a/tools/ci_build/github/azure-pipelines/templates/c-api-artifacts-package-and-publish-steps-windows.yml b/tools/ci_build/github/azure-pipelines/templates/c-api-artifacts-package-and-publish-steps-windows.yml index 5ee425405ac70..e1a514ea54123 100644 --- a/tools/ci_build/github/azure-pipelines/templates/c-api-artifacts-package-and-publish-steps-windows.yml +++ b/tools/ci_build/github/azure-pipelines/templates/c-api-artifacts-package-and-publish-steps-windows.yml @@ -57,6 +57,22 @@ steps: copy $(Build.BinariesDirectory)\${{parameters.buildConfig}}\${{parameters.buildConfig}}\onnxruntime_providers_cuda.pdb $(Build.BinariesDirectory)\${{parameters.artifactName}}\lib copy $(Build.BinariesDirectory)\${{parameters.buildConfig}}\${{parameters.buildConfig}}\onnxruntime_providers_cuda.lib $(Build.BinariesDirectory)\${{parameters.artifactName}}\lib + # Copy WebGPU dependencies if required + copy $(Build.BinariesDirectory)\${{parameters.buildConfig}}\${{parameters.buildConfig}}\dxcompiler.dll $(Build.BinariesDirectory)\${{parameters.artifactName}}\lib + copy $(Build.BinariesDirectory)\${{parameters.buildConfig}}\${{parameters.buildConfig}}\dxil.dll $(Build.BinariesDirectory)\${{parameters.artifactName}}\lib + + # Copy QNN dependencies if required + copy $(Build.BinariesDirectory)\${{parameters.buildConfig}}\${{parameters.buildConfig}}\onnxruntime_providers_qnn.dll $(Build.BinariesDirectory)\${{parameters.artifactName}}\lib + copy $(Build.BinariesDirectory)\${{parameters.buildConfig}}\${{parameters.buildConfig}}\libQnnHtp*.so $(Build.BinariesDirectory)\${{parameters.artifactName}}\lib /Y + copy $(Build.BinariesDirectory)\${{parameters.buildConfig}}\${{parameters.buildConfig}}\libqnnhtp*.cat $(Build.BinariesDirectory)\${{parameters.artifactName}}\lib /Y + copy $(Build.BinariesDirectory)\${{parameters.buildConfig}}\${{parameters.buildConfig}}\QnnCpu.dll $(Build.BinariesDirectory)\${{parameters.artifactName}}\lib + copy $(Build.BinariesDirectory)\${{parameters.buildConfig}}\${{parameters.buildConfig}}\QnnHtp.dll $(Build.BinariesDirectory)\${{parameters.artifactName}}\lib + copy $(Build.BinariesDirectory)\${{parameters.buildConfig}}\${{parameters.buildConfig}}\QnnHtpPrepare.dll $(Build.BinariesDirectory)\${{parameters.artifactName}}\lib + copy $(Build.BinariesDirectory)\${{parameters.buildConfig}}\${{parameters.buildConfig}}\QnnHtpV68Stub.dll $(Build.BinariesDirectory)\${{parameters.artifactName}}\lib + copy $(Build.BinariesDirectory)\${{parameters.buildConfig}}\${{parameters.buildConfig}}\QnnHtpV73Stub.dll $(Build.BinariesDirectory)\${{parameters.artifactName}}\lib + copy $(Build.BinariesDirectory)\${{parameters.buildConfig}}\${{parameters.buildConfig}}\QnnSaver.dll $(Build.BinariesDirectory)\${{parameters.artifactName}}\lib + copy $(Build.BinariesDirectory)\${{parameters.buildConfig}}\${{parameters.buildConfig}}\QnnSystem.dll $(Build.BinariesDirectory)\${{parameters.artifactName}}\lib + # copy trt ep libraries only when trt ep is enabled copy $(Build.BinariesDirectory)\${{parameters.buildConfig}}\${{parameters.buildConfig}}\onnxruntime_providers_tensorrt.dll $(Build.BinariesDirectory)\${{parameters.artifactName}}\lib copy $(Build.BinariesDirectory)\${{parameters.buildConfig}}\${{parameters.buildConfig}}\onnxruntime_providers_tensorrt.pdb $(Build.BinariesDirectory)\${{parameters.artifactName}}\lib diff --git a/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_openvino.yml b/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_openvino.yml new file mode 100644 index 0000000000000..f6956b426ddfc --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_openvino.yml @@ -0,0 +1,64 @@ +parameters: + - name: OpenVINOVersion + type: string + default: '2025.0.0' + +steps: + - powershell: | + $Url = "https://storage.openvinotoolkit.org/repositories/openvino/packages/2025.0/windows/openvino_toolkit_windows_2025.0.0.17942.1f68be9f594_x86_64.zip" + $OutputPath = "$env:Agent_TempDirectory\openvino.zip" + $ExtractPath = "$env:Agent_TempDirectory\openvino-v$env:OpenVINOVersion" + $TempExtractPath = "$env:Agent_TempDirectory\openvino_temp" + + # Ensure directories exist + if (Test-Path $ExtractPath) { + Remove-Item -Recurse -Force $ExtractPath + } + New-Item -ItemType Directory -Path $ExtractPath | Out-Null + New-Item -ItemType Directory -Path $TempExtractPath | Out-Null + + # Download OpenVINO ZIP + Write-Output "Downloading OpenVINO" + Invoke-WebRequest -Uri $Url -OutFile $OutputPath + + # Extract to temporary directory first + Write-Output "Extracting OpenVINO to a temporary directory" + Expand-Archive -Path $OutputPath -DestinationPath $TempExtractPath -Force + + # Locate the nested subdirectory + $InnerFolder = Get-ChildItem -Path $TempExtractPath -Directory | Select-Object -First 1 + + if ($InnerFolder) { + Write-Output "Moving extracted files to final destination" + Move-Item -Path "$($InnerFolder.FullName)\*" -Destination $ExtractPath -Force + } else { + Write-Error "Extraction failed: No expected subdirectory found in $TempExtractPath." + Write-Error "The archive may not have extracted correctly, or its structure is different than expected." + exit 1 + } + + # Clean up temporary files + Remove-Item -Recurse -Force $TempExtractPath + Remove-Item -Force $OutputPath + + # Confirm success + Write-Output "OpenVINO extracted to $ExtractPath" + displayName: 'Download OpenVINO Toolkit v${{ parameters.OpenVINOVersion }}' + env: + OpenVINOVersion: ${{ parameters.OpenVINOVersion }} + + - powershell: | + echo "##vso[task.setvariable variable=OpenVINORootDir]$(Agent.TempDirectory)\openvino-v${{ parameters.OpenVINOVersion }}" + displayName: 'Set OpenVINORootDir' + + - task: CmdLine@2 + inputs: + script: | + echo $(OpenVINORootDir) + displayName: 'Print OpenVINORootDir after downloading OpenVINO' + + - task: CmdLine@2 + displayName: 'Print contents of OpenVINO Toolkit' + inputs: + script: | + dir $(OpenVINORootDir) diff --git a/tools/ci_build/github/azure-pipelines/templates/linux-web-init-and-check.yml b/tools/ci_build/github/azure-pipelines/templates/linux-web-init-and-check.yml index a4d5a73118ea2..2b73f82615bba 100644 --- a/tools/ci_build/github/azure-pipelines/templates/linux-web-init-and-check.yml +++ b/tools/ci_build/github/azure-pipelines/templates/linux-web-init-and-check.yml @@ -1,4 +1,8 @@ steps: +- task: NodeTool@0 + inputs: + # requires Node.js v22 for float16 testing (the V8 flag "--js-float16array") + versionSpec: '22.x' - script: | npm ci workingDirectory: '$(Build.SourcesDirectory)/js' @@ -11,6 +15,10 @@ steps: npm test workingDirectory: '$(Build.SourcesDirectory)/js/common' displayName: 'run onnxruntime-common tests' +- script: | + npm run test:f16 + workingDirectory: '$(Build.SourcesDirectory)/js/common' + displayName: 'run onnxruntime-common tests (enable Float16Array)' - script: | npm ci workingDirectory: '$(Build.SourcesDirectory)/js/web' diff --git a/tools/ci_build/github/azure-pipelines/templates/py-linux-qnn.yml b/tools/ci_build/github/azure-pipelines/templates/py-linux-qnn.yml index 347a3145e8c70..8126cda449daa 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-linux-qnn.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-linux-qnn.yml @@ -6,10 +6,10 @@ parameters: type: string default: 'Release' values: - - Debug - - Release - - RelWithDebInfo - - MinSizeRel + - Debug + - Release + - RelWithDebInfo + - MinSizeRel - name: device type: string @@ -27,68 +27,82 @@ parameters: displayName: QNN SDK version type: string default: 2.31.0.250130 + +- name: is1ES + displayName: 'Whether the pipeline is running in 1ES' + type: boolean + default: false jobs: - job: Linux_py_qnn_Wheels_x64 timeoutInMinutes: 240 workspace: clean: all - pool: ${{ parameters.machine_pool }} + pool: + name: ${{ parameters.machine_pool }} + os: linux variables: - # The build machine pool doesn't have dotnet, so it can't run CG. - - name: skipComponentGovernanceDetection - value: true - - name: ORT_CACHE_DIR - value: $(Agent.TempDirectory)/ort_ccache - - name: TODAY - value: $[format('{0:dd}{0:MM}{0:yyyy}', pipeline.startTime)] - - name: extra_build_args - ${{ if ne(parameters.extra_build_arg, '') }}: - value: -x ${{ parameters.extra_build_arg }} - ${{ if eq(parameters.extra_build_arg, '') }}: - value: '' + # The build machine pool doesn't have dotnet, so it can't run CG. + - name: skipComponentGovernanceDetection + value: true + - name: ORT_CACHE_DIR + value: $(Agent.TempDirectory)/ort_ccache + - name: TODAY + value: $[format('{0:dd}{0:MM}{0:yyyy}', pipeline.startTime)] + - name: extra_build_args + ${{ if ne(parameters.extra_build_arg, '') }}: + value: -x ${{ parameters.extra_build_arg }} + ${{ if eq(parameters.extra_build_arg, '') }}: + value: '' steps: - - checkout: self - clean: true - submodules: none + - checkout: self + clean: true + submodules: none - - template: jobs/download_linux_qnn_sdk.yml - parameters: - QnnSDKVersion: ${{ parameters.QnnSdk }} + - template: jobs/download_linux_qnn_sdk.yml + parameters: + QnnSDKVersion: ${{ parameters.QnnSdk }} - - template: set-nightly-build-option-variable-step.yml + - template: set-nightly-build-option-variable-step.yml - - template: get-docker-image-steps.yml - parameters: - Dockerfile: tools/ci_build/github/linux/docker/inference/x86_64/python/cpu/Dockerfile - Context: tools/ci_build/github/linux/docker/inference/x86_64/python/cpu - DockerBuildArgs: "--build-arg BUILD_UID=$( id -u )" - Repository: onnxruntimecpubuildpythonx86_64_qnn + - template: get-docker-image-steps.yml + parameters: + Dockerfile: tools/ci_build/github/linux/docker/inference/x86_64/python/cpu/Dockerfile + Context: tools/ci_build/github/linux/docker/inference/x86_64/python/cpu + DockerBuildArgs: "--build-arg BUILD_UID=$( id -u )" + Repository: onnxruntimecpubuildpythonx86_64_qnn - - template: linux-build-step-with-cache.yml - parameters: - WithCache: ${{parameters.with_cache}} - Today: $(TODAY) - AdditionalKey: Linux_py_qnn_Wheels_x64 - CacheDir: $(ORT_CACHE_DIR) - ChangeEveryCommit: true - BuildStep: - - task: Bash@3 - displayName: 'Build Python Wheel' - inputs: - targetType: filePath - filePath: tools/ci_build/github/linux/run_python_dockerbuild.sh - arguments: -i onnxruntimecpubuildpythonx86_64_qnn -d "${{ parameters.device }}" -c ${{ parameters.cmake_build_type }} $(extra_build_args) - env: - ADDITIONAL_DOCKER_PARAMETER: "--volume $(QnnSDKRootDir):/qnn_sdk" + - template: linux-build-step-with-cache.yml + parameters: + WithCache: ${{parameters.with_cache}} + Today: $(TODAY) + AdditionalKey: Linux_py_qnn_Wheels_x64 + CacheDir: $(ORT_CACHE_DIR) + ChangeEveryCommit: true + BuildStep: + - task: Bash@3 + displayName: 'Build Python Wheel' + inputs: + targetType: filePath + filePath: tools/ci_build/github/linux/run_python_dockerbuild.sh + arguments: -i onnxruntimecpubuildpythonx86_64_qnn -d "${{ parameters.device }}" -c ${{ parameters.cmake_build_type }} $(extra_build_args) + env: + ADDITIONAL_DOCKER_PARAMETER: "--volume $(QnnSDKRootDir):/qnn_sdk" + - ${{ if eq(parameters.is1ES, true) }}: + - task: 1ES.PublishPipelineArtifact@1 + displayName: 'Publish Artifact: Linux ONNXRuntime QNN python wheel' + inputs: + targetPath: '$(Build.BinariesDirectory)/dist' + artifactName: onnxruntime-linux-qnn-x64 - - task: PublishBuildArtifacts@1 + - ${{ if eq(parameters.is1ES, false) }}: + - task: PublishPipelineArtifact@1 displayName: 'Publish Artifact: Linux ONNXRuntime QNN python wheel' inputs: - PathtoPublish: '$(Build.BinariesDirectory)/dist' - ArtifactName: onnxruntime-linux-qnn-x64 + targetPath: '$(Build.BinariesDirectory)/dist' + artifactName: onnxruntime-linux-qnn-x64 - - template: component-governance-component-detection-steps.yml - parameters : - condition : 'succeeded' + - template: component-governance-component-detection-steps.yml + parameters: + condition: 'succeeded' diff --git a/tools/ci_build/github/azure-pipelines/templates/py-linux.yml b/tools/ci_build/github/azure-pipelines/templates/py-linux.yml index e591b719ecfa9..8d0c4334f4874 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-linux.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-linux.yml @@ -9,10 +9,10 @@ parameters: type: string default: 'Release' values: - - Debug - - Release - - RelWithDebInfo - - MinSizeRel + - Debug + - Release + - RelWithDebInfo + - MinSizeRel - name: device type: string @@ -34,76 +34,98 @@ parameters: type: string default: '' +- name: is1ES + displayName: 'Whether the pipeline is running in 1ES' + type: boolean + default: false + jobs: - job: Linux_py_Wheels_${{ parameters.arch }}_${{parameters.ep}} timeoutInMinutes: 240 workspace: clean: all - pool: ${{ parameters.machine_pool }} + pool: + name: ${{ parameters.machine_pool }} + os: 'linux' + ${{ if eq(parameters.arch, 'aarch64') }}: + hostArchitecture: Arm64 variables: - # The build machine pool doesn't have dotnet, so it can't run CG. - - name: skipComponentGovernanceDetection - value: true - - name: ORT_CACHE_DIR - value: $(Agent.TempDirectory)/ort_ccache - - name: TODAY - value: $[format('{0:dd}{0:MM}{0:yyyy}', pipeline.startTime)] - - name: extra_build_args - ${{ if ne(parameters.extra_build_arg, '') }}: - value: '-x ${{ parameters.extra_build_arg }}' - ${{ if eq(parameters.extra_build_arg, '') }}: - value: '' - - name: python_exe_path - ${{ if ne(parameters.python_exe_path, '') }}: - value: '-p ${{ parameters.python_exe_path }}' - ${{ if eq(parameters.python_exe_path, '') }}: - value: '' + # The build machine pool doesn't have dotnet, so it can't run CG. + - name: skipComponentGovernanceDetection + value: true + - name: ORT_CACHE_DIR + value: $(Agent.TempDirectory)/ort_ccache + - name: TODAY + value: $[format('{0:dd}{0:MM}{0:yyyy}', pipeline.startTime)] + - name: extra_build_args + ${{ if ne(parameters.extra_build_arg, '') }}: + value: '-x ${{ parameters.extra_build_arg }}' + ${{ if eq(parameters.extra_build_arg, '') }}: + value: '' + - name: python_exe_path + ${{ if ne(parameters.python_exe_path, '') }}: + value: '-p ${{ parameters.python_exe_path }}' + ${{ if eq(parameters.python_exe_path, '') }}: + value: '' steps: - - checkout: self - clean: true - submodules: none - - - template: set-nightly-build-option-variable-step.yml - - - template: get-docker-image-steps.yml - parameters: - Dockerfile: tools/ci_build/github/linux/docker/inference/${{ parameters.arch }}/python/cpu/Dockerfile - Context: tools/ci_build/github/linux/docker/inference/${{ parameters.arch }}/python/cpu - DockerBuildArgs: "--build-arg BUILD_UID=$( id -u )" - Repository: onnxruntimecpubuildpython${{ parameters.arch }} - - - template: linux-build-step-with-cache.yml - parameters: - WithCache: ${{parameters.with_cache}} - Today: $(TODAY) - AdditionalKey: Linux_py_Wheels_${{ parameters.arch }} - CacheDir: $(ORT_CACHE_DIR) - ChangeEveryCommit: true - BuildStep: - - task: Bash@3 - displayName: 'Build Python Wheel' - inputs: - targetType: filePath - filePath: tools/ci_build/github/linux/run_python_dockerbuild.sh - arguments: -i onnxruntimecpubuildpython${{ parameters.arch }} -d "${{ parameters.device }}" -c ${{ parameters.cmake_build_type }} $(extra_build_args) $(python_exe_path) - ${{ if eq(parameters.with_cache, 'true') }}: - env: - ADDITIONAL_DOCKER_PARAMETER: "--volume $(ORT_CACHE_DIR):/cache -e CCACHE_DIR=/cache -e ORT_BUILD_WITH_CACHE=1" - - - task: PublishBuildArtifacts@1 + - checkout: self + clean: true + submodules: none + + - template: set-nightly-build-option-variable-step.yml + + - template: get-docker-image-steps.yml + parameters: + Dockerfile: tools/ci_build/github/linux/docker/inference/${{ parameters.arch }}/python/cpu/Dockerfile + Context: tools/ci_build/github/linux/docker/inference/${{ parameters.arch }}/python/cpu + DockerBuildArgs: "--build-arg BUILD_UID=$( id -u )" + Repository: onnxruntimecpubuildpython${{ parameters.arch }} + + - template: linux-build-step-with-cache.yml + parameters: + WithCache: ${{parameters.with_cache}} + Today: $(TODAY) + AdditionalKey: Linux_py_Wheels_${{ parameters.arch }} + CacheDir: $(ORT_CACHE_DIR) + ChangeEveryCommit: true + BuildStep: + - task: Bash@3 + displayName: 'Build Python Wheel' + inputs: + targetType: filePath + filePath: tools/ci_build/github/linux/run_python_dockerbuild.sh + arguments: -i onnxruntimecpubuildpython${{ parameters.arch }} -d "${{ parameters.device }}" -c ${{ parameters.cmake_build_type }} $(extra_build_args) $(python_exe_path) + ${{ if eq(parameters.with_cache, 'true') }}: + env: + ADDITIONAL_DOCKER_PARAMETER: "--volume $(ORT_CACHE_DIR):/cache -e CCACHE_DIR=/cache -e ORT_BUILD_WITH_CACHE=1" + + - ${{ if eq(parameters.is1ES, true) }}: + - task: 1ES.PublishPipelineArtifact@1 displayName: 'Publish Artifact: ONNXRuntime python wheel' inputs: - PathtoPublish: '$(Build.BinariesDirectory)/dist' - ArtifactName: onnxruntime-${{ parameters.ep }} - - - task: PublishPipelineArtifact@0 + targetPath: '$(Build.BinariesDirectory)/dist' + artifactName: onnxruntime-${{ parameters.arch }}-${{ parameters.ep }} + - task: 1ES.PublishPipelineArtifact@1 + displayName: 'Publish Test Binaries' + inputs: + artifactName: 'drop-linux-cpu-${{ parameters.arch }}-${{ parameters.ep }}' + targetPath: '$(Build.BinariesDirectory)/${{ parameters.cmake_build_type }}' + - ${{ if eq(parameters.is1ES, false) }}: + - task: PublishPipelineArtifact@1 + displayName: 'Publish Artifact: ONNXRuntime python wheel' + inputs: + targetPath: '$(Build.BinariesDirectory)/dist' + artifactName: onnxruntime-${{ parameters.arch }}-${{ parameters.ep }} + - task: PublishPipelineArtifact@1 displayName: 'Publish Test Binaries' inputs: artifactName: 'drop-linux-cpu-${{ parameters.arch }}-${{ parameters.ep }}' targetPath: '$(Build.BinariesDirectory)/${{ parameters.cmake_build_type }}' - - template: component-governance-component-detection-steps.yml - parameters : - condition : 'succeeded' + + + - template: component-governance-component-detection-steps.yml + parameters: + condition: 'succeeded' diff --git a/tools/ci_build/github/azure-pipelines/templates/py-package-smoking-test.yml b/tools/ci_build/github/azure-pipelines/templates/py-package-smoking-test.yml index 3a3da0f8f5afa..c0bd740b2d483 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-package-smoking-test.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-package-smoking-test.yml @@ -9,9 +9,13 @@ parameters: - name: machine_pool type: object -- name: python_arch +- name: ep type: string - default: 'x64' + default: 'cpu' + +- name: arch + type: string + default: 'x86_64' jobs: - job: ${{ parameters.job_name }} @@ -37,10 +41,9 @@ jobs: displayName: 'Use Python' inputs: versionSpec: $(PythonVersion) - architecture: ${{ parameters.python_arch }} - download: build # pipeline resource identifier. - artifact: 'onnxruntime' + artifact: 'onnxruntime-${{ parameters.arch }}-${{ parameters.ep }}' - task: Bash@3 inputs: @@ -51,7 +54,7 @@ jobs: FILE_NAME="${files[0]}" FILE_NAME=$(basename $FILE_NAME) PYTHON_PACKAGE_NAME=$(echo "$FILE_NAME" | cut -f 1 -d '-') - python3 -m pip install --find-links "$(Pipeline.Workspace)/build/onnxruntime" $PYTHON_PACKAGE_NAME + python3 -m pip install --find-links "$(Pipeline.Workspace)/build/onnxruntime-${{ parameters.arch }}-${{ parameters.ep }}" $PYTHON_PACKAGE_NAME python3 -m pip show $PYTHON_PACKAGE_NAME python3 -c "import onnxruntime as ort; print(ort.__version__)" workingDirectory: $(Pipeline.Workspace)/build/onnxruntime diff --git a/tools/ci_build/github/azure-pipelines/templates/py-packaging-linux-test-cpu.yml b/tools/ci_build/github/azure-pipelines/templates/py-packaging-linux-test-cpu.yml index c475feaef0018..eef97341b8d53 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-packaging-linux-test-cpu.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-packaging-linux-test-cpu.yml @@ -19,10 +19,10 @@ parameters: type: string default: 'Release' values: - - Debug - - Release - - RelWithDebInfo - - MinSizeRel + - Debug + - Release + - RelWithDebInfo + - MinSizeRel - name: timeout type: number @@ -50,29 +50,31 @@ jobs: artifact: 'drop-linux-cpu-${{ parameters.arch }}-${{parameters.ep}}' - download: current # pipeline resource identifier. - artifact: 'onnxruntime${{ parameters.python_wheel_suffix }}-${{ parameters.ep }}' + artifact: 'onnxruntime-${{ parameters.arch }}-${{ parameters.ep }}' - bash: | set -e -x mv "$(Pipeline.Workspace)/drop-linux-cpu-${{ parameters.arch }}-${{parameters.ep}}" $(Build.BinariesDirectory)/${{parameters.cmake_build_type}} - mv "$(Pipeline.Workspace)/onnxruntime${{ parameters.python_wheel_suffix }}-${{parameters.ep}}" "$(Build.BinariesDirectory)/whl" + mv "$(Pipeline.Workspace)/onnxruntime-${{ parameters.arch }}-${{ parameters.ep }}" "$(Build.BinariesDirectory)/whl" cp -r "$(Build.BinariesDirectory)/whl" $(Build.BinariesDirectory)/tmp find "$(Build.BinariesDirectory)/tmp" -name '*.whl' -exec bash -c 'unzip -d "${1%.*}" "$1"' _ {} \; + displayName: 'Move the artifacts to the binaries directory' # The private ADO project - ${{ if eq(variables['System.CollectionId'], 'bc038106-a83b-4dab-9dd3-5a41bc58f34c') }}: - download: build # pipeline resource identifier. artifact: 'drop-linux-cpu-${{ parameters.arch }}-${{parameters.ep}}' - download: build # pipeline resource identifier. - artifact: 'onnxruntime${{ parameters.python_wheel_suffix }}-${{ parameters.ep }}' + artifact: 'onnxruntime-${{ parameters.arch }}-${{ parameters.ep }}' - bash: | set -e -x ls $(Pipeline.Workspace)/build mv "$(Pipeline.Workspace)/build/drop-linux-cpu-${{ parameters.arch }}-${{parameters.ep}}" $(Build.BinariesDirectory)/${{parameters.cmake_build_type}} - mv "$(Pipeline.Workspace)/build/onnxruntime${{ parameters.python_wheel_suffix }}-${{parameters.ep}}" "$(Build.BinariesDirectory)/whl" + mv "$(Pipeline.Workspace)/build/onnxruntime-${{ parameters.arch }}-${{ parameters.ep }}" "$(Build.BinariesDirectory)/whl" cp -r "$(Build.BinariesDirectory)/whl" $(Build.BinariesDirectory)/tmp find "$(Build.BinariesDirectory)/tmp" -name '*.whl' -exec bash -c 'unzip -d "${1%.*}" "$1"' _ {} \; + displayName: 'Move the artifacts to the binaries directory' # The BinSkim task uses a dotnet program which doesn't support ARM CPUs yet - ${{ if eq(parameters.arch, 'x86_64') }}: diff --git a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml index 4c9d0dccaf48d..10ea7f6203bb1 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml @@ -19,6 +19,11 @@ parameters: type: string default: '' +- name: is1ES + displayName: 'Whether the pipeline is running in 1ES' + type: boolean + default: false + jobs: - job: Win_py_arm64_qnn_Wheels timeoutInMinutes: 210 @@ -26,6 +31,8 @@ jobs: clean: all pool: name: ${{ parameters.MACHINE_POOL }} + os: windows + hostArchitecture: Arm64 strategy: matrix: Python311_arm64: @@ -41,132 +48,140 @@ jobs: GRADLE_OPTS: '-Dorg.gradle.daemon=false' VSGenerator: 'Visual Studio 17 2022' steps: - - checkout: self - clean: true - submodules: recursive - - - template: telemetry-steps.yml - - - script: | - MKDIR $(Agent.ToolsDirectory)\Python\$(PythonVersion)\arm64 - XCOPY /s /y /h /e /c /q "$(LocalPythonDir)\*.*" $(Agent.ToolsDirectory)\Python\$(PythonVersion)\arm64\ - COPY NUL $(Agent.ToolsDirectory)\Python\$(PythonVersion)\arm64.complete - DIR $(Agent.ToolsDirectory)\Python - DIR $(Agent.ToolsDirectory)\Python\$(PythonVersion) - DIR $(Agent.ToolsDirectory)\Python\$(PythonVersion)\arm64 - displayName: Copy python $(PythonVersion) version to agent tools directory - - - task: UsePythonVersion@0 - inputs: - versionSpec: $(PythonVersion) - addToPath: true - architecture: 'arm64' - - - task: PipAuthenticate@1 - displayName: 'Pip Authenticate' - inputs: - artifactFeeds: 'Lotus' - - - task: onebranch.pipeline.tsaoptions@1 - displayName: 'OneBranch TSAOptions' - inputs: - tsaConfigFilePath: '$(Build.SourcesDirectory)\.config\tsaoptions.json' - appendSourceBranchName: false - - - task: PythonScript@0 - inputs: - scriptSource: inline - script: | - import subprocess - subprocess.call(['pip', 'install', '-q', 'setuptools', 'wheel']) - workingDirectory: '$(Build.BinariesDirectory)' - displayName: 'Install python modules' - - - template: set-nightly-build-option-variable-step.yml - - - template: jobs/download_win_qnn_sdk.yml - parameters: - QnnSDKVersion: ${{ parameters.QNN_SDK }} - - - task: PythonScript@0 - displayName: 'Generate cmake config' - inputs: - scriptPath: '$(Build.SourcesDirectory)\tools\ci_build\build.py' - arguments: > - --config RelWithDebInfo - --build_dir $(Build.BinariesDirectory) - --skip_submodule_sync - --cmake_generator "$(VSGenerator)" - --build_shared_lib - --use_qnn - --qnn_home $(QnnSDKRootDir) - --enable_pybind - --parallel --use_vcpkg --use_vcpkg_ms_internal_asset_cache --update - $(TelemetryOption) ${{ parameters.BUILD_PY_PARAMETERS }} - workingDirectory: '$(Build.BinariesDirectory)' - - - task: VSBuild@1 - displayName: 'Build' - inputs: - solution: '$(Build.BinariesDirectory)\RelWithDebInfo\onnxruntime.sln' - platform: 'arm64' - configuration: RelWithDebInfo - msbuildArchitecture: 'arm64' - maximumCpuCount: true - logProjectEvents: true - workingFolder: '$(Build.BinariesDirectory)\RelWithDebInfo' - createLogFile: true - - # Esrp signing - - template: win-esrp-dll.yml - parameters: - FolderPath: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\onnxruntime\capi' - DisplayName: 'ESRP - Sign Native dlls' - DoEsrp: true - Pattern: '*.pyd' - - - task: PythonScript@0 - displayName: 'Build wheel' - inputs: - scriptPath: '$(Build.SourcesDirectory)\setup.py' - arguments: 'bdist_wheel $(NightlyBuildOption) --wheel_name_suffix=qnn' - workingDirectory: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo' - - - task: CopyFiles@2 - displayName: 'Copy Python Wheel to: $(Build.ArtifactStagingDirectory)' - inputs: - SourceFolder: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\dist' - Contents: '*.whl' - TargetFolder: '$(Build.ArtifactStagingDirectory)' - - - task: PublishBuildArtifacts@1 - displayName: 'Publish Artifact: ONNXRuntime python wheel' - inputs: - ArtifactName: onnxruntime_qnn_arm64 - - - script: | - 7z x *.whl - workingDirectory: '$(Build.ArtifactStagingDirectory)' - displayName: 'unzip the package' - - - task: CredScan@3 - displayName: 'Run CredScan' - inputs: - debugMode: false - continueOnError: true - - - task: BinSkim@4 - displayName: 'Run BinSkim' - inputs: - AnalyzeTargetGlob: '+:file|$(Build.ArtifactStagingDirectory)\**\*.dll' - - - task: TSAUpload@2 - displayName: 'TSA upload' - condition: and (succeeded(), eq(variables['Build.SourceBranch'], 'refs/heads/main')) - inputs: - GdnPublishTsaOnboard: false - GdnPublishTsaConfigFile: '$(Build.sourcesDirectory)\.gdn\.gdntsa' - - - template: component-governance-component-detection-steps.yml - parameters: - condition: 'succeeded' + - checkout: self + clean: true + submodules: recursive + + - template: telemetry-steps.yml + + - script: | + MKDIR $(Agent.ToolsDirectory)\Python\$(PythonVersion)\arm64 + XCOPY /s /y /h /e /c /q "$(LocalPythonDir)\*.*" $(Agent.ToolsDirectory)\Python\$(PythonVersion)\arm64\ + COPY NUL $(Agent.ToolsDirectory)\Python\$(PythonVersion)\arm64.complete + DIR $(Agent.ToolsDirectory)\Python + DIR $(Agent.ToolsDirectory)\Python\$(PythonVersion) + DIR $(Agent.ToolsDirectory)\Python\$(PythonVersion)\arm64 + displayName: Copy python $(PythonVersion) version to agent tools directory + + - task: UsePythonVersion@0 + inputs: + versionSpec: $(PythonVersion) + addToPath: true + architecture: 'arm64' + + - task: PipAuthenticate@1 + displayName: 'Pip Authenticate' + inputs: + artifactFeeds: 'Lotus' + + - task: onebranch.pipeline.tsaoptions@1 + displayName: 'OneBranch TSAOptions' + inputs: + tsaConfigFilePath: '$(Build.SourcesDirectory)\.config\tsaoptions.json' + appendSourceBranchName: false + + - task: PythonScript@0 + inputs: + scriptSource: inline + script: | + import subprocess + subprocess.call(['pip', 'install', '-q', 'setuptools', 'wheel']) + workingDirectory: '$(Build.BinariesDirectory)' + displayName: 'Install python modules' + + - template: set-nightly-build-option-variable-step.yml + + - template: jobs/download_win_qnn_sdk.yml + parameters: + QnnSDKVersion: ${{ parameters.QNN_SDK }} + + - task: PythonScript@0 + displayName: 'Generate cmake config' + inputs: + scriptPath: '$(Build.SourcesDirectory)\tools\ci_build\build.py' + arguments: > + --config RelWithDebInfo + --build_dir $(Build.BinariesDirectory) + --skip_submodule_sync + --cmake_generator "$(VSGenerator)" + --build_shared_lib + --use_qnn + --qnn_home $(QnnSDKRootDir) + --enable_pybind + --parallel --use_vcpkg --use_vcpkg_ms_internal_asset_cache --update + $(TelemetryOption) ${{ parameters.BUILD_PY_PARAMETERS }} + workingDirectory: '$(Build.BinariesDirectory)' + + - task: VSBuild@1 + displayName: 'Build' + inputs: + solution: '$(Build.BinariesDirectory)\RelWithDebInfo\onnxruntime.sln' + platform: 'arm64' + configuration: RelWithDebInfo + msbuildArchitecture: 'arm64' + maximumCpuCount: true + logProjectEvents: true + workingFolder: '$(Build.BinariesDirectory)\RelWithDebInfo' + createLogFile: true + + # Esrp signing + - template: win-esrp-dll.yml + parameters: + FolderPath: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\onnxruntime\capi' + DisplayName: 'ESRP - Sign Native dlls' + DoEsrp: true + Pattern: '*.pyd' + + - task: PythonScript@0 + displayName: 'Build wheel' + inputs: + scriptPath: '$(Build.SourcesDirectory)\setup.py' + arguments: 'bdist_wheel $(NightlyBuildOption) --wheel_name_suffix=qnn' + workingDirectory: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo' + + - task: CopyFiles@2 + displayName: 'Copy Python Wheel to: $(Build.ArtifactStagingDirectory)' + inputs: + SourceFolder: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\dist' + Contents: '*.whl' + TargetFolder: '$(Build.ArtifactStagingDirectory)' + + - ${{ if eq(parameters.is1ES, true) }}: + - task: 1ES.PublishPipelineArtifact@1 + displayName: 'Publish Artifact: ONNXRuntime python wheel' + inputs: + artifactName: onnxruntime_qnn_arm64_$(PythonVersion) + targetPath: '$(Build.ArtifactStagingDirectory)' + - ${{ if eq(parameters.is1ES, false) }}: + - task: PublishPipelineArtifact@1 + displayName: 'Publish Artifact: ONNXRuntime python wheel' + input: + artifactName: onnxruntime_qnn_arm64_$(PythonVersion) + targetPath: '$(Build.ArtifactStagingDirectory)' + + - script: | + 7z x *.whl + workingDirectory: '$(Build.ArtifactStagingDirectory)' + displayName: 'unzip the package' + + - task: CredScan@3 + displayName: 'Run CredScan' + inputs: + debugMode: false + continueOnError: true + + - task: BinSkim@4 + displayName: 'Run BinSkim' + inputs: + AnalyzeTargetGlob: '+:file|$(Build.ArtifactStagingDirectory)\**\*.dll' + + - task: TSAUpload@2 + displayName: 'TSA upload' + condition: and (succeeded(), eq(variables['Build.SourceBranch'], 'refs/heads/main')) + inputs: + GdnPublishTsaOnboard: false + GdnPublishTsaConfigFile: '$(Build.sourcesDirectory)\.gdn\.gdntsa' + + - template: component-governance-component-detection-steps.yml + parameters: + condition: 'succeeded' diff --git a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64ec-qnn.yml b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64ec-qnn.yml index ed29f1e67515e..24321d2a3e1ec 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64ec-qnn.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64ec-qnn.yml @@ -19,6 +19,11 @@ parameters: type: string default: '' +- name: is1ES + displayName: 'Whether the pipeline is running in 1ES' + type: boolean + default: false + jobs: - job: Win_py_x64_qnn_Wheels timeoutInMinutes: 210 @@ -26,6 +31,7 @@ jobs: clean: all pool: name: ${{ parameters.MACHINE_POOL }} + os: windows strategy: matrix: Python310_x64: @@ -40,117 +46,124 @@ jobs: GRADLE_OPTS: '-Dorg.gradle.daemon=false' VSGenerator: 'Visual Studio 17 2022' steps: - - checkout: self - clean: true - submodules: recursive - - - template: telemetry-steps.yml - - - task: UsePythonVersion@0 - inputs: - versionSpec: $(PythonVersion) - addToPath: true - architecture: 'x64' - - - task: PipAuthenticate@1 - displayName: 'Pip Authenticate' - inputs: - artifactFeeds: 'Lotus' - - - task: onebranch.pipeline.tsaoptions@1 - displayName: 'OneBranch TSAOptions' - inputs: - tsaConfigFilePath: '$(Build.SourcesDirectory)\.config\tsaoptions.json' - appendSourceBranchName: fals - - - script: python -m pip install -r $(Build.SourcesDirectory)\tools\ci_build\github\linux\python\requirements.txt - - - - template: set-nightly-build-option-variable-step.yml - - - template: jobs/download_win_qnn_sdk.yml - parameters: - QnnSDKVersion: ${{ parameters.QNN_SDK }} - - - task: PythonScript@0 - displayName: 'Generate cmake config' - inputs: - scriptPath: '$(Build.SourcesDirectory)\tools\ci_build\build.py' - arguments: > - --config RelWithDebInfo - --build_dir $(Build.BinariesDirectory) - --skip_submodule_sync - --cmake_generator "$(VSGenerator)" - --build_shared_lib - --use_qnn - --qnn_home $(QnnSDKRootDir) - --enable_pybind - --parallel --use_vcpkg --use_vcpkg_ms_internal_asset_cache --update --arm64ec - $(TelemetryOption) ${{ parameters.BUILD_PY_PARAMETERS }} - workingDirectory: '$(Build.BinariesDirectory)' - - - task: VSBuild@1 - displayName: 'Build' - inputs: - solution: '$(Build.BinariesDirectory)\RelWithDebInfo\onnxruntime.sln' - platform: 'arm64ec' - configuration: RelWithDebInfo - msbuildArchitecture: 'x64' - maximumCpuCount: true - logProjectEvents: true - workingFolder: '$(Build.BinariesDirectory)\RelWithDebInfo' - createLogFile: true - - # Esrp signing - - template: win-esrp-dll.yml - parameters: - FolderPath: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\onnxruntime\capi' - DisplayName: 'ESRP - Sign Native dlls' - DoEsrp: true - Pattern: '*.pyd' - - - task: PythonScript@0 - displayName: 'Build wheel' - inputs: - scriptPath: '$(Build.SourcesDirectory)\setup.py' - arguments: 'bdist_wheel $(NightlyBuildOption) --wheel_name_suffix=qnn' - workingDirectory: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo' - - - task: CopyFiles@2 - displayName: 'Copy Python Wheel to: $(Build.ArtifactStagingDirectory)' - inputs: - SourceFolder: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\dist' - Contents: '*.whl' - TargetFolder: '$(Build.ArtifactStagingDirectory)' - - - task: PublishBuildArtifacts@1 - displayName: 'Publish Artifact: ONNXRuntime python wheel' - inputs: - ArtifactName: onnxruntime_qnn_arm64ec - - - script: | - 7z x *.whl - workingDirectory: '$(Build.ArtifactStagingDirectory)' - displayName: 'unzip the package' - - - task: CredScan@3 - displayName: 'Run CredScan' - inputs: - debugMode: false - continueOnError: true - - - task: BinSkim@4 - displayName: 'Run BinSkim' - inputs: - AnalyzeTargetGlob: '+:file|$(Build.ArtifactStagingDirectory)\**\*.dll' - - - task: TSAUpload@2 - displayName: 'TSA upload' - condition: and (succeeded(), eq(variables['Build.SourceBranch'], 'refs/heads/main')) - inputs: - GdnPublishTsaOnboard: false - GdnPublishTsaConfigFile: '$(Build.sourcesDirectory)\.gdn\.gdntsa' - - - template: component-governance-component-detection-steps.yml - parameters: - condition: 'succeeded' + - checkout: self + clean: true + submodules: recursive + + - template: telemetry-steps.yml + + - task: UsePythonVersion@0 + inputs: + versionSpec: $(PythonVersion) + addToPath: true + architecture: 'x64' + + - task: PipAuthenticate@1 + displayName: 'Pip Authenticate' + inputs: + artifactFeeds: 'Lotus' + + - task: onebranch.pipeline.tsaoptions@1 + displayName: 'OneBranch TSAOptions' + inputs: + tsaConfigFilePath: '$(Build.SourcesDirectory)\.config\tsaoptions.json' + appendSourceBranchName: fals + + - script: python -m pip install -r $(Build.SourcesDirectory)\tools\ci_build\github\linux\python\requirements.txt + + + - template: set-nightly-build-option-variable-step.yml + + - template: jobs/download_win_qnn_sdk.yml + parameters: + QnnSDKVersion: ${{ parameters.QNN_SDK }} + + - task: PythonScript@0 + displayName: 'Generate cmake config' + inputs: + scriptPath: '$(Build.SourcesDirectory)\tools\ci_build\build.py' + arguments: > + --config RelWithDebInfo + --build_dir $(Build.BinariesDirectory) + --skip_submodule_sync + --cmake_generator "$(VSGenerator)" + --build_shared_lib + --use_qnn + --qnn_home $(QnnSDKRootDir) + --enable_pybind + --parallel --use_vcpkg --use_vcpkg_ms_internal_asset_cache --update --arm64ec + $(TelemetryOption) ${{ parameters.BUILD_PY_PARAMETERS }} + workingDirectory: '$(Build.BinariesDirectory)' + + - task: VSBuild@1 + displayName: 'Build' + inputs: + solution: '$(Build.BinariesDirectory)\RelWithDebInfo\onnxruntime.sln' + platform: 'arm64ec' + configuration: RelWithDebInfo + msbuildArchitecture: 'x64' + maximumCpuCount: true + logProjectEvents: true + workingFolder: '$(Build.BinariesDirectory)\RelWithDebInfo' + createLogFile: true + + # Esrp signing + - template: win-esrp-dll.yml + parameters: + FolderPath: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\onnxruntime\capi' + DisplayName: 'ESRP - Sign Native dlls' + DoEsrp: true + Pattern: '*.pyd' + + - task: PythonScript@0 + displayName: 'Build wheel' + inputs: + scriptPath: '$(Build.SourcesDirectory)\setup.py' + arguments: 'bdist_wheel $(NightlyBuildOption) --wheel_name_suffix=qnn' + workingDirectory: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo' + + - task: CopyFiles@2 + displayName: 'Copy Python Wheel to: $(Build.ArtifactStagingDirectory)' + inputs: + SourceFolder: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\dist' + Contents: '*.whl' + TargetFolder: '$(Build.ArtifactStagingDirectory)' + + - ${{ if eq(parameters.is1ES, true) }}: + - task: 1ES.PublishPipelineArtifact@1 + displayName: 'Publish Artifact: ONNXRuntime python wheel' + inputs: + artifactName: onnxruntime_qnn_arm64ec_$(PythonVersion) + targetPath: '$(Build.ArtifactStagingDirectory)' + - ${{ if eq(parameters.is1ES, false) }}: + - task: PublishPipelineArtifact@1 + displayName: 'Publish Artifact: ONNXRuntime python wheel' + inputs: + artifactName: onnxruntime_qnn_arm64ec_$(PythonVersion) + targetPath: '$(Build.ArtifactStagingDirectory)' + - script: | + 7z x *.whl + workingDirectory: '$(Build.ArtifactStagingDirectory)' + displayName: 'unzip the package' + + - task: CredScan@3 + displayName: 'Run CredScan' + inputs: + debugMode: false + continueOnError: true + + - task: BinSkim@4 + displayName: 'Run BinSkim' + inputs: + AnalyzeTargetGlob: '+:file|$(Build.ArtifactStagingDirectory)\**\*.dll' + + - task: TSAUpload@2 + displayName: 'TSA upload' + condition: and (succeeded(), eq(variables['Build.SourceBranch'], 'refs/heads/main')) + inputs: + GdnPublishTsaOnboard: false + GdnPublishTsaConfigFile: '$(Build.sourcesDirectory)\.gdn\.gdntsa' + + - template: component-governance-component-detection-steps.yml + parameters: + condition: 'succeeded' diff --git a/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml b/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml index 13069846da342..175b343e55d57 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml @@ -19,6 +19,11 @@ parameters: type: string default: '' +- name: is1ES + displayName: 'Whether the pipeline is running in 1ES' + type: boolean + default: false + jobs: - job: Win_py_x64_qnn_Wheels timeoutInMinutes: 210 @@ -116,10 +121,18 @@ jobs: Contents: '*.whl' TargetFolder: '$(Build.ArtifactStagingDirectory)' - - task: PublishBuildArtifacts@1 - displayName: 'Publish Artifact: ONNXRuntime python wheel' - inputs: - ArtifactName: onnxruntime_qnn_x64 + - ${{ if eq(parameters.is1ES, true) }}: + - task: 1ES.PublishPipelineArtifact@1 + displayName: 'Publish Artifact: ONNXRuntime python wheel' + inputs: + artifactName: onnxruntime_qnn_x64_$(PythonVersion) + targetPath: '$(Build.ArtifactStagingDirectory)' + - ${{ if eq(parameters.is1ES, false) }}: + - task: PublishPipelineArtifact@1 + displayName: 'Publish Artifact: ONNXRuntime python wheel' + inputs: + artifactName: onnxruntime_qnn_x64_$(PythonVersion) + targetPath: '$(Build.ArtifactStagingDirectory)' - script: | 7z x *.whl diff --git a/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml b/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml index a93d6b5ff8419..3fa4799ec9c0e 100644 --- a/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml +++ b/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml @@ -10,6 +10,8 @@ parameters: buildPlatform: 'x64' buildArch: 'x64' StageName: 'OnnxRuntime_QNN_Nuget_Win_x64' + Is1ES: true + PublishArchive: false stages: - stage: ${{ parameters.StageName }} @@ -18,7 +20,8 @@ stages: - job: ${{ parameters.StageName }} timeoutInMinutes: 120 - pool: ${{ parameters.qnn_ep_build_pool_name }} + pool: + name: ${{ parameters.qnn_ep_build_pool_name }} variables: ${{ if eq(parameters.buildArch, 'ARM64') }}: targetArchitecture: 'arm64' @@ -28,133 +31,148 @@ stages: commonBuildArgs: '--update --compile_no_warning_as_error --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --build_shared_lib --cmake_generator "Visual Studio 17 2022" --config ${{ parameters.build_config }} --parallel --use_binskim_compliant_compile_flags ${{ parameters.buildParameter }} ' steps: - - template: set-version-number-variables-step.yml - - - task: UsePythonVersion@0 - inputs: - versionSpec: '3.12' - addToPath: true - - - template: jobs/download_win_qnn_sdk.yml + - template: set-version-number-variables-step.yml + + - task: UsePythonVersion@0 + inputs: + versionSpec: '3.12' + addToPath: true + + - template: jobs/download_win_qnn_sdk.yml + parameters: + QnnSDKVersion: ${{ parameters.QnnSdk }} + + - task: PythonScript@0 + displayName: 'Generate project' + inputs: + scriptPath: '$(Build.SourcesDirectory)\tools\ci_build\build.py' + arguments: '--use_qnn --qnn_home $(QnnSDKRootDir) $(commonBuildArgs)' + + - task: VSBuild@1 + displayName: 'Build onnxruntime' + inputs: + solution: '$(Build.BinariesDirectory)\${{ parameters.build_config }}\onnxruntime.vcxproj' + platform: ${{ parameters.buildPlatform }} + configuration: ${{ parameters.build_config }} + msbuildArchitecture: ${{ parameters.buildArch }} + maximumCpuCount: true + logProjectEvents: true + workingFolder: '$(Build.BinariesDirectory)\${{ parameters.build_config }}' + createLogFile: true + + - task: VSBuild@1 + displayName: 'Build onnx_test_runner' + inputs: + solution: '$(Build.BinariesDirectory)\${{ parameters.build_config }}\onnx_test_runner.vcxproj' + platform: ${{ parameters.buildPlatform }} + configuration: ${{ parameters.build_config }} + msbuildArchitecture: ${{ parameters.buildArch }} + maximumCpuCount: true + logProjectEvents: true + workingFolder: '$(Build.BinariesDirectory)\${{ parameters.build_config }}' + createLogFile: true + + - task: VSBuild@1 + displayName: 'Build onnxruntime_perf_test' + inputs: + solution: '$(Build.BinariesDirectory)\${{ parameters.build_config }}\onnxruntime_perf_test.vcxproj' + platform: ${{ parameters.buildPlatform }} + configuration: ${{ parameters.build_config }} + msbuildArchitecture: ${{ parameters.buildArch }} + maximumCpuCount: true + logProjectEvents: true + workingFolder: '$(Build.BinariesDirectory)\${{ parameters.build_config }}' + createLogFile: true + + - task: VSBuild@1 + displayName: 'Build onnxruntime_test_all (to copy Qnn libs)' + inputs: + solution: '$(Build.BinariesDirectory)\${{ parameters.build_config }}\onnxruntime_test_all.vcxproj' + platform: ${{ parameters.buildPlatform }} + configuration: ${{ parameters.build_config }} + msbuildArchitecture: ${{ parameters.buildArch }} + maximumCpuCount: true + logProjectEvents: true + workingFolder: '$(Build.BinariesDirectory)\${{ parameters.build_config }}' + createLogFile: true + + - task: CmdLine@2 + displayName: 'Print contents of binaries directory' + inputs: + script: | + dir $(Build.BinariesDirectory)\${{ parameters.build_config }}\${{ parameters.build_config }} + + - template: win-esrp-dll.yml + parameters: + FolderPath: '$(Build.BinariesDirectory)\${{ parameters.build_config }}\${{ parameters.build_config }}' + DisplayName: 'ESRP - Sign dlls' + DoEsrp: ${{ parameters.DoEsrp }} + Pattern: 'onnxruntime*.dll' + + - ${{ if eq(parameters.PublishArchive, true) }}: + - template: c-api-artifacts-package-and-publish-steps-windows.yml parameters: - QnnSDKVersion: ${{ parameters.QnnSdk }} - - - task: PythonScript@0 - displayName: 'Generate project' - inputs: - scriptPath: '$(Build.SourcesDirectory)\tools\ci_build\build.py' - arguments: '--use_qnn --qnn_home $(QnnSDKRootDir) $(commonBuildArgs)' - - - task: VSBuild@1 - displayName: 'Build onnxruntime' - inputs: - solution: '$(Build.BinariesDirectory)\${{ parameters.build_config }}\onnxruntime.vcxproj' - platform: ${{ parameters.buildPlatform }} - configuration: ${{ parameters.build_config }} - msbuildArchitecture: ${{ parameters.buildArch }} - maximumCpuCount: true - logProjectEvents: true - workingFolder: '$(Build.BinariesDirectory)\${{ parameters.build_config }}' - createLogFile: true - - - task: VSBuild@1 - displayName: 'Build onnx_test_runner' - inputs: - solution: '$(Build.BinariesDirectory)\${{ parameters.build_config }}\onnx_test_runner.vcxproj' - platform: ${{ parameters.buildPlatform }} - configuration: ${{ parameters.build_config }} - msbuildArchitecture: ${{ parameters.buildArch }} - maximumCpuCount: true - logProjectEvents: true - workingFolder: '$(Build.BinariesDirectory)\${{ parameters.build_config }}' - createLogFile: true - - - task: VSBuild@1 - displayName: 'Build onnxruntime_perf_test' - inputs: - solution: '$(Build.BinariesDirectory)\${{ parameters.build_config }}\onnxruntime_perf_test.vcxproj' - platform: ${{ parameters.buildPlatform }} - configuration: ${{ parameters.build_config }} - msbuildArchitecture: ${{ parameters.buildArch }} - maximumCpuCount: true - logProjectEvents: true - workingFolder: '$(Build.BinariesDirectory)\${{ parameters.build_config }}' - createLogFile: true - - - task: VSBuild@1 - displayName: 'Build onnxruntime_test_all (to copy Qnn libs)' - inputs: - solution: '$(Build.BinariesDirectory)\${{ parameters.build_config }}\onnxruntime_test_all.vcxproj' - platform: ${{ parameters.buildPlatform }} - configuration: ${{ parameters.build_config }} - msbuildArchitecture: ${{ parameters.buildArch }} - maximumCpuCount: true - logProjectEvents: true - workingFolder: '$(Build.BinariesDirectory)\${{ parameters.build_config }}' - createLogFile: true - - - task: CmdLine@2 - displayName: 'Print contents of binaries directory' - inputs: - script: | - dir $(Build.BinariesDirectory)\${{ parameters.build_config }}\${{ parameters.build_config }} + buildConfig: ${{ parameters.build_config }} + artifactName: 'onnxruntime-win-${{ parameters.buildPlatform }}-qnn' + artifactNameNoVersionString: 'onnxruntime-win-${{ parameters.buildPlatform }}-qnn' + DoEsrp: ${{ parameters.DoEsrp }} + - task: MSBuild@1 + displayName: 'Restore NuGet Packages and create project.assets.json' + inputs: + solution: '$(Build.SourcesDirectory)\csharp\OnnxRuntime.DesktopOnly.CSharp.sln' + platform: 'Any CPU' + configuration: ${{ parameters.build_config }} + msbuildArguments: '-t:restore -p:OrtPackageId=$(OrtPackageId)' + workingDirectory: '$(Build.SourcesDirectory)\csharp' + + - task: MSBuild@1 + displayName: 'Build C# bindings' + inputs: + solution: '$(Build.SourcesDirectory)\csharp\OnnxRuntime.DesktopOnly.CSharp.sln' + platform: 'Any CPU' + configuration: ${{ parameters.build_config }} + msbuildArguments: '-p:OnnxRuntimeBuildDirectory="$(Build.BinariesDirectory)" -p:OrtPackageId=$(OrtPackageId) -p:IsReleaseBuild=${{ parameters.IsReleaseBuild }}' + workingDirectory: '$(Build.SourcesDirectory)\csharp' + + - ${{ if eq(parameters.DoEsrp, true) }}: - template: win-esrp-dll.yml parameters: - FolderPath: '$(Build.BinariesDirectory)\${{ parameters.build_config }}\${{ parameters.build_config }}' - DisplayName: 'ESRP - Sign dlls' + FolderPath: '$(Build.SourcesDirectory)\csharp\src\Microsoft.ML.OnnxRuntime\bin\${{ parameters.build_config }}' + DisplayName: 'ESRP - Sign C# dlls' DoEsrp: ${{ parameters.DoEsrp }} - Pattern: 'onnxruntime*.dll' - - - task: MSBuild@1 - displayName: 'Restore NuGet Packages and create project.assets.json' - inputs: - solution: '$(Build.SourcesDirectory)\csharp\OnnxRuntime.DesktopOnly.CSharp.sln' - platform: 'Any CPU' - configuration: ${{ parameters.build_config }} - msbuildArguments: '-t:restore -p:OrtPackageId=$(OrtPackageId)' - workingDirectory: '$(Build.SourcesDirectory)\csharp' - - - task: MSBuild@1 - displayName: 'Build C# bindings' - inputs: - solution: '$(Build.SourcesDirectory)\csharp\OnnxRuntime.DesktopOnly.CSharp.sln' - platform: 'Any CPU' - configuration: ${{ parameters.build_config }} - msbuildArguments: '-p:OnnxRuntimeBuildDirectory="$(Build.BinariesDirectory)" -p:OrtPackageId=$(OrtPackageId) -p:IsReleaseBuild=${{ parameters.IsReleaseBuild }}' - workingDirectory: '$(Build.SourcesDirectory)\csharp' - - - ${{ if eq(parameters.DoEsrp, true) }}: - - template: win-esrp-dll.yml - parameters: - FolderPath: '$(Build.SourcesDirectory)\csharp\src\Microsoft.ML.OnnxRuntime\bin\${{ parameters.build_config }}' - DisplayName: 'ESRP - Sign C# dlls' - DoEsrp: ${{ parameters.DoEsrp }} - - - task: MSBuild@1 - displayName: 'Build Nuget Packages' - inputs: - solution: '$(Build.SourcesDirectory)\csharp\OnnxRuntime.CSharp.proj' - platform: 'Any CPU' - configuration: ${{ parameters.build_config }} - msbuildArguments: '-t:CreatePackage -p:OnnxRuntimeBuildDirectory="$(Build.BinariesDirectory)" -p:OrtPackageId=$(OrtPackageId) -p:IsReleaseBuild=${{ parameters.IsReleaseBuild }} -p:TargetArchitecture=$(targetArchitecture)' - workingDirectory: '$(Build.SourcesDirectory)\csharp' - - - task: CopyFiles@2 - displayName: 'Copy native nuget package to: $(Build.ArtifactStagingDirectory)' - inputs: - SourceFolder: '$(Build.BinariesDirectory)\${{ parameters.build_config }}\${{ parameters.build_config }}' - Contents: '*.nupkg' - TargetFolder: '$(Build.ArtifactStagingDirectory)' - - task: CopyFiles@2 - displayName: 'Copy native nuget symbols package to: $(Build.ArtifactStagingDirectory)' + - task: MSBuild@1 + displayName: 'Build Nuget Packages' + inputs: + solution: '$(Build.SourcesDirectory)\csharp\OnnxRuntime.CSharp.proj' + platform: 'Any CPU' + configuration: ${{ parameters.build_config }} + msbuildArguments: '-t:CreatePackage -p:OnnxRuntimeBuildDirectory="$(Build.BinariesDirectory)" -p:OrtPackageId=$(OrtPackageId) -p:IsReleaseBuild=${{ parameters.IsReleaseBuild }} -p:TargetArchitecture=$(targetArchitecture)' + workingDirectory: '$(Build.SourcesDirectory)\csharp' + + - task: CopyFiles@2 + displayName: 'Copy native nuget package to: $(Build.ArtifactStagingDirectory)' + inputs: + SourceFolder: '$(Build.BinariesDirectory)\${{ parameters.build_config }}\${{ parameters.build_config }}' + Contents: '*.nupkg' + TargetFolder: '$(Build.ArtifactStagingDirectory)' + + - task: CopyFiles@2 + displayName: 'Copy native nuget symbols package to: $(Build.ArtifactStagingDirectory)' + inputs: + SourceFolder: '$(Build.BinariesDirectory)\${{ parameters.build_config }}\${{ parameters.build_config }}' + Contents: '*.snupkg' + TargetFolder: '$(Build.ArtifactStagingDirectory)' + + - ${{ if eq(parameters.Is1ES, true) }}: + - task: 1ES.PublishPipelineArtifact@1 + displayName: 'Publish Pipeline x64 NuGet Artifact' inputs: - SourceFolder: '$(Build.BinariesDirectory)\${{ parameters.build_config }}\${{ parameters.build_config }}' - Contents: '*.snupkg' - TargetFolder: '$(Build.ArtifactStagingDirectory)' - - - task: PublishPipelineArtifact@0 + artifactName: ${{ parameters.ArtifactName }} + targetPath: '$(Build.ArtifactStagingDirectory)' + - ${{ else }}: + - task: PublishPipelineArtifact@1 displayName: 'Publish Pipeline x64 NuGet Artifact' inputs: artifactName: ${{ parameters.ArtifactName }} diff --git a/tools/ci_build/github/azure-pipelines/templates/react-native-ci.yml b/tools/ci_build/github/azure-pipelines/templates/react-native-ci.yml index 7991916a47ca4..52dbb76632e0c 100644 --- a/tools/ci_build/github/azure-pipelines/templates/react-native-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/react-native-ci.yml @@ -62,10 +62,14 @@ stages: dependsOn: '${{parameters.InitialStageDependsOn}}' jobs: - job: ReactNative_CI_iOS - pool: - name: 'Azure Pipelines' - image: 'macOS-13' - os: 'macOS' + ${{ if eq(parameters.is1ES, false) }}: + pool: + vmImage: 'macOS-13' + ${{ if eq(parameters.is1ES, true) }}: + pool: + name: 'Azure Pipelines' + image: 'macOS-13' + os: 'macOS' timeoutInMinutes: 120 diff --git a/tools/ci_build/github/azure-pipelines/templates/web-ci.yml b/tools/ci_build/github/azure-pipelines/templates/web-ci.yml index 87836880cbdb8..2e3589ee87c29 100644 --- a/tools/ci_build/github/azure-pipelines/templates/web-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/web-ci.yml @@ -83,9 +83,6 @@ stages: git submodule update --init -- cmake/external/onnx workingDirectory: '$(Build.SourcesDirectory)' displayName: 'Checkout submodule onnx' - - task: NodeTool@0 - inputs: - versionSpec: '20.x' - template: linux-web-init-and-check.yml - task: Bash@3 displayName: 'Extract commit SHA and save to __commit.txt' diff --git a/tools/ci_build/github/azure-pipelines/templates/win-ci.yml b/tools/ci_build/github/azure-pipelines/templates/win-ci.yml index 600e6d857185f..69a06c3db24b8 100644 --- a/tools/ci_build/github/azure-pipelines/templates/win-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/win-ci.yml @@ -161,7 +161,7 @@ stages: displayName: 'Generate cmake config' inputs: scriptPath: '$(Build.SourcesDirectory)\tools\ci_build\build.py' - arguments: '--config RelWithDebInfo --use_binskim_compliant_compile_flags --enable_lto --disable_rtti --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --build_shared_lib --update --build --use_vcpkg --use_vcpkg_ms_internal_asset_cache --cmake_generator "$(VSGenerator)" --enable_onnx_tests $(TelemetryOption) ${{ parameters.buildparameter }} $(timeoutParameter) $(buildJavaParameter)' + arguments: '--config RelWithDebInfo --use_binskim_compliant_compile_flags --enable_lto --disable_rtti --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --build_shared_lib --update --build --cmake_generator "$(VSGenerator)" --enable_onnx_tests $(TelemetryOption) ${{ parameters.buildparameter }} $(timeoutParameter) $(buildJavaParameter)' workingDirectory: '$(Build.BinariesDirectory)' diff --git a/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml b/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml index b77cab6a19ba0..6868043f64d81 100644 --- a/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml @@ -88,10 +88,18 @@ jobs: inputs: sourceFolder: $(Pipeline.Workspace)\artifacts contents: | - **\*.* + **\ort-*.wasm targetFolder: $(Build.SourcesDirectory)\js\web\dist flattenFolders: true - displayName: 'Binplace dist files' + displayName: 'Binplace dist files (.wasm)' + - task: CopyFiles@2 + inputs: + sourceFolder: $(Pipeline.Workspace)\artifacts + contents: | + **\ort-*.mjs + targetFolder: $(Build.SourcesDirectory)\js\web\dist + flattenFolders: true + displayName: 'Binplace dist files (.mjs)' - script: | npm ci workingDirectory: '$(Build.SourcesDirectory)\js' diff --git a/tools/ci_build/github/azure-pipelines/templates/win-web-multi-browsers.yml b/tools/ci_build/github/azure-pipelines/templates/win-web-multi-browsers.yml index e201cc0ffdd5a..00df695889b1d 100644 --- a/tools/ci_build/github/azure-pipelines/templates/win-web-multi-browsers.yml +++ b/tools/ci_build/github/azure-pipelines/templates/win-web-multi-browsers.yml @@ -44,10 +44,18 @@ jobs: inputs: sourceFolder: $(Pipeline.Workspace)\artifacts contents: | - **\*.* + **\ort-*.wasm targetFolder: $(Build.SourcesDirectory)\js\web\dist flattenFolders: true - displayName: 'Binplace dist files' + displayName: 'Binplace dist files (.wasm)' + - task: CopyFiles@2 + inputs: + sourceFolder: $(Pipeline.Workspace)\artifacts + contents: | + **\ort-*.mjs + targetFolder: $(Build.SourcesDirectory)\js\web\dist + flattenFolders: true + displayName: 'Binplace dist files (.mjs)' - script: | npm ci workingDirectory: '$(Build.SourcesDirectory)\js' diff --git a/tools/ci_build/github/azure-pipelines/templates/windowsai-steps.yml b/tools/ci_build/github/azure-pipelines/templates/windowsai-steps.yml index fb3ebdc760a7b..355a575307f0b 100644 --- a/tools/ci_build/github/azure-pipelines/templates/windowsai-steps.yml +++ b/tools/ci_build/github/azure-pipelines/templates/windowsai-steps.yml @@ -89,7 +89,7 @@ jobs: # must call vsdevcmd first to add cmake to PATH - script: | python --version - python "$(Build.SourcesDirectory)\tools\ci_build\build.py" --build_dir $(Build.BinariesDirectory) --parallel --use_binskim_compliant_compile_flags --build_shared_lib --enable_onnx_tests --ms_experimental --use_dml --use_winml --cmake_generator "Visual Studio 17 2022" --update --config RelWithDebInfo --enable_lto --use_telemetry --disable_rtti --enable_wcos --windows_sdk_version "10.0.22621.0" $(BuildFlags) --cmake_extra_defines "CMAKE_EXE_LINKER_FLAGS_RELWITHDEBINFO=/PROFILE" "CMAKE_SHARED_LINKER_FLAGS_RELWITHDEBINFO=/PROFILE" + python "$(Build.SourcesDirectory)\tools\ci_build\build.py" --build_dir $(Build.BinariesDirectory) --parallel --use_binskim_compliant_compile_flags --build_shared_lib --enable_onnx_tests --ms_experimental --use_dml --use_winml --cmake_generator "Visual Studio 17 2022" --update --config RelWithDebInfo --enable_lto --use_telemetry --disable_rtti --enable_wcos --use_vcpkg --use_vcpkg_ms_internal_asset_cache --windows_sdk_version "10.0.22621.0" $(BuildFlags) --cmake_extra_defines "CMAKE_EXE_LINKER_FLAGS_RELWITHDEBINFO=/PROFILE" "CMAKE_SHARED_LINKER_FLAGS_RELWITHDEBINFO=/PROFILE" workingDirectory: '$(Build.BinariesDirectory)' displayName: 'Generate cmake config' diff --git a/tools/ci_build/github/azure-pipelines/win-gpu-webgpu-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-gpu-webgpu-ci-pipeline.yml index bb6c210161952..a0f22fcfce14e 100644 --- a/tools/ci_build/github/azure-pipelines/win-gpu-webgpu-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/win-gpu-webgpu-ci-pipeline.yml @@ -105,3 +105,31 @@ stages: onnxruntime_webgpu_external_dawn_test.exe --no_proc_table displayName: Run tests (onnxruntime_webgpu_external_dawn_test) workingDirectory: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo' + +- stage: webgpu_minimal_build_edge + dependsOn: [] + jobs: + - template: templates/jobs/win-ci-vs-2022-job.yml + parameters: + BuildConfig: 'RelWithDebInfo' + EnvSetupScript: setup_env.bat + buildArch: x64 + additionalBuildFlags: >- + --build_shared_lib + --disable_exceptions + --disable_rtti + --enable_msvc_static_runtime + --enable_reduced_operator_type_support + --skip_tests + --use_binskim_compliant_compile_flags + --cmake_extra_defines onnxruntime_BUILD_UNIT_TESTS=OFF onnxruntime_DISABLE_SPARSE_TENSORS=ON onnxruntime_DISABLE_OPTIONAL_TYPE=ON + --minimal_build extended + --use_webgpu + msbuildPlatform: x64 + isX86: false + job_name_suffix: x64_RelWithDebInfo + RunOnnxRuntimeTests: false + ORT_EP_NAME: WebGPU + EnablePython: false + WITH_CACHE: true + MachinePool: onnxruntime-Win2022-VS2022-webgpu-A10 diff --git a/tools/ci_build/github/azure-pipelines/win-openvino-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-openvino-ci-pipeline.yml new file mode 100644 index 0000000000000..f95ac526886fa --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/win-openvino-ci-pipeline.yml @@ -0,0 +1,116 @@ +##### start trigger Don't edit it manually, Please do edit set-trigger-rules.py #### +### please do rerun set-trigger-rules.py ### +trigger: + branches: + include: + - main + - rel-* + paths: + exclude: + - docs/** + - README.md + - CONTRIBUTING.md + - BUILD.md + - 'js/web' + - 'onnxruntime/core/providers/js' +pr: + branches: + include: + - main + - rel-* + paths: + exclude: + - docs/** + - README.md + - CONTRIBUTING.md + - BUILD.md + - 'js/web' + - 'onnxruntime/core/providers/js' +#### end trigger #### + +jobs: +- job: 'BUILD_OPENVINO_EP' + pool: 'onnxruntime-Win-CPU-2022' + variables: + MsbuildArguments: '-detailedsummary -maxcpucount -consoleloggerparameters:PerformanceSummary' + OnnxRuntimeBuildDirectory: '$(Build.BinariesDirectory)' + DOTNET_SKIP_FIRST_TIME_EXPERIENCE: true + buildArch: x64 + setVcvars: true + BuildConfig: 'RelWithDebInfo' + ALLOW_RELEASED_ONNX_OPSET_ONLY: '1' + TODAY: $[format('{0:dd}{0:MM}{0:yyyy}', pipeline.startTime)] + timeoutInMinutes: 240 + workspace: + clean: all + steps: + + - task: UsePythonVersion@0 + inputs: + versionSpec: '3.12' + addToPath: true + architecture: $(buildArch) + + - template: templates/jobs/download_win_openvino.yml + + - powershell: | + Write-Output "Setting up OpenVINO environment variables" + . "$(OpenVINORootDir)\setupvars.ps1" + + Write-Output "Exporting selected environment variables to pipeline" + + $vars = @( + "INTEL_OPENVINO_DIR", + "OpenVINO_DIR", + "OpenVINOGenAI_DIR", + "OPENVINO_LIB_PATHS", + "TBB_DIR", + "PATH", + "PYTHONPATH" + ) + + foreach ($var in $vars) { + if (Test-Path "Env:$var") { + $value = [System.Environment]::GetEnvironmentVariable($var, "Process") + Write-Output "Setting $var" + Write-Output "##vso[task.setvariable variable=$var;]$value" + } else { + Write-Output "Warning: $var is not set." + } + } + + Write-Output "Selected environment variables exported successfully" + displayName: 'Set up OpenVINO environment' + + - template: templates/jobs/win-ci-build-steps.yml + parameters: + WithCache: True + Today: $(TODAY) + AdditionalKey: "win-openvino | $(BuildConfig)" + BuildPyArguments: >- + --config $(BuildConfig) + --build_dir $(Build.BinariesDirectory) + --cmake_generator "Visual Studio 17 2022" + --build_shared_lib + --use_openvino CPU + --use_binskim_compliant_compile_flags + --update --parallel + MsbuildArguments: $(MsbuildArguments) + BuildArch: $(buildArch) + Platform: 'x64' + BuildConfig: $(BuildConfig) + + - powershell: | + Write-Output "Getting CPU information" + Get-WmiObject Win32_Processor | Select-Object Name, NumberOfCores, NumberOfLogicalProcessors, Architecture | Format-Table -AutoSize + + Write-Output "Starting unit tests" + python "$(Build.SourcesDirectory)\tools\ci_build\build.py" ` + --config "$(BuildConfig)" ` + --build_dir "$(Build.BinariesDirectory)" ` + --cmake_generator "Visual Studio 17 2022" ` + --build_shared_lib ` + --use_openvino CPU ` + --use_binskim_compliant_compile_flags ` + --test --enable_onnx_tests + displayName: 'Run unit tests' diff --git a/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml index e08d7eb2b12de..1c3d911fa7dbb 100644 --- a/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml @@ -90,7 +90,7 @@ jobs: --config $(BuildConfig) --build_dir $(Build.BinariesDirectory) --cmake_generator "Visual Studio 17 2022" - --build_shared_lib + --build_shared_lib --use_vcpkg --use_vcpkg_ms_internal_asset_cache --use_qnn $(QnnLibKind) --qnn_home $(QnnSDKRootDir) --update --build --parallel diff --git a/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml index 81de3335a07d2..faef469e010f6 100644 --- a/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml @@ -78,7 +78,7 @@ jobs: --build_dir $(Build.BinariesDirectory) --cmake_generator "Visual Studio 17 2022" --build_java - --build_shared_lib + --build_shared_lib --use_vcpkg --use_vcpkg_ms_internal_asset_cache --use_qnn $(QnnLibKind) --qnn_home $(QnnSDKRootDir) --use_binskim_compliant_compile_flags diff --git a/tools/ci_build/set-trigger-rules.py b/tools/ci_build/set-trigger-rules.py index 78f59452d1284..899aaaa95216a 100644 --- a/tools/ci_build/set-trigger-rules.py +++ b/tools/ci_build/set-trigger-rules.py @@ -16,8 +16,6 @@ "android-x86_64-crosscompile-ci-pipeline.yml", "bigmodels-ci-pipeline.yml", "linux-ci-pipeline.yml", - "linux-cpu-aten-pipeline.yml", - "linux-cpu-eager-pipeline.yml", "linux-dnnl-ci-pipeline.yml", "linux-gpu-ci-pipeline.yml", "linux-gpu-tensorrt-ci-pipeline.yml", @@ -36,6 +34,7 @@ "win-gpu-doc-gen-ci-pipeline.yml", "win-gpu-tensorrt-ci-pipeline.yml", "win-gpu-webgpu-ci-pipeline.yml", + "win-openvino-ci-pipeline.yml", "win-qnn-arm64-ci-pipeline.yml", "win-qnn-ci-pipeline.yml", ] diff --git a/tools/nuget/generate_nuspec_for_custom_nuget.py b/tools/nuget/generate_nuspec_for_custom_nuget.py new file mode 100644 index 0000000000000..baf46743cbf1b --- /dev/null +++ b/tools/nuget/generate_nuspec_for_custom_nuget.py @@ -0,0 +1,150 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import argparse +import glob +import os +import shutil + +from generate_nuspec_for_native_nuget import generate_metadata + + +def generate_files(lines, args): + files_list = [""] + platform_map = { + "win-arm64": args.win_arm64, + "win-x64": args.win_x64, + } + + avoid_keywords = {"pdb"} + processed_includes = set() + for platform, platform_dir in platform_map.items(): + for file in glob.glob(os.path.join(platform_dir, "lib", "*")): + if not os.path.isfile(file): + continue + if any(keyword in file for keyword in avoid_keywords): + continue + file_name = os.path.basename(file) + + files_list.append(f'') + + for file in glob.glob(os.path.join(platform_dir, "include", "*")): + if not os.path.isfile(file): + continue + file_name = os.path.basename(file) + if file_name in processed_includes: + continue + processed_includes.add(file_name) + files_list.append(f'') + + files_list.append( + f'' + ) + + files_list.append(f'') + files_list.append( + f'' + ) + files_list.append(f'') + files_list.append( + f'' + ) + + source_props = os.path.join( + args.root_dir, + "csharp", + "src", + "Microsoft.ML.OnnxRuntime", + "targets", + "netstandard", + "props.xml", + ) + target_props = os.path.join( + args.root_dir, + "csharp", + "src", + "Microsoft.ML.OnnxRuntime", + "targets", + "netstandard", + f"{args.package_name}.props", + ) + shutil.copyfile(source_props, target_props) + files_list.append(f'') + files_list.append(f'') + + source_targets = os.path.join( + args.root_dir, + "csharp", + "src", + "Microsoft.ML.OnnxRuntime", + "targets", + "netstandard", + "targets.xml", + ) + target_targets = os.path.join( + args.root_dir, + "csharp", + "src", + "Microsoft.ML.OnnxRuntime", + "targets", + "netstandard", + f"{args.package_name}.targets", + ) + shutil.copyfile(source_targets, target_targets) + files_list.append(f'') + files_list.append(f'') + + files_list.append("") + lines.extend(files_list) + + +def parse_arguments(): + parser = argparse.ArgumentParser( + description="Create a nuspec file for the custom nuget package.", + ) + + parser.add_argument("--nuspec_path", required=True, help="Nuspec output file path.") + parser.add_argument("--root_dir", required=True, help="ORT repository root.") + parser.add_argument( + "--commit_id", + required=True, + help="The last commit id included in this package.", + ) + parser.add_argument("--win_arm64", required=True, help="Ort win-arm64 directory") + parser.add_argument("--win_x64", required=True, help="Ort win-x64 directory") + parser.add_argument("--package_version", required=True, help="Version of the package") + parser.add_argument("--package_name", required=True, help="Name of the package") + + args = parser.parse_args() + + args.sdk_info = "" + + return args + + +def generate_nuspec(args: argparse.Namespace): + lines = [''] + lines.append("") + + generate_metadata(lines, args) + generate_files(lines, args) + + lines.append("") + return lines + + +def main(): + args = parse_arguments() + + lines = generate_nuspec(args) + + with open(os.path.join(args.nuspec_path), "w") as f: + for line in lines: + # Uncomment the printing of the line if you need to debug what's produced on a CI machine + print(line) + f.write(line) + f.write("\n") + + +if __name__ == "__main__": + main() diff --git a/tools/python/run_CIs_for_external_pr.py b/tools/python/run_CIs_for_external_pr.py index 1546a9143831a..aca5f1df7d18b 100644 --- a/tools/python/run_CIs_for_external_pr.py +++ b/tools/python/run_CIs_for_external_pr.py @@ -24,6 +24,7 @@ def get_pipeline_names(): "Windows GPU DML CI Pipeline", "Windows GPU Doc Gen CI Pipeline", "Windows GPU TensorRT CI Pipeline", + "Windows OpenVINO CI Pipeline", "ONNX Runtime Web CI Pipeline", "Win_TRT_Minimal_CUDA_Test_CI", # linux diff --git a/tools/python/util/__init__.py b/tools/python/util/__init__.py index a669963e84bcf..8631218ca9e00 100644 --- a/tools/python/util/__init__.py +++ b/tools/python/util/__init__.py @@ -7,7 +7,8 @@ from .run import run # noqa: F401 from .vcpkg_helpers import ( # noqa: F401 generate_android_triplets, - generate_posix_triplets, + generate_linux_triplets, + generate_macos_triplets, generate_vcpkg_triplets_for_emscripten, generate_windows_triplets, ) diff --git a/tools/python/util/vcpkg_helpers.py b/tools/python/util/vcpkg_helpers.py index d33b2f7675690..875a6186e55c2 100644 --- a/tools/python/util/vcpkg_helpers.py +++ b/tools/python/util/vcpkg_helpers.py @@ -222,6 +222,7 @@ def generate_triplet_for_posix_platform( enable_asan: bool, crt_linkage: str, target_abi: str, + osx_deployment_target: str, ) -> None: """ Generate triplet file for POSIX platforms (Linux, macOS). @@ -235,6 +236,7 @@ def generate_triplet_for_posix_platform( enable_asan (bool): Flag indicating if AddressSanitizer is enabled. crt_linkage (str): The CRT linkage type ("static" or "dynamic"). target_abi (str): The target ABI, which maps to the VCPKG_TARGET_ARCHITECTURE variable. Valid options include x86, x64, arm, arm64, arm64ec, s390x, ppc64le, riscv32, riscv64, loongarch32, loongarch64, mips64. + osx_deployment_target (str, optional): The macOS deployment target version. The parameter sets the minimum macOS version for compiled binaries. It also changes what versions of the macOS platform SDK CMake will search for. See the CMake documentation for CMAKE_OSX_DEPLOYMENT_TARGET for more information. """ folder_name_parts = [] if enable_asan: @@ -341,6 +343,8 @@ def generate_triplet_for_posix_platform( else: osx_abi = target_abi f.write(f'set(VCPKG_OSX_ARCHITECTURES "{osx_abi}")\n') + if osx_deployment_target: + f.write(f'set(VCPKG_OSX_DEPLOYMENT_TARGET "{osx_deployment_target}")\n') f.write("set(CMAKE_POSITION_INDEPENDENT_CODE ON)\n") f.write( "list(APPEND VCPKG_CMAKE_CONFIGURE_OPTIONS --compile-no-warning-as-error -DBENCHMARK_ENABLE_WERROR=OFF)\n" @@ -501,32 +505,58 @@ def generate_windows_triplets(build_dir: str) -> None: add_port_configs(f, enable_exception, False) -def generate_posix_triplets(build_dir: str) -> None: +def generate_linux_triplets(build_dir: str) -> None: """ - Generate triplet files for POSIX platforms (Linux, macOS). + Generate triplet files for Linux platforms. Args: build_dir (str): The directory to save the generated triplet files. """ - for os_name in ["linux", "osx"]: - if os_name == "linux": - target_abis = ["x86", "x64", "arm", "arm64", "s390x", "ppc64le", "riscv64", "loongarch64", "mips64"] - else: - target_abis = ["x64", "arm64", "universal2"] - for enable_rtti in [True, False]: - for enable_exception in [True, False]: - for enable_binskim in [True, False]: - for enable_asan in [True, False]: - if enable_asan and enable_binskim: - continue - for target_abi in target_abis: - generate_triplet_for_posix_platform( - build_dir, - os_name, - enable_rtti, - enable_exception, - enable_binskim, - enable_asan, - "dynamic", - target_abi, - ) + target_abis = ["x86", "x64", "arm", "arm64", "s390x", "ppc64le", "riscv64", "loongarch64", "mips64"] + for enable_rtti in [True, False]: + for enable_exception in [True, False]: + for enable_binskim in [True, False]: + for enable_asan in [True, False]: + if enable_asan and enable_binskim: + continue + for target_abi in target_abis: + generate_triplet_for_posix_platform( + build_dir, + "linux", + enable_rtti, + enable_exception, + enable_binskim, + enable_asan, + "dynamic", + target_abi, + None, + ) + + +def generate_macos_triplets(build_dir: str, osx_deployment_target: str) -> None: + """ + Generate triplet files for macOS platforms. + + Args: + build_dir (str): The directory to save the generated triplet files. + osx_deployment_target (str, optional): The macOS deployment target version. + """ + target_abis = ["x64", "arm64", "universal2"] + for enable_rtti in [True, False]: + for enable_exception in [True, False]: + for enable_binskim in [True, False]: + for enable_asan in [True, False]: + if enable_asan and enable_binskim: + continue + for target_abi in target_abis: + generate_triplet_for_posix_platform( + build_dir, + "osx", + enable_rtti, + enable_exception, + enable_binskim, + enable_asan, + "dynamic", + target_abi, + osx_deployment_target, + ) diff --git a/winml/adapter/winml_adapter_model.cpp b/winml/adapter/winml_adapter_model.cpp index 195bf6e5f0ffd..cf02c6fa2328b 100644 --- a/winml/adapter/winml_adapter_model.cpp +++ b/winml/adapter/winml_adapter_model.cpp @@ -593,13 +593,13 @@ ORT_API_STATUS_IMPL( input.set_name(input_name); if (info->type == ONNXType::ONNX_TYPE_TENSOR) { - auto num_dims = info->data->shape.NumDimensions(); + auto num_dims = info->tensor_type_info->shape.NumDimensions(); CreateTypeProto_Tensor( input.mutable_type()->mutable_tensor_type(), input_name, - (num_dims == 0) ? nullptr : &info->data->shape[0], + (num_dims == 0) ? nullptr : &info->tensor_type_info->shape[0], num_dims, - ONNXTensorElementDataTypeToTensorProto_DataType(info->data->type) + ONNXTensorElementDataTypeToTensorProto_DataType(info->tensor_type_info->type) ); } return nullptr; @@ -619,12 +619,12 @@ ORT_API_STATUS_IMPL( ONNX_NAMESPACE::TensorProto& input = *graph.add_initializer(); input.set_name(input_name); - auto num_dims = info->data->shape.NumDimensions(); + auto num_dims = info->tensor_type_info->shape.NumDimensions(); for (size_t i = 0; i < num_dims; i++) { - input.add_dims(info->data->shape[i]); + input.add_dims(info->tensor_type_info->shape[i]); } - input.set_data_type(ONNXTensorElementDataTypeToTensorProto_DataType(info->data->type)); + input.set_data_type(ONNXTensorElementDataTypeToTensorProto_DataType(info->tensor_type_info->type)); auto tensor = value->GetMutable(); input.set_raw_data(tensor->DataRaw(), tensor->SizeInBytes()); @@ -645,9 +645,9 @@ ORT_API_STATUS_IMPL( CreateTypeProto_Tensor( output.mutable_type()->mutable_tensor_type(), output_name, - &info->data->shape[0], - info->data->shape.NumDimensions(), - ONNXTensorElementDataTypeToTensorProto_DataType(info->data->type) + &info->tensor_type_info->shape[0], + info->tensor_type_info->shape.NumDimensions(), + ONNXTensorElementDataTypeToTensorProto_DataType(info->tensor_type_info->type) ); } return nullptr; From 61b36ef3a9e5414d040deb435a95392f514482f9 Mon Sep 17 00:00:00 2001 From: jatinwadhwa921 Date: Tue, 4 Mar 2025 21:07:38 -0800 Subject: [PATCH 011/138] [OVEP] Fix for precision accuracy --- .../openvino/backends/basic_backend.cc | 4 +- .../openvino/openvino_parser_utils.cc | 86 +++++++++++++++++++ .../openvino/openvino_parser_utils.h | 22 +++++ .../openvino/openvino_provider_factory.cc | 60 +------------ 4 files changed, 113 insertions(+), 59 deletions(-) create mode 100644 onnxruntime/core/providers/openvino/openvino_parser_utils.cc create mode 100644 onnxruntime/core/providers/openvino/openvino_parser_utils.h diff --git a/onnxruntime/core/providers/openvino/backends/basic_backend.cc b/onnxruntime/core/providers/openvino/backends/basic_backend.cc index d026ce386e5c3..2e808333fd61d 100644 --- a/onnxruntime/core/providers/openvino/backends/basic_backend.cc +++ b/onnxruntime/core/providers/openvino/backends/basic_backend.cc @@ -158,10 +158,8 @@ void BasicBackend::PopulateConfigValue(ov::AnyMap& device_config) { if (session_context_.precision.find("FP32") != std::string::npos) { device_config.emplace(ov::hint::inference_precision("f32")); } - if (session_context_.precision.find("ACCURACY") != std::string::npos && - session_context_.device_type.find("GPU") != std::string::npos) { + if (session_context_.precision.find("ACCURACY") != std::string::npos) { if (session_context_.OpenVINO_Version.at(0) >= 2024) { - device_config.emplace(ov::hint::inference_precision(ov::element::dynamic)); device_config.emplace(ov::hint::execution_mode(ov::hint::ExecutionMode::ACCURACY)); } else { if (!subgraph_context_.model_precision.empty()) diff --git a/onnxruntime/core/providers/openvino/openvino_parser_utils.cc b/onnxruntime/core/providers/openvino/openvino_parser_utils.cc new file mode 100644 index 0000000000000..a7e17d1b8e498 --- /dev/null +++ b/onnxruntime/core/providers/openvino/openvino_parser_utils.cc @@ -0,0 +1,86 @@ +#include +#include "core/providers/openvino/openvino_parser_utils.h" +#include "core/providers/shared_library/provider_api.h" + +namespace onnxruntime { +namespace openvino_ep { + +std::string OpenVINOParserUtils::ParsePrecision(const ProviderOptions& provider_options, + std::string& device_type, + const std::string& option_name) { + using DeviceName = std::string; + using DefaultValue = std::string; + using ValidValues = std::list; + using foo = std::pair; + using ParserHelper = std::map; + + ParserHelper helper = { + {"GPU", {"FP16", {"FP16", "FP32", "ACCURACY"}}}, + {"NPU", {"FP16", {"FP16", "ACCURACY"}}}, + {"CPU", {"FP32", {"FP32", "ACCURACY"}}}, + }; + + std::set deprecated_device_types = { + "CPU_FP32", "GPU_FP32", "GPU.0_FP32", "GPU.1_FP32", "GPU_FP16", + "GPU.0_FP16", "GPU.1_FP16"}; + + bool is_composite = device_type.find(':') != std::string::npos; // FOR devices AUTO:,HETERO:,MULTI: + + if (provider_options.contains(option_name)) { + const auto& precision = provider_options.at(option_name); + + if (is_composite) { + std::set allowed_precisions = {"FP16", "FP32", "ACCURACY"}; + if (allowed_precisions.contains(precision)) { + return precision; + } else { + ORT_THROW("[ERROR] [OpenVINO] Unsupported inference precision is selected. ", + precision, ".\n"); + } + } else { + if (helper.contains(device_type)) { + auto const& valid_values = helper[device_type].second; + + if (precision == "ACCURACY") { + return valid_values.back(); // Return highest supported precision + } else { + if (std::find(valid_values.begin(), valid_values.end(), precision) != valid_values.end()) { + return precision; // Return precision selected if valid + } else { + auto value_iter = valid_values.begin(); + std::string valid_values_joined = *value_iter; + // Append 2nd and up, if only one then ++value_iter is same as end() + for (++value_iter; value_iter != valid_values.end(); ++value_iter) { + valid_values_joined += ", " + *value_iter; + } + + ORT_THROW("[ERROR] [OpenVINO] Unsupported inference precision is selected. ", + device_type, " only supports", valid_values_joined, ".\n"); + } + } + } else if (deprecated_device_types.contains(device_type)) { + LOGS_DEFAULT(WARNING) + << "[OpenVINO] Selected 'device_type' " + device_type + " is deprecated. \n" + << "Update the 'device_type' to specified types 'CPU', 'GPU', 'GPU.0', " + << "'GPU.1', 'NPU' or from HETERO/MULTI/AUTO options and set 'precision' separately. \n"; + auto delimit = device_type.find("_"); + device_type = device_type.substr(0, delimit); + return device_type.substr(delimit + 1); + } else { + ORT_THROW("[ERROR] [OpenVINO] Unsupported device type provided: ", + device_type, "\n"); + } + } + } else { + if (device_type.find("NPU") != std::string::npos || device_type.find("GPU") != std::string::npos) { + return "FP16"; + } else if (device_type.find("CPU") != std::string::npos) { + return "FP32"; + } else { + ORT_THROW("[ERROR] [OpenVINO] Unsupported device is selected", device_type, "\n"); + } + } +} + +} // namespace openvino_ep +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/openvino/openvino_parser_utils.h b/onnxruntime/core/providers/openvino/openvino_parser_utils.h new file mode 100644 index 0000000000000..3e23c9e788463 --- /dev/null +++ b/onnxruntime/core/providers/openvino/openvino_parser_utils.h @@ -0,0 +1,22 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include "core/framework/provider_options.h" + +namespace onnxruntime { +namespace openvino_ep { + +class OpenVINOParserUtils { + public: + static std::string ParsePrecision(const ProviderOptions& provider_options, + std::string& device_type, + const std::string& option_name); +}; + +} // namespace openvino_ep +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/openvino/openvino_provider_factory.cc b/onnxruntime/core/providers/openvino/openvino_provider_factory.cc index c4fe16e035241..a880c24760707 100644 --- a/onnxruntime/core/providers/openvino/openvino_provider_factory.cc +++ b/onnxruntime/core/providers/openvino/openvino_provider_factory.cc @@ -11,11 +11,12 @@ #include "core/providers/openvino/backend_utils.h" #include "core/session/onnxruntime_session_options_config_keys.h" #include "nlohmann/json.hpp" +#include "core/providers/openvino/openvino_parser_utils.h" namespace onnxruntime { namespace openvino_ep { void ParseConfigOptions(ProviderInfo& pi) { - if(pi.config_options==NULL) + if (pi.config_options == NULL) return; pi.so_disable_cpu_ep_fallback = pi.config_options->GetConfigOrDefault(kOrtSessionOptionsDisableCPUEPFallback, "0") == "1"; @@ -29,7 +30,6 @@ void ParseConfigOptions(ProviderInfo& pi) { map["NPU_COMPILATION_MODE_PARAMS"] = "enable-wd-blockarg-input=true compute-layers-with-higher-precision=Sqrt,Power,ReduceSum"; pi.load_config["NPU"] = std::move(map); } - } void* ParseUint64(const ProviderOptions& provider_options, std::string option_name) { @@ -115,58 +115,6 @@ std::string ParseDeviceType(std::shared_ptr ov_core, const ProviderOptio } } -// Depends on ProviderOptions. -std::string ParsePrecision(const ProviderOptions& provider_options, std::string& device_type, const std::string& option_name) { - using DeviceName = std::string; - using DefaultValue = std::string; - using ValidValues = std::list; - using foo = std::pair; - using ParserHelper = std::map; - ParserHelper helper = { - {"GPU", {"FP16", {"FP16", "FP32"}}}, - {"NPU", {"FP16", {"FP16"}}}, - {"CPU", {"FP32", {"FP32"}}}, - }; - - std::set deprecated_device_types = {"CPU_FP32", "GPU_FP32", - "GPU.0_FP32", "GPU.1_FP32", "GPU_FP16", - "GPU.0_FP16", "GPU.1_FP16"}; - - if (provider_options.contains(option_name)) { - // Start by checking if the device_type is a normal valid one - if (helper.contains(device_type)) { - auto const& valid_values = helper[device_type].second; - const auto& precision = provider_options.at(option_name); - if (precision == "ACCURACY") { - return valid_values.back(); // Return highest supported precision - } else { - if (std::find(valid_values.begin(), valid_values.end(), precision) != valid_values.end()) { - return precision; // Return precision selected if valid - } else { - auto value_iter = valid_values.begin(); - std::string valid_values_joined = *value_iter; - // Append 2nd and up, if only one then ++value_iter is same as end() - for (++value_iter; value_iter != valid_values.end(); ++value_iter) { - valid_values_joined += ", " + *value_iter; - } - - ORT_THROW("[ERROR] [OpenVINO] Unsupported inference precision is selected. ", device_type, " only supports", valid_values_joined, ".\n"); - } - } - } else if (deprecated_device_types.contains(device_type)) { - LOGS_DEFAULT(WARNING) << "[OpenVINO] Selected 'device_type' " + device_type + " is deprecated. \n" - << "Update the 'device_type' to specified types 'CPU', 'GPU', 'GPU.0', " - << "'GPU.1', 'NPU' or from" - << " HETERO/MULTI/AUTO options and set 'precision' separately. \n"; - auto delimit = device_type.find("_"); - device_type = device_type.substr(0, delimit); - return device_type.substr(delimit + 1); - } - } - // Return default - return helper[device_type].first; -} - void ParseProviderOptions([[maybe_unused]] ProviderInfo& result, [[maybe_unused]] const ProviderOptions& config_options) {} struct OpenVINOProviderFactory : IExecutionProviderFactory { @@ -204,7 +152,7 @@ struct OpenVINO_Provider : Provider { const ProviderOptions* provider_options_ptr = reinterpret_cast(pointers_array[0]); const ConfigOptions* config_options = reinterpret_cast(pointers_array[1]); - if(provider_options_ptr == NULL) { + if (provider_options_ptr == NULL) { LOGS_DEFAULT(ERROR) << "[OpenVINO EP] Passed NULL ProviderOptions to CreateExecutionProviderFactory()"; return nullptr; } @@ -234,7 +182,7 @@ struct OpenVINO_Provider : Provider { pi.cache_dir = provider_options.at("cache_dir"); } - pi.precision = ParsePrecision(provider_options, pi.device_type, "precision"); + pi.precision = OpenVINOParserUtils::ParsePrecision(provider_options, pi.device_type, "precision"); if (provider_options.contains("load_config")) { auto parse_config = [&](const std::string& config_str) -> std::map { From 7683e37baafb545a5806e343f266a081c4556256 Mon Sep 17 00:00:00 2001 From: Eric Crawford Date: Mon, 17 Mar 2025 10:29:53 -0700 Subject: [PATCH 012/138] Refactor OVRTAllocator to return base pointer of remote tensor (#613) This change allows for allocations made by the ov allocator to be imported to other APIs that require base addresses to the original device allocation. --- .../core/providers/openvino/ov_allocator.cc | 29 ++++++++----------- .../core/providers/openvino/ov_allocator.h | 5 ++++ 2 files changed, 17 insertions(+), 17 deletions(-) diff --git a/onnxruntime/core/providers/openvino/ov_allocator.cc b/onnxruntime/core/providers/openvino/ov_allocator.cc index 0e5ff8ff98efb..431f5730c0342 100644 --- a/onnxruntime/core/providers/openvino/ov_allocator.cc +++ b/onnxruntime/core/providers/openvino/ov_allocator.cc @@ -10,12 +10,6 @@ namespace onnxruntime { using namespace openvino_ep; -constexpr size_t default_alignment = 4096; - -static inline size_t align_up(size_t size, size_t pow2_alignment) { - return (size + pow2_alignment - 1) & ~(pow2_alignment - 1); -} - OVRTAllocator::OVRTAllocator(ov::Core& core, OrtDevice::DeviceType device_type, OrtDevice::DeviceId device_id, const char* name) : IAllocator(OrtMemoryInfo(name, OrtAllocatorType::OrtDeviceAllocator, OrtDevice(device_type, OrtDevice::MemType::DEFAULT, device_id), device_id, OrtMemTypeCPUInput)), core_(core) { if (device_type == OrtDevice::NPU) { remote_ctx_ = core_.get_default_context("NPU").as(); @@ -26,16 +20,11 @@ OVRTAllocator::OVRTAllocator(ov::Core& core, OrtDevice::DeviceType device_type, void* OVRTAllocator::Alloc(size_t size) { try { - size_t alloc_size = align_up(size + sizeof(ov::Tensor*) + default_alignment, default_alignment); ov::Tensor* tensor = new ov::Tensor(remote_ctx_.create_host_tensor(ov::element::Type_t::u8, - {alloc_size})); - uintptr_t data_ptr = reinterpret_cast(tensor->data()); - - ov::Tensor** ptr = reinterpret_cast(align_up(data_ptr + sizeof(ov::Tensor*), default_alignment)); - ptr[-1] = tensor; - - return reinterpret_cast(ptr); - + {size})); + std::unique_lock lock(mutex_); + allocated_.insert({tensor->data(), tensor}); + return reinterpret_cast(tensor->data()); } catch (const ov::Exception& e) { ORT_THROW(std::string("Alloc failed: ") + e.what()); } @@ -43,8 +32,14 @@ void* OVRTAllocator::Alloc(size_t size) { void OVRTAllocator::Free(void* p) { try { - ov::Tensor** ptr = reinterpret_cast(p); - delete ptr[-1]; + std::unique_lock lock(mutex_); + auto it = allocated_.find(p); + if (it != allocated_.end()) { + ov::Tensor* tensor = it->second; + allocated_.erase(it); + lock.unlock(); + delete tensor; + } } catch (const ov::Exception& e) { ORT_THROW(std::string("Free failed: ") + e.what()); } diff --git a/onnxruntime/core/providers/openvino/ov_allocator.h b/onnxruntime/core/providers/openvino/ov_allocator.h index 083cfc4d5aed3..f6e87111f47ff 100644 --- a/onnxruntime/core/providers/openvino/ov_allocator.h +++ b/onnxruntime/core/providers/openvino/ov_allocator.h @@ -3,9 +3,12 @@ #ifdef USE_OVEP_NPU_MEMORY #pragma once +#include + #include "core/common/inlined_containers.h" #include "core/framework/allocator.h" #include "openvino/runtime/remote_context.hpp" +#include "core/common/inlined_containers.h" namespace onnxruntime { @@ -18,6 +21,8 @@ class OVRTAllocator : public IAllocator { private: ov::Core& core_; ov::RemoteContext remote_ctx_; + InlinedHashMap allocated_; + std::mutex mutex_; }; } // namespace onnxruntime From 1788576673d5606d27ea97ebcdddf6c908ac5929 Mon Sep 17 00:00:00 2001 From: sfatimar Date: Tue, 18 Mar 2025 14:05:12 +0530 Subject: [PATCH 013/138] Commit Lint Errors fix (#606) --- .../providers/openvino/openvino_provider_factory.cc | 11 +++-------- onnxruntime/core/session/provider_bridge_ort.cc | 2 +- 2 files changed, 4 insertions(+), 9 deletions(-) diff --git a/onnxruntime/core/providers/openvino/openvino_provider_factory.cc b/onnxruntime/core/providers/openvino/openvino_provider_factory.cc index a880c24760707..f6251ee7049a7 100644 --- a/onnxruntime/core/providers/openvino/openvino_provider_factory.cc +++ b/onnxruntime/core/providers/openvino/openvino_provider_factory.cc @@ -16,7 +16,8 @@ namespace onnxruntime { namespace openvino_ep { void ParseConfigOptions(ProviderInfo& pi) { - if (pi.config_options == NULL) + + if (pi.config_options == nullptr) return; pi.so_disable_cpu_ep_fallback = pi.config_options->GetConfigOrDefault(kOrtSessionOptionsDisableCPUEPFallback, "0") == "1"; @@ -149,15 +150,9 @@ struct OpenVINO_Provider : Provider { } std::array pointers_array = *reinterpret_cast*>(void_params); - const ProviderOptions* provider_options_ptr = reinterpret_cast(pointers_array[0]); + const ProviderOptions provider_options = *reinterpret_cast(pointers_array[0]); const ConfigOptions* config_options = reinterpret_cast(pointers_array[1]); - if (provider_options_ptr == NULL) { - LOGS_DEFAULT(ERROR) << "[OpenVINO EP] Passed NULL ProviderOptions to CreateExecutionProviderFactory()"; - return nullptr; - } - const ProviderOptions provider_options = *provider_options_ptr; - ProviderInfo pi; pi.config_options = config_options; diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index 2ea4a93d21f2e..e46236f4ca11c 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -2020,7 +2020,7 @@ std::shared_ptr OpenVINOProviderFactoryCreator::Creat const ProviderOptions* provider_options_map, const SessionOptions* session_options) { // Append session options applicable for EP to EP Provider options. const ConfigOptions* config_options = nullptr; - if (session_options !=nullptr) { + if (session_options != nullptr) { config_options = &session_options->config_options; } From 23e17e2c0abff0ccd74fbe72b9a53d5ae47eaf9a Mon Sep 17 00:00:00 2001 From: saurabh Date: Tue, 18 Mar 2025 06:46:26 -0700 Subject: [PATCH 014/138] fix quantizedLinear layer feeds into grapg output (#615) --- .../openvino/qdq_transformations/qdq_stripping.cc | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.cc b/onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.cc index a9f6420d6ac3b..effd13abc3e3a 100644 --- a/onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.cc +++ b/onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.cc @@ -346,7 +346,7 @@ static bool CheckQFeedsIntoQuantizedOutput(const NodeUnit& node_unit, auto op_of_quantized_layer = node_unit.Outputs(); for (auto& itr : op_of_quantized_layer) { auto it = graph_op_data_type.find(itr.node_arg.Name()); - if (it != graph_op_data_type.end() && it->second == "tensor(uint8)") { + if (it != graph_op_data_type.end() && (it->second == "tensor(uint8)" || it->second == "tensor(uint16)")) { return true; } } @@ -369,6 +369,11 @@ static bool CheckQRuleSet(const NodeUnit& node_unit, graph_op_data_type[src_graph.GetNodeArg(ops->Name())->Name()] = ops->Type()->data(); } + // check If any quantized node feeds into the src graph output + if (CheckQFeedsIntoQuantizedOutput(node_unit, std::move(graph_op_data_type))) { + return true; + } + // If UInt16 Q, don't keep it if (GetQDQDataType(q_node) == DT_UINT16 || GetQDQDataType(q_node) == DT_INT16) { reason = SkipReason::Int16QDQ; @@ -381,9 +386,7 @@ static bool CheckQRuleSet(const NodeUnit& node_unit, } else if (op_type == "Add") { // Add keeps all Qs return true; - } else if (CheckQFeedsIntoQuantizedOutput(node_unit, std::move(graph_op_data_type))) { - return true; - } else { + } else { // Keep Q of an unsupported Op only if the target that succeeds it is a supported Op in this list return IsNextTargetNodeOfQValid(q_node, &target_node, src_graph, {"Conv", "Add", "MatMul"}, false); } From 8b4a6d2c0fdf93c431223cc991d2c0e0149032e7 Mon Sep 17 00:00:00 2001 From: jatinwadhwa921 <110383850+jatinwadhwa921@users.noreply.github.com> Date: Wed, 19 Mar 2025 11:38:58 +0530 Subject: [PATCH 015/138] [OVEP] Fix for dumping the model in correct format (#616) --- onnxruntime/core/providers/openvino/backend_manager.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/onnxruntime/core/providers/openvino/backend_manager.cc b/onnxruntime/core/providers/openvino/backend_manager.cc index 2a842b8a1eca8..1ed5ad8a56fd5 100644 --- a/onnxruntime/core/providers/openvino/backend_manager.cc +++ b/onnxruntime/core/providers/openvino/backend_manager.cc @@ -331,6 +331,7 @@ static void DumpOpenVINOEPModel(const std::filesystem::path& onnx_model_path_nam if (dash != std::string::npos) { auto new_name = model_name.stem().string() + subgraph_name.substr(dash, std::string::npos); model_name.replace_filename(new_name); + model_name.replace_extension(".onnx"); } std::fstream dump(model_name, std::ios::out | std::ios::trunc | std::ios::binary); From 7269615ee423f2f0a0a558a36befa28f03e97f73 Mon Sep 17 00:00:00 2001 From: jatinwadhwa921 <110383850+jatinwadhwa921@users.noreply.github.com> Date: Wed, 19 Mar 2025 13:41:57 +0530 Subject: [PATCH 016/138] [OVEP] Added Cast and Resize to operators that handle zero-valued dimensions, preventing unnecessary fallback (#619) --- onnxruntime/core/providers/openvino/ov_versions/data_ops.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc b/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc index f7326642a5544..4e1387d2ef4a9 100644 --- a/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc +++ b/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc @@ -792,7 +792,8 @@ bool DataOps::node_is_supported(const NodeIndex node_idx, bool& has_external_wei if (((device_id_.find("CPU") != std::string::npos) || (device_id_.find("GPU") != std::string::npos)) && ((optype == "Expand") || (optype == "Equal") || (optype == "Slice") || (optype == "Concat") || - (optype == "Shape"))) { + (optype == "Shape") || (optype == "Cast") || + (optype == "Resize"))) { return; } has_unsupported_dimension = true; From e240695353f09aca67633b66857a44099e01bce5 Mon Sep 17 00:00:00 2001 From: Pallavi Gupta Date: Thu, 27 Mar 2025 00:42:00 -0700 Subject: [PATCH 017/138] [OVEP] Fix for Dynamic backend creation for NPU. (#622) --- onnxruntime/core/providers/openvino/backends/basic_backend.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/openvino/backends/basic_backend.cc b/onnxruntime/core/providers/openvino/backends/basic_backend.cc index 2e808333fd61d..14a4a466613dd 100644 --- a/onnxruntime/core/providers/openvino/backends/basic_backend.cc +++ b/onnxruntime/core/providers/openvino/backends/basic_backend.cc @@ -425,8 +425,8 @@ void BasicBackend::StartAsyncInference(Ort::KernelContext& context, OVInferReque } } } // Loop subgraph original input names - - if (session_context_.device_type.find("NPU") != std::string::npos) { + if (session_context_.device_type.find("NPU") != std::string::npos && + !subgraph_context_.has_dynamic_input_shape) { // Set the output blob as remote blob auto graph_output_info = exe_network_.Get().outputs(); auto output_idx = 0; From 322a7e1ca3833ba49c82419e54368c07d57589d3 Mon Sep 17 00:00:00 2001 From: Nikolay Proshunin Date: Fri, 28 Mar 2025 15:01:22 +0100 Subject: [PATCH 018/138] Remove unnecessary device queries (#620) * Rewrote ParseDevice, ParsePrecison and OpenVINOExecutionProvider code to avoid calls to get_available_devices. * Addressed issues and added a minor change in ParseDeviceType. * Fixed bug in OpenVINOExecutionProvider constructor. * Added default precision for composite and custom devices. Modified check for consistency. * Fixed multiple GPU logic to allow choosing just GPU without index in a multiple GPU system. --- .../openvino/openvino_execution_provider.cc | 60 +++++--- .../openvino/openvino_parser_utils.cc | 136 +++++++++++------- .../openvino/openvino_provider_factory.cc | 107 ++++++++------ .../core/providers/openvino/ov_interface.cc | 38 ++++- .../core/providers/openvino/ov_interface.h | 3 +- .../core/session/provider_bridge_ort.cc | 2 +- 6 files changed, 223 insertions(+), 123 deletions(-) diff --git a/onnxruntime/core/providers/openvino/openvino_execution_provider.cc b/onnxruntime/core/providers/openvino/openvino_execution_provider.cc index 6482a07ee92bc..da12d7f27d61a 100644 --- a/onnxruntime/core/providers/openvino/openvino_execution_provider.cc +++ b/onnxruntime/core/providers/openvino/openvino_execution_provider.cc @@ -62,36 +62,50 @@ OpenVINOExecutionProvider::OpenVINOExecutionProvider(const ProviderInfo& info, s // to check if target device is available // using OVCore capability GetAvailableDevices to fetch list of devices plugged in if (info.cache_dir.empty()) { - bool device_found = false; - std::vector available_devices = OVCore::Get()->GetAvailableDevices(); + bool all_devices_found = false; // Checking for device_type configuration if (info.device_type != "") { - if (info.device_type.find("HETERO") != std::string::npos || - info.device_type.find("MULTI") != std::string::npos || - info.device_type.find("AUTO") != std::string::npos) { - device_found = true; + std::vector devices_to_check; + if (info.device_type.find("HETERO:") == 0 || + info.device_type.find("MULTI:") == 0 || + info.device_type.find("BATCH:") == 0 || + info.device_type.find("AUTO:") == 0) { + auto delimit = info.device_type.find(":"); + const auto& devices = info.device_type.substr(delimit + 1); + devices_to_check = split(devices, ','); } else { - for (const std::string& device : available_devices) { - if (device.rfind(info.device_type, 0) == 0) { - if (info.device_type.find("GPU") != std::string::npos && (info.precision == "FP32" || - info.precision == "FP16" || - info.precision == "ACCURACY")) { - device_found = true; - break; - } - if (info.device_type == "CPU" && (info.precision == "FP32")) { - device_found = true; - break; - } - if (info.device_type.find("NPU") != std::string::npos) { - device_found = true; - break; - } + devices_to_check.push_back(info.device_type); + } + + // Re-initialize before loop + all_devices_found = true; + for (const auto& device : devices_to_check) { + bool device_found = false; + std::string device_prefix = device; + int device_idx = 0; + // Get the index and remove the index from device_prefix + if (auto delimit = device_prefix.find("."); delimit != std::string::npos) { + try { + device_idx = std::stoi(device_prefix.substr(delimit + 1)); + } catch (std::exception& ex) { + ORT_THROW("[ERROR] [OpenVINO] Wrong index in specified device - " + device + " :", ex.what()); } + device_prefix = device_prefix.substr(0, delimit); + } + std::vector available_devices = OVCore::Get()->GetAvailableDevices(device_prefix); + // If idx is 0, maybe index is not set (e.g. GPU) + // Then the device is found if we have at least one device of the type + if (device_idx == 0 && available_devices.size() >= 1) { + device_found = true; + } else { + // Find full device (e.g GPU.1) in the list + if (std::find(std::begin(available_devices), std::end(available_devices), device) != std::end(available_devices)) + device_found = true; } + all_devices_found = all_devices_found && device_found; } } - if (!device_found) { + if (!all_devices_found) { ORT_THROW("[ERROR] [OpenVINO] Specified device - " + info.device_type + " is not available"); } } diff --git a/onnxruntime/core/providers/openvino/openvino_parser_utils.cc b/onnxruntime/core/providers/openvino/openvino_parser_utils.cc index a7e17d1b8e498..067076b1f84f2 100644 --- a/onnxruntime/core/providers/openvino/openvino_parser_utils.cc +++ b/onnxruntime/core/providers/openvino/openvino_parser_utils.cc @@ -10,74 +10,108 @@ std::string OpenVINOParserUtils::ParsePrecision(const ProviderOptions& provider_ const std::string& option_name) { using DeviceName = std::string; using DefaultValue = std::string; - using ValidValues = std::list; - using foo = std::pair; - using ParserHelper = std::map; - + using ValidValues = std::vector; + using DefaultAndValidPair = std::pair; + using ParserHelper = std::unordered_map; + // {Device prefix, {Default precision, {Supported precisions}}} ParserHelper helper = { {"GPU", {"FP16", {"FP16", "FP32", "ACCURACY"}}}, {"NPU", {"FP16", {"FP16", "ACCURACY"}}}, {"CPU", {"FP32", {"FP32", "ACCURACY"}}}, }; - std::set deprecated_device_types = { - "CPU_FP32", "GPU_FP32", "GPU.0_FP32", "GPU.1_FP32", "GPU_FP16", - "GPU.0_FP16", "GPU.1_FP16"}; - - bool is_composite = device_type.find(':') != std::string::npos; // FOR devices AUTO:,HETERO:,MULTI: - - if (provider_options.contains(option_name)) { - const auto& precision = provider_options.at(option_name); - - if (is_composite) { - std::set allowed_precisions = {"FP16", "FP32", "ACCURACY"}; - if (allowed_precisions.contains(precision)) { + // If we have multiple device configuration, request precision from user and check it + if ((device_type.find("HETERO:") == 0) || + (device_type.find("MULTI:") == 0) || + (device_type.find("BATCH:") == 0) || + (device_type.find("AUTO:") == 0)) { + if (!provider_options.contains(option_name)) { + LOGS_DEFAULT(INFO) << "[OpenVINO] Precision is not set. Using default OpenVINO precision for " + device_type + ". \n"; + return ""; + } else { + std::unordered_set supported_precisions = {"FP16", "FP32", "ACCURACY"}; + std::string precision = provider_options.at(option_name); + if (supported_precisions.contains(precision)) { return precision; } else { - ORT_THROW("[ERROR] [OpenVINO] Unsupported inference precision is selected. ", - precision, ".\n"); + ORT_THROW("[ERROR] [OpenVINO] Unsupported precision for the ", device_type, " device. Device supports only FP16, FP32, ACCURACY.\n"); } + } + } + + // Deprecated device specification (CPU_FP32, GPU.0_FP32, etc.) + if (auto delimit = device_type.find("_"); delimit != std::string::npos) { + if (provider_options.contains(option_name)) { + ORT_THROW("[ERROR] [OpenVINO] Precision is specified twice, please remove the _precision suffix from device name and only set the precision separately.\n"); + } + LOGS_DEFAULT(WARNING) << "[OpenVINO] Selected 'device_type' " + device_type + " is deprecated. \n" + << "Update the 'device_type' to specified types 'CPU', 'GPU', 'GPU.0', " + << "'GPU.1', 'NPU' or from" + << " HETERO/MULTI/AUTO/BATCH options and set 'precision' separately. \n"; + std::string precision = device_type.substr(delimit + 1); + // Device type is updated in-place + device_type = device_type.substr(0, delimit); + // We have to remove the index (.0, .1, etc.) to use device as key for helper + std::string device_prefix = device_type; + if (auto dot_delimit = device_prefix.find("."); dot_delimit != std::string::npos) { + device_prefix = device_prefix.substr(0, dot_delimit); + } + + if (!helper.contains(device_prefix)) { + ORT_THROW("[ERROR] [OpenVINO] Selected 'device_type' " + device_type + " is not supported with precision suffix. \n"); + } + const auto& valid_values = helper[device_prefix].second; + if (std::find(std::begin(valid_values), std::end(valid_values), precision) != std::end(valid_values)) { + return precision; } else { - if (helper.contains(device_type)) { - auto const& valid_values = helper[device_type].second; + auto value_iter = valid_values.begin(); + std::string valid_values_joined = *value_iter; + // Append 2nd and up, if only one then ++value_iter is same as end() + for (++value_iter; value_iter != valid_values.end(); ++value_iter) { + valid_values_joined += ", " + *value_iter; + } - if (precision == "ACCURACY") { - return valid_values.back(); // Return highest supported precision - } else { - if (std::find(valid_values.begin(), valid_values.end(), precision) != valid_values.end()) { - return precision; // Return precision selected if valid - } else { - auto value_iter = valid_values.begin(); - std::string valid_values_joined = *value_iter; - // Append 2nd and up, if only one then ++value_iter is same as end() - for (++value_iter; value_iter != valid_values.end(); ++value_iter) { - valid_values_joined += ", " + *value_iter; - } + ORT_THROW("[ERROR] [OpenVINO] Unsupported inference precision is selected. ", device_type, " only supports ", valid_values_joined, ".\n"); + } + } - ORT_THROW("[ERROR] [OpenVINO] Unsupported inference precision is selected. ", - device_type, " only supports", valid_values_joined, ".\n"); - } - } - } else if (deprecated_device_types.contains(device_type)) { - LOGS_DEFAULT(WARNING) - << "[OpenVINO] Selected 'device_type' " + device_type + " is deprecated. \n" - << "Update the 'device_type' to specified types 'CPU', 'GPU', 'GPU.0', " - << "'GPU.1', 'NPU' or from HETERO/MULTI/AUTO options and set 'precision' separately. \n"; - auto delimit = device_type.find("_"); - device_type = device_type.substr(0, delimit); - return device_type.substr(delimit + 1); + // Deprecated devices are already handled above + // We have to remove the index (.0, .1, etc.) to use device as key for helper + auto device_prefix = device_type; + if (auto dot_delimit = device_prefix.find("."); dot_delimit != std::string::npos) { + device_prefix = device_prefix.substr(0, dot_delimit); + } + + if (provider_options.contains(option_name)) { + std::string precision = provider_options.at(option_name); + + if (helper.contains(device_prefix)) { + auto const& valid_values = helper[device_prefix].second; + if (std::find(std::begin(valid_values), std::end(valid_values), precision) != std::end(valid_values)) { + return precision; // Return precision selected if valid } else { - ORT_THROW("[ERROR] [OpenVINO] Unsupported device type provided: ", - device_type, "\n"); + auto value_iter = valid_values.begin(); + std::string valid_values_joined = *value_iter; + // Append 2nd and up, if only one then ++value_iter is same as end() + for (++value_iter; value_iter != valid_values.end(); ++value_iter) { + valid_values_joined += ", " + *value_iter; + } + + ORT_THROW("[ERROR] [OpenVINO] Unsupported inference precision is selected. ", device_type, " only supports ", valid_values_joined, ".\n"); } + } else { + // Not found in helper - custom device, return as is + return precision; } } else { - if (device_type.find("NPU") != std::string::npos || device_type.find("GPU") != std::string::npos) { - return "FP16"; - } else if (device_type.find("CPU") != std::string::npos) { - return "FP32"; + // Precision not set + if (helper.contains(device_prefix)) { + // If found in helper - set the default + return helper[device_prefix].first; } else { - ORT_THROW("[ERROR] [OpenVINO] Unsupported device is selected", device_type, "\n"); + // Not found in helper - custom device - default precision + LOGS_DEFAULT(INFO) << "[OpenVINO] Precision is not set. Using default OpenVINO precision for " + device_type + ". \n"; + return ""; } } } diff --git a/onnxruntime/core/providers/openvino/openvino_provider_factory.cc b/onnxruntime/core/providers/openvino/openvino_provider_factory.cc index f6251ee7049a7..5809ca2cf1c7b 100644 --- a/onnxruntime/core/providers/openvino/openvino_provider_factory.cc +++ b/onnxruntime/core/providers/openvino/openvino_provider_factory.cc @@ -57,63 +57,80 @@ bool ParseBooleanOption(const ProviderOptions& provider_options, std::string opt } std::string ParseDeviceType(std::shared_ptr ov_core, const ProviderOptions& provider_options, std::string option_name) { - const std::vector ov_available_devices = ov_core->GetAvailableDevices(); - - std::set ov_supported_device_types = {"CPU", "GPU", - "GPU.0", "GPU.1", "NPU"}; - std::set deprecated_device_types = {"CPU_FP32", "GPU_FP32", - "GPU.0_FP32", "GPU.1_FP32", "GPU_FP16", - "GPU.0_FP16", "GPU.1_FP16"}; - - // Expand set of supported device with OV devices - ov_supported_device_types.insert(ov_available_devices.begin(), ov_available_devices.end()); + // This function normally does not check if the selected device is available, but does some sanity checks + // Only if the device is not standard, then availability is checked. + // Availability is checked for the selected device in the OpenVINOExecutionProvider constructor + std::vector devices_to_check; + std::string selected_device; if (provider_options.contains(option_name)) { - const auto& selected_device = provider_options.at("device_type"); - - if (deprecated_device_types.contains(selected_device)) { - // Deprecated device and precision is handled together at ParsePrecision - return selected_device; - } - - if (!((ov_supported_device_types.contains(selected_device)) || - (selected_device.find("HETERO:") == 0) || - (selected_device.find("MULTI:") == 0) || - (selected_device.find("AUTO:") == 0))) { - ORT_THROW( - "[ERROR] [OpenVINO] You have selected wrong configuration value for the key 'device_type'. " - "Select from 'CPU', 'GPU', 'NPU', 'GPU.x' where x = 0,1,2 and so on or from" - " HETERO/MULTI/AUTO options available. \n"); + selected_device = provider_options.at(option_name); + // If we have multiple device configuration, we need to check all of them + if ((selected_device.find("HETERO:") == 0) || + (selected_device.find("MULTI:") == 0) || + (selected_device.find("BATCH:") == 0) || + (selected_device.find("AUTO:") == 0)) { + auto delimit = selected_device.find(":"); + const auto& devices = selected_device.substr(delimit + 1); + devices_to_check = split(devices, ','); + } else { + devices_to_check.push_back(selected_device); } - LOGS_DEFAULT(INFO) << "[OpenVINO-EP] Choosing Device: " << selected_device; - return selected_device; } else { - std::string default_device; - // Take default behavior from project configuration #if defined OPENVINO_CONFIG_CPU - default_device = "CPU"; + selected_device = "CPU"; + LOGS_DEFAULT(INFO) << "[OpenVINO-EP] Choosing Device: " << selected_device; + return selected_device; #elif defined OPENVINO_CONFIG_GPU - default_device = "GPU"; + selected_device = "GPU"; + LOGS_DEFAULT(INFO) << "[OpenVINO-EP] Choosing Device: " << selected_device; + return selected_device; #elif defined OPENVINO_CONFIG_NPU - default_device = "NPU"; + selected_device = "NPU"; + LOGS_DEFAULT(INFO) << "[OpenVINO-EP] Choosing Device: " << selected_device; + return selected_device; #elif defined OPENVINO_CONFIG_HETERO || defined OPENVINO_CONFIG_MULTI || defined OPENVINO_CONFIG_AUTO - default_device = DEVICE_NAME; - - // Validate that devices passed are valid - int delimit = device_type.find(":"); - const auto& devices = device_type.substr(delimit + 1); - auto device_list = split(devices, ','); - for (const auto& device : devices) { - if (!ov_supported_device_types.contains(device)) { - ORT_THROW("[ERROR] [OpenVINO] Invalid device selected: ", device); - } - } + selected_device = DEVICE_NAME; + + // Add sub-devices to check-list + int delimit = selected_device.find(":"); + const auto& devices = selected_device.substr(delimit + 1); + devices_to_check = split(devices, ','); #endif + } - LOGS_DEFAULT(INFO) << "[OpenVINO-EP] Choosing Device: " << default_device; - return default_device; + // Devices considered to be supported by default + std::unordered_set supported_device_types = {"CPU", "GPU", "NPU"}; + for (auto device : devices_to_check) { + // Check deprecated device format (CPU_FP32, GPU.0_FP16, etc.) and remove the suffix in place + // Suffix will be parsed in ParsePrecision + if (auto delimit = device.find("_"); delimit != std::string::npos) { + device = device.substr(0, delimit); + } + // Just the device name without .0, .1, etc. suffix + auto device_prefix = device; + // Check if device index is appended (.0, .1, etc.), if so, remove it + if (auto delimit = device_prefix.find("."); delimit != std::string::npos) { + device_prefix = device_prefix.substr(0, delimit); + if (device_prefix == "CPU") + ORT_THROW("[ERROR] [OpenVINO] CPU device is only supported without index, CPU.x is illegal.\n"); + } + // Only device is not supported by default (some exotic device), check if it's available + if (!supported_device_types.contains(device_prefix)) { + std::vector available_devices = ov_core->GetAvailableDevices(); + // Here we need to find the full device name (with .idx, but without _precision) + if (std::find(std::begin(available_devices), std::end(available_devices), device) == std::end(available_devices)) { + ORT_THROW( + "[ERROR] [OpenVINO] You have selected wrong configuration value for the key 'device_type'. " + "Select from 'CPU', 'GPU', 'NPU', 'GPU.x' where x = 0,1,2 and so on or from" + " HETERO/MULTI/AUTO/BATCH options available. \n"); + } + } } + // All devices have passed the check, return selected device + LOGS_DEFAULT(INFO) << "[OpenVINO-EP] Choosing Device: " << selected_device; + return selected_device; } void ParseProviderOptions([[maybe_unused]] ProviderInfo& result, [[maybe_unused]] const ProviderOptions& config_options) {} diff --git a/onnxruntime/core/providers/openvino/ov_interface.cc b/onnxruntime/core/providers/openvino/ov_interface.cc index 9208f6a76e0bc..d5d23cf4a11f1 100644 --- a/onnxruntime/core/providers/openvino/ov_interface.cc +++ b/onnxruntime/core/providers/openvino/ov_interface.cc @@ -164,8 +164,42 @@ OVExeNetwork OVCore::ImportModel(std::shared_ptr model_strea } #endif -std::vector OVCore::GetAvailableDevices() { - auto available_devices = core.get_available_devices(); +std::vector OVCore::GetAvailableDevices() const { + std::vector available_devices = core.get_available_devices(); + return available_devices; +} + +std::vector OVCore::GetAvailableDevices(const std::string& device_type) const { + std::vector available_devices; + std::vector devicesIDs; + // Uses logic from OpenVINO to only return available devices of the specified type (e.g. CPU, NPU or GPU) + try { + devicesIDs = core.get_property(device_type, ov::available_devices); + } catch (const ov::Exception&) { + // plugin is not created by e.g. invalid env + // Empty device list will be returned + } catch (const std::runtime_error&) { + // plugin is not created by e.g. invalid env + // Empty device list will be returned + } catch (const std::exception& ex) { + ORT_THROW("[ERROR] [OpenVINO] An exception is thrown while trying to create the ", + device_type, + " device: ", + ex.what()); + } catch (...) { + ORT_THROW("[ERROR] [OpenVINO] Unknown exception is thrown while trying to create the ", + device_type, + " device"); + } + + if (devicesIDs.size() > 1) { + for (const auto& deviceID : devicesIDs) { + available_devices.push_back(device_type + '.' + deviceID); + } + } else if (!devicesIDs.empty()) { + available_devices.push_back(device_type); + } + return available_devices; } diff --git a/onnxruntime/core/providers/openvino/ov_interface.h b/onnxruntime/core/providers/openvino/ov_interface.h index f58b05e6017ec..7e0f3f7d917a9 100644 --- a/onnxruntime/core/providers/openvino/ov_interface.h +++ b/onnxruntime/core/providers/openvino/ov_interface.h @@ -92,7 +92,8 @@ struct OVCore : WeakSingleton { OVRemoteContextPtr context, std::string name); #endif - std::vector GetAvailableDevices(); + std::vector GetAvailableDevices() const; + std::vector GetAvailableDevices(const std::string& device_type) const; void SetCache(const std::string& cache_dir_path); void SetStreams(const std::string& device_type, int num_streams); }; diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index 5fd197d7a798b..36bc1d7412704 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -1983,7 +1983,7 @@ ProviderOptions OrtOpenVINOProviderOptionsToOrtOpenVINOProviderOptionsV2(const O ov_options_converted_map["context"] = context_string.str(); } - ov_options_converted_map["enable_opencl_throttling"] = legacy_ov_options->enable_opencl_throttling; + ov_options_converted_map["enable_opencl_throttling"] = legacy_ov_options->enable_opencl_throttling == 0 ? "true" : "false"; if (legacy_ov_options->enable_dynamic_shapes) { ov_options_converted_map["disable_dynamic_shapes"] = "false"; From 25912f7af5626fb3641fd7fa64a78fe7d6af0a58 Mon Sep 17 00:00:00 2001 From: Preetha Veeramalai Date: Wed, 2 Apr 2025 10:51:14 -0700 Subject: [PATCH 019/138] Add support for parsing AUTO, HETERO and MULTI from json config (#605) * Add support for parsing AUTO, HETERO and MULTI from json config * Fix lint issues * Address review comments --- .../openvino/backends/basic_backend.cc | 17 +++++++++++++++++ .../openvino/openvino_provider_factory.cc | 6 ++++-- 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/openvino/backends/basic_backend.cc b/onnxruntime/core/providers/openvino/backends/basic_backend.cc index 14a4a466613dd..18182284181e5 100644 --- a/onnxruntime/core/providers/openvino/backends/basic_backend.cc +++ b/onnxruntime/core/providers/openvino/backends/basic_backend.cc @@ -2,6 +2,8 @@ // Licensed under the MIT License #include +#include + #include #include #include @@ -222,6 +224,15 @@ void BasicBackend::PopulateConfigValue(ov::AnyMap& device_config) { } } } + auto find_device_type_mode = [&](const std::string& device_type) -> std::string { + std::string device_mode = ""; + auto delimiter_pos = device_type.find(':'); + if (delimiter_pos != std::string::npos) { + std::stringstream str_stream(device_type.substr(0, delimiter_pos)); + std::getline(str_stream, device_mode, ','); + } + return device_mode; + }; // Parse device types like "AUTO:CPU,GPU" and extract individual devices auto parse_individual_devices = [&](const std::string& device_type) -> std::vector { @@ -270,8 +281,14 @@ void BasicBackend::PopulateConfigValue(ov::AnyMap& device_config) { if (session_context_.device_type.find("AUTO") == 0 || session_context_.device_type.find("HETERO") == 0 || session_context_.device_type.find("MULTI") == 0) { + //// Parse to get the device mode (e.g., "AUTO:CPU,GPU" -> "AUTO") + std::unordered_set supported_mode = {"AUTO", "HETERO", "MULTI"}; + auto device_mode = find_device_type_mode(session_context_.device_type); + ORT_ENFORCE(supported_mode.find(device_mode)!=supported_mode.end(), " Invalid device mode is passed : " , session_context_.device_type); // Parse individual devices (e.g., "AUTO:CPU,GPU" -> ["CPU", "GPU"]) auto individual_devices = parse_individual_devices(session_context_.device_type); + if (!device_mode.empty()) individual_devices.emplace_back(device_mode); + // Set properties only for individual devices (e.g., "CPU", "GPU") for (const std::string& device : individual_devices) { if (target_config.count(device)) { diff --git a/onnxruntime/core/providers/openvino/openvino_provider_factory.cc b/onnxruntime/core/providers/openvino/openvino_provider_factory.cc index 5809ca2cf1c7b..b8f9aec3a68ed 100644 --- a/onnxruntime/core/providers/openvino/openvino_provider_factory.cc +++ b/onnxruntime/core/providers/openvino/openvino_provider_factory.cc @@ -2,6 +2,8 @@ // Licensed under the MIT License #include +#include + #include #include "core/providers/shared_library/provider_api.h" #include "core/providers/openvino/openvino_provider_factory.h" @@ -216,9 +218,9 @@ struct OpenVINO_Provider : Provider { for (auto& [key, value] : json_config.items()) { ov::AnyMap inner_map; - + std::set valid_ov_devices = {"CPU", "GPU", "NPU", "AUTO", "HETERO", "MULTI"}; // Ensure the key is one of "CPU", "GPU", or "NPU" - if (key != "CPU" && key != "GPU" && key != "NPU") { + if (valid_ov_devices.find(key) == valid_ov_devices.end()) { LOGS_DEFAULT(WARNING) << "Unsupported device key: " << key << ". Skipping entry.\n"; continue; } From fbf43a92d93b8082163faa572ab4cd3ae2499f8c Mon Sep 17 00:00:00 2001 From: Pallavi Gupta Date: Wed, 2 Apr 2025 20:05:45 -0700 Subject: [PATCH 020/138] Revert "[OVEP] Fix for Dynamic backend creation for NPU. (#622)" (#635) This reverts commit e240695353f09aca67633b66857a44099e01bce5. --- onnxruntime/core/providers/openvino/backends/basic_backend.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/openvino/backends/basic_backend.cc b/onnxruntime/core/providers/openvino/backends/basic_backend.cc index 18182284181e5..04297de038cd3 100644 --- a/onnxruntime/core/providers/openvino/backends/basic_backend.cc +++ b/onnxruntime/core/providers/openvino/backends/basic_backend.cc @@ -442,8 +442,8 @@ void BasicBackend::StartAsyncInference(Ort::KernelContext& context, OVInferReque } } } // Loop subgraph original input names - if (session_context_.device_type.find("NPU") != std::string::npos && - !subgraph_context_.has_dynamic_input_shape) { + + if (session_context_.device_type.find("NPU") != std::string::npos) { // Set the output blob as remote blob auto graph_output_info = exe_network_.Get().outputs(); auto output_idx = 0; From 2313d11c3656d6696a8207f5a1852b201f9f8c8e Mon Sep 17 00:00:00 2001 From: jatinwadhwa921 <110383850+jatinwadhwa921@users.noreply.github.com> Date: Thu, 3 Apr 2025 19:33:08 +0530 Subject: [PATCH 021/138] [OVEP] Fix for building OVEP without vcpkg flag (#637) --- tools/ci_build/github/linux/run_build.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tools/ci_build/github/linux/run_build.sh b/tools/ci_build/github/linux/run_build.sh index db8c271e3c876..25b3610872a04 100755 --- a/tools/ci_build/github/linux/run_build.sh +++ b/tools/ci_build/github/linux/run_build.sh @@ -37,7 +37,7 @@ if [ $BUILD_OS = "yocto" ]; then make -j$(nproc) else - COMMON_BUILD_ARGS="--skip_submodule_sync --enable_onnx_tests --parallel --use_vcpkg --use_binskim_compliant_compile_flags --cmake_path /usr/bin/cmake --ctest_path /usr/bin/ctest" + COMMON_BUILD_ARGS="--skip_submodule_sync --enable_onnx_tests --parallel --use_binskim_compliant_compile_flags --cmake_path /usr/bin/cmake --ctest_path /usr/bin/ctest" if [ $BUILD_DEVICE = "gpu" ]; then _CUDNN_VERSION=$(echo $CUDNN_VERSION | cut -d. -f1-2) From 1e85f1d30b20243a905d3c002f767b96f136845c Mon Sep 17 00:00:00 2001 From: jatinwadhwa921 <110383850+jatinwadhwa921@users.noreply.github.com> Date: Fri, 4 Apr 2025 19:33:23 +0530 Subject: [PATCH 022/138] [OVEP] Updated Documentation for python wheels (#640) --- docs/python/ReadMeOV.rst | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/docs/python/ReadMeOV.rst b/docs/python/ReadMeOV.rst index 86914699bbf6d..845f79cf8257c 100644 --- a/docs/python/ReadMeOV.rst +++ b/docs/python/ReadMeOV.rst @@ -7,7 +7,7 @@ OpenVINO™ Execution Provider for ONNX Runtime accelerates inference across man - Intel® CPUs - Intel® integrated GPUs - Intel® discrete GPUs - - Intel® integrated NPUs (Windows only) + - Intel® integrated NPUs Installation ------------ @@ -15,28 +15,28 @@ Installation Requirements ^^^^^^^^^^^^ -- Ubuntu 18.04, 20.04, RHEL(CPU only) or Windows 10 - 64 bit -- Python 3.9 or 3.10 or 3.11 for Linux and Python 3.10, 3.11 for Windows +- Ubuntu 18.04, 20.04 or Windows 10 - 64 bit +- Python 3.11, 3.12 and 3.13 for Windows and Linux This package supports: - Intel® CPUs - Intel® integrated GPUs - Intel® discrete GPUs - - Intel® integrated NPUs (Windows only) + - Intel® integrated NPUs ``pip3 install onnxruntime-openvino`` Please install OpenVINO™ PyPi Package separately for Windows. For installation instructions on Windows please refer to `OpenVINO™ Execution Provider for ONNX Runtime for Windows `_. -**OpenVINO™ Execution Provider for ONNX Runtime** Linux Wheels comes with pre-built libraries of OpenVINO™ version 2024.1.0 eliminating the need to install OpenVINO™ separately. +**OpenVINO™ Execution Provider for ONNX Runtime** Linux Wheels comes with pre-built libraries of OpenVINO™ version 2025.0.0 eliminating the need to install OpenVINO™ separately. For more details on build and installation please refer to `Build `_. Usage ^^^^^ -By default, Intel® CPU is used to run inference. However, you can change the default option to either Intel® integrated GPU, discrete GPU, integrated NPU (Windows only). +By default, Intel® CPU is used to run inference. However, you can change the default option to either Intel® integrated GPU, discrete GPU, integrated NPU. Invoke `the provider config device type argument `_ to change the hardware on which inferencing is done. For more API calls and environment variables, see `Usage `_. From 80dfee9ffbdf74ed953bdd75aedbdb8be555394b Mon Sep 17 00:00:00 2001 From: Preetha Veeramalai Date: Fri, 4 Apr 2025 14:11:05 -0700 Subject: [PATCH 023/138] Device type refactoring (#630) * Refactor device check logic * Validate the provider options key passed * Add support for mapping LUID to device * Fix lint warnings * generalise LUID for GPU and NPU --- .../openvino/backends/basic_backend.cc | 2 +- .../core/providers/openvino/contexts.h | 4 + .../openvino/openvino_execution_provider.cc | 51 ------- .../openvino/openvino_provider_factory.cc | 134 +++++++++++++----- .../core/providers/openvino/ov_interface.cc | 16 ++- .../qdq_transformations/qdq_stripping.cc | 2 +- onnxruntime/test/perftest/ort_test_session.cc | 8 +- 7 files changed, 118 insertions(+), 99 deletions(-) diff --git a/onnxruntime/core/providers/openvino/backends/basic_backend.cc b/onnxruntime/core/providers/openvino/backends/basic_backend.cc index 04297de038cd3..c814df618e3b3 100644 --- a/onnxruntime/core/providers/openvino/backends/basic_backend.cc +++ b/onnxruntime/core/providers/openvino/backends/basic_backend.cc @@ -284,7 +284,7 @@ void BasicBackend::PopulateConfigValue(ov::AnyMap& device_config) { //// Parse to get the device mode (e.g., "AUTO:CPU,GPU" -> "AUTO") std::unordered_set supported_mode = {"AUTO", "HETERO", "MULTI"}; auto device_mode = find_device_type_mode(session_context_.device_type); - ORT_ENFORCE(supported_mode.find(device_mode)!=supported_mode.end(), " Invalid device mode is passed : " , session_context_.device_type); + ORT_ENFORCE(supported_mode.find(device_mode) != supported_mode.end(), " Invalid device mode is passed : ", session_context_.device_type); // Parse individual devices (e.g., "AUTO:CPU,GPU" -> ["CPU", "GPU"]) auto individual_devices = parse_individual_devices(session_context_.device_type); if (!device_mode.empty()) individual_devices.emplace_back(device_mode); diff --git a/onnxruntime/core/providers/openvino/contexts.h b/onnxruntime/core/providers/openvino/contexts.h index a1a756a9baef7..02caf9d6ce7c4 100644 --- a/onnxruntime/core/providers/openvino/contexts.h +++ b/onnxruntime/core/providers/openvino/contexts.h @@ -6,6 +6,7 @@ #include #include #include +#include #include #include #include @@ -102,6 +103,9 @@ struct ProviderInfo { bool so_share_ep_contexts{false}; // ORT session option fs::path so_context_file_path{}; // ORT session option const ConfigOptions* config_options{NULL}; + const std::unordered_set valid_provider_keys = {"device_type", "device_id", "device_luid", "cache_dir", + "load_config", "context", "num_of_threads", "model_priority", "num_streams", "enable_opencl_throttling", "enable_qdq_optimizer", + "disable_dynamic_shapes"}; }; // Holds context applicable to the entire EP instance. diff --git a/onnxruntime/core/providers/openvino/openvino_execution_provider.cc b/onnxruntime/core/providers/openvino/openvino_execution_provider.cc index da12d7f27d61a..f9d4ab13cf2ce 100644 --- a/onnxruntime/core/providers/openvino/openvino_execution_provider.cc +++ b/onnxruntime/core/providers/openvino/openvino_execution_provider.cc @@ -58,57 +58,6 @@ OpenVINOExecutionProvider::OpenVINOExecutionProvider(const ProviderInfo& info, s shared_context_{shared_context}, ep_ctx_handle_{session_context_.openvino_sdk_version, *GetLogger()} { InitProviderOrtApi(); - - // to check if target device is available - // using OVCore capability GetAvailableDevices to fetch list of devices plugged in - if (info.cache_dir.empty()) { - bool all_devices_found = false; - // Checking for device_type configuration - if (info.device_type != "") { - std::vector devices_to_check; - if (info.device_type.find("HETERO:") == 0 || - info.device_type.find("MULTI:") == 0 || - info.device_type.find("BATCH:") == 0 || - info.device_type.find("AUTO:") == 0) { - auto delimit = info.device_type.find(":"); - const auto& devices = info.device_type.substr(delimit + 1); - devices_to_check = split(devices, ','); - } else { - devices_to_check.push_back(info.device_type); - } - - // Re-initialize before loop - all_devices_found = true; - for (const auto& device : devices_to_check) { - bool device_found = false; - std::string device_prefix = device; - int device_idx = 0; - // Get the index and remove the index from device_prefix - if (auto delimit = device_prefix.find("."); delimit != std::string::npos) { - try { - device_idx = std::stoi(device_prefix.substr(delimit + 1)); - } catch (std::exception& ex) { - ORT_THROW("[ERROR] [OpenVINO] Wrong index in specified device - " + device + " :", ex.what()); - } - device_prefix = device_prefix.substr(0, delimit); - } - std::vector available_devices = OVCore::Get()->GetAvailableDevices(device_prefix); - // If idx is 0, maybe index is not set (e.g. GPU) - // Then the device is found if we have at least one device of the type - if (device_idx == 0 && available_devices.size() >= 1) { - device_found = true; - } else { - // Find full device (e.g GPU.1) in the list - if (std::find(std::begin(available_devices), std::end(available_devices), device) != std::end(available_devices)) - device_found = true; - } - all_devices_found = all_devices_found && device_found; - } - } - if (!all_devices_found) { - ORT_THROW("[ERROR] [OpenVINO] Specified device - " + info.device_type + " is not available"); - } - } } OpenVINOExecutionProvider::~OpenVINOExecutionProvider() { diff --git a/onnxruntime/core/providers/openvino/openvino_provider_factory.cc b/onnxruntime/core/providers/openvino/openvino_provider_factory.cc index b8f9aec3a68ed..78c3be9cac35d 100644 --- a/onnxruntime/core/providers/openvino/openvino_provider_factory.cc +++ b/onnxruntime/core/providers/openvino/openvino_provider_factory.cc @@ -18,7 +18,6 @@ namespace onnxruntime { namespace openvino_ep { void ParseConfigOptions(ProviderInfo& pi) { - if (pi.config_options == nullptr) return; @@ -58,23 +57,29 @@ bool ParseBooleanOption(const ProviderOptions& provider_options, std::string opt return false; } -std::string ParseDeviceType(std::shared_ptr ov_core, const ProviderOptions& provider_options, std::string option_name) { - // This function normally does not check if the selected device is available, but does some sanity checks - // Only if the device is not standard, then availability is checked. - // Availability is checked for the selected device in the OpenVINOExecutionProvider constructor - +std::string ParseDeviceType(std::shared_ptr ov_core, const ProviderOptions& provider_options) { + std::set supported_device_types = {"CPU", "GPU", "NPU"}; + std::set supported_device_modes = {"AUTO", "HETERO", "MULTI"}; std::vector devices_to_check; std::string selected_device; - if (provider_options.contains(option_name)) { - selected_device = provider_options.at(option_name); - // If we have multiple device configuration, we need to check all of them - if ((selected_device.find("HETERO:") == 0) || - (selected_device.find("MULTI:") == 0) || - (selected_device.find("BATCH:") == 0) || - (selected_device.find("AUTO:") == 0)) { - auto delimit = selected_device.find(":"); - const auto& devices = selected_device.substr(delimit + 1); - devices_to_check = split(devices, ','); + std::vector luid_list; + std::string device_mode = ""; + std::map ov_luid_map; + + if (provider_options.contains("device_type")) { + selected_device = provider_options.at("device_type"); + std::erase(selected_device, ' '); + if (selected_device == "AUTO") return selected_device; + + if (auto delimit = selected_device.find(":"); delimit != std::string::npos) { + device_mode = selected_device.substr(0, delimit); + if (supported_device_modes.contains(device_mode)) { + const auto& devices = selected_device.substr(delimit + 1); + devices_to_check = split(devices, ','); + ORT_ENFORCE(devices_to_check.size() > 0, "Modes should have devices listed based on priority"); + } else { + ORT_THROW("[ERROR] [OpenVINO] Invalid device_type is selected. Supported modes are AUTO/HETERO/MULTI"); + } } else { devices_to_check.push_back(selected_device); } @@ -102,9 +107,18 @@ std::string ParseDeviceType(std::shared_ptr ov_core, const ProviderOptio #endif } - // Devices considered to be supported by default - std::unordered_set supported_device_types = {"CPU", "GPU", "NPU"}; + // Get the LUID passed from the provider option in a comma separated string list + // Compare each of the LUID's against the LUID obtained using ov property and map with the right device + if (provider_options.contains("device_luid")) { + std::string luid_str = provider_options.at("device_luid"); + std::erase(luid_str, ' '); + luid_list = split(luid_str, ','); + } + + bool all_devices_found = true; + for (auto device : devices_to_check) { + bool device_found = false; // Check deprecated device format (CPU_FP32, GPU.0_FP16, etc.) and remove the suffix in place // Suffix will be parsed in ParsePrecision if (auto delimit = device.find("_"); delimit != std::string::npos) { @@ -113,26 +127,57 @@ std::string ParseDeviceType(std::shared_ptr ov_core, const ProviderOptio // Just the device name without .0, .1, etc. suffix auto device_prefix = device; // Check if device index is appended (.0, .1, etc.), if so, remove it - if (auto delimit = device_prefix.find("."); delimit != std::string::npos) { + if (auto delimit = device_prefix.find("."); delimit != std::string::npos) device_prefix = device_prefix.substr(0, delimit); - if (device_prefix == "CPU") - ORT_THROW("[ERROR] [OpenVINO] CPU device is only supported without index, CPU.x is illegal.\n"); + if (supported_device_types.contains(device_prefix)) { + try { + std::vector available_devices = ov_core->GetAvailableDevices(device_prefix); + // Here we need to find the full device name (with .idx, but without _precision) + if (std::find(std::begin(available_devices), std::end(available_devices), device) != std::end(available_devices)) + device_found = true; + if (device_prefix != "CPU" && luid_list.size() > 0) { + for (auto dev : available_devices) { + ov::device::LUID ov_luid = OVCore::Get()->core.get_property(dev, ov::device::luid); + std::stringstream ov_luid_str; + ov_luid_str << ov_luid; + ov_luid_map.emplace(ov_luid_str.str(), dev); + } + } + } catch (const char* msg) { + ORT_THROW(msg); + } } - // Only device is not supported by default (some exotic device), check if it's available - if (!supported_device_types.contains(device_prefix)) { - std::vector available_devices = ov_core->GetAvailableDevices(); - // Here we need to find the full device name (with .idx, but without _precision) - if (std::find(std::begin(available_devices), std::end(available_devices), device) == std::end(available_devices)) { - ORT_THROW( - "[ERROR] [OpenVINO] You have selected wrong configuration value for the key 'device_type'. " - "Select from 'CPU', 'GPU', 'NPU', 'GPU.x' where x = 0,1,2 and so on or from" - " HETERO/MULTI/AUTO/BATCH options available. \n"); + all_devices_found = all_devices_found && device_found; + } + if (luid_list.size() > 0) { + std::string ov_luid_devices; + for (auto luid_str : luid_list) { + if (ov_luid_map.contains(luid_str)) { + if (!ov_luid_devices.empty()) ov_luid_devices = ov_luid_devices + ","; + ov_luid_devices = ov_luid_devices + ov_luid_map.at(luid_str); + } else { + ORT_THROW("Invalid device_luid is set"); + } + } + if (!device_mode.empty()) { + selected_device = device_mode + ":" + ov_luid_devices; + for (auto dev_str : devices_to_check) { + auto default_dev = split(dev_str, '.')[0]; + if (ov_luid_devices.find(default_dev) == std::string::npos) + selected_device = selected_device + "," + dev_str; } + } else { + selected_device = ov_luid_devices; } } - // All devices have passed the check, return selected device - LOGS_DEFAULT(INFO) << "[OpenVINO-EP] Choosing Device: " << selected_device; - return selected_device; + // If invalid device is chosen error is thrown + if (!all_devices_found) + ORT_THROW( + "[ERROR] [OpenVINO] You have selected wrong configuration value for the key 'device_type'. " + "Select from 'CPU', 'GPU', 'NPU', 'GPU.x' where x = 0,1,2 and so on or from" + " HETERO/MULTI/AUTO/BATCH options available. \n"); + else + return selected_device; } void ParseProviderOptions([[maybe_unused]] ProviderInfo& result, [[maybe_unused]] const ProviderOptions& config_options) {} @@ -175,12 +220,22 @@ struct OpenVINO_Provider : Provider { ProviderInfo pi; pi.config_options = config_options; + // Lambda function to check for invalid keys and throw an error + auto validateKeys = [&]() { + for (const auto& pair : provider_options) { + if (pi.valid_provider_keys.find(pair.first) == pi.valid_provider_keys.end()) { + ORT_THROW("Invalid provider_option key: " + pair.first); + } + } + }; + validateKeys(); + std::string bool_flag = ""; // Minor optimization: we'll hold an OVCore reference to ensure we don't create a new core between ParseDeviceType and // (potential) SharedContext creation. auto ov_core = OVCore::Get(); - pi.device_type = ParseDeviceType(ov_core, provider_options, "device_type"); + pi.device_type = ParseDeviceType(ov_core, provider_options); if (provider_options.contains("device_id")) { std::string dev_id = provider_options.at("device_id").data(); @@ -303,12 +358,15 @@ struct OpenVINO_Provider : Provider { << "Executing with num_streams=1"; } } - pi.enable_opencl_throttling = ParseBooleanOption(provider_options, "enable_opencl_throttling"); + try { + pi.enable_opencl_throttling = ParseBooleanOption(provider_options, "enable_opencl_throttling"); - pi.enable_qdq_optimizer = ParseBooleanOption(provider_options, "enable_qdq_optimizer"); - - pi.disable_dynamic_shapes = ParseBooleanOption(provider_options, "disable_dynamic_shapes"); + pi.enable_qdq_optimizer = ParseBooleanOption(provider_options, "enable_qdq_optimizer"); + pi.disable_dynamic_shapes = ParseBooleanOption(provider_options, "disable_dynamic_shapes"); + } catch (std::string msg) { + ORT_THROW(msg); + } // Always true for NPU plugin or when passed . if (pi.device_type.find("NPU") != std::string::npos) { pi.disable_dynamic_shapes = true; diff --git a/onnxruntime/core/providers/openvino/ov_interface.cc b/onnxruntime/core/providers/openvino/ov_interface.cc index d5d23cf4a11f1..fb9f4f4f97580 100644 --- a/onnxruntime/core/providers/openvino/ov_interface.cc +++ b/onnxruntime/core/providers/openvino/ov_interface.cc @@ -178,25 +178,31 @@ std::vector OVCore::GetAvailableDevices(const std::string& device_t } catch (const ov::Exception&) { // plugin is not created by e.g. invalid env // Empty device list will be returned - } catch (const std::runtime_error&) { + } catch (const std::runtime_error& ex) { // plugin is not created by e.g. invalid env // Empty device list will be returned + ORT_THROW("[ERROR] [OpenVINO] An exception occurred while trying to create the ", + device_type, + " device: ", + ex.what()); } catch (const std::exception& ex) { - ORT_THROW("[ERROR] [OpenVINO] An exception is thrown while trying to create the ", + ORT_THROW("[ERROR] [OpenVINO] An exception occurred while trying to create the ", device_type, " device: ", ex.what()); } catch (...) { - ORT_THROW("[ERROR] [OpenVINO] Unknown exception is thrown while trying to create the ", + ORT_THROW("[ERROR] [OpenVINO] Unknown exception occurred while trying to create the ", device_type, " device"); } - if (devicesIDs.size() > 1) { + if (devicesIDs.size() > 1 || + (devicesIDs.size() == 1 && devicesIDs[0] == "0")) { for (const auto& deviceID : devicesIDs) { available_devices.push_back(device_type + '.' + deviceID); } - } else if (!devicesIDs.empty()) { + } + if (!devicesIDs.empty()) { available_devices.push_back(device_type); } diff --git a/onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.cc b/onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.cc index effd13abc3e3a..636a1f8bfb500 100644 --- a/onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.cc +++ b/onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.cc @@ -386,7 +386,7 @@ static bool CheckQRuleSet(const NodeUnit& node_unit, } else if (op_type == "Add") { // Add keeps all Qs return true; - } else { + } else { // Keep Q of an unsupported Op only if the target that succeeds it is a supported Op in this list return IsNextTargetNodeOfQValid(q_node, &target_node, src_graph, {"Conv", "Add", "MatMul"}, false); } diff --git a/onnxruntime/test/perftest/ort_test_session.cc b/onnxruntime/test/perftest/ort_test_session.cc index 86d04fb4bbc2b..1630f63822b6a 100644 --- a/onnxruntime/test/perftest/ort_test_session.cc +++ b/onnxruntime/test/perftest/ort_test_session.cc @@ -680,11 +680,11 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)"); ov_options[key] = value; } else if (deprecated_device_types.find(value) != deprecated_device_types.end()) { ov_options[key] = value; - } else if (value.find("HETERO:") == 0) { + } else if (value.find("HETERO") == 0) { ov_options[key] = value; - } else if (value.find("MULTI:") == 0) { + } else if (value.find("MULTI") == 0) { ov_options[key] = value; - } else if (value.find("AUTO:") == 0) { + } else if (value.find("AUTO") == 0) { ov_options[key] = value; } else { ORT_THROW( @@ -792,6 +792,8 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)"); } } else if (key == "device_memory_name") { device_memory_name_ = std::move(value); + } else if (key == "device_luid") { + ov_options[key] = value; } else { ORT_THROW( "[ERROR] [OpenVINO] wrong key type entered. Choose from the following runtime key options that are available for OpenVINO." From 2e4d5410db3c05cd718f111971a6bb3f59ca6033 Mon Sep 17 00:00:00 2001 From: saurabh Date: Mon, 7 Apr 2025 22:03:35 +0530 Subject: [PATCH 024/138] Enable adaptive stripping and eliminate dependency of weight sharing feature on OVEP qdq stripping (#629) * eliminate dependency of weight sharing on ovep qdq stripping pass * fix qdqnodeunit issue * enable compiler stripping * enable adaptive stripping: cleanup code * fix backward compatibility issue * add logs to identify which stripping is enabled * address PR review comments * fix unused variable error * resolve unused var issue * fix CI issues --- .../providers/openvino/backend_manager.cc | 26 ++++++++++++++---- .../core/providers/openvino/ov_interface.cc | 11 ++++++++ .../core/providers/openvino/ov_interface.h | 3 +++ .../qdq_transformations/qdq_stripping.cc | 27 ++++++++++++++----- .../qdq_transformations/qdq_stripping.h | 3 ++- 5 files changed, 57 insertions(+), 13 deletions(-) diff --git a/onnxruntime/core/providers/openvino/backend_manager.cc b/onnxruntime/core/providers/openvino/backend_manager.cc index 1ed5ad8a56fd5..5d6251669f8ca 100644 --- a/onnxruntime/core/providers/openvino/backend_manager.cc +++ b/onnxruntime/core/providers/openvino/backend_manager.cc @@ -19,6 +19,7 @@ #include "core/providers/openvino/ibackend.h" #include "core/providers/openvino/backend_utils.h" #include "core/providers/openvino/qdq_transformations/qdq_stripping.h" +#include "core/providers/openvino/ov_interface.h" namespace onnxruntime { namespace openvino_ep { @@ -359,14 +360,29 @@ BackendManager::GetModelProtoFromFusedNode(const onnxruntime::Node& fused_node, } }; + [[maybe_unused]] bool enable_ovep_qdq_optimizer = session_context_.enable_qdq_optimizer && IsQDQGraph(subgraph); + [[maybe_unused]] std::optional enable_compiler_qdq_optimization = queryOVProperty("NPU_QDQ_OPTIMIZATION", session_context_.device_type); +#if (((OPENVINO_VERSION_MAJOR == 2025) && (OPENVINO_VERSION_MINOR > 0)) || (OPENVINO_VERSION_MAJOR > 2025)) + if (session_context_.device_type.find("NPU") != std::string::npos && session_context_.enable_qdq_optimizer) { + if (enable_compiler_qdq_optimization.has_value() && enable_compiler_qdq_optimization.value()) { + LOGS_DEFAULT(INFO) << "[OpenVINO-EP]: Compiler QDQ optimization pass is enabled"; + OVCore::Get()->core.set_property("NPU", {ov::intel_npu::qdq_optimization(true)}); + // disabling OVEP qdq stripping + // at this stage provider option "enable_qdq_optimizer" is still true but OVEP stripping is (disabled) false + // as compiler stripping is enabled + enable_ovep_qdq_optimizer = false; + } else { + LOGS_DEFAULT(INFO) << "[OpenVINO-EP]: OVEP QDQ optimization pass is enabled"; + } + } +#endif + const auto& onnx_model_path_name = subgraph.ModelPath(); // QDQ stripping enabled only for the NPU if (session_context_.device_type.find("NPU") != std::string::npos && - session_context_.enable_qdq_optimizer && - IsQDQGraph(subgraph)) { - LOGS_DEFAULT(INFO) << "[OpenVINO-EP] QDQ optimization pass status: 1"; + (enable_ovep_qdq_optimizer || session_context_.so_share_ep_contexts)) { std::unique_ptr model; - Status status = CreateModelWithStrippedQDQNodes(subgraph, logger, session_context_.so_share_ep_contexts, model, shared_context_.shared_weights); + Status status = CreateModelWithStrippedQDQNodes(subgraph, logger, session_context_.so_share_ep_contexts, model, shared_context_.shared_weights, enable_ovep_qdq_optimizer); auto model_proto = model->ToProto(); model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); print_model_proto_duration(); @@ -374,7 +390,7 @@ BackendManager::GetModelProtoFromFusedNode(const onnxruntime::Node& fused_node, ORT_ENFORCE(status.IsOK(), status.ErrorMessage()); return model_proto; } else { - LOGS_DEFAULT(INFO) << "[OpenVINO-EP] QDQ optimization pass status: 0"; + LOGS_DEFAULT(INFO) << "[OpenVINO-EP] OVEP QDQ optimization pass is disabled"; auto model = subgraph.CreateModel(logger); auto model_proto = model->ToProto(); model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); diff --git a/onnxruntime/core/providers/openvino/ov_interface.cc b/onnxruntime/core/providers/openvino/ov_interface.cc index fb9f4f4f97580..6afbd8ce761e5 100644 --- a/onnxruntime/core/providers/openvino/ov_interface.cc +++ b/onnxruntime/core/providers/openvino/ov_interface.cc @@ -46,6 +46,17 @@ void printDebugInfo(const ov::CompiledModel& obj) { } #endif +// Function to check if a given OV property is enabled +std::optional queryOVProperty(const std::string& property, const std::string& device_type) { + try { + // Get the property value + auto supported_properties = OVCore::Get()->core.get_property(device_type, ov::supported_properties); + return std::find(supported_properties.begin(), supported_properties.end(), property) != supported_properties.end(); + } catch (const std::exception&) { + return std::nullopt; // Property not found or invalid + } +} + std::shared_ptr OVCore::ReadModel(std::string&& model, const std::string& model_path) { try { std::istringstream modelStringStream(std::move(model)); diff --git a/onnxruntime/core/providers/openvino/ov_interface.h b/onnxruntime/core/providers/openvino/ov_interface.h index 7e0f3f7d917a9..bebe73bd702dd 100644 --- a/onnxruntime/core/providers/openvino/ov_interface.h +++ b/onnxruntime/core/providers/openvino/ov_interface.h @@ -8,6 +8,7 @@ #include #include #include +#include #include "openvino/openvino.hpp" #include "openvino/runtime/intel_npu/properties.hpp" @@ -37,6 +38,8 @@ typedef ov::intel_gpu::ocl::ClContext* OVRemoteContextPtr; typedef ov::RemoteContext OVRemoteContext; #endif +std::optional queryOVProperty(const std::string& property, const std::string& device_type); + template class WeakSingleton { public: diff --git a/onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.cc b/onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.cc index 636a1f8bfb500..c071db9c3a4fb 100644 --- a/onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.cc +++ b/onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.cc @@ -341,6 +341,7 @@ static bool CheckDQRuleSet(const NodeUnit& node_unit, } } +// this check is if QLinear node feed into the output of src graph which expects quantized output static bool CheckQFeedsIntoQuantizedOutput(const NodeUnit& node_unit, const std::unordered_map graph_op_data_type) { auto op_of_quantized_layer = node_unit.Outputs(); @@ -447,9 +448,17 @@ static bool HandleDoubleQDQ(onnxruntime::Graph& dst_graph, const onnxruntime::Gr static void AddStandaloneNodeUnit(onnxruntime::Graph& dst_graph, const onnxruntime::GraphViewer& src_graph, const NodeUnit& node_unit, std::set& initializers_to_keep, - const logging::Logger& /* logger */) { + const logging::Logger& /* logger */, + bool IsWeightSharingWithoutOVEPQDQStripping) { assert(node_unit.UnitType() == NodeUnit::Type::SingleNode); + // this is the scenario where WAI is enabled and ovep stripping is disabled + // do not strip off any Q or DQ node + if (IsWeightSharingWithoutOVEPQDQStripping) { + AddNode(initializers_to_keep, src_graph, dst_graph, node_unit.GetNode()); + return; + } + if (HandleDoubleQDQ(dst_graph, src_graph, node_unit, initializers_to_keep)) return; auto add_identity_op = [&](bool duplicate_dq) { @@ -511,7 +520,8 @@ static void AddQDQNodeUnit(onnxruntime::Graph& dst_graph, const onnxruntime::GraphViewer& src_graph, const NodeUnit& node_unit, std::set& initializers_to_keep, - const logging::Logger& /* logger */) { + const logging::Logger& /* logger */, + bool IsWeightSharingWithoutOVEPQDQStripping) { assert(node_unit.UnitType() == NodeUnit::Type::QDQGroup); // Collect inputs coming into the node unit. @@ -529,7 +539,7 @@ static void AddQDQNodeUnit(onnxruntime::Graph& dst_graph, SkipReason reason = SkipReason::Other; bool keep_dq = CheckDQRuleSet(node_unit, dq_node, src_graph, reason); - if (keep_dq) { + if (IsWeightSharingWithoutOVEPQDQStripping || keep_dq) { AddNode(initializers_to_keep, src_graph, dst_graph, *dq_node); dq_node_args_to_keep.insert({input_defs.at(0)->Name(), &dst_graph.GetOrCreateNodeArg(dq_node->OutputDefs().at(0)->Name(), @@ -597,7 +607,7 @@ static void AddQDQNodeUnit(onnxruntime::Graph& dst_graph, bool keep_q = CheckQRuleSet(node_unit, q_node, src_graph, reason); - if (keep_q) { + if (IsWeightSharingWithoutOVEPQDQStripping || keep_q) { AddNode(initializers_to_keep, src_graph, dst_graph, *q_node); // if keep_q, then output defs of the target node doesn't change output_args.push_back(&dst_graph.GetOrCreateNodeArg(target_node.OutputDefs().at(i)->Name(), @@ -675,7 +685,8 @@ Status CreateModelWithStrippedQDQNodes(const GraphViewer& src_graph, const logging::Logger& logger, bool enable_ovep_weight_sharing, /*out*/ std::unique_ptr& model, - /*out*/ sw& shared_weights) { + /*out*/ sw& shared_weights, + bool enable_ovep_qdq_optimizer) { // NOTE: This function is a re-implementation of GraphViewerToProto() in core/graph/graph_proto_serializer.cc // with the following differences: // - Uses onnxruntime::Graph APIs instead of onnx::GraphProto APIs. @@ -766,10 +777,12 @@ Status CreateModelWithStrippedQDQNodes(const GraphViewer& src_graph, continue; // Already handled this node unit } + bool IsWeightSharingWithoutOVEPQDQStripping = enable_ovep_weight_sharing && !enable_ovep_qdq_optimizer; + if (node_unit->UnitType() == NodeUnit::Type::SingleNode) { - AddStandaloneNodeUnit(dst_graph, src_graph, *node_unit, initializers_to_keep, logger); + AddStandaloneNodeUnit(dst_graph, src_graph, *node_unit, initializers_to_keep, logger, IsWeightSharingWithoutOVEPQDQStripping); } else { - AddQDQNodeUnit(dst_graph, src_graph, *node_unit, initializers_to_keep, logger); + AddQDQNodeUnit(dst_graph, src_graph, *node_unit, initializers_to_keep, logger, IsWeightSharingWithoutOVEPQDQStripping); } seen_node_units.insert(node_unit); diff --git a/onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.h b/onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.h index 02831525cba32..4b5696f4411bd 100644 --- a/onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.h +++ b/onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.h @@ -17,7 +17,8 @@ Status CreateModelWithStrippedQDQNodes(const GraphViewer& src_graph, const logging::Logger& logger, bool enable_ovep_weight_sharing, /*out*/ std::unique_ptr& model, - /*out*/ sw& shared_weights); + /*out*/ sw& shared_weights, + bool enable_ovep_qdq_optimizer); bool dumpMetaDataMapToBinary(const sw::Metadata::Map& shared_weights, const std::string& filename); } // namespace openvino_ep From c0c347ca2b06ff00faa7de8c862eca3f76341466 Mon Sep 17 00:00:00 2001 From: saurabh Date: Tue, 8 Apr 2025 15:05:24 +0530 Subject: [PATCH 025/138] Add Config for Release build * fix model save config * resolve unused variables error * fix model save for various configs in ovep * use generator exp to work with multi config build --- cmake/onnxruntime_providers_openvino.cmake | 6 ++++++ onnxruntime/core/providers/openvino/backend_manager.cc | 8 +++++--- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/cmake/onnxruntime_providers_openvino.cmake b/cmake/onnxruntime_providers_openvino.cmake index 143d002c6173e..f149030c15702 100644 --- a/cmake/onnxruntime_providers_openvino.cmake +++ b/cmake/onnxruntime_providers_openvino.cmake @@ -37,12 +37,18 @@ source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_openvino_cc_srcs}) onnxruntime_add_shared_library_module(onnxruntime_providers_openvino ${onnxruntime_providers_openvino_cc_srcs} "${ONNXRUNTIME_ROOT}/core/dll/onnxruntime.rc") + onnxruntime_add_include_to_target(onnxruntime_providers_openvino onnxruntime_common onnx nlohmann_json::nlohmann_json) install(FILES ${PROJECT_SOURCE_DIR}/../include/onnxruntime/core/providers/openvino/openvino_provider_factory.h DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/onnxruntime/) set_target_properties(onnxruntime_providers_openvino PROPERTIES CXX_STANDARD 20) set_target_properties(onnxruntime_providers_openvino PROPERTIES LINKER_LANGUAGE CXX) set_target_properties(onnxruntime_providers_openvino PROPERTIES FOLDER "ONNXRuntime") + + target_compile_options(onnxruntime_providers_openvino PRIVATE + $<$>:-DNOT_RELEASE> + ) + if(NOT MSVC) target_compile_options(onnxruntime_providers_openvino PRIVATE "-Wno-parentheses") endif() diff --git a/onnxruntime/core/providers/openvino/backend_manager.cc b/onnxruntime/core/providers/openvino/backend_manager.cc index 5d6251669f8ca..215cfafb2d174 100644 --- a/onnxruntime/core/providers/openvino/backend_manager.cc +++ b/onnxruntime/core/providers/openvino/backend_manager.cc @@ -321,9 +321,10 @@ static bool IsQDQGraph(const onnxruntime::GraphViewer& graph_viewer) { return false; } -static void DumpOpenVINOEPModel(const std::filesystem::path& onnx_model_path_name, - ONNX_NAMESPACE::ModelProto* model_proto, - const onnxruntime::Node& fused_node) { +static void DumpOpenVINOEPModel([[maybe_unused]] const std::filesystem::path& onnx_model_path_name, + [[maybe_unused]] ONNX_NAMESPACE::ModelProto* model_proto, + [[maybe_unused]] const onnxruntime::Node& fused_node) { +#ifdef NOT_RELEASE if (openvino_ep::backend_utils::IsDebugEnabled()) { auto model_name = onnx_model_path_name.empty() ? "unknown.onnx" : onnx_model_path_name.filename(); @@ -338,6 +339,7 @@ static void DumpOpenVINOEPModel(const std::filesystem::path& onnx_model_path_nam std::fstream dump(model_name, std::ios::out | std::ios::trunc | std::ios::binary); model_proto->SerializeToOstream(dump); } +#endif } std::unique_ptr From a8527b90ccf99af491004fcc423eb47ee95118b4 Mon Sep 17 00:00:00 2001 From: Preetha Veeramalai Date: Tue, 8 Apr 2025 09:08:39 -0700 Subject: [PATCH 026/138] Bug fix in provider key verification (#644) --- onnxruntime/core/providers/openvino/contexts.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/openvino/contexts.h b/onnxruntime/core/providers/openvino/contexts.h index 02caf9d6ce7c4..1314edd54e937 100644 --- a/onnxruntime/core/providers/openvino/contexts.h +++ b/onnxruntime/core/providers/openvino/contexts.h @@ -103,7 +103,7 @@ struct ProviderInfo { bool so_share_ep_contexts{false}; // ORT session option fs::path so_context_file_path{}; // ORT session option const ConfigOptions* config_options{NULL}; - const std::unordered_set valid_provider_keys = {"device_type", "device_id", "device_luid", "cache_dir", + const std::unordered_set valid_provider_keys = {"device_type", "device_id", "device_luid", "cache_dir", "precision", "load_config", "context", "num_of_threads", "model_priority", "num_streams", "enable_opencl_throttling", "enable_qdq_optimizer", "disable_dynamic_shapes"}; }; From 4e63ef6809c4d9d1ce76443c155d64e79e66a323 Mon Sep 17 00:00:00 2001 From: Preetha Veeramalai Date: Wed, 9 Apr 2025 22:41:08 -0700 Subject: [PATCH 027/138] Fix the LUID check (#647) * Fix the LUID check * Address review comments --- .../openvino/openvino_provider_factory.cc | 22 ++++++++++++++----- 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/onnxruntime/core/providers/openvino/openvino_provider_factory.cc b/onnxruntime/core/providers/openvino/openvino_provider_factory.cc index 78c3be9cac35d..e35143df29941 100644 --- a/onnxruntime/core/providers/openvino/openvino_provider_factory.cc +++ b/onnxruntime/core/providers/openvino/openvino_provider_factory.cc @@ -76,7 +76,7 @@ std::string ParseDeviceType(std::shared_ptr ov_core, const ProviderOptio if (supported_device_modes.contains(device_mode)) { const auto& devices = selected_device.substr(delimit + 1); devices_to_check = split(devices, ','); - ORT_ENFORCE(devices_to_check.size() > 0, "Modes should have devices listed based on priority"); + ORT_ENFORCE(devices_to_check.size() > 0, "Mode AUTO/HETERO/MULTI should have devices listed based on priority"); } else { ORT_THROW("[ERROR] [OpenVINO] Invalid device_type is selected. Supported modes are AUTO/HETERO/MULTI"); } @@ -153,16 +153,24 @@ std::string ParseDeviceType(std::shared_ptr ov_core, const ProviderOptio std::string ov_luid_devices; for (auto luid_str : luid_list) { if (ov_luid_map.contains(luid_str)) { - if (!ov_luid_devices.empty()) ov_luid_devices = ov_luid_devices + ","; - ov_luid_devices = ov_luid_devices + ov_luid_map.at(luid_str); + std::string ov_dev = ov_luid_map.at(luid_str); + std::string ov_dev_strip = split(ov_dev, '.')[0]; + if (std::find(std::begin(devices_to_check), std::end(devices_to_check), ov_dev) != std::end(devices_to_check) || + std::find(std::begin(devices_to_check), std::end(devices_to_check), ov_dev_strip) != std::end(devices_to_check)) { + if (!ov_luid_devices.empty()) ov_luid_devices = ov_luid_devices + ","; + ov_luid_devices = ov_luid_devices + ov_dev; + } else { + ORT_THROW(" LUID : ", ov_dev, " does not match with device_type : ", selected_device); + } } else { - ORT_THROW("Invalid device_luid is set"); + ORT_THROW(provider_options.at("device_luid"), " does not exist for the selected device_type : ", selected_device); } } if (!device_mode.empty()) { selected_device = device_mode + ":" + ov_luid_devices; for (auto dev_str : devices_to_check) { auto default_dev = split(dev_str, '.')[0]; + if (ov_luid_devices.find(default_dev) == std::string::npos) selected_device = selected_device + "," + dev_str; } @@ -171,13 +179,15 @@ std::string ParseDeviceType(std::shared_ptr ov_core, const ProviderOptio } } // If invalid device is chosen error is thrown - if (!all_devices_found) + if (!all_devices_found) { ORT_THROW( "[ERROR] [OpenVINO] You have selected wrong configuration value for the key 'device_type'. " "Select from 'CPU', 'GPU', 'NPU', 'GPU.x' where x = 0,1,2 and so on or from" " HETERO/MULTI/AUTO/BATCH options available. \n"); - else + } else { + LOGS_DEFAULT(INFO) << "[OpenVINO-EP] Choosing Device: " << selected_device; return selected_device; + } } void ParseProviderOptions([[maybe_unused]] ProviderInfo& result, [[maybe_unused]] const ProviderOptions& config_options) {} From 6fc0ed0c9525bb7a70e67c0dc9b911a3f0f1f6ae Mon Sep 17 00:00:00 2001 From: jatinwadhwa921 Date: Fri, 11 Apr 2025 09:56:33 +0530 Subject: [PATCH 028/138] Update OV version for Intel Internal CI --- .github/workflows/internal_ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/internal_ci.yml b/.github/workflows/internal_ci.yml index 6ece42bb90571..638dbfb797591 100644 --- a/.github/workflows/internal_ci.yml +++ b/.github/workflows/internal_ci.yml @@ -46,4 +46,4 @@ jobs: run: | cd tools/ci_build/github/linux/ dir - ./run_dockerbuild.sh -o ubuntu22.04 -p 3.10 -d openvino -v 2025.0.0 -x "--config Release --use_openvino CPU --build_wheel --build_shared_lib --parallel " + ./run_dockerbuild.sh -o ubuntu22.04 -p 3.10 -d openvino -v 2025.1.0 -x "--config Release --use_openvino CPU --enable_generic_interface --build_wheel --build_shared_lib --parallel " From c2558f3a6f5998bd9de5b226c66531c795c87dd0 Mon Sep 17 00:00:00 2001 From: jatinwadhwa921 <110383850+jatinwadhwa921@users.noreply.github.com> Date: Fri, 11 Apr 2025 20:15:25 +0530 Subject: [PATCH 029/138] [OVEP] Update ov version in ort (#653) --- .../github/azure-pipelines/linux-openvino-ci-pipeline.yml | 2 +- .../github/linux/docker/Dockerfile.ubuntu_openvino | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/tools/ci_build/github/azure-pipelines/linux-openvino-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-openvino-ci-pipeline.yml index 48627e656b9a8..6f122034e3698 100644 --- a/tools/ci_build/github/azure-pipelines/linux-openvino-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-openvino-ci-pipeline.yml @@ -33,5 +33,5 @@ jobs: parameters: AgentPool : 'Linux-CPU-2019' JobName: 'Linux_CI_Dev' - RunDockerBuildArgs: '-o ubuntu22.04 -p 3.10 -d openvino -v 2025.0.0 -x "--enable_generic_interface --use_openvino CPU --build_wheel"' + RunDockerBuildArgs: '-o ubuntu22.04 -p 3.10 -d openvino -v 2025.1.0 -x "--enable_generic_interface --use_openvino CPU --build_wheel"' TimeoutInMinutes: 120 diff --git a/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_openvino b/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_openvino index b53a2302be403..e8e4f22153ca5 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_openvino +++ b/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_openvino @@ -1,7 +1,7 @@ ARG UBUNTU_VERSION=22.04 FROM ubuntu:${UBUNTU_VERSION} -ARG OPENVINO_VERSION=2025.0.0 +ARG OPENVINO_VERSION=2025.1.0 ARG PYTHON_VERSION=3.10 ADD scripts /tmp/scripts @@ -19,9 +19,9 @@ ENV IE_PLUGINS_PATH=$INTEL_OPENVINO_DIR/runtime/lib/intel64 ENV DEBIAN_FRONTEND=noninteractive RUN cd /opt && mkdir -p intel && cd intel && \ - wget https://storage.openvinotoolkit.org/repositories/openvino/packages/2025.0/linux/openvino_toolkit_ubuntu22_2025.0.0.17942.1f68be9f594_x86_64.tgz && \ - tar xzf openvino_toolkit_ubuntu22_2025.0.0.17942.1f68be9f594_x86_64.tgz && rm -rf openvino_toolkit_ubuntu22_2025.0.0.17942.1f68be9f594_x86_64.tgz && \ - mv openvino_toolkit_ubuntu22_2025.0.0.17942.1f68be9f594_x86_64 openvino_2025.0.0 && \ + wget https://storage.openvinotoolkit.org/repositories/openvino/packages/2025.1/linux/openvino_toolkit_ubuntu22_2025.1.0.18503.6fec06580ab_x86_64.tgz && \ + tar xzf openvino_toolkit_ubuntu22_2025.1.0.18503.6fec06580ab_x86_64.tgz && rm -rf openvino_toolkit_ubuntu22_2025.1.0.18503.6fec06580ab_x86_64.tgz && \ + mv openvino_toolkit_ubuntu22_2025.1.0.18503.6fec06580ab_x86_64 openvino_2025.1.0 && \ cd $INTEL_OPENVINO_DIR/install_dependencies && ./install_openvino_dependencies.sh -y WORKDIR /root From ad5f824d7645f9fdccaa65051f299a82660039e7 Mon Sep 17 00:00:00 2001 From: jatinwadhwa921 <110383850+jatinwadhwa921@users.noreply.github.com> Date: Wed, 16 Apr 2025 23:32:56 -0700 Subject: [PATCH 030/138] Removing support for internal ci(Intel Internal) (#662) --- .github/workflows/internal_ci.yml | 49 ------------------------------- 1 file changed, 49 deletions(-) delete mode 100644 .github/workflows/internal_ci.yml diff --git a/.github/workflows/internal_ci.yml b/.github/workflows/internal_ci.yml deleted file mode 100644 index 638dbfb797591..0000000000000 --- a/.github/workflows/internal_ci.yml +++ /dev/null @@ -1,49 +0,0 @@ -name : Internal CI - -on: - pull_request_target: - branches: - - '**' # Triggers on a PR to any Branch - -permissions: - contents: read - pull-requests: read - -jobs: - build: - - if: github.event.pull_request.draft == false - runs-on: [self-hosted, Linux, X64] # Runs on a Lunar lake - env: - BUILD_SOURCESDIRECTORY: ${{ github.workspace }} - BUILD_BINARIESDIRECTORY: ${{ github.workspace }}/build - - steps: - - name: Check PR Author Authorization - run: | - if [[ "${{ github.event.pull_request.head.repo.full_name }}" != "${{ github.repository }}" ]]; then - echo "PR is from a fork: ${{ github.event.pull_request.head.repo.full_name }}" - fi - - - name: Checkout PR Branch - uses: actions/checkout@v4 - with: - ref: ${{ github.event.pull_request.head.ref }} - repository: ${{ github.event.pull_request.head.repo.full_name }} - fetch-depth: 1 # checkout the pr branch - - - name: Set up Python - uses: actions/setup-python@v4 - with: - python-version: '3.10' - - - name: Create build directory - run: | - mkdir -p ${{ env.BUILD_BINARIESDIRECTORY }} - chmod -R 777 ${{ env.BUILD_BINARIESDIRECTORY }} - - - name: Running Internal CI # Trigger Internal CI on the pr branch - run: | - cd tools/ci_build/github/linux/ - dir - ./run_dockerbuild.sh -o ubuntu22.04 -p 3.10 -d openvino -v 2025.1.0 -x "--config Release --use_openvino CPU --enable_generic_interface --build_wheel --build_shared_lib --parallel " From 4c0acd8a18e2139a96f4c525bb419c1a97f0d99b Mon Sep 17 00:00:00 2001 From: jatinwadhwa921 <110383850+jatinwadhwa921@users.noreply.github.com> Date: Mon, 21 Apr 2025 00:22:08 -0700 Subject: [PATCH 031/138] Support for triggering Internal Ci(Intel Internal) (#665) --- .github/workflows/linux_openvino_ci_intel.yml | 45 +++++ .../workflows/reusable_linux_build_intel.yml | 183 ++++++++++++++++++ 2 files changed, 228 insertions(+) create mode 100644 .github/workflows/linux_openvino_ci_intel.yml create mode 100644 .github/workflows/reusable_linux_build_intel.yml diff --git a/.github/workflows/linux_openvino_ci_intel.yml b/.github/workflows/linux_openvino_ci_intel.yml new file mode 100644 index 0000000000000..985d014994877 --- /dev/null +++ b/.github/workflows/linux_openvino_ci_intel.yml @@ -0,0 +1,45 @@ +name: Linux OpenVINO CI + +on: + push: + branches: [ main, 'rel-*' ] + pull_request: + branches: ['**' ] + workflow_dispatch: + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +permissions: + contents: read + packages: write # Needed if the reusable workflow pushes images + attestations: write # Optional: for artifact attestations if enabled + id-token: write # Optional: may be needed for OIDC authentication (e.g., ACR) + +jobs: + build_test_openvino: + name: Build and Test OpenVINO EP (AlamLinux8, Py3.12) + # Use the reusable workflow as the other Linux CI pipelines + uses: ./.github/workflows/reusable_linux_build_intel.yml + with: + pool_name: "onnxruntime-github-Ubuntu2204-AMD-CPU" + build_config: Release + # Architecture: OpenVino only supports Intel X64 + architecture: x64 + dockerfile_path: tools/ci_build/github/linux/docker/inference/x86_64/python/openvino/Dockerfile + docker_image_repo: onnxruntimeopenvino + + execution_providers: 'openvino' + + extra_build_flags: '--use_openvino CPU --enable_generic_interface --build_shared_lib' + + # Python Path Prefix: Set the correct Python 3.12 path inside the manylinux container + python_path_prefix: 'PATH=/opt/python/cp312-cp312/bin:$PATH' + + run_tests: true + upload_build_output: false + + # Secrets: Pass the necessary GitHub token + secrets: + GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/reusable_linux_build_intel.yml b/.github/workflows/reusable_linux_build_intel.yml new file mode 100644 index 0000000000000..00859bb99d7f0 --- /dev/null +++ b/.github/workflows/reusable_linux_build_intel.yml @@ -0,0 +1,183 @@ +name: Reusable Linux CPU/GPU Build and Test + +on: + workflow_call: + inputs: + pool_name: + description: 'The specific 1ES pool name (e.g., onnxruntime-github-Ubuntu2204-AMD-CPU)' + required: true + type: string + build_config: + description: 'Build configuration (Debug or Release)' + required: true + type: string + architecture: + description: 'Target architecture (x64 or arm64)' + required: true + type: string + dockerfile_path: + description: 'Path to the Dockerfile relative to the workspace root' + required: true + type: string + docker_image_repo: + description: 'Name for the Docker image repository' + required: true + type: string + docker_build_args: + description: 'Arguments to pass to the docker image build command' + required: false + type: string + default: '' + execution_providers: + description: 'Space-separated list of execution providers to enable (passed to build.py)' + required: false + type: string + default: '' + extra_build_flags: + description: 'Additional flags for the build.py script (appended after EP flags)' + required: false + type: string + default: '' + python_path_prefix: + description: 'Optional prefix to add to the PATH for python command (e.g., PATH=/opt/python/cp310-cp310/bin:$PATH)' + required: false + type: string + default: '' + python_version: + description: 'Python version to set up on the runner host' + required: false + type: string + default: '3.x' + run_tests: + description: 'Whether to execute the test suite after building' + required: false + type: boolean + default: true + upload_build_output: + description: 'Whether to upload the build output directory as an artifact (used when tests are skipped)' + required: false + type: boolean + default: false + secrets: + GH_TOKEN: + description: 'GitHub token for accessing actions/packages' + required: true + +jobs: + build_test_pipeline: + runs-on: [self-hosted, Linux, X64] + permissions: + contents: read + packages: write + attestations: write + id-token: write + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Set up Python ${{ inputs.python_version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ inputs.python_version }} + + - name: Build Docker Image (${{ inputs.architecture }} / ${{ inputs.build_config }}) + uses: microsoft/onnxruntime-github-actions/build-docker-image@v0.0.5 + id: build_docker_image_step + with: + dockerfile: ${{ github.workspace }}/${{ inputs.dockerfile_path }} + image-name: ghcr.io/microsoft/onnxruntime/${{ inputs.docker_image_repo }} + build-args: ${{ inputs.docker_build_args }} + push: true + azure-container-registry-name: onnxruntimebuildcache + env: + GITHUB_TOKEN: ${{ secrets.GH_TOKEN }} + + - name: Export GitHub Actions cache environment variables + uses: actions/github-script@v7 + with: + script: | + core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || ''); + core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || ''); + # ------------- Update Step (CMake Generation) ------------- + - name: Generate Build Files (CMake) (${{ inputs.architecture }} / ${{ inputs.build_config }}) + id: update_step + uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.5 + with: + docker_image: ${{ steps.build_docker_image_step.outputs.full-image-name }} + build_config: ${{ inputs.build_config }} + mode: 'update' + execution_providers: ${{ inputs.execution_providers }} # Pass down EP list + extra_build_flags: ${{ inputs.extra_build_flags }} + python_path_prefix: ${{ inputs.python_path_prefix }} + + # ------------- Build Step (Compilation) ------------- + - name: Build ONNX Runtime (${{ inputs.architecture }} / ${{ inputs.build_config }}) + id: build_step + uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.5 + with: + docker_image: ${{ steps.build_docker_image_step.outputs.full-image-name }} + build_config: ${{ inputs.build_config }} + mode: 'build' + execution_providers: ${{ inputs.execution_providers }} # Pass down EP list + extra_build_flags: ${{ inputs.extra_build_flags }} + python_path_prefix: ${{ inputs.python_path_prefix }} + + # ------------- Test Step ------------- + - name: Test ONNX Runtime (${{ inputs.architecture }} / ${{ inputs.build_config }}) + id: test_step + if: inputs.run_tests == true + uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.5 + with: + docker_image: ${{ steps.build_docker_image_step.outputs.full-image-name }} + build_config: ${{ inputs.build_config }} + mode: 'test' + execution_providers: ${{ inputs.execution_providers }} # Pass down EP list + extra_build_flags: ${{ inputs.extra_build_flags }} + python_path_prefix: ${{ inputs.python_path_prefix }} + + # ------------- Prepare Artifact Step ------------- + - name: Prepare Build Output for Upload + if: inputs.upload_build_output == true + shell: bash + run: | + #!/bin/bash + set -e -x + BUILD_DIR="${{ runner.temp }}/${{ inputs.build_config }}" + if [ ! -d "${BUILD_DIR}" ]; then + echo "Error: Build directory ${BUILD_DIR} not found. Cannot prepare artifact." + exit 1 + fi + echo "--- Cleaning build directory: ${BUILD_DIR} ---" + rm -rf "${BUILD_DIR}/onnxruntime" || true + rm -rf "${BUILD_DIR}/pybind11" || true + rm -rf "${BUILD_DIR}/vcpkg_installed" || true + rm -f "${BUILD_DIR}/models" || true + DEPS_DIR="${BUILD_DIR}/_deps" + if [ -d "${DEPS_DIR}" ]; then + echo "Cleaning ${DEPS_DIR}, keeping onnx-src..." + find "${DEPS_DIR}" -mindepth 1 ! -regex "^${DEPS_DIR}/onnx-src\(/.*\)?$" -delete + else + echo "${DEPS_DIR} does not exist, skipping deps cleanup." + fi + echo "--- Saving executable permissions ---" + cd "${BUILD_DIR}" + find . -executable -type f -printf '%p\n' > perms.txt + echo "--- Cleanup and permission saving complete for ${BUILD_DIR} ---" + + # ------------- Upload Build Output Step ------------- + - name: Upload Build Output Artifact + if: inputs.upload_build_output == true + uses: actions/upload-artifact@v4 + with: + name: build-output-${{ inputs.architecture }}-${{ inputs.build_config }} + path: ${{ runner.temp }}/${{ inputs.build_config }} + if-no-files-found: error + + # ------------- Upload Log on Build Failure Step ------------- + - name: Upload VCPKG Manifest Install Log on Update or Build Failure + if: steps.update_step.outcome == 'failure' || steps.build_step.outcome == 'failure' + uses: actions/upload-artifact@v4 + with: + name: vcpkg-manifest-install-log-${{ inputs.architecture }}-${{ inputs.build_config }} + path: ${{ runner.temp }}/${{ inputs.build_config }}/${{ inputs.build_config }}/vcpkg-manifest-install.log + if-no-files-found: ignore From 740133526dfabdb0752b70afda48e3945c24eee7 Mon Sep 17 00:00:00 2001 From: saurabh Date: Mon, 21 Apr 2025 01:16:02 -0700 Subject: [PATCH 032/138] incorporate requested changes for PR:24394 (#661) Co-authored-by: sfatimar --- cmake/onnxruntime_providers_openvino.cmake | 2 +- .../core/providers/openvino/backend_manager.cc | 12 ++++++------ .../qdq_transformations/qdq_stripping.cc | 16 ++++++++-------- .../openvino/qdq_transformations/qdq_stripping.h | 4 ++-- 4 files changed, 17 insertions(+), 17 deletions(-) diff --git a/cmake/onnxruntime_providers_openvino.cmake b/cmake/onnxruntime_providers_openvino.cmake index f149030c15702..03f67983c70ab 100644 --- a/cmake/onnxruntime_providers_openvino.cmake +++ b/cmake/onnxruntime_providers_openvino.cmake @@ -46,7 +46,7 @@ set_target_properties(onnxruntime_providers_openvino PROPERTIES FOLDER "ONNXRuntime") target_compile_options(onnxruntime_providers_openvino PRIVATE - $<$>:-DNOT_RELEASE> + $<$:-DRELEASE> ) if(NOT MSVC) diff --git a/onnxruntime/core/providers/openvino/backend_manager.cc b/onnxruntime/core/providers/openvino/backend_manager.cc index d758430f39108..13f09b9d9acdb 100644 --- a/onnxruntime/core/providers/openvino/backend_manager.cc +++ b/onnxruntime/core/providers/openvino/backend_manager.cc @@ -13,13 +13,13 @@ #include #include "core/providers/shared_library/provider_api.h" -#include "core/providers/openvino/ov_versions/capability.h" -#include "core/providers/openvino/contexts.h" #include "core/providers/openvino/backend_manager.h" -#include "core/providers/openvino/ibackend.h" #include "core/providers/openvino/backend_utils.h" -#include "core/providers/openvino/qdq_transformations/qdq_stripping.h" +#include "core/providers/openvino/contexts.h" +#include "core/providers/openvino/ibackend.h" #include "core/providers/openvino/ov_interface.h" +#include "core/providers/openvino/ov_versions/capability.h" +#include "core/providers/openvino/qdq_transformations/qdq_stripping.h" namespace onnxruntime { namespace openvino_ep { @@ -324,7 +324,7 @@ static bool IsQDQGraph(const onnxruntime::GraphViewer& graph_viewer) { static void DumpOpenVINOEPModel([[maybe_unused]] const std::filesystem::path& onnx_model_path_name, [[maybe_unused]] ONNX_NAMESPACE::ModelProto* model_proto, [[maybe_unused]] const onnxruntime::Node& fused_node) { -#ifdef NOT_RELEASE +#ifndef RELEASE if (openvino_ep::backend_utils::IsDebugEnabled()) { auto model_name = onnx_model_path_name.empty() ? "unknown.onnx" : onnx_model_path_name.filename(); @@ -384,7 +384,7 @@ BackendManager::GetModelProtoFromFusedNode(const onnxruntime::Node& fused_node, if (session_context_.device_type.find("NPU") != std::string::npos && (enable_ovep_qdq_optimizer || session_context_.so_share_ep_contexts)) { std::unique_ptr model; - Status status = CreateModelWithStrippedQDQNodes(subgraph, logger, session_context_.so_share_ep_contexts, model, shared_context_.shared_weights, enable_ovep_qdq_optimizer); + Status status = CreateModelWithStrippedQDQNodes(subgraph, logger, session_context_.so_share_ep_contexts, enable_ovep_qdq_optimizer, model, shared_context_.shared_weights); auto model_proto = model->ToProto(); model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); print_model_proto_duration(); diff --git a/onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.cc b/onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.cc index c071db9c3a4fb..860cfb5713903 100644 --- a/onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.cc +++ b/onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.cc @@ -448,8 +448,8 @@ static bool HandleDoubleQDQ(onnxruntime::Graph& dst_graph, const onnxruntime::Gr static void AddStandaloneNodeUnit(onnxruntime::Graph& dst_graph, const onnxruntime::GraphViewer& src_graph, const NodeUnit& node_unit, std::set& initializers_to_keep, - const logging::Logger& /* logger */, - bool IsWeightSharingWithoutOVEPQDQStripping) { + bool IsWeightSharingWithoutOVEPQDQStripping, + const logging::Logger& /* logger */) { assert(node_unit.UnitType() == NodeUnit::Type::SingleNode); // this is the scenario where WAI is enabled and ovep stripping is disabled @@ -520,8 +520,8 @@ static void AddQDQNodeUnit(onnxruntime::Graph& dst_graph, const onnxruntime::GraphViewer& src_graph, const NodeUnit& node_unit, std::set& initializers_to_keep, - const logging::Logger& /* logger */, - bool IsWeightSharingWithoutOVEPQDQStripping) { + bool IsWeightSharingWithoutOVEPQDQStripping, + const logging::Logger& /* logger */) { assert(node_unit.UnitType() == NodeUnit::Type::QDQGroup); // Collect inputs coming into the node unit. @@ -684,9 +684,9 @@ static void AddInitializerAsInput(onnxruntime::Graph& dst_graph, Status CreateModelWithStrippedQDQNodes(const GraphViewer& src_graph, const logging::Logger& logger, bool enable_ovep_weight_sharing, + bool enable_ovep_qdq_optimizer, /*out*/ std::unique_ptr& model, - /*out*/ sw& shared_weights, - bool enable_ovep_qdq_optimizer) { + /*out*/ sw& shared_weights) { // NOTE: This function is a re-implementation of GraphViewerToProto() in core/graph/graph_proto_serializer.cc // with the following differences: // - Uses onnxruntime::Graph APIs instead of onnx::GraphProto APIs. @@ -780,9 +780,9 @@ Status CreateModelWithStrippedQDQNodes(const GraphViewer& src_graph, bool IsWeightSharingWithoutOVEPQDQStripping = enable_ovep_weight_sharing && !enable_ovep_qdq_optimizer; if (node_unit->UnitType() == NodeUnit::Type::SingleNode) { - AddStandaloneNodeUnit(dst_graph, src_graph, *node_unit, initializers_to_keep, logger, IsWeightSharingWithoutOVEPQDQStripping); + AddStandaloneNodeUnit(dst_graph, src_graph, *node_unit, initializers_to_keep, IsWeightSharingWithoutOVEPQDQStripping, logger); } else { - AddQDQNodeUnit(dst_graph, src_graph, *node_unit, initializers_to_keep, logger, IsWeightSharingWithoutOVEPQDQStripping); + AddQDQNodeUnit(dst_graph, src_graph, *node_unit, initializers_to_keep, IsWeightSharingWithoutOVEPQDQStripping, logger); } seen_node_units.insert(node_unit); diff --git a/onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.h b/onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.h index 4b5696f4411bd..53de0fd019311 100644 --- a/onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.h +++ b/onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.h @@ -16,9 +16,9 @@ using sw = SharedContext::SharedWeights; Status CreateModelWithStrippedQDQNodes(const GraphViewer& src_graph, const logging::Logger& logger, bool enable_ovep_weight_sharing, + bool enable_ovep_qdq_optimizer, /*out*/ std::unique_ptr& model, - /*out*/ sw& shared_weights, - bool enable_ovep_qdq_optimizer); + /*out*/ sw& shared_weights); bool dumpMetaDataMapToBinary(const sw::Metadata::Map& shared_weights, const std::string& filename); } // namespace openvino_ep From 269f6fe2b87950f7e5716a7839f8fa68a498039a Mon Sep 17 00:00:00 2001 From: Javier Martinez Date: Wed, 23 Apr 2025 00:19:10 -0700 Subject: [PATCH 033/138] Add support for session option ep.stop_context_sharing (#655) * Add function to query external initializer file name * Decouple external weight processing from shared context and add support for stop context sharing --- .../providers/openvino/backend_manager.cc | 38 ++++---- .../core/providers/openvino/backend_manager.h | 1 + .../core/providers/openvino/backend_utils.cc | 90 ++++++++++++++----- .../core/providers/openvino/backend_utils.h | 9 +- .../openvino/backends/basic_backend.cc | 6 +- .../core/providers/openvino/contexts.h | 65 ++++++-------- .../openvino/openvino_execution_provider.cc | 12 ++- .../openvino/openvino_provider_factory.cc | 1 + .../qdq_transformations/qdq_stripping.cc | 21 +++-- .../qdq_transformations/qdq_stripping.h | 8 +- 10 files changed, 152 insertions(+), 99 deletions(-) diff --git a/onnxruntime/core/providers/openvino/backend_manager.cc b/onnxruntime/core/providers/openvino/backend_manager.cc index 13f09b9d9acdb..139a0eac512a4 100644 --- a/onnxruntime/core/providers/openvino/backend_manager.cc +++ b/onnxruntime/core/providers/openvino/backend_manager.cc @@ -83,22 +83,23 @@ BackendManager::BackendManager(SessionContext& session_context, } std::string device_type = session_context_.device_type; - auto& sw = shared_context_.shared_weights; - if (session_context_.so_share_ep_contexts) { - std::filesystem::path weight_filename = session_context_.onnx_model_path_name.parent_path(); - if (sw.external_weight_filename.empty() && !sw.metadata.empty()) { - // Reasonable assumption that all metadata entries have the same external file location - sw.external_weight_filename = sw.metadata.begin()->second.location; - } - weight_filename /= sw.external_weight_filename; - std::ifstream weight_file(weight_filename); + // Check if model is using external weights + if (auto filename = backend_utils::GetExternalWeightFilename(subgraph)) { + std::filesystem::path weights_filepath = session_context_.onnx_model_path_name.parent_path() / filename.value(); - if (weight_file) { - if (!sw.mapped_weights) { - sw.mapped_weights = std::make_unique(weight_filename); - } - backend_utils::CreateOVTensors(session_context_.device_type, sw.metadata, *sw.mapped_weights); + // Initialize external weights with fully qualified path + if (!std::filesystem::exists(weights_filepath)) { + ORT_THROW("Error: Failed to locate weight file at ", weights_filepath.string()); } + + external_weights_.emplace(weights_filepath); + } + + if (session_context_.so_share_ep_contexts) { + ORT_ENFORCE(external_weights_.has_value(), "Expected external weight object to be valid"); + backend_utils::CreateOVTensors(session_context_.device_type, + shared_context_.shared_weights.metadata, + external_weights_.value()); } if (ModelHasSymbolicInputDims(subgraph)) { @@ -324,7 +325,7 @@ static bool IsQDQGraph(const onnxruntime::GraphViewer& graph_viewer) { static void DumpOpenVINOEPModel([[maybe_unused]] const std::filesystem::path& onnx_model_path_name, [[maybe_unused]] ONNX_NAMESPACE::ModelProto* model_proto, [[maybe_unused]] const onnxruntime::Node& fused_node) { -#ifndef RELEASE +#ifdef NOT_RELEASE if (openvino_ep::backend_utils::IsDebugEnabled()) { auto model_name = onnx_model_path_name.empty() ? "unknown.onnx" : onnx_model_path_name.filename(); @@ -384,7 +385,12 @@ BackendManager::GetModelProtoFromFusedNode(const onnxruntime::Node& fused_node, if (session_context_.device_type.find("NPU") != std::string::npos && (enable_ovep_qdq_optimizer || session_context_.so_share_ep_contexts)) { std::unique_ptr model; - Status status = CreateModelWithStrippedQDQNodes(subgraph, logger, session_context_.so_share_ep_contexts, enable_ovep_qdq_optimizer, model, shared_context_.shared_weights); + Status status = CreateModelWithStrippedQDQNodes(subgraph, + logger, + session_context_.so_share_ep_contexts, + enable_ovep_qdq_optimizer, + model, + shared_context_.shared_weights.metadata); auto model_proto = model->ToProto(); model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); print_model_proto_duration(); diff --git a/onnxruntime/core/providers/openvino/backend_manager.h b/onnxruntime/core/providers/openvino/backend_manager.h index cdc27701ec2e6..22936acf3ea66 100644 --- a/onnxruntime/core/providers/openvino/backend_manager.h +++ b/onnxruntime/core/providers/openvino/backend_manager.h @@ -54,6 +54,7 @@ class BackendManager { EPCtxHandler& ep_ctx_handle_; SessionContext& session_context_; SharedContext& shared_context_; + std::optional external_weights_; }; } // namespace openvino_ep diff --git a/onnxruntime/core/providers/openvino/backend_utils.cc b/onnxruntime/core/providers/openvino/backend_utils.cc index 2ee5e9ec3e3a9..58309d37877f1 100644 --- a/onnxruntime/core/providers/openvino/backend_utils.cc +++ b/onnxruntime/core/providers/openvino/backend_utils.cc @@ -4,6 +4,7 @@ #include #include #include +#include #include #include @@ -20,22 +21,7 @@ using Exception = ov::Exception; namespace onnxruntime { namespace openvino_ep { -SharedContext::SharedWeights::WeightsFile::WeightsFile(std::filesystem::path filename) : file_(filename, std::ios::in | std::ios::binary) { - try { - file_.exceptions(std::ifstream::failbit | std::ifstream::badbit); - weights_size_ = file_.seekg(0, std::ios::end).tellg(); - } catch (std::ifstream::failure& e) { - ORT_THROW("Error: Failed to open weight file at ", filename.string(), " ", e.what()); - } -} - -void SharedContext::SharedWeights::WeightsFile::load_weights(size_t file_offset, void* data, size_t size) { - ORT_ENFORCE(file_offset < weights_size_ && size <= weights_size_ && (file_offset <= weights_size_ - size), "Error: File offset is out of bounds."); - file_.seekg(file_offset); - file_.read(reinterpret_cast(data), size); -} - -std::ostream& operator<<(std::ostream& stream, const SharedContext::SharedWeights::Metadata::Map& metadata) { +std::ostream& operator<<(std::ostream& stream, const Metadata::Map& metadata) { try { stream << metadata.size(); @@ -69,14 +55,14 @@ std::ostream& operator<<(std::ostream& stream, const SharedContext::SharedWeight return stream; } -std::istream& operator>>(std::istream& stream, SharedContext::SharedWeights::Metadata::Map& metadata) { +std::istream& operator>>(std::istream& stream, Metadata::Map& metadata) { size_t map_size{0}; try { stream >> map_size; while (!stream.eof()) { - SharedContext::SharedWeights::Metadata::Key key; - SharedContext::SharedWeights::Metadata::Value value; + Metadata::Key key; + Metadata::Value value; stream >> key.name; stream >> value.location; stream >> value.data_offset; @@ -399,8 +385,19 @@ ov::element::Type GetOpenVINOElementType(ONNX_NAMESPACE::TensorProto_DataType dt // Function to handle tensor creation from external data void CreateOVTensors(const std::string& device_name, - SharedContext::SharedWeights::Metadata::Map& metadata_map, - SharedContext::SharedWeights::WeightsFile& weights) { + Metadata::Map& metadata_map, + std::filesystem::path& weights_filepath) { + // File is guaranteed to exist at this point + std::ifstream file(weights_filepath, std::ios::in | std::ios::binary); + file.exceptions(std::ifstream::failbit | std::ifstream::badbit); + size_t weights_size = std::filesystem::file_size(weights_filepath); + + const auto load_weights = [&file, weights_size](size_t file_offset, void* data, size_t size) { + ORT_ENFORCE(file_offset < weights_size && size <= weights_size && (file_offset <= weights_size - size), "Error: File offset is out of bounds."); + file.seekg(file_offset); + file.read(reinterpret_cast(data), size); + }; + for (auto& [key, value] : metadata_map) { if (value.tensor) continue; @@ -416,18 +413,18 @@ void CreateOVTensors(const std::string& device_name, auto&& remote_tensor = npu_context.create_l0_host_tensor(ov_elementType, value.dimensions, ov::intel_npu::TensorType::INPUT); // Copy data to remote tensor - weights.load_weights(value.data_offset, remote_tensor.get(), value.size); + load_weights(value.data_offset, remote_tensor.get(), value.size); value.tensor = std::make_shared(remote_tensor); } else { // Use vanilla tensors value.tensor = std::make_shared(ov_elementType, value.dimensions); - weights.load_weights(value.data_offset, value.tensor->data(), value.size); + load_weights(value.data_offset, value.tensor->data(), value.size); } ORT_ENFORCE(value.tensor->get_byte_size() == value.size, "Unexpected tensor size mismatch"); } } -void DestroyOVTensors(SharedContext::SharedWeights::Metadata::Map& metadata_map) { +void DestroyOVTensors(Metadata::Map& metadata_map) { for (auto& [key, value] : metadata_map) { if (value.tensor) { value.tensor.reset(); @@ -436,6 +433,51 @@ void DestroyOVTensors(SharedContext::SharedWeights::Metadata::Map& metadata_map) metadata_map.clear(); } +std::optional GetExternalWeightFilename(const GraphViewer& graph) { + auto get_external_location = [](const ONNX_NAMESPACE::TensorProto& proto) -> std::optional { + using mutable_proto_t = ONNX_NAMESPACE::TensorProto*; + auto& mutable_proto = *const_cast(&proto); + auto* entry_protos = mutable_proto.mutable_external_data(); + + if (proto.has_data_location() && proto.data_location() == ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL) { + for (int i = 0; i < entry_protos->size(); i++) { + auto& string_entry_proto{entry_protos->at(i)}; + const auto& pb_key{*(string_entry_proto.mutable_key())}; + const auto& pb_value{*(string_entry_proto.mutable_value())}; + if (pb_key == "location") { + return std::make_optional(pb_value); + } + } + } + + return std::nullopt; + }; + + // Handle constant initializers + auto& initializers = graph.GetAllInitializedTensors(); + for (const auto& it : initializers) { + if (auto result = get_external_location(*it.second)) { + return result; + } + } + + // Handle outer-scope constant initializers + for (auto& node_idx : graph.GetNodesInTopologicalOrder()) { + const auto& node = graph.GetNode(node_idx); + for (const auto& input : node->InputDefs()) { + if (graph.IsConstantInitializer(input->Name(), true)) { + const auto& initializer_tensor = *graph.GetConstantInitializer(input->Name(), true); + + if (auto result = get_external_location(initializer_tensor)) { + return result; + } + } + } + } + + return std::nullopt; +} + } // namespace backend_utils } // namespace openvino_ep } // namespace onnxruntime diff --git a/onnxruntime/core/providers/openvino/backend_utils.h b/onnxruntime/core/providers/openvino/backend_utils.h index f13b1b05ced67..b56c5e6e7f6ef 100644 --- a/onnxruntime/core/providers/openvino/backend_utils.h +++ b/onnxruntime/core/providers/openvino/backend_utils.h @@ -67,15 +67,18 @@ CreateOVModel(std::string&& model, std::map>& const_outputs_map); void CreateOVTensors(const std::string& device_name, - SharedContext::SharedWeights::Metadata::Map& metadata_map, - SharedContext::SharedWeights::WeightsFile& weights); -void DestroyOVTensors(SharedContext::SharedWeights::Metadata::Map& metadata_map); + Metadata::Map& metadata_map, + std::filesystem::path& weights_filepath); +void DestroyOVTensors(Metadata::Map& metadata_map); void printPerformanceCounts(const std::vector& performanceMap, std::ostream& stream, std::string deviceName); void printPerformanceCounts(OVInferRequestPtr request, std::ostream& stream, std::string deviceName); +// Returns the location string from the first external initializer nodes found or nullopt if none found +std::optional GetExternalWeightFilename(const GraphViewer& graph); + } // namespace backend_utils } // namespace openvino_ep } // namespace onnxruntime diff --git a/onnxruntime/core/providers/openvino/backends/basic_backend.cc b/onnxruntime/core/providers/openvino/backends/basic_backend.cc index c814df618e3b3..c11f853dd1122 100644 --- a/onnxruntime/core/providers/openvino/backends/basic_backend.cc +++ b/onnxruntime/core/providers/openvino/backends/basic_backend.cc @@ -125,10 +125,12 @@ BasicBackend::BasicBackend(std::unique_ptr& model_pr std::function initializer = [](OVInferRequestPtr) {}; auto metadata = shared_context_.shared_weights.metadata; if (session_context_.so_share_ep_contexts) { + // When shared ep contexts is set external weight references are transformed to model inputs. This + // creates an initializer to populate/bind input weight tensors to each inference request initializer = [&metadata](OVInferRequestPtr ir_ptr) { const auto input_count = ir_ptr->GetNumInputs(); for (auto i = 0u; i < input_count; i++) { - using Key = SharedContext::SharedWeights::Metadata::Key; + using Key = Metadata::Key; const auto tensor_key = Key{ir_ptr->GetInputTensorName(i)}; if (metadata.contains(tensor_key)) { auto& value = metadata.at(tensor_key); @@ -137,6 +139,8 @@ BasicBackend::BasicBackend(std::unique_ptr& model_pr } }; } + + // Create inference request queue and initialize according to passed function inferRequestsQueue_ = std::unique_ptr(new InferRequestsQueue(exe_network_, num_infer_req, std::move(initializer))); } diff --git a/onnxruntime/core/providers/openvino/contexts.h b/onnxruntime/core/providers/openvino/contexts.h index 1314edd54e937..eaac62036e21e 100644 --- a/onnxruntime/core/providers/openvino/contexts.h +++ b/onnxruntime/core/providers/openvino/contexts.h @@ -18,6 +18,29 @@ namespace openvino_ep { namespace fs = std::filesystem; +struct Metadata { + struct Key { + std::string name; + bool operator==(const Key&) const = default; + }; + struct Hash { + std::size_t operator()(const Key& key) const noexcept { + return std::hash()(key.name); + } + }; + struct Value { + std::string location; + unsigned int data_offset; + unsigned int size; + std::vector dimensions; + std::int32_t element_type; + std::shared_ptr tensor; + }; + using Map = std::unordered_map; + friend std::ostream& operator<<(std::ostream& right, const Metadata::Map& metadata); + friend std::istream& operator>>(std::istream& right, Metadata::Map& metadata); +}; + class SharedContext : public WeakSingleton { // Keep the core alive as long as the shared SharedContext are alive. std::shared_ptr OVCore_; @@ -25,45 +48,12 @@ class SharedContext : public WeakSingleton { public: SharedContext() : OVCore_(OVCore::Get()) {} struct SharedWeights { - struct Metadata { - struct Key { - std::string name; - bool operator==(const Key&) const = default; - }; - struct Hash { - std::size_t operator()(const Key& key) const noexcept { - return std::hash()(key.name); - } - }; - struct Value { - std::string location; - unsigned int data_offset; - unsigned int size; - std::vector dimensions; - std::int32_t element_type; - std::shared_ptr tensor; - }; - using Map = std::unordered_map; - friend std::ostream& operator<<(std::ostream& right, const Metadata::Map& metadata); - friend std::istream& operator>>(std::istream& right, Metadata::Map& metadata); - }; - - struct WeightsFile { - ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(WeightsFile); - WeightsFile() = delete; - explicit WeightsFile(std::filesystem::path filename); - - void load_weights(size_t file_offset, void* data, size_t size); - - private: - std::ifstream file_; - size_t weights_size_; - }; - - fs::path external_weight_filename; - std::unique_ptr mapped_weights; Metadata::Map metadata; } shared_weights; + + void clear() { // Deletes the data stored in the SharedContext + shared_weights.metadata.clear(); + } }; using config_t = std::map; @@ -102,6 +92,7 @@ struct ProviderInfo { bool so_context_embed_mode{false}; // ORT session option bool so_share_ep_contexts{false}; // ORT session option fs::path so_context_file_path{}; // ORT session option + bool so_stop_share_ep_contexts{false}; // ORT session option const ConfigOptions* config_options{NULL}; const std::unordered_set valid_provider_keys = {"device_type", "device_id", "device_luid", "cache_dir", "precision", "load_config", "context", "num_of_threads", "model_priority", "num_streams", "enable_opencl_throttling", "enable_qdq_optimizer", diff --git a/onnxruntime/core/providers/openvino/openvino_execution_provider.cc b/onnxruntime/core/providers/openvino/openvino_execution_provider.cc index f9d4ab13cf2ce..767b6519f1387 100644 --- a/onnxruntime/core/providers/openvino/openvino_execution_provider.cc +++ b/onnxruntime/core/providers/openvino/openvino_execution_provider.cc @@ -65,6 +65,7 @@ OpenVINOExecutionProvider::~OpenVINOExecutionProvider() { backend_manager.ShutdownBackendManager(); } backend_managers_.clear(); + shared_context_.reset(); } std::vector> @@ -106,7 +107,12 @@ common::Status OpenVINOExecutionProvider::Compile( auto& metadata = shared_context_->shared_weights.metadata; if (session_context_.so_share_ep_contexts && metadata.empty()) { // Metadata is always read from model location, this could be a source or epctx model - fs::path metadata_filename = session_context_.onnx_model_path_name.parent_path() / "metadata.bin"; + fs::path metadata_filename; + if (session_context_.so_context_file_path.empty()) { + metadata_filename = session_context_.onnx_model_path_name.parent_path() / "metadata.bin"; + } else { + metadata_filename = session_context_.so_context_file_path.parent_path() / "metadata.bin"; + } std::ifstream file(metadata_filename, std::ios::binary); if (file) { file >> metadata; @@ -191,6 +197,10 @@ common::Status OpenVINOExecutionProvider::Compile( } } + if (session_context_.so_stop_share_ep_contexts) { + shared_context_->clear(); + } + return status; } diff --git a/onnxruntime/core/providers/openvino/openvino_provider_factory.cc b/onnxruntime/core/providers/openvino/openvino_provider_factory.cc index e36ff48d0351d..4548cec6eadb0 100644 --- a/onnxruntime/core/providers/openvino/openvino_provider_factory.cc +++ b/onnxruntime/core/providers/openvino/openvino_provider_factory.cc @@ -26,6 +26,7 @@ void ParseConfigOptions(ProviderInfo& pi) { pi.so_context_embed_mode = pi.config_options->GetConfigOrDefault(kOrtSessionOptionEpContextEmbedMode, "0") == "1"; pi.so_share_ep_contexts = pi.config_options->GetConfigOrDefault(kOrtSessionOptionShareEpContexts, "0") == "1"; pi.so_context_file_path = pi.config_options->GetConfigOrDefault(kOrtSessionOptionEpContextFilePath, ""); + pi.so_stop_share_ep_contexts = pi.config_options->GetConfigOrDefault(kOrtSessionOptionStopShareEpContexts, "0") == "1"; if (pi.so_share_ep_contexts) { ov::AnyMap map; diff --git a/onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.cc b/onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.cc index 860cfb5713903..61040c5552c71 100644 --- a/onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.cc +++ b/onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.cc @@ -11,6 +11,7 @@ #include #include #include +#include #include "core/providers/shared_library/provider_api.h" #include "core/providers/openvino/qdq_transformations/qdq_stripping.h" @@ -683,10 +684,10 @@ static void AddInitializerAsInput(onnxruntime::Graph& dst_graph, // Creates a new model without the DQ/Q operators in the src graph. Status CreateModelWithStrippedQDQNodes(const GraphViewer& src_graph, const logging::Logger& logger, - bool enable_ovep_weight_sharing, + bool transform_weight_as_input, bool enable_ovep_qdq_optimizer, /*out*/ std::unique_ptr& model, - /*out*/ sw& shared_weights) { + /*out*/ Metadata::Map& weight_metadata) { // NOTE: This function is a re-implementation of GraphViewerToProto() in core/graph/graph_proto_serializer.cc // with the following differences: // - Uses onnxruntime::Graph APIs instead of onnx::GraphProto APIs. @@ -777,7 +778,7 @@ Status CreateModelWithStrippedQDQNodes(const GraphViewer& src_graph, continue; // Already handled this node unit } - bool IsWeightSharingWithoutOVEPQDQStripping = enable_ovep_weight_sharing && !enable_ovep_qdq_optimizer; + bool IsWeightSharingWithoutOVEPQDQStripping = transform_weight_as_input && !enable_ovep_qdq_optimizer; if (node_unit->UnitType() == NodeUnit::Type::SingleNode) { AddStandaloneNodeUnit(dst_graph, src_graph, *node_unit, initializers_to_keep, IsWeightSharingWithoutOVEPQDQStripping, logger); @@ -802,11 +803,9 @@ Status CreateModelWithStrippedQDQNodes(const GraphViewer& src_graph, std::sort(const_inits.begin(), const_inits.end()); // initialize map for creating metadata for initilizers with external weights - auto& metadata = shared_weights.metadata; - - const auto& insert_metadata = [&metadata](const ONNX_NAMESPACE::TensorProto& proto) { - sw::Metadata::Map::key_type key{proto.name()}; - sw::Metadata::Map::mapped_type value{}; + const auto& insert_metadata = [&weight_metadata](const ONNX_NAMESPACE::TensorProto& proto) { + Metadata::Map::key_type key{proto.name()}; + Metadata::Map::mapped_type value{}; using mutable_proto_t = ONNX_NAMESPACE::TensorProto*; auto& mutable_proto = *const_cast(&proto); @@ -829,7 +828,7 @@ Status CreateModelWithStrippedQDQNodes(const GraphViewer& src_graph, dim = proto.dims()[index++]; } - metadata.emplace(key, std::move(value)); + weight_metadata.emplace(key, std::move(value)); }; // Handle constant initializers @@ -839,7 +838,7 @@ Status CreateModelWithStrippedQDQNodes(const GraphViewer& src_graph, // Check if the initializer has external data if (initializer_tensor.has_data_location() && initializer_tensor.data_location() == ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL && - enable_ovep_weight_sharing) { + transform_weight_as_input) { insert_metadata(initializer_tensor); // Add initializer with external data as input @@ -867,7 +866,7 @@ Status CreateModelWithStrippedQDQNodes(const GraphViewer& src_graph, // Check if the initializer has external data if (initializer_tensor.has_data_location() && initializer_tensor.data_location() == ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL && - enable_ovep_weight_sharing) { + transform_weight_as_input) { insert_metadata(initializer_tensor); // Add initializer as input if it has external data diff --git a/onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.h b/onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.h index 53de0fd019311..7e87352e5992d 100644 --- a/onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.h +++ b/onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.h @@ -10,16 +10,12 @@ namespace onnxruntime { namespace openvino_ep { -using sw = SharedContext::SharedWeights; - // Creates a new model without the DQ/Q operators in the src graph as per pre-defined rulesets Status CreateModelWithStrippedQDQNodes(const GraphViewer& src_graph, const logging::Logger& logger, - bool enable_ovep_weight_sharing, + bool transform_weight_as_input, bool enable_ovep_qdq_optimizer, /*out*/ std::unique_ptr& model, - /*out*/ sw& shared_weights); - -bool dumpMetaDataMapToBinary(const sw::Metadata::Map& shared_weights, const std::string& filename); + /*out*/ Metadata::Map& metadata); } // namespace openvino_ep } // namespace onnxruntime From 9ee331a90acf6fac4cc530060e3a596f15eb3570 Mon Sep 17 00:00:00 2001 From: Ankit Maheshkar Date: Thu, 24 Apr 2025 18:54:08 +0530 Subject: [PATCH 034/138] Revert "Add support for session option ep.stop_context_sharing (#655)" (#674) This reverts commit 269f6fe2b87950f7e5716a7839f8fa68a498039a. --- .../providers/openvino/backend_manager.cc | 38 ++++---- .../core/providers/openvino/backend_manager.h | 1 - .../core/providers/openvino/backend_utils.cc | 90 +++++-------------- .../core/providers/openvino/backend_utils.h | 9 +- .../openvino/backends/basic_backend.cc | 6 +- .../core/providers/openvino/contexts.h | 65 ++++++++------ .../openvino/openvino_execution_provider.cc | 12 +-- .../openvino/openvino_provider_factory.cc | 1 - .../qdq_transformations/qdq_stripping.cc | 21 ++--- .../qdq_transformations/qdq_stripping.h | 8 +- 10 files changed, 99 insertions(+), 152 deletions(-) diff --git a/onnxruntime/core/providers/openvino/backend_manager.cc b/onnxruntime/core/providers/openvino/backend_manager.cc index 139a0eac512a4..13f09b9d9acdb 100644 --- a/onnxruntime/core/providers/openvino/backend_manager.cc +++ b/onnxruntime/core/providers/openvino/backend_manager.cc @@ -83,23 +83,22 @@ BackendManager::BackendManager(SessionContext& session_context, } std::string device_type = session_context_.device_type; - // Check if model is using external weights - if (auto filename = backend_utils::GetExternalWeightFilename(subgraph)) { - std::filesystem::path weights_filepath = session_context_.onnx_model_path_name.parent_path() / filename.value(); - - // Initialize external weights with fully qualified path - if (!std::filesystem::exists(weights_filepath)) { - ORT_THROW("Error: Failed to locate weight file at ", weights_filepath.string()); + auto& sw = shared_context_.shared_weights; + if (session_context_.so_share_ep_contexts) { + std::filesystem::path weight_filename = session_context_.onnx_model_path_name.parent_path(); + if (sw.external_weight_filename.empty() && !sw.metadata.empty()) { + // Reasonable assumption that all metadata entries have the same external file location + sw.external_weight_filename = sw.metadata.begin()->second.location; } + weight_filename /= sw.external_weight_filename; + std::ifstream weight_file(weight_filename); - external_weights_.emplace(weights_filepath); - } - - if (session_context_.so_share_ep_contexts) { - ORT_ENFORCE(external_weights_.has_value(), "Expected external weight object to be valid"); - backend_utils::CreateOVTensors(session_context_.device_type, - shared_context_.shared_weights.metadata, - external_weights_.value()); + if (weight_file) { + if (!sw.mapped_weights) { + sw.mapped_weights = std::make_unique(weight_filename); + } + backend_utils::CreateOVTensors(session_context_.device_type, sw.metadata, *sw.mapped_weights); + } } if (ModelHasSymbolicInputDims(subgraph)) { @@ -325,7 +324,7 @@ static bool IsQDQGraph(const onnxruntime::GraphViewer& graph_viewer) { static void DumpOpenVINOEPModel([[maybe_unused]] const std::filesystem::path& onnx_model_path_name, [[maybe_unused]] ONNX_NAMESPACE::ModelProto* model_proto, [[maybe_unused]] const onnxruntime::Node& fused_node) { -#ifdef NOT_RELEASE +#ifndef RELEASE if (openvino_ep::backend_utils::IsDebugEnabled()) { auto model_name = onnx_model_path_name.empty() ? "unknown.onnx" : onnx_model_path_name.filename(); @@ -385,12 +384,7 @@ BackendManager::GetModelProtoFromFusedNode(const onnxruntime::Node& fused_node, if (session_context_.device_type.find("NPU") != std::string::npos && (enable_ovep_qdq_optimizer || session_context_.so_share_ep_contexts)) { std::unique_ptr model; - Status status = CreateModelWithStrippedQDQNodes(subgraph, - logger, - session_context_.so_share_ep_contexts, - enable_ovep_qdq_optimizer, - model, - shared_context_.shared_weights.metadata); + Status status = CreateModelWithStrippedQDQNodes(subgraph, logger, session_context_.so_share_ep_contexts, enable_ovep_qdq_optimizer, model, shared_context_.shared_weights); auto model_proto = model->ToProto(); model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); print_model_proto_duration(); diff --git a/onnxruntime/core/providers/openvino/backend_manager.h b/onnxruntime/core/providers/openvino/backend_manager.h index 22936acf3ea66..cdc27701ec2e6 100644 --- a/onnxruntime/core/providers/openvino/backend_manager.h +++ b/onnxruntime/core/providers/openvino/backend_manager.h @@ -54,7 +54,6 @@ class BackendManager { EPCtxHandler& ep_ctx_handle_; SessionContext& session_context_; SharedContext& shared_context_; - std::optional external_weights_; }; } // namespace openvino_ep diff --git a/onnxruntime/core/providers/openvino/backend_utils.cc b/onnxruntime/core/providers/openvino/backend_utils.cc index 58309d37877f1..2ee5e9ec3e3a9 100644 --- a/onnxruntime/core/providers/openvino/backend_utils.cc +++ b/onnxruntime/core/providers/openvino/backend_utils.cc @@ -4,7 +4,6 @@ #include #include #include -#include #include #include @@ -21,7 +20,22 @@ using Exception = ov::Exception; namespace onnxruntime { namespace openvino_ep { -std::ostream& operator<<(std::ostream& stream, const Metadata::Map& metadata) { +SharedContext::SharedWeights::WeightsFile::WeightsFile(std::filesystem::path filename) : file_(filename, std::ios::in | std::ios::binary) { + try { + file_.exceptions(std::ifstream::failbit | std::ifstream::badbit); + weights_size_ = file_.seekg(0, std::ios::end).tellg(); + } catch (std::ifstream::failure& e) { + ORT_THROW("Error: Failed to open weight file at ", filename.string(), " ", e.what()); + } +} + +void SharedContext::SharedWeights::WeightsFile::load_weights(size_t file_offset, void* data, size_t size) { + ORT_ENFORCE(file_offset < weights_size_ && size <= weights_size_ && (file_offset <= weights_size_ - size), "Error: File offset is out of bounds."); + file_.seekg(file_offset); + file_.read(reinterpret_cast(data), size); +} + +std::ostream& operator<<(std::ostream& stream, const SharedContext::SharedWeights::Metadata::Map& metadata) { try { stream << metadata.size(); @@ -55,14 +69,14 @@ std::ostream& operator<<(std::ostream& stream, const Metadata::Map& metadata) { return stream; } -std::istream& operator>>(std::istream& stream, Metadata::Map& metadata) { +std::istream& operator>>(std::istream& stream, SharedContext::SharedWeights::Metadata::Map& metadata) { size_t map_size{0}; try { stream >> map_size; while (!stream.eof()) { - Metadata::Key key; - Metadata::Value value; + SharedContext::SharedWeights::Metadata::Key key; + SharedContext::SharedWeights::Metadata::Value value; stream >> key.name; stream >> value.location; stream >> value.data_offset; @@ -385,19 +399,8 @@ ov::element::Type GetOpenVINOElementType(ONNX_NAMESPACE::TensorProto_DataType dt // Function to handle tensor creation from external data void CreateOVTensors(const std::string& device_name, - Metadata::Map& metadata_map, - std::filesystem::path& weights_filepath) { - // File is guaranteed to exist at this point - std::ifstream file(weights_filepath, std::ios::in | std::ios::binary); - file.exceptions(std::ifstream::failbit | std::ifstream::badbit); - size_t weights_size = std::filesystem::file_size(weights_filepath); - - const auto load_weights = [&file, weights_size](size_t file_offset, void* data, size_t size) { - ORT_ENFORCE(file_offset < weights_size && size <= weights_size && (file_offset <= weights_size - size), "Error: File offset is out of bounds."); - file.seekg(file_offset); - file.read(reinterpret_cast(data), size); - }; - + SharedContext::SharedWeights::Metadata::Map& metadata_map, + SharedContext::SharedWeights::WeightsFile& weights) { for (auto& [key, value] : metadata_map) { if (value.tensor) continue; @@ -413,18 +416,18 @@ void CreateOVTensors(const std::string& device_name, auto&& remote_tensor = npu_context.create_l0_host_tensor(ov_elementType, value.dimensions, ov::intel_npu::TensorType::INPUT); // Copy data to remote tensor - load_weights(value.data_offset, remote_tensor.get(), value.size); + weights.load_weights(value.data_offset, remote_tensor.get(), value.size); value.tensor = std::make_shared(remote_tensor); } else { // Use vanilla tensors value.tensor = std::make_shared(ov_elementType, value.dimensions); - load_weights(value.data_offset, value.tensor->data(), value.size); + weights.load_weights(value.data_offset, value.tensor->data(), value.size); } ORT_ENFORCE(value.tensor->get_byte_size() == value.size, "Unexpected tensor size mismatch"); } } -void DestroyOVTensors(Metadata::Map& metadata_map) { +void DestroyOVTensors(SharedContext::SharedWeights::Metadata::Map& metadata_map) { for (auto& [key, value] : metadata_map) { if (value.tensor) { value.tensor.reset(); @@ -433,51 +436,6 @@ void DestroyOVTensors(Metadata::Map& metadata_map) { metadata_map.clear(); } -std::optional GetExternalWeightFilename(const GraphViewer& graph) { - auto get_external_location = [](const ONNX_NAMESPACE::TensorProto& proto) -> std::optional { - using mutable_proto_t = ONNX_NAMESPACE::TensorProto*; - auto& mutable_proto = *const_cast(&proto); - auto* entry_protos = mutable_proto.mutable_external_data(); - - if (proto.has_data_location() && proto.data_location() == ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL) { - for (int i = 0; i < entry_protos->size(); i++) { - auto& string_entry_proto{entry_protos->at(i)}; - const auto& pb_key{*(string_entry_proto.mutable_key())}; - const auto& pb_value{*(string_entry_proto.mutable_value())}; - if (pb_key == "location") { - return std::make_optional(pb_value); - } - } - } - - return std::nullopt; - }; - - // Handle constant initializers - auto& initializers = graph.GetAllInitializedTensors(); - for (const auto& it : initializers) { - if (auto result = get_external_location(*it.second)) { - return result; - } - } - - // Handle outer-scope constant initializers - for (auto& node_idx : graph.GetNodesInTopologicalOrder()) { - const auto& node = graph.GetNode(node_idx); - for (const auto& input : node->InputDefs()) { - if (graph.IsConstantInitializer(input->Name(), true)) { - const auto& initializer_tensor = *graph.GetConstantInitializer(input->Name(), true); - - if (auto result = get_external_location(initializer_tensor)) { - return result; - } - } - } - } - - return std::nullopt; -} - } // namespace backend_utils } // namespace openvino_ep } // namespace onnxruntime diff --git a/onnxruntime/core/providers/openvino/backend_utils.h b/onnxruntime/core/providers/openvino/backend_utils.h index b56c5e6e7f6ef..f13b1b05ced67 100644 --- a/onnxruntime/core/providers/openvino/backend_utils.h +++ b/onnxruntime/core/providers/openvino/backend_utils.h @@ -67,18 +67,15 @@ CreateOVModel(std::string&& model, std::map>& const_outputs_map); void CreateOVTensors(const std::string& device_name, - Metadata::Map& metadata_map, - std::filesystem::path& weights_filepath); -void DestroyOVTensors(Metadata::Map& metadata_map); + SharedContext::SharedWeights::Metadata::Map& metadata_map, + SharedContext::SharedWeights::WeightsFile& weights); +void DestroyOVTensors(SharedContext::SharedWeights::Metadata::Map& metadata_map); void printPerformanceCounts(const std::vector& performanceMap, std::ostream& stream, std::string deviceName); void printPerformanceCounts(OVInferRequestPtr request, std::ostream& stream, std::string deviceName); -// Returns the location string from the first external initializer nodes found or nullopt if none found -std::optional GetExternalWeightFilename(const GraphViewer& graph); - } // namespace backend_utils } // namespace openvino_ep } // namespace onnxruntime diff --git a/onnxruntime/core/providers/openvino/backends/basic_backend.cc b/onnxruntime/core/providers/openvino/backends/basic_backend.cc index c11f853dd1122..c814df618e3b3 100644 --- a/onnxruntime/core/providers/openvino/backends/basic_backend.cc +++ b/onnxruntime/core/providers/openvino/backends/basic_backend.cc @@ -125,12 +125,10 @@ BasicBackend::BasicBackend(std::unique_ptr& model_pr std::function initializer = [](OVInferRequestPtr) {}; auto metadata = shared_context_.shared_weights.metadata; if (session_context_.so_share_ep_contexts) { - // When shared ep contexts is set external weight references are transformed to model inputs. This - // creates an initializer to populate/bind input weight tensors to each inference request initializer = [&metadata](OVInferRequestPtr ir_ptr) { const auto input_count = ir_ptr->GetNumInputs(); for (auto i = 0u; i < input_count; i++) { - using Key = Metadata::Key; + using Key = SharedContext::SharedWeights::Metadata::Key; const auto tensor_key = Key{ir_ptr->GetInputTensorName(i)}; if (metadata.contains(tensor_key)) { auto& value = metadata.at(tensor_key); @@ -139,8 +137,6 @@ BasicBackend::BasicBackend(std::unique_ptr& model_pr } }; } - - // Create inference request queue and initialize according to passed function inferRequestsQueue_ = std::unique_ptr(new InferRequestsQueue(exe_network_, num_infer_req, std::move(initializer))); } diff --git a/onnxruntime/core/providers/openvino/contexts.h b/onnxruntime/core/providers/openvino/contexts.h index c0c4551607202..7560f4570bd32 100644 --- a/onnxruntime/core/providers/openvino/contexts.h +++ b/onnxruntime/core/providers/openvino/contexts.h @@ -18,29 +18,6 @@ namespace openvino_ep { namespace fs = std::filesystem; -struct Metadata { - struct Key { - std::string name; - bool operator==(const Key&) const = default; - }; - struct Hash { - std::size_t operator()(const Key& key) const noexcept { - return std::hash()(key.name); - } - }; - struct Value { - std::string location; - unsigned int data_offset; - unsigned int size; - std::vector dimensions; - std::int32_t element_type; - std::shared_ptr tensor; - }; - using Map = std::unordered_map; - friend std::ostream& operator<<(std::ostream& right, const Metadata::Map& metadata); - friend std::istream& operator>>(std::istream& right, Metadata::Map& metadata); -}; - class SharedContext : public WeakSingleton { // Keep the core alive as long as the shared SharedContext are alive. std::shared_ptr OVCore_; @@ -48,12 +25,45 @@ class SharedContext : public WeakSingleton { public: SharedContext() : OVCore_(OVCore::Get()) {} struct SharedWeights { + struct Metadata { + struct Key { + std::string name; + bool operator==(const Key&) const = default; + }; + struct Hash { + std::size_t operator()(const Key& key) const noexcept { + return std::hash()(key.name); + } + }; + struct Value { + std::string location; + unsigned int data_offset; + unsigned int size; + std::vector dimensions; + std::int32_t element_type; + std::shared_ptr tensor; + }; + using Map = std::unordered_map; + friend std::ostream& operator<<(std::ostream& right, const Metadata::Map& metadata); + friend std::istream& operator>>(std::istream& right, Metadata::Map& metadata); + }; + + struct WeightsFile { + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(WeightsFile); + WeightsFile() = delete; + explicit WeightsFile(std::filesystem::path filename); + + void load_weights(size_t file_offset, void* data, size_t size); + + private: + std::ifstream file_; + size_t weights_size_; + }; + + fs::path external_weight_filename; + std::unique_ptr mapped_weights; Metadata::Map metadata; } shared_weights; - - void clear() { // Deletes the data stored in the SharedContext - shared_weights.metadata.clear(); - } }; using config_t = std::map; @@ -92,7 +102,6 @@ struct ProviderInfo { bool so_context_embed_mode{false}; // ORT session option bool so_share_ep_contexts{false}; // ORT session option fs::path so_context_file_path{}; // ORT session option - bool so_stop_share_ep_contexts{false}; // ORT session option const ConfigOptions* config_options{NULL}; const std::unordered_set valid_provider_keys = {"device_type", "device_id", "device_luid", "cache_dir", "precision", "load_config", "context", "num_of_threads", "model_priority", "num_streams", "enable_opencl_throttling", "enable_qdq_optimizer", diff --git a/onnxruntime/core/providers/openvino/openvino_execution_provider.cc b/onnxruntime/core/providers/openvino/openvino_execution_provider.cc index 767b6519f1387..f9d4ab13cf2ce 100644 --- a/onnxruntime/core/providers/openvino/openvino_execution_provider.cc +++ b/onnxruntime/core/providers/openvino/openvino_execution_provider.cc @@ -65,7 +65,6 @@ OpenVINOExecutionProvider::~OpenVINOExecutionProvider() { backend_manager.ShutdownBackendManager(); } backend_managers_.clear(); - shared_context_.reset(); } std::vector> @@ -107,12 +106,7 @@ common::Status OpenVINOExecutionProvider::Compile( auto& metadata = shared_context_->shared_weights.metadata; if (session_context_.so_share_ep_contexts && metadata.empty()) { // Metadata is always read from model location, this could be a source or epctx model - fs::path metadata_filename; - if (session_context_.so_context_file_path.empty()) { - metadata_filename = session_context_.onnx_model_path_name.parent_path() / "metadata.bin"; - } else { - metadata_filename = session_context_.so_context_file_path.parent_path() / "metadata.bin"; - } + fs::path metadata_filename = session_context_.onnx_model_path_name.parent_path() / "metadata.bin"; std::ifstream file(metadata_filename, std::ios::binary); if (file) { file >> metadata; @@ -197,10 +191,6 @@ common::Status OpenVINOExecutionProvider::Compile( } } - if (session_context_.so_stop_share_ep_contexts) { - shared_context_->clear(); - } - return status; } diff --git a/onnxruntime/core/providers/openvino/openvino_provider_factory.cc b/onnxruntime/core/providers/openvino/openvino_provider_factory.cc index 93ec08b88ae21..f7f15dc62fd11 100644 --- a/onnxruntime/core/providers/openvino/openvino_provider_factory.cc +++ b/onnxruntime/core/providers/openvino/openvino_provider_factory.cc @@ -28,7 +28,6 @@ void ParseConfigOptions(ProviderInfo& pi) { pi.so_context_embed_mode = pi.config_options->GetConfigOrDefault(kOrtSessionOptionEpContextEmbedMode, "0") == "1"; pi.so_share_ep_contexts = pi.config_options->GetConfigOrDefault(kOrtSessionOptionShareEpContexts, "0") == "1"; pi.so_context_file_path = pi.config_options->GetConfigOrDefault(kOrtSessionOptionEpContextFilePath, ""); - pi.so_stop_share_ep_contexts = pi.config_options->GetConfigOrDefault(kOrtSessionOptionStopShareEpContexts, "0") == "1"; if (pi.so_share_ep_contexts) { ov::AnyMap map; diff --git a/onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.cc b/onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.cc index 61040c5552c71..860cfb5713903 100644 --- a/onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.cc +++ b/onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.cc @@ -11,7 +11,6 @@ #include #include #include -#include #include "core/providers/shared_library/provider_api.h" #include "core/providers/openvino/qdq_transformations/qdq_stripping.h" @@ -684,10 +683,10 @@ static void AddInitializerAsInput(onnxruntime::Graph& dst_graph, // Creates a new model without the DQ/Q operators in the src graph. Status CreateModelWithStrippedQDQNodes(const GraphViewer& src_graph, const logging::Logger& logger, - bool transform_weight_as_input, + bool enable_ovep_weight_sharing, bool enable_ovep_qdq_optimizer, /*out*/ std::unique_ptr& model, - /*out*/ Metadata::Map& weight_metadata) { + /*out*/ sw& shared_weights) { // NOTE: This function is a re-implementation of GraphViewerToProto() in core/graph/graph_proto_serializer.cc // with the following differences: // - Uses onnxruntime::Graph APIs instead of onnx::GraphProto APIs. @@ -778,7 +777,7 @@ Status CreateModelWithStrippedQDQNodes(const GraphViewer& src_graph, continue; // Already handled this node unit } - bool IsWeightSharingWithoutOVEPQDQStripping = transform_weight_as_input && !enable_ovep_qdq_optimizer; + bool IsWeightSharingWithoutOVEPQDQStripping = enable_ovep_weight_sharing && !enable_ovep_qdq_optimizer; if (node_unit->UnitType() == NodeUnit::Type::SingleNode) { AddStandaloneNodeUnit(dst_graph, src_graph, *node_unit, initializers_to_keep, IsWeightSharingWithoutOVEPQDQStripping, logger); @@ -803,9 +802,11 @@ Status CreateModelWithStrippedQDQNodes(const GraphViewer& src_graph, std::sort(const_inits.begin(), const_inits.end()); // initialize map for creating metadata for initilizers with external weights - const auto& insert_metadata = [&weight_metadata](const ONNX_NAMESPACE::TensorProto& proto) { - Metadata::Map::key_type key{proto.name()}; - Metadata::Map::mapped_type value{}; + auto& metadata = shared_weights.metadata; + + const auto& insert_metadata = [&metadata](const ONNX_NAMESPACE::TensorProto& proto) { + sw::Metadata::Map::key_type key{proto.name()}; + sw::Metadata::Map::mapped_type value{}; using mutable_proto_t = ONNX_NAMESPACE::TensorProto*; auto& mutable_proto = *const_cast(&proto); @@ -828,7 +829,7 @@ Status CreateModelWithStrippedQDQNodes(const GraphViewer& src_graph, dim = proto.dims()[index++]; } - weight_metadata.emplace(key, std::move(value)); + metadata.emplace(key, std::move(value)); }; // Handle constant initializers @@ -838,7 +839,7 @@ Status CreateModelWithStrippedQDQNodes(const GraphViewer& src_graph, // Check if the initializer has external data if (initializer_tensor.has_data_location() && initializer_tensor.data_location() == ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL && - transform_weight_as_input) { + enable_ovep_weight_sharing) { insert_metadata(initializer_tensor); // Add initializer with external data as input @@ -866,7 +867,7 @@ Status CreateModelWithStrippedQDQNodes(const GraphViewer& src_graph, // Check if the initializer has external data if (initializer_tensor.has_data_location() && initializer_tensor.data_location() == ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL && - transform_weight_as_input) { + enable_ovep_weight_sharing) { insert_metadata(initializer_tensor); // Add initializer as input if it has external data diff --git a/onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.h b/onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.h index 7e87352e5992d..53de0fd019311 100644 --- a/onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.h +++ b/onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.h @@ -10,12 +10,16 @@ namespace onnxruntime { namespace openvino_ep { +using sw = SharedContext::SharedWeights; + // Creates a new model without the DQ/Q operators in the src graph as per pre-defined rulesets Status CreateModelWithStrippedQDQNodes(const GraphViewer& src_graph, const logging::Logger& logger, - bool transform_weight_as_input, + bool enable_ovep_weight_sharing, bool enable_ovep_qdq_optimizer, /*out*/ std::unique_ptr& model, - /*out*/ Metadata::Map& metadata); + /*out*/ sw& shared_weights); + +bool dumpMetaDataMapToBinary(const sw::Metadata::Map& shared_weights, const std::string& filename); } // namespace openvino_ep } // namespace onnxruntime From a077c79ad98270144bbb064fc09762658cfeb954 Mon Sep 17 00:00:00 2001 From: Bartlomiej Filipek Date: Mon, 28 Apr 2025 23:06:25 -0700 Subject: [PATCH 035/138] Release model proto after we have the serialized string to reduce peak memory consumption (#672) Signed-off-by: bfilipek --- onnxruntime/core/providers/openvino/backends/basic_backend.cc | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/openvino/backends/basic_backend.cc b/onnxruntime/core/providers/openvino/backends/basic_backend.cc index c814df618e3b3..c7ea76fabe815 100644 --- a/onnxruntime/core/providers/openvino/backends/basic_backend.cc +++ b/onnxruntime/core/providers/openvino/backends/basic_backend.cc @@ -98,9 +98,11 @@ BasicBackend::BasicBackend(std::unique_ptr& model_pr auto_unified_compile) { // Unified OV compile_model is efficient when ov model caching is enabled // Unified OV compile_model API is supported with AUTO from version 2024.3 and above - // Inputs with static dimenstions + // Inputs with static dimensions // Not enabled for models with external weights and when ep context is set. const std::string model = model_proto->SerializeAsString(); + // we have the serialized string, so we can release model proto to lower the peak memory consumption + model_proto.reset(); exe_network_ = OVCore::Get()->CompileModel(model, hw_target, device_config, From 8d2f3c41bd29479ac6dd9e19354d28c1c955df53 Mon Sep 17 00:00:00 2001 From: bopeng1234 Date: Wed, 14 May 2025 15:35:49 +0800 Subject: [PATCH 036/138] add channel wise quantization option for QDQ, and opt for intel NPU (#669) * add channel wise quantization option for QDQ, it optimize for intel NPU * add channel_wised_quantize args to MatMulNBitsQuantizer --- .../quantization/matmul_nbits_quantizer.py | 61 ++++++++++++++++++- 1 file changed, 58 insertions(+), 3 deletions(-) diff --git a/onnxruntime/python/tools/quantization/matmul_nbits_quantizer.py b/onnxruntime/python/tools/quantization/matmul_nbits_quantizer.py index b1d58b713eea8..f70a7b545e60a 100644 --- a/onnxruntime/python/tools/quantization/matmul_nbits_quantizer.py +++ b/onnxruntime/python/tools/quantization/matmul_nbits_quantizer.py @@ -202,6 +202,7 @@ def __init__( op_types_to_quantize: tuple[str, ...] | None = None, quant_axes: tuple[tuple[str, int], ...] | None = None, bits: int = 4, + channel_wised_quantize: bool = False, ): """ This is a class for weight only affine quantization configuration. @@ -236,6 +237,9 @@ def __init__( self.is_symmetric = is_symmetric self.bits = bits self.accuracy_level = accuracy_level + self.channel_wised_quantize = channel_wised_quantize + if channel_wised_quantize and quant_format == QuantFormat.QOperator: + raise NotImplementedError("QuantFormat.QOperator is not supported channel_wised_quantize yet") class NVAWQWeightOnlyQuantConfig(WeightOnlyQuantConfig): @@ -734,6 +738,26 @@ def get_initializer(name, graph_path: list[GraphProto]) -> tuple[TensorProto, Gr return None, None +# transpose int4 matrix (packed as uint8) +def transpose_packed_int4_matrix(packed, rows, cols): + # unpack to int4 matrix + total = rows * cols + high = (packed >> 4) & 0x0F + low = packed & 0x0F + int4_vals = np.empty(total, dtype=np.uint8) + int4_vals[0::2] = low + int4_vals[1::2] = high + int4_matrix = int4_vals.reshape((rows, cols)) + + # transpose int4 matrix + int4_matrix_transposed = int4_matrix.T + + # pack to uint8 + flat = int4_matrix_transposed.reshape(-1) + packed = ((flat[1::2] << 4) & 0xF0) | (flat[0::2] & 0x0F) + return packed.astype(np.uint8) + + class DefaultWeightOnlyQuantizer: def __init__(self, config: DefaultWeightOnlyQuantConfig): self.config = config @@ -770,6 +794,10 @@ def qbits_block_quant(self, fp32weight: npt.ArrayLike) -> tuple[np.ndarray, np.n packed, fp32weight, scales, zero_point, block_size, cols, rows, self.config.is_symmetric ) else: + # block size equal to rows (K) if channel wised quantize enabled + block_size = rows if self.config.channel_wised_quantize else self.config.block_size + k_blocks = (rows + block_size - 1) // block_size + assert qbits == 4, "QDQ format only support 4 bits quantization" packed = np.zeros((rows * cols + 1) // 2, dtype="uint8") zero_point = np.zeros((cols * k_blocks + 1) // 2, dtype="uint8") @@ -812,6 +840,16 @@ def quantize_matmul(self, node: NodeProto, graph_stack: list[GraphProto]) -> lis ) scales_tensor = onnx.numpy_helper.from_array(scales, b_tensor.name + "_DQ_scales") + # if QDQ, CW and SYM enabled, optimize for Intel NPU, tranpose the weight to NHWC format will increase performance + qdq_opt_for_intel_npu_enabled = self.config.quant_format == QuantFormat.QDQ \ + and self.config.channel_wised_quantize and self.config.is_symmetric + if qdq_opt_for_intel_npu_enabled: + rows, cols = b_ndarray.shape + packed = transpose_packed_int4_matrix(packed, rows, cols) + scales = scales.reshape((cols, 1)) # (cols, 1) + b_quant = onnx.helper.make_tensor(b_tensor.name + f"_DQ_Q{bits}", qtype, [cols, rows], packed.tobytes(), True) + scales_tensor = onnx.numpy_helper.from_array(scales, b_tensor.name + "_DQ_scales") + for input in b_graph.input: if input.name == input_b: b_graph.input.remove(input) @@ -849,7 +887,9 @@ def quantize_matmul(self, node: NodeProto, graph_stack: list[GraphProto]) -> lis else: dq_input_names = [b_quant.name, scales_tensor.name] dq_output_names = [b_quant.name + "_output"] - matmul_input_names = [node.input[0], dq_output_names[0]] + tp_input_names = [dq_output_names[0]] + tp_output_names = [dq_output_names[0] + "_transposed"] + matmul_input_names = [node.input[0], tp_output_names[0] if qdq_opt_for_intel_npu_enabled else dq_output_names[0]] matmul_output_names = [node.output[0]] if not self.config.is_symmetric: zp_tensor = onnx.helper.make_tensor( @@ -857,7 +897,11 @@ def quantize_matmul(self, node: NodeProto, graph_stack: list[GraphProto]) -> lis ) dq_input_names.append(zp_tensor.name) b_graph.initializer.extend([zp_tensor]) - dq_kwargs = {"axis": 0, "block_size": self.config.block_size} + rows, cols = b_ndarray.shape + dq_kwargs = { + "axis": 1 if qdq_opt_for_intel_npu_enabled else 0, + "block_size": rows if self.config.channel_wised_quantize else self.config.block_size + } dq_node = onnx.helper.make_node( "DequantizeLinear", inputs=dq_input_names, @@ -871,7 +915,16 @@ def quantize_matmul(self, node: NodeProto, graph_stack: list[GraphProto]) -> lis outputs=matmul_output_names, name=node.name + f"_matmul_Q{bits}" if node.name else "", ) - output_nodes.extend([dq_node, matmul_node]) + if qdq_opt_for_intel_npu_enabled: + tp_node = onnx.helper.make_node( + "Transpose", + inputs=tp_input_names, + outputs=tp_output_names, + perm=[1,0], + ) + output_nodes.extend([dq_node, tp_node, matmul_node]) + else: + output_nodes.extend([dq_node, matmul_node]) return output_nodes @@ -1136,6 +1189,7 @@ def __init__( quant_format=QuantFormat.QOperator, op_types_to_quantize: tuple[str, ...] | None = None, quant_axes: tuple[tuple[str, int], ...] | None = None, + channel_wised_quantize: bool = False, algo_config: WeightOnlyQuantConfig | None = None, ): if nodes_to_exclude is None: @@ -1158,6 +1212,7 @@ def __init__( op_types_to_quantize=op_types_to_quantize, quant_axes=quant_axes, bits=4, # default to 4 bits + channel_wised_quantize=channel_wised_quantize, ) self.algo_config = algo_config From 76d312246159a4d4267e4de001c774319cec8e2a Mon Sep 17 00:00:00 2001 From: Eric Crawford Date: Thu, 15 May 2025 13:18:07 -0700 Subject: [PATCH 037/138] Don't include initializers in compute capability (#686) Co-authored-by: MayureshV1 <47039074+MayureshV1@users.noreply.github.com> --- onnxruntime/core/providers/openvino/ov_versions/capability.cc | 3 --- 1 file changed, 3 deletions(-) diff --git a/onnxruntime/core/providers/openvino/ov_versions/capability.cc b/onnxruntime/core/providers/openvino/ov_versions/capability.cc index d56687f868c3d..bbe5d5a4b966c 100644 --- a/onnxruntime/core/providers/openvino/ov_versions/capability.cc +++ b/onnxruntime/core/providers/openvino/ov_versions/capability.cc @@ -126,9 +126,6 @@ std::vector> GetCapability::Execute() { } } - // Initializers need to be part of meta_def->inputs - Iterable2String(inputs, ng_required_initializers); - // Fill outputs with names Iterable2String(outputs, graph_viewer_.GetOutputs()); From 8d57482dcadfd610561089df06f3735c28c67114 Mon Sep 17 00:00:00 2001 From: jatinwadhwa921 <110383850+jatinwadhwa921@users.noreply.github.com> Date: Mon, 19 May 2025 16:57:19 +0530 Subject: [PATCH 038/138] [OVEP] Fixed coverity issues (#693) --- .../openvino/openvino_execution_provider.cc | 2 +- .../openvino/openvino_provider_factory.cc | 14 +++++++------- .../core/providers/openvino/ov_allocator.cc | 12 ++++++++---- 3 files changed, 16 insertions(+), 12 deletions(-) diff --git a/onnxruntime/core/providers/openvino/openvino_execution_provider.cc b/onnxruntime/core/providers/openvino/openvino_execution_provider.cc index f9d4ab13cf2ce..70bc64c0f65bc 100644 --- a/onnxruntime/core/providers/openvino/openvino_execution_provider.cc +++ b/onnxruntime/core/providers/openvino/openvino_execution_provider.cc @@ -55,7 +55,7 @@ static std::vector parseDevices(const std::string& device_string, OpenVINOExecutionProvider::OpenVINOExecutionProvider(const ProviderInfo& info, std::shared_ptr shared_context) : IExecutionProvider{onnxruntime::kOpenVINOExecutionProvider}, session_context_(info), - shared_context_{shared_context}, + shared_context_{std::move(shared_context)}, ep_ctx_handle_{session_context_.openvino_sdk_version, *GetLogger()} { InitProviderOrtApi(); } diff --git a/onnxruntime/core/providers/openvino/openvino_provider_factory.cc b/onnxruntime/core/providers/openvino/openvino_provider_factory.cc index f7f15dc62fd11..e5526ecd52bb9 100644 --- a/onnxruntime/core/providers/openvino/openvino_provider_factory.cc +++ b/onnxruntime/core/providers/openvino/openvino_provider_factory.cc @@ -138,7 +138,7 @@ std::string ParseDeviceType(std::shared_ptr ov_core, const ProviderOptio if (std::find(std::begin(available_devices), std::end(available_devices), device) != std::end(available_devices)) device_found = true; if (device_prefix != "CPU" && luid_list.size() > 0) { - for (auto dev : available_devices) { + for (const auto& dev : available_devices) { ov::device::LUID ov_luid = OVCore::Get()->core.get_property(dev, ov::device::luid); std::stringstream ov_luid_str; ov_luid_str << ov_luid; @@ -153,7 +153,7 @@ std::string ParseDeviceType(std::shared_ptr ov_core, const ProviderOptio } if (luid_list.size() > 0) { std::string ov_luid_devices; - for (auto luid_str : luid_list) { + for (const auto& luid_str : luid_list) { if (ov_luid_map.contains(luid_str)) { std::string ov_dev = ov_luid_map.at(luid_str); std::string ov_dev_strip = split(ov_dev, '.')[0]; @@ -170,14 +170,14 @@ std::string ParseDeviceType(std::shared_ptr ov_core, const ProviderOptio } if (!device_mode.empty()) { selected_device = device_mode + ":" + ov_luid_devices; - for (auto dev_str : devices_to_check) { - auto default_dev = split(dev_str, '.')[0]; + for (const auto& dev_str : devices_to_check) { + const auto default_dev = split(dev_str, '.')[0]; if (ov_luid_devices.find(default_dev) == std::string::npos) selected_device = selected_device + "," + dev_str; } } else { - selected_device = ov_luid_devices; + selected_device = std::move(ov_luid_devices); } } // If invalid device is chosen error is thrown @@ -215,7 +215,7 @@ static void ParseProviderInfo(const ProviderOptions& provider_options, // Minor optimization: we'll hold an OVCore reference to ensure we don't create a new core between ParseDeviceType and // (potential) SharedContext creation. auto ov_core = OVCore::Get(); - pi.device_type = ParseDeviceType(ov_core, provider_options); + pi.device_type = ParseDeviceType(std::move(ov_core), provider_options); if (provider_options.contains("device_id")) { std::string dev_id = provider_options.at("device_id").data(); @@ -355,7 +355,7 @@ static void ParseProviderInfo(const ProviderOptions& provider_options, struct OpenVINOProviderFactory : IExecutionProviderFactory { OpenVINOProviderFactory(ProviderInfo provider_info, std::shared_ptr shared_context) - : provider_info_(std::move(provider_info)), shared_context_(shared_context) {} + : provider_info_(std::move(provider_info)), shared_context_(std::move(shared_context)) {} ~OpenVINOProviderFactory() override {} diff --git a/onnxruntime/core/providers/openvino/ov_allocator.cc b/onnxruntime/core/providers/openvino/ov_allocator.cc index 431f5730c0342..1bbe71441cbff 100644 --- a/onnxruntime/core/providers/openvino/ov_allocator.cc +++ b/onnxruntime/core/providers/openvino/ov_allocator.cc @@ -22,7 +22,7 @@ void* OVRTAllocator::Alloc(size_t size) { try { ov::Tensor* tensor = new ov::Tensor(remote_ctx_.create_host_tensor(ov::element::Type_t::u8, {size})); - std::unique_lock lock(mutex_); + std::lock_guard lock(mutex_); allocated_.insert({tensor->data(), tensor}); return reinterpret_cast(tensor->data()); } catch (const ov::Exception& e) { @@ -32,12 +32,16 @@ void* OVRTAllocator::Alloc(size_t size) { void OVRTAllocator::Free(void* p) { try { - std::unique_lock lock(mutex_); + ov::Tensor* tensor = nullptr; + { + std::lock_guard lock(mutex_); auto it = allocated_.find(p); if (it != allocated_.end()) { - ov::Tensor* tensor = it->second; + tensor = it->second; allocated_.erase(it); - lock.unlock(); + } + } + if (tensor) { delete tensor; } } catch (const ov::Exception& e) { From 599cd25909bc7d7455f0bf55c5aff7d26501106d Mon Sep 17 00:00:00 2001 From: Bartlomiej Filipek Date: Fri, 23 May 2025 02:15:23 -0700 Subject: [PATCH 039/138] [GPU] Enable qdq_stripping path for GPU (#694) * update the statement so that we run CreateModelWithStrippedQDQNodes on GPU * ensure the capability checks are also updated * update the comment Signed-off-by: bfilipek --------- Signed-off-by: bfilipek --- onnxruntime/core/providers/openvino/backend_manager.cc | 5 +++-- .../core/providers/openvino/ov_versions/capability.cc | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/onnxruntime/core/providers/openvino/backend_manager.cc b/onnxruntime/core/providers/openvino/backend_manager.cc index 9ef7e4b86db5f..cf8e11826ce8b 100644 --- a/onnxruntime/core/providers/openvino/backend_manager.cc +++ b/onnxruntime/core/providers/openvino/backend_manager.cc @@ -380,8 +380,9 @@ BackendManager::GetModelProtoFromFusedNode(const onnxruntime::Node& fused_node, #endif const auto& onnx_model_path_name = subgraph.ModelPath(); - // QDQ stripping enabled only for the NPU - if (session_context_.device_type.find("NPU") != std::string::npos && + // QDQ stripping enabled only for the NPU and experimentally on the GPU + if ((session_context_.device_type.find("NPU") != std::string::npos || + session_context_.device_type.find("GPU") != std::string::npos) && (enable_ovep_qdq_optimizer || session_context_.so_share_ep_contexts)) { std::unique_ptr model; Status status = CreateModelWithStrippedQDQNodes(subgraph, logger, session_context_.so_share_ep_contexts, enable_ovep_qdq_optimizer, model, shared_context_.shared_weights); diff --git a/onnxruntime/core/providers/openvino/ov_versions/capability.cc b/onnxruntime/core/providers/openvino/ov_versions/capability.cc index bbe5d5a4b966c..46d2f6e02c70e 100644 --- a/onnxruntime/core/providers/openvino/ov_versions/capability.cc +++ b/onnxruntime/core/providers/openvino/ov_versions/capability.cc @@ -34,7 +34,7 @@ GetCapability::GetCapability(const EPCtxHandler& ep_ctx_handler, graph_viewer_(graph_viewer_param), device_type_(std::move(device_type_param)) { bool npu_qdq_optimizer_enabled = false; - if (device_type_.find("NPU") != std::string::npos) { + if (device_type_.find("NPU") != std::string::npos || device_type_.find("GPU") != std::string::npos) { device_type_ = "CPU"; if (enable_qdq_optimizer) npu_qdq_optimizer_enabled = true; } From 1a22a5724cedf758a75840949db105e2dd4efd54 Mon Sep 17 00:00:00 2001 From: Eric Crawford Date: Mon, 26 May 2025 06:05:49 -0700 Subject: [PATCH 040/138] Optimize CPU time spent in inference path (#682) * Optimize CPU time spent in inference path Move input/output name to ort/ov input output bindings to compilation. Reduce tensor lookups by name in favor of index look ups. * Fix dynamic shape handling --------- Co-authored-by: Preetha Veeramalai --- .../core/providers/openvino/backend_utils.cc | 4 +- .../openvino/backends/basic_backend.cc | 136 +++++------------- .../openvino/backends/basic_backend.h | 46 +++++- 3 files changed, 85 insertions(+), 101 deletions(-) diff --git a/onnxruntime/core/providers/openvino/backend_utils.cc b/onnxruntime/core/providers/openvino/backend_utils.cc index 2ee5e9ec3e3a9..1382c187f6b4e 100644 --- a/onnxruntime/core/providers/openvino/backend_utils.cc +++ b/onnxruntime/core/providers/openvino/backend_utils.cc @@ -121,7 +121,7 @@ std::istream& operator>>(std::istream& stream, SharedContext::SharedWeights::Met namespace backend_utils { bool IsDebugEnabled() { - const std::string env_name = onnxruntime::GetEnvironmentVar("ORT_OPENVINO_ENABLE_DEBUG"); + static std::string env_name = onnxruntime::GetEnvironmentVar("ORT_OPENVINO_ENABLE_DEBUG"); if (!env_name.empty()) { return true; } @@ -129,7 +129,7 @@ bool IsDebugEnabled() { } bool IsCILogEnabled() { - const std::string env_name = onnxruntime::GetEnvironmentVar("ORT_OPENVINO_ENABLE_CI_LOG"); + static std::string env_name = onnxruntime::GetEnvironmentVar("ORT_OPENVINO_ENABLE_CI_LOG"); if (!env_name.empty()) { return true; } diff --git a/onnxruntime/core/providers/openvino/backends/basic_backend.cc b/onnxruntime/core/providers/openvino/backends/basic_backend.cc index c7ea76fabe815..e77ff973f3a87 100644 --- a/onnxruntime/core/providers/openvino/backends/basic_backend.cc +++ b/onnxruntime/core/providers/openvino/backends/basic_backend.cc @@ -140,6 +140,7 @@ BasicBackend::BasicBackend(std::unique_ptr& model_pr }; } inferRequestsQueue_ = std::unique_ptr(new InferRequestsQueue(exe_network_, num_infer_req, std::move(initializer))); + bindings_ = std::make_unique(exe_network_, subgraph_context_); } bool BasicBackend::ValidateSubgraph(std::map>& const_outputs_map) { @@ -362,29 +363,16 @@ void BasicBackend::SetNumThreads(ov::AnyMap& device_config) { // an Infer Request indexed by infer_req_idx void BasicBackend::StartAsyncInference(Ort::KernelContext& context, OVInferRequestPtr infer_request) { try { - auto ov_input_info = exe_network_.Get().inputs(); - - // Loop over subgraph original input names to find the correspondent OV input name - for (const auto& [onnx_input_name, onnx_input_index] : subgraph_context_.input_names) { - std::string input_name{}; - uint32_t input_idx = 0; - for (uint32_t index = 0; const auto& ov_input : ov_input_info) { - if (ov_input.get_names().contains(onnx_input_name)) { - input_name = onnx_input_name; - input_idx = index; - break; - } - index++; - } - ORT_ENFORCE(!input_name.empty(), log_tag, - "Input names mismatch between OpenVINO and ONNX. ", onnx_input_name, - " doesn't exist in the list of OpenVINO input tensor names"); + bool cpu_or_gpu = (session_context_.device_type.find("CPU") != std::string::npos || + session_context_.device_type.find("GPU") != std::string::npos); + bool npu = (session_context_.device_type.find("NPU") != std::string::npos); + + for (const auto& input_info : bindings_->network_inputs_) { size_t batch_slice_idx = 0; if (subgraph_context_.has_dynamic_input_shape && !session_context_.disable_dynamic_shapes && - (session_context_.device_type.find("CPU") != std::string::npos || - session_context_.device_type.find("GPU") != std::string::npos)) { - auto tensor = context.GetInput(subgraph_context_.input_names.at(input_name)); + cpu_or_gpu) { + auto tensor = context.GetInput(input_info.onnx_index); auto tensor_info = tensor.GetTensorTypeAndShapeInfo(); auto tensor_shape = tensor_info.GetShape(); auto tensor_size = tensor_shape.size(); @@ -395,98 +383,72 @@ void BasicBackend::StartAsyncInference(Ort::KernelContext& context, OVInferReque input_tensor_shape[tensor_iter] = *i; tensor_iter += 1; } - const auto& input = ov_input_info.at(input_idx); OVTensorPtr tensor_ptr; // avoid input copies on the CPU device if (session_context_.device_type.find("CPU") != std::string::npos) { - tensor_ptr = std::make_shared(input.get_element_type(), input_tensor_shape, + tensor_ptr = std::make_shared(input_info.type, input_tensor_shape, (void*)tensor_data); } else { - tensor_ptr = std::make_shared(input.get_element_type(), input_tensor_shape); - FillInputBlob(tensor_ptr, batch_slice_idx, input_name, context, subgraph_context_); + tensor_ptr = std::make_shared(input_info.type, input_tensor_shape); + FillInputBlob(tensor_ptr, batch_slice_idx, input_info.name, context, subgraph_context_); } try { - infer_request->SetTensor(std::move(input_name), tensor_ptr); + infer_request->SetTensor(input_info.name, tensor_ptr); } catch (const char* msg) { ORT_THROW(msg); } } else { - if ((session_context_.device_type.find("CPU") != std::string::npos || - session_context_.device_type.find("GPU") != std::string::npos)) { + if (cpu_or_gpu) { OVTensorPtr graph_input_blob; try { - graph_input_blob = infer_request->GetTensor(input_name); + graph_input_blob = infer_request->GetTensor(input_info.name); } catch (const char* msg) { ORT_THROW(msg); } - FillInputBlob(std::move(graph_input_blob), batch_slice_idx, std::move(input_name), context, subgraph_context_); + FillInputBlob(std::move(graph_input_blob), batch_slice_idx, input_info.name, context, subgraph_context_); } else { - auto tensor = context.GetInput(subgraph_context_.input_names.at(input_name)); - ort_tensor_key_t ort_tensor_key{input_name}; + auto tensor = context.GetInput(input_info.onnx_index); + ort_tensor_key_t ort_tensor_key{input_info.name}; auto it = ort_ov_tensor_map.find(ort_tensor_key); - if ((it == ort_ov_tensor_map.end()) || - (it != ort_ov_tensor_map.end() && (it->second.ort_ptr != tensor.GetTensorRawData()))) { + if ((it == ort_ov_tensor_map.end()) || it->second.ort_ptr != tensor.GetTensorRawData()) { ov_tensor_data_t ov_tensor_data; - const auto& input = ov_input_info.at(input_idx); - ov_tensor_data.tensor_ptr = std::make_shared(input.get_element_type(), input.get_shape(), + ov_tensor_data.tensor_ptr = std::make_shared(input_info.type, input_info.ov_shape.get_shape(), const_cast(tensor.GetTensorRawData())); ov_tensor_data.ort_ptr = tensor.GetTensorRawData(); ort_ov_tensor_map[ort_tensor_key] = ov_tensor_data; try { - infer_request->SetTensor(std::move(input_name), ov_tensor_data.tensor_ptr); + infer_request->SetTensor(input_info.name, ov_tensor_data.tensor_ptr); } catch (const char* msg) { ORT_THROW(msg); } } } } - } // Loop subgraph original input names + } // Loop subgraph original input - if (session_context_.device_type.find("NPU") != std::string::npos) { + if (npu) { // Set the output blob as remote blob - auto graph_output_info = exe_network_.Get().outputs(); - auto output_idx = 0; - for (auto output_info_iter = graph_output_info.begin(); - output_info_iter != graph_output_info.end(); ++output_info_iter) { - auto output_names = output_info_iter->get_names(); - std::string onnx_output_name; - std::string output_name; - // using the output name retrieved from ONNX original to match with the output names returned by OV tensors - for (auto it = subgraph_context_.output_names.begin(); it != subgraph_context_.output_names.end(); ++it) { - onnx_output_name = it->first; - if (output_names.find(onnx_output_name) != output_names.end()) { - // Assigning the output_name - output_name = it->first; - break; - } - } - size_t batch_size = 1; - Ort::UnownedValue tensor = GetOutputTensor(context, - batch_size, - infer_request, - output_name, - subgraph_context_.output_names); - ort_tensor_key_t ort_tensor_key{output_name}; + for (const auto& output_info : bindings_->network_outputs_) { + Ort::UnownedValue tensor = context.GetOutput(output_info.onnx_index, output_info.onnx_shape); + + ort_tensor_key_t ort_tensor_key{output_info.name}; const auto& it = ort_ov_tensor_map.find(ort_tensor_key); - if ((it == ort_ov_tensor_map.end()) || - (it != ort_ov_tensor_map.end() && (it->second.ort_ptr != tensor.GetTensorRawData()))) { + if ((it == ort_ov_tensor_map.end()) || (it->second.ort_ptr != tensor.GetTensorRawData())) { ov_tensor_data_t ov_tensor_data; - const auto& output = graph_output_info.at(output_idx); ov_tensor_data.ort_ptr = tensor.GetTensorRawData(); - ov_tensor_data.tensor_ptr = std::make_shared(output.get_element_type(), output.get_shape(), + ov_tensor_data.tensor_ptr = std::make_shared(output_info.type, output_info.ov_shape.get_shape(), const_cast(tensor.GetTensorRawData())); ort_ov_tensor_map[ort_tensor_key] = ov_tensor_data; try { - infer_request->SetTensor(std::move(output_name), ov_tensor_data.tensor_ptr); + infer_request->SetTensor(output_info.name, ov_tensor_data.tensor_ptr); } catch (const char* msg) { ORT_THROW(msg); } } - output_idx++; } } @@ -611,44 +573,22 @@ void BasicBackend::StartRemoteAsyncInference(Ort::KernelContext& context, OVInfe void BasicBackend::CompleteAsyncInference(Ort::KernelContext& context, OVInferRequestPtr infer_request) { // Wait for Async inference completion try { + bool cpu_or_gpu = session_context_.device_type.find("CPU") != std::string::npos || + session_context_.device_type.find("GPU") != std::string::npos; + infer_request->WaitRequest(); - auto graph_output_info = exe_network_.Get().outputs(); - for (auto output_info_iter = graph_output_info.begin(); - output_info_iter != graph_output_info.end(); ++output_info_iter) { - OVTensorPtr graph_output_blob; - auto output_names = output_info_iter->get_names(); - std::string onnx_output_name; - std::string output_name; - bool output_name_found = false; - // using the output name retrieved from ONNX original to match with the output names returned by OV tensors - for (auto it = subgraph_context_.output_names.begin(); it != subgraph_context_.output_names.end(); ++it) { - onnx_output_name = it->first; - if (output_names.find(onnx_output_name) != output_names.end()) { - // Assigning the output_name - output_name = it->first; - output_name_found = true; - break; - } - } - if (!output_name_found) { - ORT_THROW( - log_tag + - "Output names mismatch between OpenVINO and ONNX. " - "[ONNX Output: ] " + - onnx_output_name + - " doesn't exist in the " - "list of OpenVINO output tensor names"); - } - if ((session_context_.device_type.find("CPU") != std::string::npos || - session_context_.device_type.find("GPU") != std::string::npos)) { + + if (cpu_or_gpu) { + for (const auto& output_info : bindings_->network_outputs_) { + OVTensorPtr graph_output_blob; try { - graph_output_blob = infer_request->GetTensor(output_name); + graph_output_blob = infer_request->GetTensor(output_info.name); } catch (const char* msg) { ORT_THROW(msg); } size_t batch_size = 1; Ort::UnownedValue output_tensor = - GetOutputTensor(context, batch_size, infer_request, std::move(output_name), subgraph_context_.output_names); + GetOutputTensor(context, batch_size, infer_request, output_info.name, subgraph_context_.output_names); auto mem_info = output_tensor.GetTensorMemoryInfo(); if (mem_info.GetAllocatorName() == OpenVINO_GPU) { return; diff --git a/onnxruntime/core/providers/openvino/backends/basic_backend.h b/onnxruntime/core/providers/openvino/backends/basic_backend.h index 7d905f4a1e2f7..230d3cb5db34a 100644 --- a/onnxruntime/core/providers/openvino/backends/basic_backend.h +++ b/onnxruntime/core/providers/openvino/backends/basic_backend.h @@ -13,11 +13,14 @@ #include #include #include +#include +#include #include "core/session/onnxruntime_cxx_api.h" #include "core/providers/openvino/contexts.h" #include "core/providers/openvino/ibackend.h" #include "core/providers/openvino/ov_interface.h" +#include "core/providers/openvino/backend_utils.h" namespace onnxruntime { namespace openvino_ep { @@ -27,6 +30,47 @@ struct ov_tensor_data_t { const void* ort_ptr; }; +struct OnnxToOvNetworkBindings { + struct ParameterInfo { + std::string name; + uint32_t ov_index; + uint32_t onnx_index; + ov::element::Type type; + ov::PartialShape ov_shape; + std::vector onnx_shape; + }; + std::vector network_outputs_; + std::vector network_inputs_; + + OnnxToOvNetworkBindings(OVExeNetwork& exec_network, SubGraphContext& subgraph_context) { + auto populate = [&](auto& input_output_map, const SubGraphContext::string_index_map_t& onnx_input_map, const auto& ov_parameters) { + for (const auto& [onnx_name, onnx_param_index] : onnx_input_map) { + auto it = std::find_if(ov_parameters.begin(), ov_parameters.end(), + [&onnx_name](const auto& ov_parameter_info) { return ov_parameter_info.get_names().contains(onnx_name); }); + + ORT_ENFORCE(it != ov_parameters.end(), backend_utils::log_tag, + "Input names mismatch between OpenVINO and ONNX. ", onnx_name, + " doesn't exist in the list of OpenVINO input tensor names"); + + auto ov_param_index = std::distance(ov_parameters.begin(), it); + + auto shape = ov_parameters[ov_param_index].get_partial_shape(); + auto type = ov_parameters[ov_param_index].get_element_type(); + ParameterInfo info{onnx_name, ov_param_index, onnx_param_index, type, shape}; + + if (shape.is_static()) { + auto static_shape = shape.get_shape(); + std::transform(static_shape.begin(), static_shape.end(), std::back_inserter(info.onnx_shape), [](const auto& dim) { return static_cast(dim); }); + } + input_output_map.push_back(std::move(info)); + } + }; + + populate(network_inputs_, subgraph_context.input_names, exec_network.Get().inputs()); + populate(network_outputs_, subgraph_context.output_names, exec_network.Get().outputs()); + } +}; + class InferRequestsQueue; class BasicBackend : public IBackend { public: @@ -43,7 +87,6 @@ class BasicBackend : public IBackend { } private: - void PopulateCompiledDirectory(std::string, std::string&, std::string&, bool&); bool ValidateSubgraph(std::map>& const_outputs_map); void PopulateConfigValue(ov::AnyMap& device_config); void EnableCaching(); @@ -71,6 +114,7 @@ class BasicBackend : public IBackend { using ort_tensor_key_t = const std::string; std::map ort_ov_tensor_map; + std::unique_ptr bindings_; }; class InferRequestsQueue { From b602f472cfe11d753a797418b684a2b22c5bac68 Mon Sep 17 00:00:00 2001 From: Preetha Veeramalai Date: Mon, 26 May 2025 11:05:38 -0700 Subject: [PATCH 041/138] Support workload type for dynamic shaped models (#690) --- .../core/providers/openvino/backend_manager.cc | 7 ++++--- .../core/providers/openvino/backend_manager.h | 2 +- .../providers/openvino/backends/basic_backend.h | 2 +- onnxruntime/core/providers/openvino/ibackend.h | 2 +- .../openvino/openvino_execution_provider.cc | 16 +++++++++++++--- 5 files changed, 20 insertions(+), 9 deletions(-) diff --git a/onnxruntime/core/providers/openvino/backend_manager.cc b/onnxruntime/core/providers/openvino/backend_manager.cc index cf8e11826ce8b..b7e6245b1834f 100644 --- a/onnxruntime/core/providers/openvino/backend_manager.cc +++ b/onnxruntime/core/providers/openvino/backend_manager.cc @@ -28,9 +28,10 @@ SessionContext& BackendManager::GetSessionContext() { return session_context_; } -ov::CompiledModel& BackendManager::GetOVCompiledModel() { - ov::CompiledModel& ov_ptr = concrete_backend_->GetOVCompiledModel(); - return (ov_ptr); +ov::CompiledModel BackendManager::GetOVCompiledModel() { + if (concrete_backend_) + return concrete_backend_->GetOVCompiledModel(); + return ov::CompiledModel(); } BackendManager::BackendManager(SessionContext& session_context, diff --git a/onnxruntime/core/providers/openvino/backend_manager.h b/onnxruntime/core/providers/openvino/backend_manager.h index cdc27701ec2e6..cb1ca7001a00c 100644 --- a/onnxruntime/core/providers/openvino/backend_manager.h +++ b/onnxruntime/core/providers/openvino/backend_manager.h @@ -29,7 +29,7 @@ class BackendManager { void ShutdownBackendManager(); SessionContext& GetSessionContext(); Status ExportCompiledBlobAsEPCtxNode(const onnxruntime::GraphViewer& subgraph); - ov::CompiledModel& GetOVCompiledModel(); + ov::CompiledModel GetOVCompiledModel(); private: std::unique_ptr GetModelProtoFromFusedNode( diff --git a/onnxruntime/core/providers/openvino/backends/basic_backend.h b/onnxruntime/core/providers/openvino/backends/basic_backend.h index 230d3cb5db34a..130699abd465b 100644 --- a/onnxruntime/core/providers/openvino/backends/basic_backend.h +++ b/onnxruntime/core/providers/openvino/backends/basic_backend.h @@ -82,7 +82,7 @@ class BasicBackend : public IBackend { void Infer(OrtKernelContext* context) override; ~BasicBackend() override = default; - ov::CompiledModel& GetOVCompiledModel() override { + ov::CompiledModel GetOVCompiledModel() override { return exe_network_.Get(); } diff --git a/onnxruntime/core/providers/openvino/ibackend.h b/onnxruntime/core/providers/openvino/ibackend.h index 04d1f52cbf834..4532349897d17 100644 --- a/onnxruntime/core/providers/openvino/ibackend.h +++ b/onnxruntime/core/providers/openvino/ibackend.h @@ -15,7 +15,7 @@ namespace openvino_ep { class IBackend { public: virtual void Infer(OrtKernelContext* context) = 0; - virtual ov::CompiledModel& GetOVCompiledModel() = 0; + virtual ov::CompiledModel GetOVCompiledModel() = 0; virtual ~IBackend() = default; }; using ptr_stream_t = std::unique_ptr; diff --git a/onnxruntime/core/providers/openvino/openvino_execution_provider.cc b/onnxruntime/core/providers/openvino/openvino_execution_provider.cc index 70bc64c0f65bc..3793317749a04 100644 --- a/onnxruntime/core/providers/openvino/openvino_execution_provider.cc +++ b/onnxruntime/core/providers/openvino/openvino_execution_provider.cc @@ -238,10 +238,20 @@ common::Status OpenVINOExecutionProvider::SetEpDynamicOptions(gsl::span Date: Thu, 5 Jun 2025 15:14:06 +0530 Subject: [PATCH 042/138] Sahar/tdr failure (#698) * Catch exception with TDR * Handle exceptions during parallel execution with OVEP * Remove IO Buffer Implementation --------- Co-authored-by: TejalKhade28 Co-authored-by: Preetha Veeramalai --- cmake/onnxruntime_providers_openvino.cmake | 5 - .../openvino/backends/basic_backend.cc | 211 +++--------------- .../openvino/backends/basic_backend.h | 22 +- .../core/providers/openvino/ov_interface.cc | 42 +--- .../core/providers/openvino/ov_interface.h | 20 +- .../test/perftest/performance_runner.cc | 9 +- 6 files changed, 64 insertions(+), 245 deletions(-) diff --git a/cmake/onnxruntime_providers_openvino.cmake b/cmake/onnxruntime_providers_openvino.cmake index 03f67983c70ab..d7cb2d5ea0d0f 100644 --- a/cmake/onnxruntime_providers_openvino.cmake +++ b/cmake/onnxruntime_providers_openvino.cmake @@ -30,11 +30,6 @@ endif() list(APPEND OPENVINO_LIB_LIST openvino::frontend::onnx openvino::runtime ${PYTHON_LIBRARIES}) - if ((DEFINED ENV{OPENCL_LIBS}) AND (DEFINED ENV{OPENCL_INCS}) AND onnxruntime_USE_OPENVINO_GPU) - add_definitions(-DIO_BUFFER_ENABLED=1) - list(APPEND OPENVINO_LIB_LIST $ENV{OPENCL_LIBS}) - endif() - source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_openvino_cc_srcs}) onnxruntime_add_shared_library_module(onnxruntime_providers_openvino ${onnxruntime_providers_openvino_cc_srcs} "${ONNXRUNTIME_ROOT}/core/dll/onnxruntime.rc") diff --git a/onnxruntime/core/providers/openvino/backends/basic_backend.cc b/onnxruntime/core/providers/openvino/backends/basic_backend.cc index e77ff973f3a87..dedb6da1bae58 100644 --- a/onnxruntime/core/providers/openvino/backends/basic_backend.cc +++ b/onnxruntime/core/providers/openvino/backends/basic_backend.cc @@ -62,25 +62,6 @@ BasicBackend::BasicBackend(std::unique_ptr& model_pr try { // IO_BUFFER is enabled on GPU HW. // Pre-requisite is provider_option "context" must be set -#if defined(IO_BUFFER_ENABLED) - cl_context ctx = static_cast(session_context_.context); - remote_context_ = new ov::intel_gpu::ocl::ClContext(OVCore::Get()->core, ctx); - if (subgraph_context_.is_ep_ctx_graph) { - exe_network_ = OVCore::Get()->ImportModel(*model_stream, - remote_context_, - subgraph_context_.subgraph_name); - model_stream.reset(); // Delete stream after it is no longer needed - } else { - std::string model = model_proto->SerializeAsString(); - if (!subgraph_context.has_dynamic_input_shape) { - model_proto.reset() - } - auto ov_model = CreateOVModel(std::move(model), session_context_, const_outputs_map_); - LOGS_DEFAULT(INFO) << log_tag << "IO Buffering Enabled"; - exe_network_ = OVCore::Get()->CompileModel( - ov_model, remote_context_, subgraph_context_.subgraph_name); - } -#else // !IO_BUFFER_ENABLED auto auto_unified_compile = ((hw_target.find("AUTO") == std::string::npos) || (session_context_.OpenVINO_Version.at(0) >= 2024 && session_context_.OpenVINO_Version.at(1) > 2)); @@ -117,7 +98,6 @@ BasicBackend::BasicBackend(std::unique_ptr& model_pr exe_network_ = OVCore::Get()->CompileModel( ov_model, hw_target, device_config, subgraph_context_.subgraph_name); } -#endif LOGS_DEFAULT(INFO) << log_tag << "Loaded model to the plugin"; } catch (const char* msg) { ORT_THROW(msg); @@ -459,150 +439,46 @@ void BasicBackend::StartAsyncInference(Ort::KernelContext& context, OVInferReque } } -#ifdef IO_BUFFER_ENABLED -// Wait for Remote Aynchronous inference completion -void BasicBackend::StartRemoteAsyncInference(Ort::KernelContext& context, OVInferRequestPtr infer_request) { - try { - auto graph_input_info = exe_network_.Get().inputs(); - int input_idx = 0; - for (auto input_info_iter = graph_input_info.begin(); - input_info_iter != graph_input_info.end(); ++input_info_iter) { - auto input_names = input_info_iter->get_names(); - std::string onnx_input_name; - std::string input_name; - // use names retrieved from original ONNX model to assign the right onnx input name for the graph - for (auto it = subgraph_context_.input_names.begin(); it != subgraph_context_.input_names.end(); ++it) { - if (it->second == input_idx) { - onnx_input_name = it->first; - break; - } - } - // using the input name retrieved from ONNX original to match with the input names returned by OV tensors - if (input_names.find(onnx_input_name) != input_names.end()) { - input_name = onnx_input_name; - } else { - ORT_THROW(log_tag + - "Input names mismatch between OpenVINO and ONNX. " + - onnx_input_name + - " doesn't exist in the list of OpenVINO input tensor names"); - } - input_idx++; - // Kernel Context Input Buffer - const auto tensor = context.GetInput(subgraph_context_.input_names.at(input_name)); - // If the ORTValue wraps a device pointer - auto mem_info = tensor.GetTensorMemoryInfo(); - if (mem_info.GetAllocatorName() == OpenVINO_GPU) { - // Get the shared buffer pointer - const void* tensor_data = tensor.GetTensorRawData(); - const cl::Buffer* shared_buffer_const = static_cast(tensor_data); - // Create an Input Remote Blob - auto input = graph_input_info.at(0); - auto remote_blob = remote_context_->create_tensor( - input.get_element_type(), input.get_shape(), *shared_buffer_const); - ov::Tensor tensor_remote = static_cast(remote_blob); - OVTensorPtr tensor_ptr = std::make_shared(tensor_remote); - infer_request->SetTensor(input_name, tensor_ptr); - } else { - OVTensorPtr graph_input_blob; - graph_input_blob = infer_request->GetTensor(input_name); - size_t batch_slice_idx = 0; - FillInputBlob(graph_input_blob, batch_slice_idx, input_name, context, subgraph_context_); - } - } - - // Set the output blob as remote blob - auto graph_output_info = exe_network_.Get().outputs(); - for (auto output_info_iter = graph_output_info.begin(); - output_info_iter != graph_output_info.end(); ++output_info_iter) { - auto output_names = output_info_iter->get_names(); - std::string onnx_output_name; - std::string output_name; - bool output_name_found = false; - // using the output name retrieved from ONNX original to match with the output names returned by OV tensors - for (auto it = subgraph_context_.output_names.begin(); it != subgraph_context_.output_names.end(); ++it) { - onnx_output_name = it->first; - if (output_names.find(onnx_output_name) != output_names.end()) { - // Assigning the output_name - output_name = it->first; - output_name_found = true; - break; - } - } - if (!output_name_found) { - ORT_THROW( - log_tag + - "Output names mismatch between OpenVINO and ONNX. [ONNX Output: ] " + - onnx_output_name + " doesn't exist in the list of OpenVINO output tensor names"); - } - - size_t batch_size = 1; - Ort::UnownedValue tensor = GetOutputTensor(context, - batch_size, - infer_request, - output_name, - subgraph_context_.output_names); - auto mem_info = tensor.GetTensorMemoryInfo(); - // Check if ORT Value wraps a device pointer - if (mem_info.GetAllocatorName() == OpenVINO_GPU) { - const void* tensor_data = tensor.GetTensorRawData(); - const cl::Buffer* shared_buffer_const = static_cast(tensor_data); - // Create a shared Blob, set the Infer Request Output Blob - auto output = graph_output_info.at(0); - auto remote_tensor = - remote_context_->create_tensor(output.get_element_type(), output.get_shape(), *shared_buffer_const); - ov::Tensor tensor_t = static_cast(remote_tensor); - OVTensorPtr tensor_ptr = std::make_shared(tensor_t); - try { - infer_request->SetTensor(output_name, tensor_ptr); - } catch (const char* msg) { - ORT_THROW(msg); - } - } - } - - // Start Async inference - infer_request->StartAsync(); - } catch (const char* msg) { - ORT_THROW(msg); - } -} -#endif - // Wait for asynchronous inference completion on an Infer Request object indexed by infer_req_idx // and copy the results into a slice location within the batched output buffer indexed by batch_slice_idx void BasicBackend::CompleteAsyncInference(Ort::KernelContext& context, OVInferRequestPtr infer_request) { // Wait for Async inference completion try { - bool cpu_or_gpu = session_context_.device_type.find("CPU") != std::string::npos || - session_context_.device_type.find("GPU") != std::string::npos; - infer_request->WaitRequest(); + } catch(const std::runtime_error& e) { + infer_request->CancelRequest(); + inferRequestsQueue_->deleteRequest(); + ORT_THROW(log_tag + e.what()); + } - if (cpu_or_gpu) { - for (const auto& output_info : bindings_->network_outputs_) { - OVTensorPtr graph_output_blob; - try { - graph_output_blob = infer_request->GetTensor(output_info.name); - } catch (const char* msg) { - ORT_THROW(msg); - } - size_t batch_size = 1; - Ort::UnownedValue output_tensor = - GetOutputTensor(context, batch_size, infer_request, output_info.name, subgraph_context_.output_names); - auto mem_info = output_tensor.GetTensorMemoryInfo(); - if (mem_info.GetAllocatorName() == OpenVINO_GPU) { + bool cpu_or_gpu = session_context_.device_type.find("CPU") != std::string::npos || + session_context_.device_type.find("GPU") != std::string::npos; + if (cpu_or_gpu) { + for (const auto& output_info : bindings_->network_outputs_) { + OVTensorPtr graph_output_blob; + try { + graph_output_blob = infer_request->GetTensor(output_info.name); + } catch (const char* msg) { + ORT_THROW(msg); + } + size_t batch_size = 1; + Ort::UnownedValue output_tensor = + GetOutputTensor(context, batch_size, infer_request, output_info.name, subgraph_context_.output_names); + auto mem_info = output_tensor.GetTensorMemoryInfo(); + if (mem_info.GetAllocatorName() == OpenVINO_GPU) { return; - } else { - size_t batch_slice = 0; - FillOutputBlob(std::move(graph_output_blob), output_tensor, batch_slice); - } + } else { + size_t batch_slice = 0; + FillOutputBlob(std::move(graph_output_blob), output_tensor, batch_slice); } } + } - if (!const_outputs_map_.empty()) { - for (const auto& item : const_outputs_map_) { - const auto& out_name = item.first; - auto node = item.second; + if (!const_outputs_map_.empty()) { + for (const auto& item : const_outputs_map_) { + const auto& out_name = item.first; + auto node = item.second; + try { Ort::UnownedValue output_tensor = GetOutputTensor(context, out_name, subgraph_context_.output_names, @@ -613,10 +489,10 @@ void BasicBackend::CompleteAsyncInference(Ort::KernelContext& context, OVInferRe } else { FillOutputsWithConstantData(std::move(node), output_tensor); } + } catch (std::string const& msg) { + ORT_THROW(msg); } } - } catch (const char* msg) { - ORT_THROW(msg); } } @@ -650,31 +526,20 @@ void BasicBackend::Infer(OrtKernelContext* ctx) { } } else { - // Requesting for an idle infer_request from a pool of infer_requests_ OVInferRequestPtr infer_request; infer_request = inferRequestsQueue_->getIdleRequest(); -#ifdef IO_BUFFER_ENABLED - if ((session_context_.device_type.find("GPU") != std::string::npos) && - (session_context_.context != nullptr) && session_context_.is_wholly_supported_graph) { - try { - StartRemoteAsyncInference(context, infer_request); - } catch (std::string const& msg) { - ORT_THROW(msg); - } - } else { - try { - StartAsyncInference(context, infer_request); - } catch (std::string const& msg) { - ORT_THROW(msg); - } + if(infer_request == nullptr) { + ORT_THROW("OpenVINO Execution Provider :: There are no inference requests"); + LOGS_DEFAULT(FATAL) << log_tag << "Create Infer Requests do not exist"; + return; } -#else + + LOGS_DEFAULT(INFO) << log_tag << "Get Idle Request"; try { StartAsyncInference(context, infer_request); } catch (const std::runtime_error& e) { ORT_THROW(log_tag + " Exception at StartAsyncInference: " + e.what()); } -#endif try { CompleteAsyncInference(context, infer_request); } catch (const std::runtime_error& e) { @@ -696,13 +561,11 @@ void BasicBackend::Infer(OrtKernelContext* ctx) { // Once the inference is completed, the infer_request becomes free and is placed back into pool of infer_requests_ inferRequestsQueue_->putIdleRequest(std::move(infer_request)); #ifndef NDEBUG -#ifndef IO_BUFFER_ENABLED // Printing performance counts is disabled when IO_BUFFER_ENABLED if (openvino_ep::backend_utils::IsDebugEnabled()) { inferRequestsQueue_->printstatus(); // Printing the elements of infer_requests_ vector pool only in debug mode std::string& hw_target = session_context_.device_type; printPerformanceCounts(std::move(infer_request_), std::cout, hw_target); } -#endif #endif } } diff --git a/onnxruntime/core/providers/openvino/backends/basic_backend.h b/onnxruntime/core/providers/openvino/backends/basic_backend.h index 130699abd465b..697c088a80620 100644 --- a/onnxruntime/core/providers/openvino/backends/basic_backend.h +++ b/onnxruntime/core/providers/openvino/backends/basic_backend.h @@ -94,11 +94,6 @@ class BasicBackend : public IBackend { void EnableStreams(); void SetNumThreads(ov::AnyMap& device_config); void StartAsyncInference(Ort::KernelContext& context, std::shared_ptr infer_request); - -#ifdef IO_BUFFER_ENABLED - void StartRemoteAsyncInference(Ort::KernelContext& context, std::shared_ptr infer_request); -#endif - void CompleteAsyncInference(Ort::KernelContext& context, std::shared_ptr infer_request); SessionContext& session_context_; @@ -108,10 +103,6 @@ class BasicBackend : public IBackend { OVExeNetwork exe_network_; std::map> const_outputs_map_; std::unique_ptr inferRequestsQueue_; -#if defined IO_BUFFER_ENABLED - OVRemoteContextPtr remote_context_; -#endif - using ort_tensor_key_t = const std::string; std::map ort_ov_tensor_map; std::unique_ptr bindings_; @@ -121,6 +112,7 @@ class InferRequestsQueue { public: InferRequestsQueue(OVExeNetwork& net, size_t nireq, std::function initializer) { OVInferRequestPtr infer_request; + live_threads=nireq; for (size_t id = 0; id < nireq; id++) { infer_request = std::make_shared(net.CreateInferRequest()); initializer(infer_request); @@ -152,16 +144,28 @@ class InferRequestsQueue { OVInferRequestPtr getIdleRequest() { std::unique_lock lock(_mutex); + std::cout << "get Idle Request" << live_threads << "\n"; + if(live_threads==0) { + return nullptr; + } + _cv.wait(lock, [this] { return infer_requests_.size() > 0; }); auto request = infer_requests_.at(0); infer_requests_.erase(infer_requests_.begin()); return request; } + void deleteRequest() { + std::unique_lock lock(_mutex); + live_threads=live_threads-1; + std::cout << "delete Request" << live_threads << "\n"; + } + private: std::mutex _mutex; std::condition_variable _cv; std::vector infer_requests_; + int live_threads; }; } // namespace openvino_ep diff --git a/onnxruntime/core/providers/openvino/ov_interface.cc b/onnxruntime/core/providers/openvino/ov_interface.cc index a175ca863d1d1..0024a5e121bbf 100644 --- a/onnxruntime/core/providers/openvino/ov_interface.cc +++ b/onnxruntime/core/providers/openvino/ov_interface.cc @@ -143,38 +143,6 @@ void OVCore::SetCache(const std::string& cache_dir_path) { core.set_property(ov::cache_dir(cache_dir_path)); } -#ifdef IO_BUFFER_ENABLED -OVExeNetwork OVCore::CompileModel(std::shared_ptr& model, - OVRemoteContextPtr context, std::string name) { - try { - auto obj = core.compile_model(model, *context); -#ifndef NDEBUG - printDebugInfo(obj); -#endif - return OVExeNetwork(obj); - } catch (const Exception& e) { - ORT_THROW(log_tag + " Exception while Loading Network for graph: " + name + e.what()); - } catch (...) { - ORT_THROW(log_tag + " Exception while Loading Network for graph " + name); - } -} -OVExeNetwork OVCore::ImportModel(std::shared_ptr model_stream, - OVRemoteContextPtr context, std::string name) { - try { - auto obj = core.import_model(*model_stream, *context); -#ifndef NDEBUG - printDebugInfo(obj); -#endif - OVExeNetwork exe(obj); - return exe; - } catch (const Exception& e) { - ORT_THROW(log_tag + " Exception while Loading Network for graph: " + name + e.what()); - } catch (...) { - ORT_THROW(log_tag + " Exception while Loading Network for graph " + name); - } -} -#endif - std::vector OVCore::GetAvailableDevices() const { std::vector available_devices = core.get_available_devices(); return available_devices; @@ -294,12 +262,16 @@ void OVInferRequest::Infer() { } void OVInferRequest::WaitRequest() { + ovInfReq.wait(); +} + +void OVInferRequest::CancelRequest() { try { - ovInfReq.wait(); + ovInfReq.cancel(); } catch (const Exception& e) { - ORT_THROW(log_tag + " Wait Model Failed: " + e.what()); + ORT_THROW(log_tag + " Cancel Model Failed: " + e.what()); } catch (...) { - ORT_THROW(log_tag + " Wait Mode Failed"); + ORT_THROW(log_tag + " Cancel Mode Failed"); } } diff --git a/onnxruntime/core/providers/openvino/ov_interface.h b/onnxruntime/core/providers/openvino/ov_interface.h index bebe73bd702dd..866f4a02f7780 100644 --- a/onnxruntime/core/providers/openvino/ov_interface.h +++ b/onnxruntime/core/providers/openvino/ov_interface.h @@ -14,11 +14,6 @@ #include "openvino/runtime/intel_npu/properties.hpp" #include "openvino/pass/convert_fp32_to_fp16.hpp" #include "openvino/frontend/manager.hpp" - -#ifdef IO_BUFFER_ENABLED -#include -#endif - #include namespace onnxruntime { @@ -32,12 +27,6 @@ typedef ov::ProfilingInfo OVProfilingInfo; typedef ov::Model OVNetwork; typedef std::shared_ptr OVInferRequestPtr; typedef std::shared_ptr OVTensorPtr; - -#ifdef IO_BUFFER_ENABLED -typedef ov::intel_gpu::ocl::ClContext* OVRemoteContextPtr; -typedef ov::RemoteContext OVRemoteContext; -#endif - std::optional queryOVProperty(const std::string& property, const std::string& device_type); template @@ -87,14 +76,6 @@ struct OVCore : WeakSingleton { std::string hw_target, const ov::AnyMap& device_config, std::string name); -#ifdef IO_BUFFER_ENABLED - OVExeNetwork CompileModel(std::shared_ptr& model, - OVRemoteContextPtr context, - std::string name); - OVExeNetwork ImportModel(std::shared_ptr model_stream, - OVRemoteContextPtr context, - std::string name); -#endif std::vector GetAvailableDevices() const; std::vector GetAvailableDevices(const std::string& device_type) const; void SetCache(const std::string& cache_dir_path); @@ -122,6 +103,7 @@ class OVInferRequest { void StartAsync(); void Infer(); void WaitRequest(); + void CancelRequest(); void QueryStatus(); explicit OVInferRequest(ov::InferRequest obj) : ovInfReq(std::move(obj)) {} OVInferRequest() : ovInfReq(ov::InferRequest()) {} diff --git a/onnxruntime/test/perftest/performance_runner.cc b/onnxruntime/test/perftest/performance_runner.cc index faf0c34193717..8ec9694227c14 100644 --- a/onnxruntime/test/perftest/performance_runner.cc +++ b/onnxruntime/test/perftest/performance_runner.cc @@ -203,8 +203,9 @@ Status PerformanceRunner::RunParallelDuration() { counter++; tpool->Schedule([this, &counter, &m, &cv]() { auto status = RunOneIteration(); - if (!status.IsOK()) + if (!status.IsOK()) { std::cerr << status.ErrorMessage(); + } // Simplified version of Eigen::Barrier std::lock_guard lg(m); counter--; @@ -216,8 +217,10 @@ Status PerformanceRunner::RunParallelDuration() { } while (duration_seconds.count() < performance_test_config_.run_config.duration_in_seconds); // Join - std::unique_lock lock(m); - cv.wait(lock, [&counter]() { return counter == 0; }); + tpool->Schedule([this, &counter, &m, &cv]() { + std::unique_lock lock(m); + cv.wait(lock, [&counter]() { return counter == 0; }); + }); return Status::OK(); } From 660adfcf10aebda1a5416f46bb5eda4084b4b367 Mon Sep 17 00:00:00 2001 From: Ankit Maheshkar Date: Thu, 5 Jun 2025 18:30:48 +0530 Subject: [PATCH 043/138] feat: ORT GenAI Stateful Compilation changes (#676) * feat: ORT GenAI Stateful Compilation changes * fix: Disabled UT as testdata/attention_past_state.onnx model is invalid * fix:lint fixes * fix: refactor tensor caching * update: Fix optional position ids caching logic * fix: remove unwanted comment --- .../providers/openvino/backend_manager.cc | 19 +- .../core/providers/openvino/backend_manager.h | 1 + .../openvino/backends/basic_backend.cc | 67 +++- .../openvino/backends/basic_backend.h | 16 +- .../core/providers/openvino/contexts.h | 1 + .../core/providers/openvino/ibackend.h | 1 + .../openvino/openvino_execution_provider.cc | 19 + .../openvino/openvino_provider_factory.cc | 9 +- .../core/providers/openvino/ov_interface.cc | 260 ++++++++++++- .../core/providers/openvino/ov_interface.h | 50 ++- .../openvino/ov_stateful_patch_utils.cc | 350 ++++++++++++++++++ .../openvino/ov_stateful_patch_utils.h | 84 +++++ .../test/contrib_ops/attention_op_test.cc | 2 +- onnxruntime/test/perftest/ort_test_session.cc | 12 +- 14 files changed, 839 insertions(+), 52 deletions(-) create mode 100644 onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc create mode 100644 onnxruntime/core/providers/openvino/ov_stateful_patch_utils.h diff --git a/onnxruntime/core/providers/openvino/backend_manager.cc b/onnxruntime/core/providers/openvino/backend_manager.cc index b7e6245b1834f..c22f2e9cc0fa1 100644 --- a/onnxruntime/core/providers/openvino/backend_manager.cc +++ b/onnxruntime/core/providers/openvino/backend_manager.cc @@ -44,6 +44,10 @@ BackendManager::BackendManager(SessionContext& session_context, shared_context_{shared_context} { subgraph_context_.is_ep_ctx_graph = ep_ctx_handle_.CheckForOVEPCtxNodeInGraph(subgraph); + bool cpu_or_gpu = session_context_.device_type.find("CPU") != std::string::npos || + session_context_.device_type.find("GPU") != std::string::npos; + bool npu = session_context_.device_type.find("NPU") != std::string::npos; + subgraph_context_.model_precision = [&](const GraphViewer& graph_viewer) { // return empty if graph has no inputs or if types are not one of FP32/FP16 // else assume the type of the first input @@ -105,8 +109,7 @@ BackendManager::BackendManager(SessionContext& session_context, if (ModelHasSymbolicInputDims(subgraph)) { subgraph_context_.has_dynamic_input_shape = true; LOGS_DEFAULT(INFO) << "[OpenVINO-EP] Model has symbolic input dims"; - if ((session_context_.device_type.find("CPU") != std::string::npos || - session_context_.device_type.find("GPU") != std::string::npos) && + if (cpu_or_gpu || (npu && session_context_.enable_causallm) && !session_context_.disable_dynamic_shapes) { LOGS_DEFAULT(INFO) << "[OpenVINO-EP] Starting backend initialization. " << "Creating backend Dynamic Shapes"; @@ -480,6 +483,9 @@ BackendManager::ReWriteBatchDimWithOne(const ONNX_NAMESPACE::ModelProto& model_p void BackendManager::Compute(OrtKernelContext* context) { Ort::KernelContext ctx(context); std::chrono::high_resolution_clock::time_point start_compute, end_compute; + bool cpu_or_gpu = session_context_.device_type.find("CPU") != std::string::npos || + session_context_.device_type.find("GPU") != std::string::npos; + bool npu = session_context_.device_type.find("NPU") != std::string::npos; #ifdef OPENVINO_FIL_ENABLED static bool fil_enabled = true; if (fil_enabled) { @@ -493,8 +499,7 @@ void BackendManager::Compute(OrtKernelContext* context) { // disable_dynamic_shapes is always set to true for OV NPU plugin. if (subgraph_context_.has_dynamic_input_shape && !session_context_.disable_dynamic_shapes && - (session_context_.device_type.find("CPU") != std::string::npos || - session_context_.device_type.find("GPU") != std::string::npos)) { + (cpu_or_gpu || (npu && session_context_.enable_causallm))) { concrete_backend_->Infer(context); } else if (subgraph_context_.has_dynamic_input_shape) { std::vector> tensor_shapes = GetInputTensorShapes(ctx); @@ -567,5 +572,11 @@ void BackendManager::ShutdownBackendManager() { concrete_backend_.reset(); } +void BackendManager::RewindKVCache(size_t index) { + if (concrete_backend_) { + concrete_backend_->RewindKVCache(index); + } +} + } // namespace openvino_ep } // namespace onnxruntime diff --git a/onnxruntime/core/providers/openvino/backend_manager.h b/onnxruntime/core/providers/openvino/backend_manager.h index cb1ca7001a00c..799dc50dd7a63 100644 --- a/onnxruntime/core/providers/openvino/backend_manager.h +++ b/onnxruntime/core/providers/openvino/backend_manager.h @@ -30,6 +30,7 @@ class BackendManager { SessionContext& GetSessionContext(); Status ExportCompiledBlobAsEPCtxNode(const onnxruntime::GraphViewer& subgraph); ov::CompiledModel GetOVCompiledModel(); + void RewindKVCache(size_t index); private: std::unique_ptr GetModelProtoFromFusedNode( diff --git a/onnxruntime/core/providers/openvino/backends/basic_backend.cc b/onnxruntime/core/providers/openvino/backends/basic_backend.cc index dedb6da1bae58..7902b3edb2276 100644 --- a/onnxruntime/core/providers/openvino/backends/basic_backend.cc +++ b/onnxruntime/core/providers/openvino/backends/basic_backend.cc @@ -15,6 +15,7 @@ #include "core/providers/openvino/backends/basic_backend.h" #include "core/providers/openvino/onnx_ctx_model_helper.h" #include "core/providers/openvino/backend_manager.h" +#include "core/providers/openvino/ov_stateful_patch_utils.h" namespace onnxruntime { @@ -29,6 +30,7 @@ BasicBackend::BasicBackend(std::unique_ptr& model_pr ptr_stream_t& model_stream) : session_context_{session_context}, subgraph_context_{subgraph_context}, shared_context_{shared_context} { std::string& hw_target = session_context_.device_type; + bool enable_causallm = session_context_.enable_causallm; if (ValidateSubgraph(const_outputs_map_)) return; @@ -43,7 +45,7 @@ BasicBackend::BasicBackend(std::unique_ptr& model_pr // Setting OpenCL queue throttling for GPU EnableGPUThrottling(device_config); - // Enable streams; default=1 unless ovverriden by user config + // Enable streams; default=1 unless overridden by user configuration EnableStreams(); // Set the inference_num_threads property of the CPU @@ -76,7 +78,7 @@ BasicBackend::BasicBackend(std::unique_ptr& model_pr } else if (!session_context_.has_external_weights && !subgraph_context_.has_dynamic_input_shape && !session_context_.so_context_enable && - auto_unified_compile) { + !enable_causallm && auto_unified_compile) { // Unified OV compile_model is efficient when ov model caching is enabled // Unified OV compile_model API is supported with AUTO from version 2024.3 and above // Inputs with static dimensions @@ -96,7 +98,7 @@ BasicBackend::BasicBackend(std::unique_ptr& model_pr } auto ov_model = CreateOVModel(std::move(model), session_context_, const_outputs_map_); exe_network_ = OVCore::Get()->CompileModel( - ov_model, hw_target, device_config, subgraph_context_.subgraph_name); + ov_model, hw_target, device_config, enable_causallm, subgraph_context_.subgraph_name); } LOGS_DEFAULT(INFO) << log_tag << "Loaded model to the plugin"; } catch (const char* msg) { @@ -120,7 +122,7 @@ BasicBackend::BasicBackend(std::unique_ptr& model_pr }; } inferRequestsQueue_ = std::unique_ptr(new InferRequestsQueue(exe_network_, num_infer_req, std::move(initializer))); - bindings_ = std::make_unique(exe_network_, subgraph_context_); + bindings_ = std::make_unique(exe_network_, subgraph_context_, session_context_); } bool BasicBackend::ValidateSubgraph(std::map>& const_outputs_map) { @@ -181,6 +183,15 @@ void BasicBackend::PopulateConfigValue(ov::AnyMap& device_config) { if (!session_context_.load_config.empty()) { const std::map& target_config = session_context_.load_config; + if ((session_context_.device_type.find("NPU") != std::string::npos) && session_context_.enable_causallm) { + if (target_config.find("NPU") != target_config.end()) { + auto npu_genai_config = target_config.at("NPU"); + CausalLMConfig().ApplyConfig(npu_genai_config, device_config); + } else { + LOGS_DEFAULT(WARNING) << "ORT GenAI CausalLMConfig Configuration not found."; + } + } + if (session_context_.device_type.find("NPU") != std::string::npos) { auto npuw_config = target_config.at("NPU"); @@ -246,7 +257,8 @@ void BasicBackend::PopulateConfigValue(ov::AnyMap& device_config) { auto set_target_properties = [&](const std::string& device, const ov::AnyMap& config_options, const std::vector& supported_properties) { for (const auto& [key, value] : config_options) { - if (key.find("NPUW") != std::string::npos) { + if ((key.find("NPUW") != std::string::npos) || + ((device_config.find(key) != device_config.end()) && session_context_.enable_causallm)) { continue; } if (is_supported_and_mutable(key, supported_properties)) { @@ -339,6 +351,13 @@ void BasicBackend::SetNumThreads(ov::AnyMap& device_config) { device_config.emplace(ov::inference_num_threads(session_context_.num_of_threads)); } +void BasicBackend::RewindKVCache(size_t index) { + OVInferRequestPtr infer_request; + infer_request = inferRequestsQueue_->getIdleRequest(); + infer_request->RewindKVCache(index); + inferRequestsQueue_->putIdleRequest(std::move(infer_request)); +} + // Starts an asynchronous inference request for data in slice indexed by batch_slice_idx on // an Infer Request indexed by infer_req_idx void BasicBackend::StartAsyncInference(Ort::KernelContext& context, OVInferRequestPtr infer_request) { @@ -351,7 +370,7 @@ void BasicBackend::StartAsyncInference(Ort::KernelContext& context, OVInferReque size_t batch_slice_idx = 0; if (subgraph_context_.has_dynamic_input_shape && !session_context_.disable_dynamic_shapes && - cpu_or_gpu) { + cpu_or_gpu || (npu && session_context_.enable_causallm)) { auto tensor = context.GetInput(input_info.onnx_index); auto tensor_info = tensor.GetTensorTypeAndShapeInfo(); auto tensor_shape = tensor_info.GetShape(); @@ -409,7 +428,8 @@ void BasicBackend::StartAsyncInference(Ort::KernelContext& context, OVInferReque } } // Loop subgraph original input - if (npu) { + // For Stateful Compilation i.e. enable_causallm as True, we use the dynamic shapes path for NPU plugin as well. + if (npu && !session_context_.enable_causallm) { // Set the output blob as remote blob for (const auto& output_info : bindings_->network_outputs_) { Ort::UnownedValue tensor = context.GetOutput(output_info.onnx_index, output_info.onnx_shape); @@ -453,19 +473,20 @@ void BasicBackend::CompleteAsyncInference(Ort::KernelContext& context, OVInferRe bool cpu_or_gpu = session_context_.device_type.find("CPU") != std::string::npos || session_context_.device_type.find("GPU") != std::string::npos; - if (cpu_or_gpu) { + bool npu = session_context_.device_type.find("NPU") != std::string::npos; + if (cpu_or_gpu || (npu && session_context_.enable_causallm)) { for (const auto& output_info : bindings_->network_outputs_) { - OVTensorPtr graph_output_blob; - try { - graph_output_blob = infer_request->GetTensor(output_info.name); - } catch (const char* msg) { - ORT_THROW(msg); - } - size_t batch_size = 1; - Ort::UnownedValue output_tensor = - GetOutputTensor(context, batch_size, infer_request, output_info.name, subgraph_context_.output_names); - auto mem_info = output_tensor.GetTensorMemoryInfo(); - if (mem_info.GetAllocatorName() == OpenVINO_GPU) { + OVTensorPtr graph_output_blob; + try { + graph_output_blob = infer_request->GetTensor(output_info.name); + } catch (const char* msg) { + ORT_THROW(msg); + } + size_t batch_size = 1; + Ort::UnownedValue output_tensor = + GetOutputTensor(context, batch_size, infer_request, output_info.name, subgraph_context_.output_names); + auto mem_info = output_tensor.GetTensorMemoryInfo(); + if (mem_info.GetAllocatorName() == OpenVINO_GPU) { return; } else { size_t batch_slice = 0; @@ -538,11 +559,19 @@ void BasicBackend::Infer(OrtKernelContext* ctx) { try { StartAsyncInference(context, infer_request); } catch (const std::runtime_error& e) { + // If the inference fails (exception from ov::InferRequest::infer()), + // we need to put the infer_request back into the pool to avoid deadlocks + // and to allow the next inference request to proceed. + inferRequestsQueue_->putIdleRequest(std::move(infer_request)); ORT_THROW(log_tag + " Exception at StartAsyncInference: " + e.what()); } try { CompleteAsyncInference(context, infer_request); } catch (const std::runtime_error& e) { + // If the inference fails (exception from ov::InferRequest::infer()), + // we need to put the infer_request back into the pool to avoid deadlocks + // and to allow the next inference request to proceed. + inferRequestsQueue_->putIdleRequest(std::move(infer_request)); ORT_THROW(log_tag + " Exception at CompleteAsyncInference: " + e.what()); } diff --git a/onnxruntime/core/providers/openvino/backends/basic_backend.h b/onnxruntime/core/providers/openvino/backends/basic_backend.h index 697c088a80620..fe178ccb5661b 100644 --- a/onnxruntime/core/providers/openvino/backends/basic_backend.h +++ b/onnxruntime/core/providers/openvino/backends/basic_backend.h @@ -42,12 +42,22 @@ struct OnnxToOvNetworkBindings { std::vector network_outputs_; std::vector network_inputs_; - OnnxToOvNetworkBindings(OVExeNetwork& exec_network, SubGraphContext& subgraph_context) { + OnnxToOvNetworkBindings(OVExeNetwork& exec_network, SubGraphContext& subgraph_context, SessionContext& session_context) { auto populate = [&](auto& input_output_map, const SubGraphContext::string_index_map_t& onnx_input_map, const auto& ov_parameters) { for (const auto& [onnx_name, onnx_param_index] : onnx_input_map) { auto it = std::find_if(ov_parameters.begin(), ov_parameters.end(), [&onnx_name](const auto& ov_parameter_info) { return ov_parameter_info.get_names().contains(onnx_name); }); + // For Stateful Model Compilation, the ONNX model includes KV cache (past/present) tensors. + // However, these tensors are internally converted to a stateful representation, which removes them. + // To prevent runtime exceptions, we simply continue processing here. + if ((onnx_name.empty() || onnx_name == "beam_idx" || + onnx_name.find("past_key_values") != std::string::npos || + onnx_name.find("present") != std::string::npos) && + session_context.enable_causallm) { + continue; + } + ORT_ENFORCE(it != ov_parameters.end(), backend_utils::log_tag, "Input names mismatch between OpenVINO and ONNX. ", onnx_name, " doesn't exist in the list of OpenVINO input tensor names"); @@ -85,6 +95,7 @@ class BasicBackend : public IBackend { ov::CompiledModel GetOVCompiledModel() override { return exe_network_.Get(); } + void RewindKVCache(size_t index) override; private: bool ValidateSubgraph(std::map>& const_outputs_map); @@ -114,7 +125,7 @@ class InferRequestsQueue { OVInferRequestPtr infer_request; live_threads=nireq; for (size_t id = 0; id < nireq; id++) { - infer_request = std::make_shared(net.CreateInferRequest()); + infer_request = net.CreateInferRequest(); initializer(infer_request); infer_requests_.push_back(infer_request); } @@ -144,7 +155,6 @@ class InferRequestsQueue { OVInferRequestPtr getIdleRequest() { std::unique_lock lock(_mutex); - std::cout << "get Idle Request" << live_threads << "\n"; if(live_threads==0) { return nullptr; } diff --git a/onnxruntime/core/providers/openvino/contexts.h b/onnxruntime/core/providers/openvino/contexts.h index 7560f4570bd32..2506d587dd3ad 100644 --- a/onnxruntime/core/providers/openvino/contexts.h +++ b/onnxruntime/core/providers/openvino/contexts.h @@ -97,6 +97,7 @@ struct ProviderInfo { bool disable_dynamic_shapes{false}; // [disable_dynamic_shapes]: Rewrite dynamic shaped models to // static shape at runtime and execute. bool enable_qdq_optimizer{false}; // Enables QDQ pruning for efficient inference latency with NPU + bool enable_causallm{false}; // Enables Causal LM Compilation for ORT GenAI OVEP Pass bool so_context_enable{false}; // ORT session option bool so_disable_cpu_ep_fallback{false}; // ORT session option bool so_context_embed_mode{false}; // ORT session option diff --git a/onnxruntime/core/providers/openvino/ibackend.h b/onnxruntime/core/providers/openvino/ibackend.h index 4532349897d17..752668b3c6fbe 100644 --- a/onnxruntime/core/providers/openvino/ibackend.h +++ b/onnxruntime/core/providers/openvino/ibackend.h @@ -17,6 +17,7 @@ class IBackend { virtual void Infer(OrtKernelContext* context) = 0; virtual ov::CompiledModel GetOVCompiledModel() = 0; virtual ~IBackend() = default; + virtual void RewindKVCache(size_t index) {} }; using ptr_stream_t = std::unique_ptr; class BackendFactory { diff --git a/onnxruntime/core/providers/openvino/openvino_execution_provider.cc b/onnxruntime/core/providers/openvino/openvino_execution_provider.cc index 3793317749a04..d12f1edc57da5 100644 --- a/onnxruntime/core/providers/openvino/openvino_execution_provider.cc +++ b/onnxruntime/core/providers/openvino/openvino_execution_provider.cc @@ -254,6 +254,25 @@ common::Status OpenVINOExecutionProvider::SetEpDynamicOptions(gsl::span= 0) { + backend.RewindKVCache(static_cast(index)); + } else { + LOGS_DEFAULT(WARNING) << "kvcache_rewind index is < 0:\t" << index; + } + } } else { // Handle unknown options LOGS_DEFAULT(WARNING) << "Unknown key/value pair - ignoring " << key << "/" << value; diff --git a/onnxruntime/core/providers/openvino/openvino_provider_factory.cc b/onnxruntime/core/providers/openvino/openvino_provider_factory.cc index e5526ecd52bb9..f7e64a9be2c60 100644 --- a/onnxruntime/core/providers/openvino/openvino_provider_factory.cc +++ b/onnxruntime/core/providers/openvino/openvino_provider_factory.cc @@ -343,13 +343,20 @@ static void ParseProviderInfo(const ProviderOptions& provider_options, pi.enable_qdq_optimizer = ParseBooleanOption(provider_options, "enable_qdq_optimizer"); + pi.enable_causallm = ParseBooleanOption(provider_options, "enable_causallm"); + pi.disable_dynamic_shapes = ParseBooleanOption(provider_options, "disable_dynamic_shapes"); } catch (std::string msg) { ORT_THROW(msg); } // Always true for NPU plugin or when passed . if (pi.device_type.find("NPU") != std::string::npos) { - pi.disable_dynamic_shapes = true; + // For Stateful Compilation i.e. enable_causallm as True, we use the dynamic shapes path. + if (pi.enable_causallm) { + pi.disable_dynamic_shapes = false; + } else { + pi.disable_dynamic_shapes = true; + } } } diff --git a/onnxruntime/core/providers/openvino/ov_interface.cc b/onnxruntime/core/providers/openvino/ov_interface.cc index 0024a5e121bbf..0818f350562e9 100644 --- a/onnxruntime/core/providers/openvino/ov_interface.cc +++ b/onnxruntime/core/providers/openvino/ov_interface.cc @@ -7,6 +7,8 @@ #include "core/session/onnxruntime_cxx_api.h" #include "core/providers/shared_library/provider_api.h" #include "core/providers/openvino/backend_utils.h" +#include "core/providers/openvino/backends/basic_backend.h" +#include "core/providers/openvino/ov_stateful_patch_utils.h" using Exception = ov::Exception; @@ -82,17 +84,85 @@ std::shared_ptr OVCore::ReadModel(std::string&& model, const std::str } } +OVExeNetwork OVCore::StatefulCompileModel(std::shared_ptr& model, + std::string& hw_target, + const ov::AnyMap& device_config) { + ov::CompiledModel compiled_model; + ov::AnyMap config = device_config; + + if (onnxruntime::openvino_ep::backend_utils::IsDebugEnabled()) { + std::cout << "Stateless OV Model Statistic:" << std::endl; + LogBasicModelInfo(model); + } + + LOGS_DEFAULT(INFO) << log_tag << "Converting from Stateless OV Model to Stateful OV Model" << std::endl; + bool model_status = IsStateful(model); + LOGS_DEFAULT(INFO) << log_tag << "Model IsStateful() Status:\t" << (model_status ? "True" : "False"); + if (!model_status) { + PatchStatefulDecoder(model); + } + + if (onnxruntime::openvino_ep::backend_utils::IsDebugEnabled()) { + std::cout << "Stateful OV Model Statistic:" << std::endl; + LogBasicModelInfo(model); + } + + auto kv_pos = GetKVAxesPos(model); + + if (hw_target.find("NPU") != std::string::npos) { + KVDesc kv_desc; + auto parse_genai_config = [&](const std::string& key, unsigned int default_value) { + return (config.count(key) && !config.at(key).empty() && config.at(key).as() != "0") ? config.at(key).as() : default_value; + }; + + kv_desc.max_prompt_len = parse_genai_config("MAX_PROMPT_LEN", CausalLMConfig().max_prompt_len); + kv_desc.min_response_len = parse_genai_config("MIN_RESPONSE_LEN", CausalLMConfig().min_response_len); + + // For compilation, MAX_PROMPT_LEN & MIN_RESPONSE_LEN should not be 0 + if (kv_desc.max_prompt_len == 0 || kv_desc.min_response_len == 0) { + ORT_THROW(log_tag + "MAX_PROMPT_LEN and MIN_RESPONSE_LEN cannot be 0 or empty"); + } + + if (onnxruntime::openvino_ep::backend_utils::IsDebugEnabled()) { + std::cout << "kv_pos.batch = " << kv_pos.batch << std::endl; + std::cout << "kv_pos.seq_len = " << kv_pos.seq_len << std::endl; + std::cout << "kv_desc.max_prompt_len:\t" << kv_desc.max_prompt_len << std::endl; + std::cout << "kv_desc.min_response_len:\t" << kv_desc.min_response_len << std::endl; + } + + UpdateNPUConfig(config, kv_pos, kv_desc); + } else { + // This patches the OV IR model so that it only produces the logits required for sampling. + // Actually either way that happens within NPUW::LLMCompiledModel creation for NPU device, + // while this is here mostly to align this behavior for other devices viz. (CPU, GPU). + ApplySliceBeforeMatmulTransformation(model); + } + + LOGS_DEFAULT(INFO) << log_tag << "Compiling OV Model using Stateful Transformation flow"; + compiled_model = OVCore::Get()->core.compile_model(model, hw_target, config); + OVExeNetwork exe(compiled_model, hw_target, true); + return exe; +} + OVExeNetwork OVCore::CompileModel(std::shared_ptr& ie_cnn_network, std::string& hw_target, ov::AnyMap& device_config, + bool enable_causallm, const std::string& name) { - ov::CompiledModel obj; + OVExeNetwork exe; try { - obj = core.compile_model(ie_cnn_network, hw_target, device_config); + if (enable_causallm) { + auto mutable_model = ie_cnn_network->clone(); + exe = OVCore::Get()->StatefulCompileModel(mutable_model, hw_target, device_config); + } else { + auto obj = core.compile_model(ie_cnn_network, hw_target, device_config); + exe = OVExeNetwork(obj, hw_target); + } + #ifndef NDEBUG - printDebugInfo(obj); + printDebugInfo(exe.Get()); #endif - OVExeNetwork exe(obj); + return exe; } catch (const Exception& e) { ORT_THROW(log_tag + " Exception while Loading Network for graph: " + name + e.what()); @@ -111,7 +181,7 @@ OVExeNetwork OVCore::CompileModel(const std::string& onnx_model, #ifndef NDEBUG printDebugInfo(obj); #endif - OVExeNetwork exe(obj); + OVExeNetwork exe(obj, hw_target); return exe; } catch (const Exception& e) { ORT_THROW(log_tag + " Exception while Loading Network for graph: " + name + e.what()); @@ -128,9 +198,9 @@ OVExeNetwork OVCore::ImportModel(std::istream& model_stream, ov::CompiledModel obj; obj = core.import_model(model_stream, hw_target, device_config); #ifndef NDEBUG - printDebugInfo(obj); + printDebugInfo(exe.Get()); #endif - OVExeNetwork exe(obj); + OVExeNetwork exe(obj, hw_target); return exe; } catch (const Exception& e) { ORT_THROW(log_tag + " Exception while Loading Network for graph: " + name + e.what()); @@ -192,11 +262,16 @@ void OVCore::SetStreams(const std::string& device_type, int num_streams) { core.set_property(device_type, {ov::num_streams(num_streams)}); } -OVInferRequest OVExeNetwork::CreateInferRequest() { +std::shared_ptr OVExeNetwork::CreateInferRequest() { try { - auto infReq = obj.create_infer_request(); - OVInferRequest inf_obj(std::move(infReq)); - return inf_obj; + auto infReq = compiled_model_obj.create_infer_request(); + std::shared_ptr ovInfReq; + if (is_stateful_causallm) { + ovInfReq = std::make_shared(std::move(infReq), target_device); + } else { + ovInfReq = std::make_shared(std::move(infReq)); + } + return ovInfReq; } catch (const Exception& e) { ORT_THROW(log_tag + "Exception while creating InferRequest object: " + e.what()); } catch (...) { @@ -245,9 +320,9 @@ void OVInferRequest::StartAsync() { try { ovInfReq.start_async(); } catch (const Exception& e) { - ORT_THROW(log_tag + " Couldn't start Inference: " + e.what()); + throw std::runtime_error(log_tag + " Couldn't start Inference: " + e.what()); } catch (...) { - ORT_THROW(log_tag + " In Error Couldn't start Inference"); + throw std::runtime_error(log_tag + " In Error Couldn't start Inference"); } } @@ -255,9 +330,9 @@ void OVInferRequest::Infer() { try { ovInfReq.infer(); } catch (const Exception& e) { - ORT_THROW(log_tag + " Couldn't start Inference: " + e.what()); + throw std::runtime_error(log_tag + " Couldn't start Inference: " + e.what()); } catch (...) { - ORT_THROW(log_tag + " In Error Couldn't start Inference"); + throw std::runtime_error(log_tag + " In Error Couldn't start Inference"); } } @@ -279,5 +354,160 @@ void OVInferRequest::QueryStatus() { std::cout << "ovInfReq.query_state()" << " "; } + +StatefulOVInferRequest::StatefulOVInferRequest(ov::InferRequest infer_request, std::string device) + : OVInferRequest(std::move(infer_request)), target_device(device) { + bool gpu_or_npu = ((device.find("NPU") != std::string::npos) || (device.find("GPU") != std::string::npos)); + if (gpu_or_npu) { + prefill_use_full_chat_history = true; + } +} + +void StatefulOVInferRequest::FillTensor(const std::string& tensor_name, const ov::element::Type& type, + const std::vector& shape, int32_t fill_value) { + ov::Tensor tensor = ov::Tensor(type, shape); + std::fill_n(tensor.data(), tensor.get_size(), fill_value); + ovInfReq.set_tensor(tensor_name, tensor); +} + +void StatefulOVInferRequest::CacheTensor(const std::string& tensor_name, std::vector& cache) { + auto tensor = ovInfReq.get_tensor(tensor_name); + auto* pData = tensor.data(); + for (size_t i = 0; i < tensor.get_size(); i++) { + cache.emplace_back(pData[i]); + } +} + +void StatefulOVInferRequest::SetTensorFromCache(const std::string& tensor_name, + const std::vector& cache_data) { + auto tensor = ovInfReq.get_tensor(tensor_name); + auto new_shape = tensor.get_shape(); + new_shape[1] = cache_data.size(); + + auto new_tensor = ov::Tensor(tensor.get_element_type(), new_shape); + auto* pNewData = new_tensor.data(); + std::memcpy(pNewData, cache_data.data(), cache_data.size() * sizeof(int64_t)); + + ovInfReq.set_tensor(tensor_name, new_tensor); +} + +std::optional StatefulOVInferRequest::FindTensor(const std::string& tensor_name) { + // Check if tensor exists by examining input names in the compiled model + const auto& model = ovInfReq.get_compiled_model(); + bool tensor_exists = false; + + for (const auto& input : model.inputs()) { + const auto& names = input.get_names(); + if (names.find(tensor_name) != names.end()) { + tensor_exists = true; + break; + } + } + + if (tensor_exists) { + return ovInfReq.get_tensor(tensor_name); + } + + return std::nullopt; +} + +void StatefulOVInferRequest::PreProcessInferRequest() { + // Workaround: Setting the value here as it cannot be set at the ORT GenAI layer currently. + // TODO(ankit): Address this issue and implement the fix at the appropriate layer. + FillTensor("beam_idx", ov::element::i32, {1}, 0); + + // If 'prefill use full chat history' mode is enabled, we need to cache input_ids and position_ids. + if (prefill_use_full_chat_history) { + auto input_ids_tensor = ovInfReq.get_tensor("input_ids"); + CacheTensor("input_ids", cached_input_ids); + + // "position_ids" (GQA with Rotary Embeddings doesnt have position_ids) - check if exists + auto position_ids_opt = FindTensor("position_ids"); + bool has_position_ids = position_ids_opt.has_value(); + + if (has_position_ids) { + CacheTensor("position_ids", cached_position_ids); + } + + // If we're about to run the prefill model + if (input_ids_tensor.get_size() > 1) { + // Check if the size of the current "input_ids" tensor does not match the size of the cached "input_ids". + // This indicates that we are running a subsequent prompt (not the initial prefill). + if (input_ids_tensor.get_shape()[1] != cached_input_ids.size()) { + // Clear the internal KVCache state. For NPU device, this operation is a no-op. + ovInfReq.reset_state(); + + // Set tensors using cached values + SetTensorFromCache("input_ids", cached_input_ids); + + // Only set position_ids if it exists and we have cached values + if (has_position_ids && !cached_position_ids.empty()) { + SetTensorFromCache("position_ids", cached_position_ids); + } + } + } + } +} + +void StatefulOVInferRequest::StartAsync() { + PreProcessInferRequest(); + OVInferRequest::StartAsync(); +} + +void StatefulOVInferRequest::Infer() { + PreProcessInferRequest(); + OVInferRequest::Infer(); +} + +void StatefulOVInferRequest::RewindKVCache(size_t index) { + LOGS_DEFAULT(INFO) << log_tag << "RewindKVCache: Rewinding OpenVINO-internal KVCache state to index=" << index; + + if (prefill_use_full_chat_history) { + // Clear the internal KVCache state. For NPU device, this operation is a no-op. + ovInfReq.reset_state(); + + // Resize the cached "input_ids" and "position_ids" to the specified index. + if (cached_input_ids.size() > index) { + cached_input_ids.resize(index); + } + + if (cached_position_ids.size() > index) { + cached_position_ids.resize(index); + } + } else { + if (index == 0) { + // In this case, since we're resetting the entire KVCache, simply reset the state. + ovInfReq.reset_state(); + } else { + // Retrieve KVCache states and trim them to the specified index. + // The following logic is adapted from: + // https://github.com/openvinotoolkit/openvino.genai/blob/releases/2025/1/src/cpp/src/utils.cpp#L329 + auto states = ovInfReq.query_state(); + for (auto& state : states) { + ov::Tensor old_tensor = state.get_state(); + // Tensor shape: [batch_size, num_kv_heads, seq_len, head_size] + auto shape = old_tensor.get_shape(); + + if (shape[2] > index) { + // Update the sequence length dimension to the specified index. + shape[2] = index; + + ov::Coordinate new_shape_begin{0, 0, 0, 0}; + ov::Coordinate new_shape_end{shape}; + + // Create a trimmed tensor with the updated shape. + auto trimmed_tensor = ov::Tensor(old_tensor, new_shape_begin, new_shape_end); + + // Copy the trimmed tensor into a new tensor and update the state. + ov::Tensor new_tensor(old_tensor.get_element_type(), shape); + trimmed_tensor.copy_to(new_tensor); + + state.set_state(new_tensor); + } + } + } + } +} + } // namespace openvino_ep } // namespace onnxruntime diff --git a/onnxruntime/core/providers/openvino/ov_interface.h b/onnxruntime/core/providers/openvino/ov_interface.h index 866f4a02f7780..c3d165b40840c 100644 --- a/onnxruntime/core/providers/openvino/ov_interface.h +++ b/onnxruntime/core/providers/openvino/ov_interface.h @@ -61,10 +61,14 @@ struct OVCore : WeakSingleton { // OV Interface For Reading Model std::shared_ptr ReadModel(std::string&& model_stream, const std::string& model_path); + OVExeNetwork StatefulCompileModel(std::shared_ptr& model, + std::string& hw_target, + const ov::AnyMap& device_config); // OV Interface for Compiling OV Model Type OVExeNetwork CompileModel(std::shared_ptr& ie_cnn_network, std::string& hw_target, ov::AnyMap& device_config, + bool enable_causallm, const std::string& name); // OV Interface for Fast Compile OVExeNetwork CompileModel(const std::string& onnx_model, @@ -83,16 +87,20 @@ struct OVCore : WeakSingleton { }; class OVExeNetwork { - ov::CompiledModel obj; + ov::CompiledModel compiled_model_obj; + std::string target_device; + bool is_stateful_causallm; public: - explicit OVExeNetwork(ov::CompiledModel md) : obj(md) {} - OVExeNetwork() : obj(ov::CompiledModel()) {} - ov::CompiledModel& Get() { return obj; } - OVInferRequest CreateInferRequest(); + explicit OVExeNetwork(ov::CompiledModel compiled_model, std::string device, bool stateful_causallm = false) + : compiled_model_obj(compiled_model), target_device(device), is_stateful_causallm(stateful_causallm) {} + OVExeNetwork() : compiled_model_obj(ov::CompiledModel()) {} + ov::CompiledModel& Get() { return compiled_model_obj; } + std::shared_ptr CreateInferRequest(); }; class OVInferRequest { + protected: ov::InferRequest ovInfReq; public: @@ -100,16 +108,42 @@ class OVInferRequest { OVTensorPtr GetTensor(const std::string& name); std::string GetInputTensorName(uint32_t index); void SetTensor(const std::string& name, OVTensorPtr& blob); - void StartAsync(); - void Infer(); + virtual void StartAsync(); + virtual void Infer(); void WaitRequest(); void CancelRequest(); void QueryStatus(); - explicit OVInferRequest(ov::InferRequest obj) : ovInfReq(std::move(obj)) {} + explicit OVInferRequest(ov::InferRequest infer_request_obj) : ovInfReq(std::move(infer_request_obj)) {} OVInferRequest() : ovInfReq(ov::InferRequest()) {} ov::InferRequest& GetNewObj() { return ovInfReq; } + virtual void RewindKVCache(size_t index) {} +}; + +class StatefulOVInferRequest : public OVInferRequest { + public: + explicit StatefulOVInferRequest(ov::InferRequest infer_request, std::string device); + + void StartAsync() override; + void Infer() override; + void RewindKVCache(size_t index) override; + void FillTensor(const std::string& tensor_name, const ov::element::Type& type, + const std::vector& shape, int32_t fill_value); + void CacheTensor(const std::string& tensor_name, std::vector& cache); + void SetTensorFromCache(const std::string& tensor_name, const std::vector& cache_data); + std::optional FindTensor(const std::string& tensor_name); + + private: + void PreProcessInferRequest(); + std::string target_device; + + // If prefill_use_full_chat_history is true, cache the "input_ids" & "position_ids" tensors, + // and ensure that full chat history is passed for each prefill call. + bool prefill_use_full_chat_history = false; + std::vector cached_input_ids; + std::vector cached_position_ids; }; + } // namespace openvino_ep } // namespace onnxruntime diff --git a/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc b/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc new file mode 100644 index 0000000000000..67ba42884e4f0 --- /dev/null +++ b/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc @@ -0,0 +1,350 @@ +// Copyright (C) Intel Corporation +// Licensed under the MIT License + +#include "core/providers/openvino/ov_stateful_patch_utils.h" + +namespace onnxruntime { +namespace openvino_ep { + +void LogBasicModelInfo(const std::shared_ptr& model) { + std::cout << "Model Name: " << model->get_friendly_name() << std::endl; + + // Log detailed information about model inputs and outputs + auto inputs = model->inputs(); + auto outputs = model->outputs(); + + std::cout << "\tInputs: " << std::endl; + for (const ov::Output& input : inputs) { + const std::string name = input.get_any_name(); + const ov::element::Type type = input.get_element_type(); + const ov::PartialShape shape = input.get_partial_shape(); + const ov::Layout layout = ov::layout::get_layout(input); + + std::cout << "\t\t" << name << ", " << type << ", " << shape << ", " << layout.to_string() << std::endl; + } + + std::cout << "\tOutputs: " << std::endl; + for (const ov::Output& output : outputs) { + const std::string name = output.get_any_name(); + const ov::element::Type type = output.get_element_type(); + const ov::PartialShape shape = output.get_partial_shape(); + const ov::Layout layout = ov::layout::get_layout(output); + + std::cout << "\t\t" << name << ", " << type << ", " << shape << ", " << layout.to_string() << std::endl; + } + + return; +} + +bool ModelHasInputOutputNames(std::shared_ptr model, const std::string& name_to_match) { + for (const ov::Output& input : model->inputs()) { + auto& names = input.get_names(); + + for (auto& name : names) { + if (name == name_to_match) { + return true; + } + } + } + + for (const ov::Output& output : model->outputs()) { + auto& names = output.get_names(); + for (auto& name : names) { + if (name == name_to_match) { + return true; + } + } + } + + return false; +} + +void FuseCacheReorder(std::shared_ptr ov_model, + std::vector& not_kv_inputs, + const std::vector& key_value_input_names, + int gather_dim) { + if (ModelHasInputOutputNames(ov_model, "beam_idx")) { + throw std::runtime_error("Model already has fused cache"); + } + + std::string main_input_name = "inputs_embeds"; + if (ModelHasInputOutputNames(ov_model, "input_ids")) { + main_input_name = "input_ids"; + } + + auto input_batch = ov_model->input(main_input_name).get_partial_shape()[0]; + + auto beam_idx = std::make_shared(ov::element::i32, ov::PartialShape({input_batch})); + beam_idx->set_friendly_name("beam_idx"); + beam_idx->output(0).get_tensor().add_names({"beam_idx"}); + ov_model->add_parameters({beam_idx}); + not_kv_inputs.push_back(beam_idx->get_friendly_name()); + + // Go over all cache parameters and fuse _reorder_cache with indices provided by the new parameter beam_idx + for (const auto& input_name : key_value_input_names) { + auto parameter_output_port = ov_model->input(input_name); + auto consumers = parameter_output_port.get_target_inputs(); + + auto gather_op = + std::make_shared(parameter_output_port, + beam_idx, + ov::opset13::Constant::create(ov::element::i64, {}, {gather_dim})); + + // Replace the source output for all consumers of the input tensor + for (auto& consumer : consumers) { + consumer.replace_source_output(gather_op->output(0)); + } + } + + // Validate the modified model + ov_model->validate_nodes_and_infer_types(); +} + +void MakeStateful(std::shared_ptr& ov_model, + const std::vector& key_value_input_names, + const std::vector& key_value_output_names) { + std::map input_output_map; + + // Create mapping for KV-cache inputs and outputs + for (size_t i = 0; i < key_value_input_names.size(); ++i) { + input_output_map[key_value_input_names[i]] = key_value_output_names[i]; + } + + // Apply the transformation to make the model stateful + ov::pass::Manager manager; + manager.register_pass(input_output_map); + manager.run_passes(ov_model); +} + +// Converted to C++ from below reference URL: +// https://github.com/huggingface/optimum-intel/blob/main/optimum/exporters/openvino/stateful.py#L281 +void PatchStatefulDecoder(std::shared_ptr model) { + std::vector key_value_input_names; + std::vector not_kv_inputs; + for (const ov::Output& input : model->inputs()) { + auto& names = input.get_names(); + + bool found = false; + for (auto& name : names) { + if (name.find("key_values") != std::string::npos) { + key_value_input_names.push_back(name); + found = true; + break; + } + } + + if (!found) { + not_kv_inputs.push_back(input.get_any_name()); + } + } + + std::vector key_value_output_names; + for (const ov::Output& output : model->outputs()) { + auto& names = output.get_names(); + for (auto& name : names) { + if (name.find("present") != std::string::npos) { + key_value_output_names.push_back(name); + break; + } + } + } + + if (key_value_input_names.empty() || key_value_output_names.empty()) { + std::cout << "no key_value_input_names or key_value_output_names found" << std::endl; + return; + } + + // By default, batch is the 0 - th but chatglm uses 1 - st dimension as batch + // TODO(ryan): Deduce from a model via ordinal reshape(? ) and topology + // batch_dim = 1 if config.model_type == "chatglm" and not hasattr(config, "rope_ratio") else 0 + auto batch_dim = 0; + + FuseCacheReorder(model, not_kv_inputs, key_value_input_names, batch_dim); + + MakeStateful(model, key_value_input_names, key_value_output_names); +} + +// Some other utility functions copied from OpenVINO GenAI +bool HasOpWithType(const std::shared_ptr& function, const std::string& type_name) { + for (const auto& op : function->get_ops()) { + if (op->get_type_name() == type_name) { + return true; + } + } + return false; +} + +std::tuple, int64_t> FindLLMMatmul(const std::shared_ptr& model) { + auto last_node = model->output(0).get_node()->input_value(0).get_node_shared_ptr(); + std::shared_ptr matmul = ov::as_type_ptr(last_node); + + // In the case of PagedAttention, all tokens are moved to the batch dimension, + // and slicing/gathering must be performed accordingly. + const bool pa_based_model = HasOpWithType(model, "PagedAttentionExtension"); + int64_t slice_gather_dim = pa_based_model ? 0 : 1; + + // There are several patterns for MatMul we are looking for: + // MatMul -> Result + // MatMul -> Add -> Result + // MatMul -> Transpose -> Result + // MatMul -> Divide -> Tanh -> Multiply -> Result + // MatMul -> Convert -> Result + if (!matmul) { + if (auto add = ov::as_type_ptr(last_node)) { + matmul = ov::as_type_ptr(add->input_value(0).get_node_shared_ptr()); + } else if (auto transpose = ov::as_type_ptr(last_node)) { + matmul = ov::as_type_ptr(transpose->input_value(0).get_node_shared_ptr()); + auto order = ov::as_type_ptr(transpose->input_value(1).get_node_shared_ptr())->get_axis_vector_val(); + slice_gather_dim = order[slice_gather_dim]; + } else if (auto multiply = ov::as_type_ptr(last_node)) { + if (auto tanh = ov::as_type_ptr(multiply->input_value(0).get_node_shared_ptr())) { + if (auto divide = ov::as_type_ptr(tanh->input_value(0).get_node_shared_ptr())) { + matmul = ov::as_type_ptr(divide->input_value(0).get_node_shared_ptr()); + } + } + } else if (auto convert = ov::as_type_ptr(last_node)) { + matmul = ov::as_type_ptr(convert->input_value(0).get_node_shared_ptr()); + } + } + return std::make_tuple(matmul, slice_gather_dim); +} + +void ApplySliceBeforeMatmulTransformation(std::shared_ptr model) { + std::shared_ptr matmul = nullptr; + int64_t slice_gather_dim = -1; + std::tie(matmul, slice_gather_dim) = FindLLMMatmul(model); + + if (matmul && matmul->input(0).get_partial_shape().rank().get_length() == 3) { + auto start = std::make_shared(ov::element::i64, ov::Shape{1}, std::vector{-1}); + auto stop = std::make_shared(ov::element::i64, ov::Shape{1}, std::vector{-2}); + auto step = std::make_shared(ov::element::i64, ov::Shape{1}, std::vector{-1}); + auto axis = std::make_shared(ov::element::i64, ov::Shape{1}, std::vector{slice_gather_dim}); + auto slice = std::make_shared(matmul->input_value(0), start, stop, step, axis); + matmul->input(0).replace_source_output(slice); + } +} + +void UpdateConfig(ov::AnyMap& config, const std::pair& pair) { + if (config.count(pair.first) == 0) { + config.insert(pair); + } +} + +std::optional PopOption(ov::AnyMap& config, const std::string& option_name) { + if (auto it = config.find(option_name); it != config.end()) { + std::optional found = std::make_optional(it->second); + config.erase(it); + return found; + } + return std::nullopt; +} + +void RenameKey(ov::AnyMap& config, const std::string& old_key, const std::string& new_key) { + if (config.count(old_key) != 0) { + auto opt_value = PopOption(config, old_key); + config[new_key] = opt_value.value(); + } +} + +KVAxesPosition GetKVAxesPos(std::shared_ptr model) { + // Sequence length axis in key/values tensors. For most cases, the tensor shape is + // [batch_size, num_kv_heads, seq_len, head_size]. Therefore, the sequence length axis + // is usually at index 2, and the batch axis is at index 0. + KVAxesPosition kv_pos{0u, 2u}; + + // "ReadValue" node is KV cache representation in stateful model + std::string kv_node_type_name = std::string(ov::op::v6::ReadValue::get_type_info_static().name); + + for (const auto& op : model->get_ops()) { + // Check input size, as in LoRA adapters case it could be 0 + if (op->get_type_name() != kv_node_type_name || op->get_input_size() < 1) { + continue; + } + + // Shape example: [-1,4,0,64] + auto shape = op->get_input_partial_shape(0); + + for (int64_t i = 0; i < shape.rank().get_length(); i++) { + // Find axis = 0. This would be sequence length axis. + if (shape[i] == 0) { + kv_pos.seq_len = i; + } else if (shape[i].is_dynamic()) { + // Dynamic axis is a batch + kv_pos.batch = i; + } + } + break; + } + + return kv_pos; +} + +void UpdateNPUConfig(ov::AnyMap& config, const KVAxesPosition& kv_pos, const KVDesc& kv_desc) { + UpdateConfig(config, {"NPU_USE_NPUW", "YES"}); + UpdateConfig(config, {"NPUW_LLM", "YES"}); + + UpdateConfig(config, {"NPUW_LLM_BATCH_DIM", kv_pos.batch}); + UpdateConfig(config, {"NPUW_LLM_SEQ_LEN_DIM", kv_pos.seq_len}); + + UpdateConfig(config, {"NPUW_LLM_MAX_PROMPT_LEN", kv_desc.max_prompt_len}); + UpdateConfig(config, {"NPUW_LLM_MIN_RESPONSE_LEN", kv_desc.min_response_len}); + + RenameKey(config, "++PREFILL_CONFIG", "++NPUW_LLM_PREFILL_CONFIG"); + RenameKey(config, "++GENERATE_CONFIG", "++NPUW_LLM_GENERATE_CONFIG"); + RenameKey(config, "PREFILL_CONFIG", "NPUW_LLM_PREFILL_CONFIG"); + RenameKey(config, "PREFILL_HINT", "NPUW_LLM_PREFILL_HINT"); + RenameKey(config, "GENERATE_CONFIG", "NPUW_LLM_GENERATE_CONFIG"); + RenameKey(config, "GENERATE_HINT", "NPUW_LLM_GENERATE_HINT"); + + const size_t npuw_context_len_threshold = 2048; + if ((kv_desc.max_prompt_len + kv_desc.min_response_len) >= npuw_context_len_threshold) { + // This improves accuracy for generation sequences that exceed 2k tokens. + config["++NPUW_LLM_PREFILL_CONFIG"] = ov::AnyMap{{"NPUW_DEVICES", "NPU,CPU"}, {"NPUW_ONLINE_AVOID", "P:SinCos/NPU"}}; + config["++NPUW_LLM_GENERATE_CONFIG"] = ov::AnyMap{{"NPUW_DEVICES", "NPU,CPU"}, {"NPUW_ONLINE_AVOID", "P:SinCos/NPU"}}; + } +} + +std::optional PopOptionNew(ov::AnyMap& config, const std::string& option_name) { + if (auto it = config.find(option_name); it != config.end()) { + std::optional found = std::make_optional(it->second); + config.erase(it); + return found; + } + return std::nullopt; +} + +std::optional PopIntAndCast(ov::AnyMap& config, const std::string& key) { + auto anyopt = PopOptionNew(config, key); + if (anyopt.has_value()) { + const auto any = anyopt.value(); + int64_t value; + // NB: Integer value coming from python has int64_t datatype + if (any.is()) { + value = any.as(); + } else if (any.is()) { + value = any.as(); + } else { + OPENVINO_THROW("Failed to extract " + key + ". Type mismatch: expected types: int or int64_t"); + } + if (value < 0) { + OPENVINO_THROW(key + " cannot be negative!"); + } + return std::make_optional(static_cast(value)); + } + return std::nullopt; +} + +bool IsStateful(const std::shared_ptr& model) { + for (auto&& ptr : model->get_ordered_ops()) { + if (ov::is_type(ptr) || + ov::is_type(ptr) || + ov::is_type(ptr) || + ov::is_type(ptr)) { + return true; + } + } + return false; +} + +} // namespace openvino_ep +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.h b/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.h new file mode 100644 index 0000000000000..0b89c4ed02e13 --- /dev/null +++ b/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.h @@ -0,0 +1,84 @@ +// Copyright (C) Intel Corporation +// Licensed under the MIT License + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include "openvino/pass/manager.hpp" +#include "openvino/pass/make_stateful.hpp" +#include "openvino/opsets/opset13.hpp" + +namespace onnxruntime { +namespace openvino_ep { + +void LogBasicModelInfo(const std::shared_ptr& model); + +bool ModelHasInputOutputNames(std::shared_ptr model, const std::string& name_to_match); + +void FuseCacheReorder(std::shared_ptr ov_model, + std::vector& not_kv_inputs, + const std::vector& key_value_input_names, + int gather_dim); + +void MakeStateful(std::shared_ptr& ov_model, + const std::vector& key_value_input_names, + const std::vector& key_value_output_names); + +void PatchStatefulDecoder(std::shared_ptr model); + +bool HasOpWithType(const std::shared_ptr& function, const std::string& type_name); + +std::tuple, int64_t> FindLLMMatmul(const std::shared_ptr& model); + +void ApplySliceBeforeMatmulTransformation(std::shared_ptr model); + +void UpdateConfig(ov::AnyMap& config, const std::pair& pair); + +std::optional PopOption(ov::AnyMap& config, const std::string& option_name); + +void RenameKey(ov::AnyMap& config, const std::string& old_key, const std::string& new_key); + +struct KVAxesPosition { + size_t batch; + size_t seq_len; +}; + +KVAxesPosition GetKVAxesPos(std::shared_ptr model); + +struct KVDesc { + uint32_t max_prompt_len; + uint32_t min_response_len; +}; + +struct CausalLMConfig { + void ApplyConfig(const ov::AnyMap& external_config, ov::AnyMap& genai_config) { + if (external_config.find("MAX_PROMPT_LEN") != external_config.end()) { + max_prompt_len = external_config.at("MAX_PROMPT_LEN").as(); + } + if (external_config.find("MIN_RESPONSE_LEN") != external_config.end()) { + min_response_len = external_config.at("MIN_RESPONSE_LEN").as(); + } + genai_config["MAX_PROMPT_LEN"] = ov::Any(max_prompt_len); + genai_config["MIN_RESPONSE_LEN"] = ov::Any(min_response_len); + } + + unsigned int max_prompt_len = 1024; + unsigned int min_response_len = 128; +}; + +void UpdateNPUConfig(ov::AnyMap& config, const KVAxesPosition& kv_pos, const KVDesc& kv_desc); + +std::optional PopOptionNew(ov::AnyMap& config, const std::string& option_name); +std::optional PopIntAndCast(ov::AnyMap& config, const std::string& key); + +bool IsStateful(const std::shared_ptr& model); + +} // namespace openvino_ep +} // namespace onnxruntime diff --git a/onnxruntime/test/contrib_ops/attention_op_test.cc b/onnxruntime/test/contrib_ops/attention_op_test.cc index 61e5fa05c66c1..4dff0376fcd84 100644 --- a/onnxruntime/test/contrib_ops/attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/attention_op_test.cc @@ -2047,7 +2047,7 @@ TEST(AttentionTest, AttentionPastState_dynamic) { test.AddInput("past", past_dims, past_data); test.AddReferenceOutputs("testdata/attention_past_state.onnx", 0.005f); - test.Run(); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider}); } #endif //! defined(__wasm__) diff --git a/onnxruntime/test/perftest/ort_test_session.cc b/onnxruntime/test/perftest/ort_test_session.cc index 05136ec0750a1..e8eda5af1dc29 100644 --- a/onnxruntime/test/perftest/ort_test_session.cc +++ b/onnxruntime/test/perftest/ort_test_session.cc @@ -760,6 +760,15 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)"); } else { ORT_THROW("[ERROR] [OpenVINO] The value for the key 'enable_qdq_optimizer' should be a boolean i.e. true or false. Default value is false.\n"); } + } else if (key == "enable_causallm") { + if (value == "true" || value == "True" || + value == "false" || value == "False") { + ov_options[key] = value; + } else { + ORT_THROW( + "[ERROR] [OpenVINO] The value for the key 'enable_causallm' should be a boolean i.e. true or false." + " Default value is false. This provider option must be used with CausalLM Models viz. LLMs & SLMs only.\n"); + } } else if (key == "disable_dynamic_shapes") { if (value == "true" || value == "True" || value == "false" || value == "False") { @@ -817,7 +826,8 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)"); ORT_THROW( "[ERROR] [OpenVINO] wrong key type entered. Choose from the following runtime key options that are available for OpenVINO." " ['device_type', 'device_id', 'num_of_threads', 'load_config', 'cache_dir', 'num_streams', " - "'enable_opencl_throttling', 'disable_dynamic_shapes', 'enable_qdq_optimizer', 'model_priority'] \n"); + "'enable_opencl_throttling', 'disable_dynamic_shapes', 'enable_qdq_optimizer'," + " 'enable_causallm', 'model_priority'] \n"); } } session_options.AppendExecutionProvider_OpenVINO_V2(ov_options); From 5ec3c7ee2ba2d4d6f81c2f1319c882f5fe0acde8 Mon Sep 17 00:00:00 2001 From: Vishnudas Thaniel S Date: Tue, 10 Jun 2025 09:25:42 -0700 Subject: [PATCH 044/138] OVEP unittest (#685) * test:Added OVEP unittest for ep_context feature * Comment out the device selection, so that the device chosen during build time will be used --- cmake/onnxruntime_unittests.cmake | 7 ++ .../openvino/openvino_ep_context_test.cc | 87 +++++++++++++++++++ 2 files changed, 94 insertions(+) create mode 100644 onnxruntime/test/providers/openvino/openvino_ep_context_test.cc diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index 15cc238173f29..c8de91d6c6eb6 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -770,6 +770,13 @@ if(onnxruntime_USE_AZURE) list(APPEND onnxruntime_test_providers_libs onnxruntime_providers_azure) endif() +if (onnxruntime_USE_OPENVINO) + list(APPEND onnxruntime_test_framework_src_patterns ${TEST_SRC_DIR}/providers/openvino/*) + list(APPEND onnxruntime_test_framework_libs onnxruntime_providers_openvino) + list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_openvino) + list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_shared) +endif() + file(GLOB onnxruntime_test_framework_src CONFIGURE_DEPENDS ${onnxruntime_test_framework_src_patterns} ) diff --git a/onnxruntime/test/providers/openvino/openvino_ep_context_test.cc b/onnxruntime/test/providers/openvino/openvino_ep_context_test.cc new file mode 100644 index 0000000000000..bf5d2d57727a6 --- /dev/null +++ b/onnxruntime/test/providers/openvino/openvino_ep_context_test.cc @@ -0,0 +1,87 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include + +#include "core/framework/provider_options.h" +#include "core/framework/tensor_shape.h" +#include "core/framework/float16.h" + +#include "test/util/include/test_utils.h" +#include "test/util/include/test/test_environment.h" +#include "test/util/include/default_providers.h" + +#include "core/session/onnxruntime_cxx_api.h" +#include "core/session/onnxruntime_session_options_config_keys.h" +#include "core/session/inference_session.h" +#include "core/graph/model_saving_options.h" + +#include "test/optimizer/qdq_test_utils.h" + +#include "gtest/gtest.h" +#include "gmock/gmock.h" + +using namespace ONNX_NAMESPACE; +using namespace onnxruntime::logging; + + +extern std::unique_ptr ort_env; + +class OVEPEPContextTests : public ::testing::Test { + + +}; + +namespace onnxruntime { +namespace test { + + +// Test if folder path given to ep_context_file_path throws an error +TEST_F(OVEPEPContextTests, OVEPEPContextFolderPath) { + + Ort::SessionOptions sessionOptions; + std::unordered_map ov_options; + + //The below line could fail the test in non NPU platforms.Commenting it out so that the device used for building OVEP will be used. + //ov_options["device_type"] = "NPU"; + + + const std::unordered_map domain_to_version = {{"", 13}, {kMSDomain, 1}}; + + auto& logging_manager = DefaultLoggingManager(); + logging_manager.SetDefaultLoggerSeverity(logging::Severity::kERROR); + + onnxruntime::Model model("OVEP_Test_Model", false, ModelMetaData(), PathString(), + IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, {}, + logging_manager.DefaultLogger()); + + ASSERT_STATUS_OK(model.MainGraph().Resolve()); + + // Serialize the model to a string. + std::string model_data; + model.ToProto().SerializeToString(&model_data); + + const auto model_data_span = AsByteSpan(model_data.data(), model_data.size()); + + const std::string ep_context_file_path = "./ep_context_folder_path/"; + + + sessionOptions.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1"); + sessionOptions.AddConfigEntry(kOrtSessionOptionEpContextFilePath,ep_context_file_path.c_str()); + sessionOptions.AppendExecutionProvider_OpenVINO_V2(ov_options); + + + try { + Ort::Session session(*ort_env, model_data_span.data(), model_data_span.size(), sessionOptions); + FAIL(); // Should not get here! + } catch (const Ort::Exception& excpt) { + ASSERT_EQ(excpt.GetOrtErrorCode(), ORT_INVALID_ARGUMENT); + ASSERT_THAT(excpt.what(), testing::HasSubstr("context_file_path should not point to a folder.")); + } + +} + + +} // namespace test +} // namespace onnxruntime From ca06b7a39c814a437c4efa59bcc1fbca0834e2af Mon Sep 17 00:00:00 2001 From: Preetha Veeramalai Date: Fri, 13 Jun 2025 21:59:57 -0700 Subject: [PATCH 045/138] Support for bounded dynamic model (#701) * Refactored the code for reshape feature * Refactor the inference logic accomodating bounded dimensions * Fix lint issues * Refactor OV shapes classification to be a part of bindings struct * Refactor the provider options key verification for python interface * Restrict removal of model proto when CPU fallback is enabled and fix unit test failures --------- Co-authored-by: jatinwadhwa921 --- .../providers/openvino/backend_manager.cc | 124 +++++++++-- .../core/providers/openvino/backend_manager.h | 4 + .../core/providers/openvino/backend_utils.cc | 4 + .../openvino/backends/basic_backend.cc | 193 ++++++++++-------- .../openvino/backends/basic_backend.h | 54 ++++- .../core/providers/openvino/contexts.h | 5 +- .../openvino/openvino_execution_provider.cc | 6 +- .../openvino/openvino_parser_utils.cc | 120 +++++++++++ .../openvino/openvino_parser_utils.h | 4 + .../openvino/openvino_provider_factory.cc | 6 +- .../core/providers/openvino/ov_allocator.cc | 12 +- .../core/providers/openvino/ov_interface.cc | 4 +- .../core/providers/openvino/ov_interface.h | 5 +- .../python/onnxruntime_pybind_state.cc | 60 +----- .../quantization/matmul_nbits_quantizer.py | 22 +- onnxruntime/test/perftest/ort_test_session.cc | 2 + .../openvino/openvino_ep_context_test.cc | 16 +- 17 files changed, 445 insertions(+), 196 deletions(-) diff --git a/onnxruntime/core/providers/openvino/backend_manager.cc b/onnxruntime/core/providers/openvino/backend_manager.cc index c22f2e9cc0fa1..cb7acfd2ca95a 100644 --- a/onnxruntime/core/providers/openvino/backend_manager.cc +++ b/onnxruntime/core/providers/openvino/backend_manager.cc @@ -70,6 +70,9 @@ BackendManager::BackendManager(SessionContext& session_context, // Save the indexes of graph inputs among fused_node's inputDefs // (which also contains initializers). for (uint32_t index = 0; const auto& node : subgraph.GetInputs()) { + if (subgraph.GetGraph().GetConsumerNodes(node->Name()).size() == 0) { + continue; // Skip if the input is a dangling node + } subgraph_context_.input_names.insert({node->Name(), index++}); } @@ -110,7 +113,7 @@ BackendManager::BackendManager(SessionContext& session_context, subgraph_context_.has_dynamic_input_shape = true; LOGS_DEFAULT(INFO) << "[OpenVINO-EP] Model has symbolic input dims"; if (cpu_or_gpu || (npu && session_context_.enable_causallm) && - !session_context_.disable_dynamic_shapes) { + !session_context_.disable_dynamic_shapes) { LOGS_DEFAULT(INFO) << "[OpenVINO-EP] Starting backend initialization. " << "Creating backend Dynamic Shapes"; try { @@ -291,24 +294,83 @@ bool BackendManager::ModelHasBatchedInputs(const ONNX_NAMESPACE::ModelProto& mod } bool BackendManager::ModelHasSymbolicInputDims(const onnxruntime::GraphViewer& subgraph) const { - bool has_sym_dims = false; - auto graph_inputs = subgraph.GetInputs(); - for (auto input : graph_inputs) { + const auto& graph_inputs = subgraph.GetInputs(); + + // First validate shapes if provided by user + bool shapes_valid = true; + if (!session_context_.reshape.empty()) { + try { + ValidateInputShapes(session_context_.reshape, graph_inputs); + } catch (const std::exception& e) { + LOGS_DEFAULT(ERROR) << "[OpenVINO-EP] Shape validation failed: " << e.what(); + session_context_.reshape.clear(); // Clear the shape map as it's invalid + shapes_valid = false; + } + } + + // Count dynamic inputs and check if reshape covers all of them + size_t dynamic_input_count = 0; + bool all_dynamic_inputs_covered = true; + + for (const auto* input : graph_inputs) { + // Skip dangling inputs (no consumers) + if (subgraph.GetGraph().GetConsumerNodes(input->Name()).empty()) { + continue; + } + + // Check if input has dynamic dimensions + bool has_dynamic_dim = false; + + // Case 1: Completely undefined shape if (input->Shape() == nullptr) { - has_sym_dims = true; - break; + has_dynamic_dim = true; } - for (auto& dim : input->Shape()->dim()) { - if (dim.value_case() != dim.kDimValue) { - has_sym_dims = true; - break; + // Case 2: Shape defined but with symbolic dimensions + else { + for (const auto& dim : input->Shape()->dim()) { + if (dim.value_case() != dim.kDimValue) { + has_dynamic_dim = true; + break; + } } } - if (has_sym_dims) { - break; + + // If dynamic, count it and check if reshape covers it + if (has_dynamic_dim) { + dynamic_input_count++; + + // Check if this dynamic input is covered by reshape input + if (!session_context_.reshape.empty() && + session_context_.reshape.find(input->Name()) == session_context_.reshape.end()) { + all_dynamic_inputs_covered = false; + LOGS_DEFAULT(WARNING) << "[OpenVINO-EP] reshape_input is provided but doesn't cover dynamic input: " + << input->Name(); + } } } - return has_sym_dims; + + const bool has_symbolic_dims = (dynamic_input_count > 0); + + // Early return if no reshape input provided + if (session_context_.reshape.empty()) { + return has_symbolic_dims; // Return based on whether model has symbolic dims + } + + // For dynamic models with incomplete reshape coverage, clear shapes + if (has_symbolic_dims && !all_dynamic_inputs_covered) { + session_context_.reshape.clear(); + LOGS_DEFAULT(WARNING) << "reshape_input does not cover all dynamic dimensions, " + << "ignoring all provided shapes"; + return true; // Model is dynamic + } + + // If shapes are valid with complete coverage for dynamic model, treat as concrete + if (has_symbolic_dims && shapes_valid && all_dynamic_inputs_covered) { + LOGS_DEFAULT(INFO) << "All dynamic dimensions successfully covered by reshape_input"; + return false; // Model is now effectively static with concrete shapes + } + + return has_symbolic_dims; // Return dynamic status based on symbolic dimensions } // Check to see if the graph is QDQ @@ -386,7 +448,7 @@ BackendManager::GetModelProtoFromFusedNode(const onnxruntime::Node& fused_node, const auto& onnx_model_path_name = subgraph.ModelPath(); // QDQ stripping enabled only for the NPU and experimentally on the GPU if ((session_context_.device_type.find("NPU") != std::string::npos || - session_context_.device_type.find("GPU") != std::string::npos) && + session_context_.device_type.find("GPU") != std::string::npos) && (enable_ovep_qdq_optimizer || session_context_.so_share_ep_contexts)) { std::unique_ptr model; Status status = CreateModelWithStrippedQDQNodes(subgraph, logger, session_context_.so_share_ep_contexts, enable_ovep_qdq_optimizer, model, shared_context_.shared_weights); @@ -480,6 +542,40 @@ BackendManager::ReWriteBatchDimWithOne(const ONNX_NAMESPACE::ModelProto& model_p return model_copy; } +void BackendManager::ValidateInputShapes(const reshape_t& shapes, + const std::vector& graph_inputs) const { + for (const auto& [tensor_name, requested_shape] : shapes) { + // Find matching input in graph + const NodeArg* graph_input = nullptr; + for (const auto* input : graph_inputs) { + if (input->Name() == tensor_name) { + graph_input = input; + break; + } + } + + if (!graph_input) { + ORT_THROW("Input '" + tensor_name + "' specified in reshape_input does not exist in the graph"); + } + + const ONNX_NAMESPACE::TensorShapeProto* graph_shape = graph_input->Shape(); + if (!graph_shape) { + ORT_THROW("Graph input '" + tensor_name + "' has no shape information"); + } + + // Check dimensions count matches + size_t graph_dim_count = graph_shape->dim_size(); + size_t requested_dim_count = requested_shape.get_max_shape().size(); + + if (graph_dim_count != requested_dim_count) { + ORT_THROW("Dimensions mismatch for input '" + tensor_name + + "': graph expects " + std::to_string(graph_dim_count) + + " dimensions but reshape_input specifies " + + std::to_string(requested_dim_count) + " dimensions"); + } + } +} + void BackendManager::Compute(OrtKernelContext* context) { Ort::KernelContext ctx(context); std::chrono::high_resolution_clock::time_point start_compute, end_compute; diff --git a/onnxruntime/core/providers/openvino/backend_manager.h b/onnxruntime/core/providers/openvino/backend_manager.h index 799dc50dd7a63..7165b9cf2e14c 100644 --- a/onnxruntime/core/providers/openvino/backend_manager.h +++ b/onnxruntime/core/providers/openvino/backend_manager.h @@ -39,7 +39,11 @@ class BackendManager { const logging::Logger& logger) const; bool ModelHasSymbolicInputDims(const onnxruntime::GraphViewer& subgraph) const; + std::unordered_set IdentifyDynamicInputs(const onnxruntime::GraphViewer& subgraph, + const std::vector& graph_inputs) const; bool ModelHasBatchedInputs(const ONNX_NAMESPACE::ModelProto& model_proto) const; + void ValidateInputShapes(const reshape_t& shapes, + const std::vector& graph_inputs) const; std::shared_ptr ReWriteBatchDimWithOne(const ONNX_NAMESPACE::ModelProto& model_proto); diff --git a/onnxruntime/core/providers/openvino/backend_utils.cc b/onnxruntime/core/providers/openvino/backend_utils.cc index 1382c187f6b4e..49eedfb3e4fcd 100644 --- a/onnxruntime/core/providers/openvino/backend_utils.cc +++ b/onnxruntime/core/providers/openvino/backend_utils.cc @@ -146,6 +146,10 @@ CreateOVModel(std::string&& model, try { auto ov_model = OVCore::Get()->ReadModel(std::move(model), session_context.onnx_model_path_name.string()); + if (!session_context.reshape.empty()) { + LOGS_DEFAULT(INFO) << log_tag << "Reshaping the ov tensor to specified shape"; + ov_model->reshape(session_context.reshape); + } // Check for Constant Folding if ((session_context.device_type != "NPU") && !session_context.is_wholly_supported_graph) { ov::pass::ConstantFolding pass_const_obj; diff --git a/onnxruntime/core/providers/openvino/backends/basic_backend.cc b/onnxruntime/core/providers/openvino/backends/basic_backend.cc index 7902b3edb2276..3105c307706ad 100644 --- a/onnxruntime/core/providers/openvino/backends/basic_backend.cc +++ b/onnxruntime/core/providers/openvino/backends/basic_backend.cc @@ -67,6 +67,9 @@ BasicBackend::BasicBackend(std::unique_ptr& model_pr auto auto_unified_compile = ((hw_target.find("AUTO") == std::string::npos) || (session_context_.OpenVINO_Version.at(0) >= 2024 && session_context_.OpenVINO_Version.at(1) > 2)); + bool disable_cpu_fallback = !(hw_target.find("NPU") != std::string::npos && + !session_context_.so_disable_cpu_ep_fallback && + !subgraph_context_.is_ep_ctx_graph); if (subgraph_context_.is_ep_ctx_graph) { // If the blob is held in an EPContext node, then skip FE+Compile // and directly move on to creating a backend with the executable blob @@ -78,14 +81,16 @@ BasicBackend::BasicBackend(std::unique_ptr& model_pr } else if (!session_context_.has_external_weights && !subgraph_context_.has_dynamic_input_shape && !session_context_.so_context_enable && - !enable_causallm && auto_unified_compile) { + session_context_.reshape.empty() && + !enable_causallm && + auto_unified_compile) { // Unified OV compile_model is efficient when ov model caching is enabled // Unified OV compile_model API is supported with AUTO from version 2024.3 and above // Inputs with static dimensions // Not enabled for models with external weights and when ep context is set. const std::string model = model_proto->SerializeAsString(); // we have the serialized string, so we can release model proto to lower the peak memory consumption - model_proto.reset(); + if (disable_cpu_fallback) model_proto.reset(); exe_network_ = OVCore::Get()->CompileModel(model, hw_target, device_config, @@ -93,7 +98,9 @@ BasicBackend::BasicBackend(std::unique_ptr& model_pr } else { // For all other types use ov::ov_core read_model() to generate OV IR // followed by ov::ov_core compile_model() std::string model = model_proto->SerializeAsString(); - if (!subgraph_context.has_dynamic_input_shape) { + // Reset model proto only when cpu fallback is disabled or when the model has dynamic input shapes. + // This is to avoid memory peak usage when the model is large. + if (!subgraph_context.has_dynamic_input_shape && disable_cpu_fallback) { model_proto.reset(); } auto ov_model = CreateOVModel(std::move(model), session_context_, const_outputs_map_); @@ -351,6 +358,26 @@ void BasicBackend::SetNumThreads(ov::AnyMap& device_config) { device_config.emplace(ov::inference_num_threads(session_context_.num_of_threads)); } +void BasicBackend::ValidateOrtDimsAgainstPartialShape(const std::vector& ort_dims, + const ov::PartialShape& partial_shape) const { + // Check if the number of dimensions matches + if (static_cast(ort_dims.size()) != partial_shape.rank().get_length()) { + ORT_THROW("Mismatch in number of dimensions between ORT tensor and OpenVINO PartialShape."); + } + // Validate each dimension + for (size_t i = 0; i < ort_dims.size(); ++i) { + const auto& ov_dim = partial_shape[i]; // OpenVINO dimension at index i + int64_t ort_dim = ort_dims[i]; // ORT dimension at index i + + // Check if the ORT dimension is within the specified range + int64_t min_dim = ov_dim.get_min_length(); + int64_t max_dim = ov_dim.get_max_length(); + if (ort_dim < min_dim || ort_dim > max_dim) { + ORT_THROW(" ORT Dimension is out of range"); + } + } +} + void BasicBackend::RewindKVCache(size_t index) { OVInferRequestPtr infer_request; infer_request = inferRequestsQueue_->getIdleRequest(); @@ -362,95 +389,91 @@ void BasicBackend::RewindKVCache(size_t index) { // an Infer Request indexed by infer_req_idx void BasicBackend::StartAsyncInference(Ort::KernelContext& context, OVInferRequestPtr infer_request) { try { - bool cpu_or_gpu = (session_context_.device_type.find("CPU") != std::string::npos || - session_context_.device_type.find("GPU") != std::string::npos); - bool npu = (session_context_.device_type.find("NPU") != std::string::npos); + const bool is_cpu = session_context_.device_type.find("CPU") != std::string::npos; + const bool is_gpu = session_context_.device_type.find("GPU") != std::string::npos; + const bool is_npu = session_context_.device_type.find("NPU") != std::string::npos; + const bool is_cpu_or_gpu = is_cpu || is_gpu; + // Loop over subgraph original input names to find the correspondent OV input name for (const auto& input_info : bindings_->network_inputs_) { size_t batch_slice_idx = 0; - if (subgraph_context_.has_dynamic_input_shape && - !session_context_.disable_dynamic_shapes && - cpu_or_gpu || (npu && session_context_.enable_causallm)) { - auto tensor = context.GetInput(input_info.onnx_index); - auto tensor_info = tensor.GetTensorTypeAndShapeInfo(); - auto tensor_shape = tensor_info.GetShape(); - auto tensor_size = tensor_shape.size(); - const char* tensor_data = tensor.GetTensorData(); - auto tensor_iter = 0; - ov::Shape input_tensor_shape = ov::Shape(tensor_size, 0); - for (auto i = tensor_shape.begin(); i != tensor_shape.end(); ++i) { - input_tensor_shape[tensor_iter] = *i; - tensor_iter += 1; - } - OVTensorPtr tensor_ptr; - // avoid input copies on the CPU device - if (session_context_.device_type.find("CPU") != std::string::npos) { - tensor_ptr = std::make_shared(input_info.type, input_tensor_shape, - (void*)tensor_data); - } else { - tensor_ptr = std::make_shared(input_info.type, input_tensor_shape); - FillInputBlob(tensor_ptr, batch_slice_idx, input_info.name, context, subgraph_context_); - } - - try { - infer_request->SetTensor(input_info.name, tensor_ptr); - } catch (const char* msg) { - ORT_THROW(msg); - } - } else { - if (cpu_or_gpu) { - OVTensorPtr graph_input_blob; + auto tensor = context.GetInput(input_info.onnx_index); + auto tensor_info = tensor.GetTensorTypeAndShapeInfo(); + auto tensor_shape = tensor_info.GetShape(); + auto tensor_data = tensor.GetTensorData(); + if (input_info.IsBoundedDynamic()) { + ov::PartialShape partial_shape = input_info.ov_shape; + ValidateOrtDimsAgainstPartialShape(tensor_shape, partial_shape); + } + ov::Shape input_tensor_shape(tensor_shape.begin(), tensor_shape.end()); + OVTensorPtr tensor_ptr; + if (is_cpu_or_gpu) { + if (input_info.IsStatic()) { try { - graph_input_blob = infer_request->GetTensor(input_info.name); + auto graph_input_blob = infer_request->GetTensor(input_info.name); + FillInputBlob(std::move(graph_input_blob), batch_slice_idx, input_info.name, context, subgraph_context_); } catch (const char* msg) { ORT_THROW(msg); } - FillInputBlob(std::move(graph_input_blob), batch_slice_idx, input_info.name, context, subgraph_context_); } else { - auto tensor = context.GetInput(input_info.onnx_index); - ort_tensor_key_t ort_tensor_key{input_info.name}; - auto it = ort_ov_tensor_map.find(ort_tensor_key); - if ((it == ort_ov_tensor_map.end()) || it->second.ort_ptr != tensor.GetTensorRawData()) { - ov_tensor_data_t ov_tensor_data; - ov_tensor_data.tensor_ptr = std::make_shared(input_info.type, input_info.ov_shape.get_shape(), - const_cast(tensor.GetTensorRawData())); - - ov_tensor_data.ort_ptr = tensor.GetTensorRawData(); - ort_ov_tensor_map[ort_tensor_key] = ov_tensor_data; + if (is_cpu) { + tensor_ptr = std::make_shared(input_info.type, input_tensor_shape, (void*)tensor_data); + } else { // GPU + tensor_ptr = std::make_shared(input_info.type, input_tensor_shape); + FillInputBlob(tensor_ptr, batch_slice_idx, input_info.name, context, subgraph_context_); + } - try { - infer_request->SetTensor(input_info.name, ov_tensor_data.tensor_ptr); - } catch (const char* msg) { - ORT_THROW(msg); - } + try { + infer_request->SetTensor(input_info.name, tensor_ptr); + } catch (const char* msg) { + ORT_THROW(msg); } } - } - } // Loop subgraph original input + } else { // Other device path + ort_tensor_key_t ort_tensor_key{input_info.name}; + auto it = ort_ov_tensor_map.find(ort_tensor_key); - // For Stateful Compilation i.e. enable_causallm as True, we use the dynamic shapes path for NPU plugin as well. - if (npu && !session_context_.enable_causallm) { - // Set the output blob as remote blob - for (const auto& output_info : bindings_->network_outputs_) { - Ort::UnownedValue tensor = context.GetOutput(output_info.onnx_index, output_info.onnx_shape); - - ort_tensor_key_t ort_tensor_key{output_info.name}; - const auto& it = ort_ov_tensor_map.find(ort_tensor_key); - if ((it == ort_ov_tensor_map.end()) || (it->second.ort_ptr != tensor.GetTensorRawData())) { + if (it == ort_ov_tensor_map.end() || it->second.ort_ptr != tensor.GetTensorRawData()) { ov_tensor_data_t ov_tensor_data; - ov_tensor_data.ort_ptr = tensor.GetTensorRawData(); - ov_tensor_data.tensor_ptr = std::make_shared(output_info.type, output_info.ov_shape.get_shape(), + ov_tensor_data.tensor_ptr = std::make_shared(input_info.type, input_tensor_shape, const_cast(tensor.GetTensorRawData())); + ov_tensor_data.ort_ptr = tensor.GetTensorRawData(); ort_ov_tensor_map[ort_tensor_key] = ov_tensor_data; try { - infer_request->SetTensor(output_info.name, ov_tensor_data.tensor_ptr); + infer_request->SetTensor(input_info.name, ov_tensor_data.tensor_ptr); } catch (const char* msg) { ORT_THROW(msg); } } } } + // Handle output + if (is_npu && !session_context_.enable_causallm) { + // Set the output blob as remote blob + for (const auto& output_info : bindings_->network_outputs_) { + if (output_info.IsStatic()) { + // Set remote tensor for static outputs only + Ort::UnownedValue tensor = context.GetOutput(output_info.onnx_index, output_info.onnx_shape); + + ort_tensor_key_t ort_tensor_key{output_info.name}; + const auto& it = ort_ov_tensor_map.find(ort_tensor_key); + if ((it == ort_ov_tensor_map.end()) || (it->second.ort_ptr != tensor.GetTensorRawData())) { + ov_tensor_data_t ov_tensor_data; + ov_tensor_data.ort_ptr = tensor.GetTensorRawData(); + ov_tensor_data.tensor_ptr = std::make_shared(output_info.type, output_info.ov_shape.get_shape(), + const_cast(tensor.GetTensorRawData())); + ort_ov_tensor_map[ort_tensor_key] = ov_tensor_data; + + try { + infer_request->SetTensor(output_info.name, ov_tensor_data.tensor_ptr); + } catch (const char* msg) { + ORT_THROW(msg); + } + } + } + } + } // Start Async inference infer_request->StartAsync(); @@ -465,7 +488,7 @@ void BasicBackend::CompleteAsyncInference(Ort::KernelContext& context, OVInferRe // Wait for Async inference completion try { infer_request->WaitRequest(); - } catch(const std::runtime_error& e) { + } catch (const std::runtime_error& e) { infer_request->CancelRequest(); inferRequestsQueue_->deleteRequest(); ORT_THROW(log_tag + e.what()); @@ -474,20 +497,20 @@ void BasicBackend::CompleteAsyncInference(Ort::KernelContext& context, OVInferRe bool cpu_or_gpu = session_context_.device_type.find("CPU") != std::string::npos || session_context_.device_type.find("GPU") != std::string::npos; bool npu = session_context_.device_type.find("NPU") != std::string::npos; - if (cpu_or_gpu || (npu && session_context_.enable_causallm)) { - for (const auto& output_info : bindings_->network_outputs_) { - OVTensorPtr graph_output_blob; - try { - graph_output_blob = infer_request->GetTensor(output_info.name); - } catch (const char* msg) { - ORT_THROW(msg); - } - size_t batch_size = 1; - Ort::UnownedValue output_tensor = - GetOutputTensor(context, batch_size, infer_request, output_info.name, subgraph_context_.output_names); - auto mem_info = output_tensor.GetTensorMemoryInfo(); - if (mem_info.GetAllocatorName() == OpenVINO_GPU) { - return; + for (const auto& output_info : bindings_->network_outputs_) { + if (cpu_or_gpu || (npu && (session_context_.enable_causallm || !output_info.IsStatic()))) { + OVTensorPtr graph_output_blob; + try { + graph_output_blob = infer_request->GetTensor(output_info.name); + } catch (const char* msg) { + ORT_THROW(msg); + } + size_t batch_size = 1; + Ort::UnownedValue output_tensor = + GetOutputTensor(context, batch_size, infer_request, output_info.name, subgraph_context_.output_names); + auto mem_info = output_tensor.GetTensorMemoryInfo(); + if (mem_info.GetAllocatorName() == OpenVINO_GPU) { + return; } else { size_t batch_slice = 0; FillOutputBlob(std::move(graph_output_blob), output_tensor, batch_slice); @@ -549,7 +572,7 @@ void BasicBackend::Infer(OrtKernelContext* ctx) { } else { OVInferRequestPtr infer_request; infer_request = inferRequestsQueue_->getIdleRequest(); - if(infer_request == nullptr) { + if (infer_request == nullptr) { ORT_THROW("OpenVINO Execution Provider :: There are no inference requests"); LOGS_DEFAULT(FATAL) << log_tag << "Create Infer Requests do not exist"; return; diff --git a/onnxruntime/core/providers/openvino/backends/basic_backend.h b/onnxruntime/core/providers/openvino/backends/basic_backend.h index fe178ccb5661b..8e76c9e69e223 100644 --- a/onnxruntime/core/providers/openvino/backends/basic_backend.h +++ b/onnxruntime/core/providers/openvino/backends/basic_backend.h @@ -38,7 +38,23 @@ struct OnnxToOvNetworkBindings { ov::element::Type type; ov::PartialShape ov_shape; std::vector onnx_shape; + uint8_t dynamic_flags = 0; // bit 0: fully_dynamic, bit 1: bounded_dynamic + + // Query methods + bool IsStatic() const { return dynamic_flags == 0; } + bool IsFullyDynamic() const { return dynamic_flags & 1; } + bool IsBoundedDynamic() const { return dynamic_flags & 2; } + bool IsMixed() const { return (dynamic_flags & 3) == 3; } + + // Setter methods + void SetFullyDynamic(bool value) { + dynamic_flags = value ? (dynamic_flags | 1) : (dynamic_flags & ~1); + } + void SetBoundedDynamic(bool value) { + dynamic_flags = value ? (dynamic_flags | 2) : (dynamic_flags & ~2); + } }; + std::vector network_outputs_; std::vector network_inputs_; @@ -52,8 +68,8 @@ struct OnnxToOvNetworkBindings { // However, these tensors are internally converted to a stateful representation, which removes them. // To prevent runtime exceptions, we simply continue processing here. if ((onnx_name.empty() || onnx_name == "beam_idx" || - onnx_name.find("past_key_values") != std::string::npos || - onnx_name.find("present") != std::string::npos) && + onnx_name.find("past_key_values") != std::string::npos || + onnx_name.find("present") != std::string::npos) && session_context.enable_causallm) { continue; } @@ -68,19 +84,40 @@ struct OnnxToOvNetworkBindings { auto type = ov_parameters[ov_param_index].get_element_type(); ParameterInfo info{onnx_name, ov_param_index, onnx_param_index, type, shape}; + // Analyze shape dynamism and set flags if (shape.is_static()) { + // dynamic_flags remains 0 (static) auto static_shape = shape.get_shape(); - std::transform(static_shape.begin(), static_shape.end(), std::back_inserter(info.onnx_shape), [](const auto& dim) { return static_cast(dim); }); + std::transform(static_shape.begin(), static_shape.end(), std::back_inserter(info.onnx_shape), + [](const auto& dim) { return static_cast(dim); }); + } else { + // Analyze dynamic dimensions + bool has_fully_dynamic = false; + bool has_bounded_dynamic = false; + + for (const auto& dim : shape) { + if (dim.is_dynamic()) { + if (dim.get_interval().has_upper_bound()) { + has_bounded_dynamic = true; + } else { + has_fully_dynamic = true; + } + } + } + + info.SetFullyDynamic(has_fully_dynamic); + info.SetBoundedDynamic(has_bounded_dynamic); } + input_output_map.push_back(std::move(info)); } }; + // Populate inputs and outputs populate(network_inputs_, subgraph_context.input_names, exec_network.Get().inputs()); populate(network_outputs_, subgraph_context.output_names, exec_network.Get().outputs()); } }; - class InferRequestsQueue; class BasicBackend : public IBackend { public: @@ -105,6 +142,8 @@ class BasicBackend : public IBackend { void EnableStreams(); void SetNumThreads(ov::AnyMap& device_config); void StartAsyncInference(Ort::KernelContext& context, std::shared_ptr infer_request); + void ValidateOrtDimsAgainstPartialShape(const std::vector& ort_dims, + const ov::PartialShape& partial_shape) const; void CompleteAsyncInference(Ort::KernelContext& context, std::shared_ptr infer_request); SessionContext& session_context_; @@ -123,7 +162,7 @@ class InferRequestsQueue { public: InferRequestsQueue(OVExeNetwork& net, size_t nireq, std::function initializer) { OVInferRequestPtr infer_request; - live_threads=nireq; + live_threads = nireq; for (size_t id = 0; id < nireq; id++) { infer_request = net.CreateInferRequest(); initializer(infer_request); @@ -155,7 +194,7 @@ class InferRequestsQueue { OVInferRequestPtr getIdleRequest() { std::unique_lock lock(_mutex); - if(live_threads==0) { + if (live_threads == 0) { return nullptr; } @@ -167,8 +206,7 @@ class InferRequestsQueue { void deleteRequest() { std::unique_lock lock(_mutex); - live_threads=live_threads-1; - std::cout << "delete Request" << live_threads << "\n"; + live_threads = live_threads - 1; } private: diff --git a/onnxruntime/core/providers/openvino/contexts.h b/onnxruntime/core/providers/openvino/contexts.h index 2506d587dd3ad..09d48a5e916e1 100644 --- a/onnxruntime/core/providers/openvino/contexts.h +++ b/onnxruntime/core/providers/openvino/contexts.h @@ -12,6 +12,7 @@ #include #include "core/common/common.h" #include "core/providers/openvino/ov_interface.h" +#include "core/providers/shared_library/provider_api.h" namespace onnxruntime { namespace openvino_ep { @@ -67,6 +68,7 @@ class SharedContext : public WeakSingleton { }; using config_t = std::map; +using reshape_t = std::map; struct ProviderInfo { std::string device_type{""}; // [device_type]: Overrides the accelerator hardware type and @@ -84,6 +86,7 @@ struct ProviderInfo { // dump and load the blobs for the model caching/kernel caching // (GPU) feature. If blob files are already present, // it will be directly loaded. + reshape_t reshape{}; // Used for reshaping the ov input tensor shape at runtime. std::string model_priority{"DEFAULT"}; // High-level OpenVINO model priority hint // Defines what model should be provided with more performant // bounded resource first @@ -106,7 +109,7 @@ struct ProviderInfo { const ConfigOptions* config_options{NULL}; const std::unordered_set valid_provider_keys = {"device_type", "device_id", "device_luid", "cache_dir", "precision", "load_config", "context", "num_of_threads", "model_priority", "num_streams", "enable_opencl_throttling", "enable_qdq_optimizer", - "enable_causallm", "disable_dynamic_shapes"}; + "enable_causallm", "disable_dynamic_shapes", "reshape_input"}; }; // Holds context applicable to the entire EP instance. diff --git a/onnxruntime/core/providers/openvino/openvino_execution_provider.cc b/onnxruntime/core/providers/openvino/openvino_execution_provider.cc index d12f1edc57da5..5c8293a213f40 100644 --- a/onnxruntime/core/providers/openvino/openvino_execution_provider.cc +++ b/onnxruntime/core/providers/openvino/openvino_execution_provider.cc @@ -240,9 +240,9 @@ common::Status OpenVINOExecutionProvider::SetEpDynamicOptions(gsl::span +#include #include "core/providers/openvino/openvino_parser_utils.h" #include "core/providers/shared_library/provider_api.h" @@ -116,5 +117,124 @@ std::string OpenVINOParserUtils::ParsePrecision(const ProviderOptions& provider_ } } +reshape_t OpenVINOParserUtils::ParseInputShape(const std::string& reshape_input_definition) { + reshape_t parsed_shape_map; + + // Return empty map for empty input + if (reshape_input_definition.empty()) { + ORT_THROW("Empty input shape definition provided in reshape_input parameter"); + } + + // Regular expressions for parsing + const std::regex tensor_pattern(R"(([^\[\],]+)\s*\[(.*?)\])"); // e.g. "input_1[1..5, 2, 3..4],data[1,2,3]" + // const std::regex dimension_pattern(R"(\s*(\d+(?:\.\.\d+)?)\s*)"); // e.g. "1..5", "2", "3..4" + const std::regex dimension_pattern(R"(\s*([^,\s]+)\s*)"); + // Find all tensor shape definitions using regex + auto tensor_begin = std::sregex_iterator( + reshape_input_definition.begin(), + reshape_input_definition.end(), + tensor_pattern); + auto tensor_end = std::sregex_iterator(); + + // If no matches found, throw error + if (tensor_begin == tensor_end) { + ORT_THROW("Invalid input shape definition format: " + reshape_input_definition); + } + + // Process each tensor definition e.g. "input_1[1..5, 2, 3..4],data[1,2,3]" + for (std::sregex_iterator i = tensor_begin; i != tensor_end; ++i) { + std::smatch tensor_match = *i; + + // Extract tensor name and trim whitespace + std::string tensor_name = tensor_match[1].str(); // Group 1: tensor name e.g. "input_1" + tensor_name = TrimWhitespace(tensor_name); + + if (tensor_name.empty()) { + ORT_THROW("Empty tensor name provided in reshape_input parameter"); + } + + // Extract dimensions string + std::string dimensions_str = tensor_match[2].str(); // Group 2: dimensions string [e.g. "1..5, 2, 3..4"] + std::vector dimensions; + + // Find all dimension e.g. "1..5", "2", "3..4" using regex + auto dim_begin = std::sregex_iterator( + dimensions_str.begin(), + dimensions_str.end(), + dimension_pattern); + auto dim_end = std::sregex_iterator(); + + // Process each dimension + for (std::sregex_iterator j = dim_begin; j != dim_end; ++j) { + std::smatch dim_match = *j; + std::string dim_value = dim_match[1].str(); + + // Check if dimension is a range + size_t range_separator_pos = dim_value.find(".."); + if (range_separator_pos != std::string::npos) { + // Parse range + dimensions.push_back(ParseDimensionRange(dim_value, tensor_name)); + } else { + // Parse single value + bool is_valid_integer = !dim_value.empty() && + std::all_of(dim_value.begin(), dim_value.end(), [](char c) { + return std::isdigit(static_cast(c)); + }); + + if (!is_valid_integer) { + ORT_THROW("Invalid dimension value: '" + dim_value + "' for tensor: " + tensor_name); + } + + dimensions.push_back(std::stoi(dim_value)); + } + } + + // Store parsed shape in result map + parsed_shape_map[tensor_name] = ov::PartialShape(dimensions); + } + + return parsed_shape_map; +} + +// Helper function to trim whitespace from a string +std::string OpenVINOParserUtils::TrimWhitespace(const std::string& str) { + const std::string whitespace = " \t\n\r\f\v"; + size_t start = str.find_first_not_of(whitespace); + + if (start == std::string::npos) { + return ""; + } + + size_t end = str.find_last_not_of(whitespace); + return str.substr(start, end - start + 1); +} + +// Helper function to parse dimension range (e.g. "1..5") +ov::Dimension OpenVINOParserUtils::ParseDimensionRange(const std::string& range_str, const std::string& tensor_name) { + size_t range_separator_pos = range_str.find(".."); + if (range_separator_pos == std::string::npos) { + ORT_THROW("Invalid dimension range format: " + range_str); + } + + std::string range_start_str = TrimWhitespace(range_str.substr(0, range_separator_pos)); + std::string range_end_str = TrimWhitespace(range_str.substr(range_separator_pos + 2)); + + // Validate range values + if (range_start_str.empty() || range_end_str.empty() || + !std::all_of(range_start_str.begin(), range_start_str.end(), ::isdigit) || + !std::all_of(range_end_str.begin(), range_end_str.end(), ::isdigit)) { + ORT_THROW("Invalid dimension range format: '" + range_str + "' for tensor: " + tensor_name); + } + + int range_start = std::stoi(range_start_str); + int range_end = std::stoi(range_end_str); + + if (range_start > range_end) { + ORT_THROW("Invalid dimension range (start > end): " + range_str + " for tensor: " + tensor_name); + } + + return ov::Dimension(range_start, range_end); +} + } // namespace openvino_ep } // namespace onnxruntime diff --git a/onnxruntime/core/providers/openvino/openvino_parser_utils.h b/onnxruntime/core/providers/openvino/openvino_parser_utils.h index 3e23c9e788463..e6aa0e0a46a3b 100644 --- a/onnxruntime/core/providers/openvino/openvino_parser_utils.h +++ b/onnxruntime/core/providers/openvino/openvino_parser_utils.h @@ -7,6 +7,7 @@ #include #include "core/framework/provider_options.h" +#include "core/providers/openvino/contexts.h" namespace onnxruntime { namespace openvino_ep { @@ -16,6 +17,9 @@ class OpenVINOParserUtils { static std::string ParsePrecision(const ProviderOptions& provider_options, std::string& device_type, const std::string& option_name); + static reshape_t ParseInputShape(const std::string& reshape_input_definition); + static std::string TrimWhitespace(const std::string& str); + static ov::Dimension ParseDimensionRange(const std::string& range_str, const std::string& tensor_name); }; } // namespace openvino_ep diff --git a/onnxruntime/core/providers/openvino/openvino_provider_factory.cc b/onnxruntime/core/providers/openvino/openvino_provider_factory.cc index f7e64a9be2c60..85594d1c70dd3 100644 --- a/onnxruntime/core/providers/openvino/openvino_provider_factory.cc +++ b/onnxruntime/core/providers/openvino/openvino_provider_factory.cc @@ -233,6 +233,10 @@ static void ParseProviderInfo(const ProviderOptions& provider_options, pi.precision = OpenVINOParserUtils::ParsePrecision(provider_options, pi.device_type, "precision"); + if (provider_options.contains("reshape_input")) { + pi.reshape = OpenVINOParserUtils::ParseInputShape(provider_options.at("reshape_input")); + } + if (provider_options.contains("load_config")) { auto parse_config = [&](const std::string& config_str) -> std::map { // If the config string is empty, return an empty map and skip processing @@ -349,7 +353,7 @@ static void ParseProviderInfo(const ProviderOptions& provider_options, } catch (std::string msg) { ORT_THROW(msg); } - // Always true for NPU plugin or when passed . + if (pi.device_type.find("NPU") != std::string::npos) { // For Stateful Compilation i.e. enable_causallm as True, we use the dynamic shapes path. if (pi.enable_causallm) { diff --git a/onnxruntime/core/providers/openvino/ov_allocator.cc b/onnxruntime/core/providers/openvino/ov_allocator.cc index 1bbe71441cbff..9e4ac6009e2e3 100644 --- a/onnxruntime/core/providers/openvino/ov_allocator.cc +++ b/onnxruntime/core/providers/openvino/ov_allocator.cc @@ -34,12 +34,12 @@ void OVRTAllocator::Free(void* p) { try { ov::Tensor* tensor = nullptr; { - std::lock_guard lock(mutex_); - auto it = allocated_.find(p); - if (it != allocated_.end()) { - tensor = it->second; - allocated_.erase(it); - } + std::lock_guard lock(mutex_); + auto it = allocated_.find(p); + if (it != allocated_.end()) { + tensor = it->second; + allocated_.erase(it); + } } if (tensor) { delete tensor; diff --git a/onnxruntime/core/providers/openvino/ov_interface.cc b/onnxruntime/core/providers/openvino/ov_interface.cc index 0818f350562e9..3afe38ad12e71 100644 --- a/onnxruntime/core/providers/openvino/ov_interface.cc +++ b/onnxruntime/core/providers/openvino/ov_interface.cc @@ -364,7 +364,7 @@ StatefulOVInferRequest::StatefulOVInferRequest(ov::InferRequest infer_request, s } void StatefulOVInferRequest::FillTensor(const std::string& tensor_name, const ov::element::Type& type, - const std::vector& shape, int32_t fill_value) { + const std::vector& shape, int32_t fill_value) { ov::Tensor tensor = ov::Tensor(type, shape); std::fill_n(tensor.data(), tensor.get_size(), fill_value); ovInfReq.set_tensor(tensor_name, tensor); @@ -379,7 +379,7 @@ void StatefulOVInferRequest::CacheTensor(const std::string& tensor_name, std::ve } void StatefulOVInferRequest::SetTensorFromCache(const std::string& tensor_name, - const std::vector& cache_data) { + const std::vector& cache_data) { auto tensor = ovInfReq.get_tensor(tensor_name); auto new_shape = tensor.get_shape(); new_shape[1] = cache_data.size(); diff --git a/onnxruntime/core/providers/openvino/ov_interface.h b/onnxruntime/core/providers/openvino/ov_interface.h index c3d165b40840c..82a8c27fa035c 100644 --- a/onnxruntime/core/providers/openvino/ov_interface.h +++ b/onnxruntime/core/providers/openvino/ov_interface.h @@ -14,6 +14,9 @@ #include "openvino/runtime/intel_npu/properties.hpp" #include "openvino/pass/convert_fp32_to_fp16.hpp" #include "openvino/frontend/manager.hpp" +#include "openvino/core/dimension.hpp" +#include "openvino/core/partial_shape.hpp" + #include namespace onnxruntime { @@ -129,7 +132,7 @@ class StatefulOVInferRequest : public OVInferRequest { void Infer() override; void RewindKVCache(size_t index) override; void FillTensor(const std::string& tensor_name, const ov::element::Type& type, - const std::vector& shape, int32_t fill_value); + const std::vector& shape, int32_t fill_value); void CacheTensor(const std::string& tensor_name, std::vector& cache); void SetTensorFromCache(const std::string& tensor_name, const std::vector& cache_data); std::optional FindTensor(const std::string& tensor_name); diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index 151fb0c3db0c0..abdb58f4f1801 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -1165,63 +1165,13 @@ static std::shared_ptr CreateExecutionProviderFactory } else if (type == kOpenVINOExecutionProvider) { #if defined(USE_OPENVINO) || defined(USE_OPENVINO_PROVIDER_INTERFACE) ProviderOptions OV_provider_options_map; + const std::unordered_set valid_provider_keys = {"device_type", "device_id", "device_luid", "cache_dir", "precision", + "load_config", "context", "num_of_threads", "model_priority", "num_streams", "enable_opencl_throttling", "enable_qdq_optimizer", + "enable_causallm", "disable_dynamic_shapes", "reshape_input"}; auto it = provider_options_map.find(type); if (it != provider_options_map.end()) { for (auto option : it->second) { - if (option.first == "device_type") { - OV_provider_options_map[option.first] = option.second; - continue; - } else if (option.first == "precision") { - OV_provider_options_map[option.first] = option.second; - continue; - } else if (option.first == "enable_opencl_throttling") { - if (!(option.second == "True" || option.second == "true" || - option.second == "False" || option.second == "false")) { - ORT_THROW("Invalid value passed for enable_opencl_throttling: ", option.second); - } - OV_provider_options_map[option.first] = option.second; - } else if (option.first == "disable_dynamic_shapes") { - if (!(option.second == "True" || option.second == "true" || - option.second == "False" || option.second == "false")) { - ORT_THROW("Invalid value passed for disable_dynamic_shapes: ", option.second); - } - OV_provider_options_map[option.first] = option.second; - } else if (option.first == "enable_dynamic_shapes") { - LOGS_DEFAULT(WARNING) << " Deprecation notice - 'enable_dynamic_shapes' is Deprected. Upgrade the API to disable_dynamic_shapes parameter." - "Please refer https://onnxruntime.ai/docs/execution-providers/OpenVINO-ExecutionProvider.html#requirements to ensure all dependencies are met."; - std::string value; - if (!(option.second == "True" || option.second == "true" || - option.second == "False" || option.second == "false")) { - ORT_THROW("Invalid value passed for enable_dynamic_shapes: ", option.second); - } - if (option.second == "True" || option.second == "true") { - value = "false"; - } else { - value = "true"; - } - OV_provider_options_map["disable_dynamic_shapes"] = value; - } else if (option.first == "num_of_threads") { - OV_provider_options_map[option.first] = option.second; - continue; - } else if (option.first == "model_priority") { - OV_provider_options_map[option.first] = option.second; - continue; - } else if (option.first == "num_streams") { - OV_provider_options_map[option.first] = option.second; - continue; - } else if (option.first == "load_config") { - OV_provider_options_map[option.first] = option.second; - continue; - } else if (option.first == "cache_dir") { - OV_provider_options_map[option.first] = option.second; - continue; - } else if (option.first == "context") { - OV_provider_options_map[option.first] = option.second; - continue; - } else if (option.first == "enable_qdq_optimizer") { - OV_provider_options_map[option.first] = option.second; - continue; - } else if (option.first == "enable_causallm") { + if (valid_provider_keys.count(option.first)) { OV_provider_options_map[option.first] = option.second; continue; } else { @@ -2065,7 +2015,7 @@ for model inference.)pbdoc"); ORT_THROW("OrtEpDevices are not supported in this build"); #endif }, - R"pbdoc(Adds the execution provider that is responsible for the selected OrtEpDevice instances. All OrtEpDevice instances + R"pbdoc(Adds the execution provider that is responsible for the selected OrtEpDevice instances. All OrtEpDevice instances must refer to the same execution provider.)pbdoc") .def( // Equivalent to the C API's SessionOptionsSetEpSelectionPolicy. diff --git a/onnxruntime/python/tools/quantization/matmul_nbits_quantizer.py b/onnxruntime/python/tools/quantization/matmul_nbits_quantizer.py index c7a832420203d..174527118ce8b 100644 --- a/onnxruntime/python/tools/quantization/matmul_nbits_quantizer.py +++ b/onnxruntime/python/tools/quantization/matmul_nbits_quantizer.py @@ -874,13 +874,18 @@ def quantize_matmul(self, node: NodeProto, graph_stack: list[GraphProto]) -> lis scales_tensor = onnx.numpy_helper.from_array(scales, b_tensor.name + "_DQ_scales") # if QDQ, CW and SYM enabled, optimize for Intel NPU, tranpose the weight to NHWC format will increase performance - qdq_opt_for_intel_npu_enabled = self.config.quant_format == QuantFormat.QDQ \ - and self.config.channel_wised_quantize and self.config.is_symmetric + qdq_opt_for_intel_npu_enabled = ( + self.config.quant_format == QuantFormat.QDQ + and self.config.channel_wised_quantize + and self.config.is_symmetric + ) if qdq_opt_for_intel_npu_enabled: rows, cols = b_ndarray.shape packed = transpose_packed_int4_matrix(packed, rows, cols) - scales = scales.reshape((cols, 1)) # (cols, 1) - b_quant = onnx.helper.make_tensor(b_tensor.name + f"_DQ_Q{bits}", qtype, [cols, rows], packed.tobytes(), True) + scales = scales.reshape((cols, 1)) # (cols, 1) + b_quant = onnx.helper.make_tensor( + b_tensor.name + f"_DQ_Q{bits}", qtype, [cols, rows], packed.tobytes(), True + ) scales_tensor = onnx.numpy_helper.from_array(scales, b_tensor.name + "_DQ_scales") for input in b_graph.input: @@ -924,7 +929,10 @@ def quantize_matmul(self, node: NodeProto, graph_stack: list[GraphProto]) -> lis dq_output_names = [b_quant.name + "_output"] tp_input_names = [dq_output_names[0]] tp_output_names = [dq_output_names[0] + "_transposed"] - matmul_input_names = [node.input[0], tp_output_names[0] if qdq_opt_for_intel_npu_enabled else dq_output_names[0]] + matmul_input_names = [ + node.input[0], + tp_output_names[0] if qdq_opt_for_intel_npu_enabled else dq_output_names[0], + ] matmul_output_names = [node.output[0]] if not self.config.is_symmetric: zp_tensor = onnx.helper.make_tensor( @@ -935,7 +943,7 @@ def quantize_matmul(self, node: NodeProto, graph_stack: list[GraphProto]) -> lis rows, cols = b_ndarray.shape dq_kwargs = { "axis": 1 if qdq_opt_for_intel_npu_enabled else 0, - "block_size": rows if self.config.channel_wised_quantize else self.config.block_size + "block_size": rows if self.config.channel_wised_quantize else self.config.block_size, } dq_node = onnx.helper.make_node( "DequantizeLinear", @@ -955,7 +963,7 @@ def quantize_matmul(self, node: NodeProto, graph_stack: list[GraphProto]) -> lis "Transpose", inputs=tp_input_names, outputs=tp_output_names, - perm=[1,0], + perm=[1, 0], ) output_nodes.extend([dq_node, tp_node, matmul_node]) else: diff --git a/onnxruntime/test/perftest/ort_test_session.cc b/onnxruntime/test/perftest/ort_test_session.cc index e8eda5af1dc29..d036375874c4b 100644 --- a/onnxruntime/test/perftest/ort_test_session.cc +++ b/onnxruntime/test/perftest/ort_test_session.cc @@ -822,6 +822,8 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)"); device_memory_name_ = std::move(value); } else if (key == "device_luid") { ov_options[key] = value; + } else if (key == "reshape_input") { + ov_options[key] = value; } else { ORT_THROW( "[ERROR] [OpenVINO] wrong key type entered. Choose from the following runtime key options that are available for OpenVINO." diff --git a/onnxruntime/test/providers/openvino/openvino_ep_context_test.cc b/onnxruntime/test/providers/openvino/openvino_ep_context_test.cc index bf5d2d57727a6..e205b3aeb064a 100644 --- a/onnxruntime/test/providers/openvino/openvino_ep_context_test.cc +++ b/onnxruntime/test/providers/openvino/openvino_ep_context_test.cc @@ -25,27 +25,21 @@ using namespace ONNX_NAMESPACE; using namespace onnxruntime::logging; - extern std::unique_ptr ort_env; class OVEPEPContextTests : public ::testing::Test { - - }; namespace onnxruntime { namespace test { - // Test if folder path given to ep_context_file_path throws an error TEST_F(OVEPEPContextTests, OVEPEPContextFolderPath) { - Ort::SessionOptions sessionOptions; std::unordered_map ov_options; - //The below line could fail the test in non NPU platforms.Commenting it out so that the device used for building OVEP will be used. - //ov_options["device_type"] = "NPU"; - + // The below line could fail the test in non NPU platforms.Commenting it out so that the device used for building OVEP will be used. + // ov_options["device_type"] = "NPU"; const std::unordered_map domain_to_version = {{"", 13}, {kMSDomain, 1}}; @@ -66,12 +60,10 @@ TEST_F(OVEPEPContextTests, OVEPEPContextFolderPath) { const std::string ep_context_file_path = "./ep_context_folder_path/"; - sessionOptions.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1"); - sessionOptions.AddConfigEntry(kOrtSessionOptionEpContextFilePath,ep_context_file_path.c_str()); + sessionOptions.AddConfigEntry(kOrtSessionOptionEpContextFilePath, ep_context_file_path.c_str()); sessionOptions.AppendExecutionProvider_OpenVINO_V2(ov_options); - try { Ort::Session session(*ort_env, model_data_span.data(), model_data_span.size(), sessionOptions); FAIL(); // Should not get here! @@ -79,9 +71,7 @@ TEST_F(OVEPEPContextTests, OVEPEPContextFolderPath) { ASSERT_EQ(excpt.GetOrtErrorCode(), ORT_INVALID_ARGUMENT); ASSERT_THAT(excpt.what(), testing::HasSubstr("context_file_path should not point to a folder.")); } - } - } // namespace test } // namespace onnxruntime From 9b245a45f43503824949596bd6ba827b9991a321 Mon Sep 17 00:00:00 2001 From: jatinwadhwa921 <110383850+jatinwadhwa921@users.noreply.github.com> Date: Mon, 16 Jun 2025 12:17:03 +0530 Subject: [PATCH 046/138] [OVEP] Fix for appropriate device not selected (#696) --- .../openvino/openvino_provider_factory.cc | 19 ++++++------------- 1 file changed, 6 insertions(+), 13 deletions(-) diff --git a/onnxruntime/core/providers/openvino/openvino_provider_factory.cc b/onnxruntime/core/providers/openvino/openvino_provider_factory.cc index 85594d1c70dd3..0b4e65f72fdf8 100644 --- a/onnxruntime/core/providers/openvino/openvino_provider_factory.cc +++ b/onnxruntime/core/providers/openvino/openvino_provider_factory.cc @@ -117,8 +117,6 @@ std::string ParseDeviceType(std::shared_ptr ov_core, const ProviderOptio luid_list = split(luid_str, ','); } - bool all_devices_found = true; - for (auto device : devices_to_check) { bool device_found = false; // Check deprecated device format (CPU_FP32, GPU.0_FP16, etc.) and remove the suffix in place @@ -137,6 +135,9 @@ std::string ParseDeviceType(std::shared_ptr ov_core, const ProviderOptio // Here we need to find the full device name (with .idx, but without _precision) if (std::find(std::begin(available_devices), std::end(available_devices), device) != std::end(available_devices)) device_found = true; + if (!device_found) { + ORT_THROW("[ERROR] [OpenVINO] Device ", device, " is not available"); + } if (device_prefix != "CPU" && luid_list.size() > 0) { for (const auto& dev : available_devices) { ov::device::LUID ov_luid = OVCore::Get()->core.get_property(dev, ov::device::luid); @@ -149,7 +150,6 @@ std::string ParseDeviceType(std::shared_ptr ov_core, const ProviderOptio ORT_THROW(msg); } } - all_devices_found = all_devices_found && device_found; } if (luid_list.size() > 0) { std::string ov_luid_devices; @@ -180,16 +180,9 @@ std::string ParseDeviceType(std::shared_ptr ov_core, const ProviderOptio selected_device = std::move(ov_luid_devices); } } - // If invalid device is chosen error is thrown - if (!all_devices_found) { - ORT_THROW( - "[ERROR] [OpenVINO] You have selected wrong configuration value for the key 'device_type'. " - "Select from 'CPU', 'GPU', 'NPU', 'GPU.x' where x = 0,1,2 and so on or from" - " HETERO/MULTI/AUTO/BATCH options available. \n"); - } else { - LOGS_DEFAULT(INFO) << "[OpenVINO-EP] Choosing Device: " << selected_device; - return selected_device; - } + + LOGS_DEFAULT(INFO) << "[OpenVINO-EP] Choosing Device: " << selected_device; + return selected_device; } void ParseProviderOptions([[maybe_unused]] ProviderInfo& result, [[maybe_unused]] const ProviderOptions& config_options) {} From 409b2243e9a5bc2f933886c7ff7abd06511af907 Mon Sep 17 00:00:00 2001 From: Bartlomiej Filipek Date: Mon, 16 Jun 2025 10:29:49 -0700 Subject: [PATCH 047/138] Improve the condition for device=GPU and qdq stripping enabled (#705) * make the condition stricter for gpu Signed-off-by: bfilipek * corrected form * improve the condition, allow int16 types when qdq stripping enabled Signed-off-by: bfilipek * revert wrongly commited file Signed-off-by: bfilipek --------- Signed-off-by: bfilipek --- .../core/providers/openvino/ov_versions/capability.cc | 5 ++++- onnxruntime/core/providers/openvino/ov_versions/data_ops.cc | 3 +++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/openvino/ov_versions/capability.cc b/onnxruntime/core/providers/openvino/ov_versions/capability.cc index 46d2f6e02c70e..45ea822685710 100644 --- a/onnxruntime/core/providers/openvino/ov_versions/capability.cc +++ b/onnxruntime/core/providers/openvino/ov_versions/capability.cc @@ -34,10 +34,13 @@ GetCapability::GetCapability(const EPCtxHandler& ep_ctx_handler, graph_viewer_(graph_viewer_param), device_type_(std::move(device_type_param)) { bool npu_qdq_optimizer_enabled = false; - if (device_type_.find("NPU") != std::string::npos || device_type_.find("GPU") != std::string::npos) { + if (device_type_.find("NPU") != std::string::npos) { device_type_ = "CPU"; if (enable_qdq_optimizer) npu_qdq_optimizer_enabled = true; + } else if (enable_qdq_optimizer && device_type_.find("GPU") != std::string::npos) { + npu_qdq_optimizer_enabled = true; // see data_ops.cc ~615 where we check for int16 types for gpu, this may change to a better approach later } + #if OPENVINO_VERSION_MAJOR == 2024 && OPENVINO_VERSION_MINOR == 5 data_ops_ = new DataOps(graph_viewer_, V_2024_5, device_type_, npu_qdq_optimizer_enabled); #elif OPENVINO_VERSION_MAJOR == 2024 && OPENVINO_VERSION_MINOR == 6 diff --git a/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc b/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc index 4e1387d2ef4a9..b88f0d04d21f2 100644 --- a/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc +++ b/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc @@ -612,6 +612,9 @@ bool DataOps::type_is_supported(const NodeArg* node_arg, bool is_initializer) { (var.second == dtype)) { return true; } + // experimentally for GPU and qdq stripping mode allow int16 types + if (npu_qdq_optimizer_enabled_ && (dtype == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT16 || dtype == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT16)) + return true; } #ifndef NDEBUG if (openvino_ep::backend_utils::IsDebugEnabled()) { From cbef617526d0384d8a6071250f8be8a64b57e525 Mon Sep 17 00:00:00 2001 From: Eric Crawford Date: Wed, 18 Jun 2025 21:57:39 -0700 Subject: [PATCH 048/138] Optimize CPU time spent in inference path (continued) (#695) * Use infer instead of start async/wait * Introduce OvExeceptionBoundary for exception handling * unbound infer request pool * Fix dynamically sized i/o * Rename onnx->ort + remove unused parameter shape functions * fix linux build issue + review dog comments * more linux build fixes + copilot feedback * disable ReduceSum_noop_axes_input_initializer_opset_18 * review feedback + last minute touch ups * slightly more scalable llm handling * Simplify dynamic shape checks * add missing staged changes * Remove references to IO_BUFFER_ENABLED * Minor tweaks to InferRequestPool * remove unused mem_info * Move ParameterShape and ParameterInfo out of ov_interface --------- Co-authored-by: MayureshV1 <47039074+MayureshV1@users.noreply.github.com> --- .../providers/openvino/backend_manager.cc | 38 +-- .../core/providers/openvino/backend_manager.h | 1 + .../core/providers/openvino/backend_utils.cc | 44 +-- .../core/providers/openvino/backend_utils.h | 51 ++- .../openvino/backends/basic_backend.cc | 306 ++++++------------ .../openvino/backends/basic_backend.h | 170 +++++----- .../core/providers/openvino/ibackend.h | 2 +- .../openvino/openvino_provider_factory.cc | 16 +- .../core/providers/openvino/ov_interface.cc | 158 +++------ .../core/providers/openvino/ov_interface.h | 29 +- .../cpu/reduction/reduction_ops_test.cc | 6 +- 11 files changed, 324 insertions(+), 497 deletions(-) diff --git a/onnxruntime/core/providers/openvino/backend_manager.cc b/onnxruntime/core/providers/openvino/backend_manager.cc index cb7acfd2ca95a..684f94eed54c3 100644 --- a/onnxruntime/core/providers/openvino/backend_manager.cc +++ b/onnxruntime/core/providers/openvino/backend_manager.cc @@ -44,10 +44,6 @@ BackendManager::BackendManager(SessionContext& session_context, shared_context_{shared_context} { subgraph_context_.is_ep_ctx_graph = ep_ctx_handle_.CheckForOVEPCtxNodeInGraph(subgraph); - bool cpu_or_gpu = session_context_.device_type.find("CPU") != std::string::npos || - session_context_.device_type.find("GPU") != std::string::npos; - bool npu = session_context_.device_type.find("NPU") != std::string::npos; - subgraph_context_.model_precision = [&](const GraphViewer& graph_viewer) { // return empty if graph has no inputs or if types are not one of FP32/FP16 // else assume the type of the first input @@ -112,8 +108,7 @@ BackendManager::BackendManager(SessionContext& session_context, if (ModelHasSymbolicInputDims(subgraph)) { subgraph_context_.has_dynamic_input_shape = true; LOGS_DEFAULT(INFO) << "[OpenVINO-EP] Model has symbolic input dims"; - if (cpu_or_gpu || (npu && session_context_.enable_causallm) && - !session_context_.disable_dynamic_shapes) { + if (!session_context_.disable_dynamic_shapes) { LOGS_DEFAULT(INFO) << "[OpenVINO-EP] Starting backend initialization. " << "Creating backend Dynamic Shapes"; try { @@ -579,9 +574,7 @@ void BackendManager::ValidateInputShapes(const reshape_t& shapes, void BackendManager::Compute(OrtKernelContext* context) { Ort::KernelContext ctx(context); std::chrono::high_resolution_clock::time_point start_compute, end_compute; - bool cpu_or_gpu = session_context_.device_type.find("CPU") != std::string::npos || - session_context_.device_type.find("GPU") != std::string::npos; - bool npu = session_context_.device_type.find("NPU") != std::string::npos; + #ifdef OPENVINO_FIL_ENABLED static bool fil_enabled = true; if (fil_enabled) { @@ -589,20 +582,26 @@ void BackendManager::Compute(OrtKernelContext* context) { LOGS_DEFAULT(INFO) << "Start Compute"; } #endif - // OV NPU doesn't support dynamic shaped model inference. + // if disable_dynamic_shapes is set to true then execution of dynamic model is done // by rewriting the model to static shaped model at runtime based on input shape. - // disable_dynamic_shapes is always set to true for OV NPU plugin. - if (subgraph_context_.has_dynamic_input_shape && - !session_context_.disable_dynamic_shapes && - (cpu_or_gpu || (npu && session_context_.enable_causallm))) { + // disable_dynamic_shapes should be set for devices that don't support dynamic shapes. + bool need_dynamic_backend = subgraph_context_.has_dynamic_input_shape && + session_context_.disable_dynamic_shapes; + + if (!need_dynamic_backend) { concrete_backend_->Infer(context); - } else if (subgraph_context_.has_dynamic_input_shape) { + } else { std::vector> tensor_shapes = GetInputTensorShapes(ctx); auto key = MakeMapKeyString(tensor_shapes, session_context_.device_type); std::shared_ptr dynamic_backend; - auto search = backend_map_.find(key); - if (search == backend_map_.end()) { + + { + std::unique_lock lock(mutex_); + dynamic_backend = backend_map_[key]; + } + + if (!dynamic_backend) { ptr_stream_t model_stream; LOGS_DEFAULT(INFO) << "[OpenVINO-EP] " << "Creating dynamic backend for key: " << key; @@ -643,14 +642,11 @@ void BackendManager::Compute(OrtKernelContext* context) { } #endif } + std::unique_lock lock(mutex_); backend_map_.insert({key, dynamic_backend}); - } else { - dynamic_backend = search->second; } dynamic_backend->Infer(context); - } else { - concrete_backend_->Infer(context); } #ifdef OPENVINO_FIL_ENABLED if (fil_enabled) { diff --git a/onnxruntime/core/providers/openvino/backend_manager.h b/onnxruntime/core/providers/openvino/backend_manager.h index 7165b9cf2e14c..f091f95fe1c16 100644 --- a/onnxruntime/core/providers/openvino/backend_manager.h +++ b/onnxruntime/core/providers/openvino/backend_manager.h @@ -54,6 +54,7 @@ class BackendManager { std::unique_ptr model_proto_; std::shared_ptr concrete_backend_; + std::mutex mutex_; std::map> backend_map_; SubGraphContext subgraph_context_; EPCtxHandler& ep_ctx_handle_; diff --git a/onnxruntime/core/providers/openvino/backend_utils.cc b/onnxruntime/core/providers/openvino/backend_utils.cc index 49eedfb3e4fcd..7598f7cfffba5 100644 --- a/onnxruntime/core/providers/openvino/backend_utils.cc +++ b/onnxruntime/core/providers/openvino/backend_utils.cc @@ -179,32 +179,6 @@ CreateOVModel(std::string&& model, } } -Ort::UnownedValue -GetOutputTensor(Ort::KernelContext& context, size_t batch_size, - OVInferRequestPtr infer_request, - std::string output_name, - const SubGraphContext::string_index_map_t& output_names) { - auto graph_output_blob = infer_request->GetTensor(output_name); - - auto graph_output_dims = graph_output_blob->get_shape(); - - if (batch_size > 1) { - // Add the batch size as dim 0. - graph_output_dims.insert(graph_output_dims.begin(), batch_size); - } - size_t num_dims = graph_output_dims.size(); - std::unique_ptr output_shape(new int64_t[num_dims]); - for (size_t j = 0; j < num_dims; j++) { - output_shape[j] = static_cast(graph_output_dims[j]); - } - auto it = output_names.find(output_name); - if (it == output_names.end()) { - ORT_THROW(log_tag + "Output names mismatch between OpenVINO and ONNX"); - } - int index = it->second; - return context.GetOutput(index, output_shape.get(), num_dims); -} - Ort::UnownedValue GetOutputTensor(Ort::KernelContext& context, std::string output_name, @@ -220,14 +194,9 @@ GetOutputTensor(Ort::KernelContext& context, ORT_THROW(log_tag + "Output names mismatch between OpenVINO and ONNX"); } int index = it->second; - auto shape = node->get_shape(); + auto output_shape = ParameterShape::ToOrtShape(node->get_shape()); - size_t num_dims = shape.size(); - std::unique_ptr output_shape(new int64_t[num_dims]); - for (size_t j = 0; j < num_dims; j++) { - output_shape[j] = static_cast(shape[j]); - } - return context.GetOutput(index, output_shape.get(), num_dims); + return context.GetOutput(index, output_shape); } int GetFirstAvailableDevice(SessionContext& session_context) { @@ -312,15 +281,6 @@ void FillInputBlob(OVTensorPtr inputBlob, size_t batch_slice_idx, std::memcpy(input_data, batch_memory_offset, input_data_size); } -void FillOutputBlob(OVTensorPtr outputBlob, Ort::UnownedValue& output_tensor, - size_t batch_slice_idx) { - auto output_data = outputBlob->data(); - size_t output_data_size = outputBlob->get_byte_size(); - char* tensor_data = output_tensor.GetTensorMutableData(); - char* batch_memory_offset = tensor_data + output_data_size * batch_slice_idx; - std::memcpy(batch_memory_offset, output_data, output_data_size); -} - void printPerformanceCounts(const std::vector& performanceMap, std::ostream& stream, std::string deviceName) { int64_t totalTime = 0; diff --git a/onnxruntime/core/providers/openvino/backend_utils.h b/onnxruntime/core/providers/openvino/backend_utils.h index f13b1b05ced67..0e68d2f7526fd 100644 --- a/onnxruntime/core/providers/openvino/backend_utils.h +++ b/onnxruntime/core/providers/openvino/backend_utils.h @@ -27,8 +27,48 @@ namespace onnxruntime { namespace openvino_ep { +constexpr std::string log_tag = "[OpenVINO-EP] "; + +struct ParameterShape { + using ort_shape_t = std::vector; + + static ov::PartialShape ToOvPartialShape(const ort_shape_t& ort_shape) { + std::vector ov_shape(ort_shape.size()); + std::transform(ort_shape.begin(), ort_shape.end(), ov_shape.begin(), [](int64_t dim) { + return dim == -1 ? ov::Dimension::dynamic() : ov::Dimension(dim); + }); + return ov::PartialShape(ov_shape); + } + + static ort_shape_t ToOrtShape(const ov::PartialShape& ov_shape) { + ort_shape_t ort_shape(ov_shape.size()); + std::transform(ov_shape.begin(), ov_shape.end(), ort_shape.begin(), [](const auto& dim) { + return dim.is_dynamic() ? -1 : dim.get_length(); + }); + return ort_shape; + } + + static ort_shape_t ToOrtShape(const ov::Shape& ov_shape) { + ort_shape_t ort_shape(ov_shape.size()); + std::transform(ov_shape.begin(), ov_shape.end(), ort_shape.begin(), [](const auto& dim) { + return narrow(dim); + }); + return ort_shape; + } + + operator ov::Shape() const { return ov_.get_shape(); } + operator const ov::PartialShape&() const { return ov_; } + operator const ort_shape_t&() const { return ort_; } + + explicit ParameterShape(const ort_shape_t& ort_shape) : ort_(ort_shape), ov_(ToOvPartialShape(ort_shape)) {} + explicit ParameterShape(const ov::PartialShape& ov_partial_shape) : ov_(ov_partial_shape), ort_(ToOrtShape(ov_partial_shape)) {} + + private: + ort_shape_t ort_; + ov::PartialShape ov_; +}; + namespace backend_utils { -const std::string log_tag = "[OpenVINO-EP] "; bool IsDebugEnabled(); @@ -48,19 +88,10 @@ GetOutputTensor(Ort::KernelContext& context, const SubGraphContext::string_index_map_t& output_names, std::shared_ptr node); -Ort::UnownedValue -GetOutputTensor(Ort::KernelContext& context, size_t batch_size, - OVInferRequestPtr infer_request, - std::string output_name, - const SubGraphContext::string_index_map_t& output_names); - void FillInputBlob(OVTensorPtr inputBlob, size_t batch_slice_idx, std::string input_name, Ort::KernelContext& context, const SubGraphContext& subgraph_context); -void FillOutputBlob(OVTensorPtr outputBlob, Ort::UnownedValue& output_tensor, - size_t batch_slice_idx); - std::shared_ptr CreateOVModel(std::string&& model, const SessionContext& session_context, diff --git a/onnxruntime/core/providers/openvino/backends/basic_backend.cc b/onnxruntime/core/providers/openvino/backends/basic_backend.cc index 3105c307706ad..1b7ba1a1b5a82 100644 --- a/onnxruntime/core/providers/openvino/backends/basic_backend.cc +++ b/onnxruntime/core/providers/openvino/backends/basic_backend.cc @@ -9,6 +9,7 @@ #include #include #include +#include #include "core/providers/shared_library/provider_api.h" #include "core/providers/openvino/backend_utils.h" @@ -128,7 +129,7 @@ BasicBackend::BasicBackend(std::unique_ptr& model_pr } }; } - inferRequestsQueue_ = std::unique_ptr(new InferRequestsQueue(exe_network_, num_infer_req, std::move(initializer))); + infer_req_pool_ = std::make_unique(exe_network_, num_infer_req, std::move(initializer)); bindings_ = std::make_unique(exe_network_, subgraph_context_, session_context_); } @@ -379,170 +380,12 @@ void BasicBackend::ValidateOrtDimsAgainstPartialShape(const std::vector } void BasicBackend::RewindKVCache(size_t index) { - OVInferRequestPtr infer_request; - infer_request = inferRequestsQueue_->getIdleRequest(); - infer_request->RewindKVCache(index); - inferRequestsQueue_->putIdleRequest(std::move(infer_request)); + infer_req_pool_->forEachIdleRequest([&](OVInferRequestPtr& infer_request) { + infer_request->RewindKVCache(index); + }); } -// Starts an asynchronous inference request for data in slice indexed by batch_slice_idx on -// an Infer Request indexed by infer_req_idx -void BasicBackend::StartAsyncInference(Ort::KernelContext& context, OVInferRequestPtr infer_request) { - try { - const bool is_cpu = session_context_.device_type.find("CPU") != std::string::npos; - const bool is_gpu = session_context_.device_type.find("GPU") != std::string::npos; - const bool is_npu = session_context_.device_type.find("NPU") != std::string::npos; - const bool is_cpu_or_gpu = is_cpu || is_gpu; - - // Loop over subgraph original input names to find the correspondent OV input name - for (const auto& input_info : bindings_->network_inputs_) { - size_t batch_slice_idx = 0; - auto tensor = context.GetInput(input_info.onnx_index); - auto tensor_info = tensor.GetTensorTypeAndShapeInfo(); - auto tensor_shape = tensor_info.GetShape(); - auto tensor_data = tensor.GetTensorData(); - if (input_info.IsBoundedDynamic()) { - ov::PartialShape partial_shape = input_info.ov_shape; - ValidateOrtDimsAgainstPartialShape(tensor_shape, partial_shape); - } - ov::Shape input_tensor_shape(tensor_shape.begin(), tensor_shape.end()); - OVTensorPtr tensor_ptr; - if (is_cpu_or_gpu) { - if (input_info.IsStatic()) { - try { - auto graph_input_blob = infer_request->GetTensor(input_info.name); - FillInputBlob(std::move(graph_input_blob), batch_slice_idx, input_info.name, context, subgraph_context_); - } catch (const char* msg) { - ORT_THROW(msg); - } - } else { - if (is_cpu) { - tensor_ptr = std::make_shared(input_info.type, input_tensor_shape, (void*)tensor_data); - } else { // GPU - tensor_ptr = std::make_shared(input_info.type, input_tensor_shape); - FillInputBlob(tensor_ptr, batch_slice_idx, input_info.name, context, subgraph_context_); - } - - try { - infer_request->SetTensor(input_info.name, tensor_ptr); - } catch (const char* msg) { - ORT_THROW(msg); - } - } - } else { // Other device path - ort_tensor_key_t ort_tensor_key{input_info.name}; - auto it = ort_ov_tensor_map.find(ort_tensor_key); - - if (it == ort_ov_tensor_map.end() || it->second.ort_ptr != tensor.GetTensorRawData()) { - ov_tensor_data_t ov_tensor_data; - ov_tensor_data.tensor_ptr = std::make_shared(input_info.type, input_tensor_shape, - const_cast(tensor.GetTensorRawData())); - ov_tensor_data.ort_ptr = tensor.GetTensorRawData(); - ort_ov_tensor_map[ort_tensor_key] = ov_tensor_data; - - try { - infer_request->SetTensor(input_info.name, ov_tensor_data.tensor_ptr); - } catch (const char* msg) { - ORT_THROW(msg); - } - } - } - } - // Handle output - if (is_npu && !session_context_.enable_causallm) { - // Set the output blob as remote blob - for (const auto& output_info : bindings_->network_outputs_) { - if (output_info.IsStatic()) { - // Set remote tensor for static outputs only - Ort::UnownedValue tensor = context.GetOutput(output_info.onnx_index, output_info.onnx_shape); - - ort_tensor_key_t ort_tensor_key{output_info.name}; - const auto& it = ort_ov_tensor_map.find(ort_tensor_key); - if ((it == ort_ov_tensor_map.end()) || (it->second.ort_ptr != tensor.GetTensorRawData())) { - ov_tensor_data_t ov_tensor_data; - ov_tensor_data.ort_ptr = tensor.GetTensorRawData(); - ov_tensor_data.tensor_ptr = std::make_shared(output_info.type, output_info.ov_shape.get_shape(), - const_cast(tensor.GetTensorRawData())); - ort_ov_tensor_map[ort_tensor_key] = ov_tensor_data; - - try { - infer_request->SetTensor(output_info.name, ov_tensor_data.tensor_ptr); - } catch (const char* msg) { - ORT_THROW(msg); - } - } - } - } - } - - // Start Async inference - infer_request->StartAsync(); - } catch (const char* msg) { - ORT_THROW(msg); - } -} - -// Wait for asynchronous inference completion on an Infer Request object indexed by infer_req_idx -// and copy the results into a slice location within the batched output buffer indexed by batch_slice_idx -void BasicBackend::CompleteAsyncInference(Ort::KernelContext& context, OVInferRequestPtr infer_request) { - // Wait for Async inference completion - try { - infer_request->WaitRequest(); - } catch (const std::runtime_error& e) { - infer_request->CancelRequest(); - inferRequestsQueue_->deleteRequest(); - ORT_THROW(log_tag + e.what()); - } - - bool cpu_or_gpu = session_context_.device_type.find("CPU") != std::string::npos || - session_context_.device_type.find("GPU") != std::string::npos; - bool npu = session_context_.device_type.find("NPU") != std::string::npos; - for (const auto& output_info : bindings_->network_outputs_) { - if (cpu_or_gpu || (npu && (session_context_.enable_causallm || !output_info.IsStatic()))) { - OVTensorPtr graph_output_blob; - try { - graph_output_blob = infer_request->GetTensor(output_info.name); - } catch (const char* msg) { - ORT_THROW(msg); - } - size_t batch_size = 1; - Ort::UnownedValue output_tensor = - GetOutputTensor(context, batch_size, infer_request, output_info.name, subgraph_context_.output_names); - auto mem_info = output_tensor.GetTensorMemoryInfo(); - if (mem_info.GetAllocatorName() == OpenVINO_GPU) { - return; - } else { - size_t batch_slice = 0; - FillOutputBlob(std::move(graph_output_blob), output_tensor, batch_slice); - } - } - } - - if (!const_outputs_map_.empty()) { - for (const auto& item : const_outputs_map_) { - const auto& out_name = item.first; - auto node = item.second; - try { - Ort::UnownedValue output_tensor = GetOutputTensor(context, - out_name, - subgraph_context_.output_names, - node); - auto mem_info = output_tensor.GetTensorMemoryInfo(); - if (mem_info.GetAllocatorName() == OpenVINO_GPU) { - ORT_THROW(log_tag + "IO Buffering is not supported for constant subgraphs"); - } else { - FillOutputsWithConstantData(std::move(node), output_tensor); - } - } catch (std::string const& msg) { - ORT_THROW(msg); - } - } - } -} - -void BasicBackend::Infer(OrtKernelContext* ctx) { - // Preliminary Thread safety mechanism - // currently allows a maximum of 8 Infer request's to parallel execute at the same time +void BasicBackend::Infer(OrtKernelContext* ctx) const { Ort::KernelContext context(ctx); LOGS_DEFAULT(INFO) << log_tag << "Running graph " << subgraph_context_.subgraph_name; @@ -552,74 +395,107 @@ void BasicBackend::Infer(OrtKernelContext* ctx) { for (const auto& item : const_outputs_map_) { std::string out_name = item.first; std::shared_ptr node = item.second; - try { - Ort::UnownedValue output_tensor = GetOutputTensor(context, - std::move(out_name), - subgraph_context_.output_names, - node); - FillOutputsWithConstantData(std::move(node), output_tensor); - } catch (std::string const& msg) { - ORT_THROW(msg); - } + Ort::UnownedValue output_tensor = GetOutputTensor(context, + std::move(out_name), + subgraph_context_.output_names, + node); + FillOutputsWithConstantData(std::move(node), output_tensor); } - // Get Output tensors + LOGS_DEFAULT(INFO) << log_tag << "Inference successful"; - // Enable CI Logs + if (IsCILogEnabled()) { std::cout << "Inference successful" << std::endl; } + return; + } - } else { - OVInferRequestPtr infer_request; - infer_request = inferRequestsQueue_->getIdleRequest(); - if (infer_request == nullptr) { - ORT_THROW("OpenVINO Execution Provider :: There are no inference requests"); - LOGS_DEFAULT(FATAL) << log_tag << "Create Infer Requests do not exist"; - return; + // guarded_request will be released back to the pool when it goes out of scope + auto guarded_request = infer_req_pool_->getRequest(); + auto& infer_request = guarded_request.infer_request_; + + if (bindings_->has_dynamic_io_) { + // Dynamic shape inference + + // We don't know the output shapes so we need to get the outputs from the infer request and copy them into the ort + // tensors instead of binding them to the infer request directly. + + // Bind inputs + for (const auto& input_info : bindings_->network_inputs_) { + // Set the input shape based on the input tensor from ort + auto tensor = context.GetInput(input_info.onnx_index); + auto ort_shape = tensor.GetTensorTypeAndShapeInfo().GetShape(); + if (input_info.IsBoundedDynamic()) { + ValidateOrtDimsAgainstPartialShape(ort_shape, input_info.shape); + } + auto input_shape = ParameterShape(ort_shape); + + infer_request->SetTensor(input_info.name, + input_info.type, + input_shape, + const_cast(tensor.GetTensorRawData())); } - LOGS_DEFAULT(INFO) << log_tag << "Get Idle Request"; - try { - StartAsyncInference(context, infer_request); - } catch (const std::runtime_error& e) { - // If the inference fails (exception from ov::InferRequest::infer()), - // we need to put the infer_request back into the pool to avoid deadlocks - // and to allow the next inference request to proceed. - inferRequestsQueue_->putIdleRequest(std::move(infer_request)); - ORT_THROW(log_tag + " Exception at StartAsyncInference: " + e.what()); + // Run Inference + infer_request->Infer(); + + // Copy outputs + for (const auto& output_info : bindings_->network_outputs_) { + auto ov_tensor = infer_request->GetTensor(output_info.name); + auto output_shape = ParameterShape::ToOrtShape(ov_tensor->get_shape()); + auto ort_tensor = context.GetOutput(output_info.onnx_index, output_shape); + + ORT_ENFORCE(ov_tensor->get_byte_size() == ort_tensor.GetTensorSizeInBytes(), + log_tag + "Output tensor size mismatch for " + output_info.name); + + std::memcpy(ort_tensor.GetTensorMutableRawData(), + ov_tensor->data(), + ov_tensor->get_byte_size()); } - try { - CompleteAsyncInference(context, infer_request); - } catch (const std::runtime_error& e) { - // If the inference fails (exception from ov::InferRequest::infer()), - // we need to put the infer_request back into the pool to avoid deadlocks - // and to allow the next inference request to proceed. - inferRequestsQueue_->putIdleRequest(std::move(infer_request)); - ORT_THROW(log_tag + " Exception at CompleteAsyncInference: " + e.what()); + } else { + // Static shape inference + + // Bind inputs + for (const auto& input_info : bindings_->network_inputs_) { + infer_request->SetTensor(input_info.name, + input_info.type, + input_info.shape, + const_cast(context.GetInput(input_info.onnx_index).GetTensorRawData())); } - // Get Output tensors - LOGS_DEFAULT(INFO) << log_tag << "Inference successful"; - // Enable CI Logs - if (IsCILogEnabled()) { - std::cout << "Inference successful" << std::endl; + // Bind outputs + for (const auto& output_info : bindings_->network_outputs_) { + infer_request->SetTensor(output_info.name, + output_info.type, + output_info.shape, + context.GetOutput(output_info.onnx_index, output_info.shape).GetTensorMutableRawData()); } - // Create a duplicate infer_request_ shared ptr on the stack in the current local scope, - // as the infer_request gets freed in the next stage the reference count for the infer_request decrements & - // thus we dont have any dangling ptr leading to seg faults in the debug mode subsequent execution call - OVInferRequestPtr infer_request_ = infer_request; + // Run Inference + infer_request->Infer(); + } + + // Fill constant outputs if needed + for (const auto& [name, node] : const_outputs_map_) { + Ort::UnownedValue output_tensor = GetOutputTensor(context, + name, + subgraph_context_.output_names, + node); + FillOutputsWithConstantData(node, output_tensor); + } + + LOGS_DEFAULT(INFO) << log_tag << "Inference successful"; + if (IsCILogEnabled()) { + std::cout << "Inference successful" << std::endl; + } - // Once the inference is completed, the infer_request becomes free and is placed back into pool of infer_requests_ - inferRequestsQueue_->putIdleRequest(std::move(infer_request)); #ifndef NDEBUG - if (openvino_ep::backend_utils::IsDebugEnabled()) { - inferRequestsQueue_->printstatus(); // Printing the elements of infer_requests_ vector pool only in debug mode - std::string& hw_target = session_context_.device_type; - printPerformanceCounts(std::move(infer_request_), std::cout, hw_target); - } -#endif + // Print performance counts before releasing the infer_request for thread safety + if (openvino_ep::backend_utils::IsDebugEnabled()) { + std::string& hw_target = session_context_.device_type; + printPerformanceCounts(infer_request, std::cout, hw_target); } +#endif } } // namespace openvino_ep diff --git a/onnxruntime/core/providers/openvino/backends/basic_backend.h b/onnxruntime/core/providers/openvino/backends/basic_backend.h index 8e76c9e69e223..b1d5406fcf3e2 100644 --- a/onnxruntime/core/providers/openvino/backends/basic_backend.h +++ b/onnxruntime/core/providers/openvino/backends/basic_backend.h @@ -25,56 +25,59 @@ namespace onnxruntime { namespace openvino_ep { -struct ov_tensor_data_t { - OVTensorPtr tensor_ptr; - const void* ort_ptr; +struct ParameterInfo { + std::string name; + uint32_t ov_index; + uint32_t onnx_index; + ov::element::Type type; + ParameterShape shape; + uint8_t dynamic_flags = 0; + + // Query methods + bool IsStatic() const { return dynamic_flags == 0; } + bool IsFullyDynamic() const { return dynamic_flags & 1; } + bool IsBoundedDynamic() const { return dynamic_flags & 2; } + bool IsMixed() const { return (dynamic_flags & 3) == 3; } + + // Setter methods + void SetFullyDynamic(bool value) { + dynamic_flags = value ? (dynamic_flags | 1) : (dynamic_flags & ~1); + } + void SetBoundedDynamic(bool value) { + dynamic_flags = value ? (dynamic_flags | 2) : (dynamic_flags & ~2); + } }; struct OnnxToOvNetworkBindings { - struct ParameterInfo { - std::string name; - uint32_t ov_index; - uint32_t onnx_index; - ov::element::Type type; - ov::PartialShape ov_shape; - std::vector onnx_shape; - uint8_t dynamic_flags = 0; // bit 0: fully_dynamic, bit 1: bounded_dynamic - - // Query methods - bool IsStatic() const { return dynamic_flags == 0; } - bool IsFullyDynamic() const { return dynamic_flags & 1; } - bool IsBoundedDynamic() const { return dynamic_flags & 2; } - bool IsMixed() const { return (dynamic_flags & 3) == 3; } - - // Setter methods - void SetFullyDynamic(bool value) { - dynamic_flags = value ? (dynamic_flags | 1) : (dynamic_flags & ~1); - } - void SetBoundedDynamic(bool value) { - dynamic_flags = value ? (dynamic_flags | 2) : (dynamic_flags & ~2); - } - }; - std::vector network_outputs_; std::vector network_inputs_; + bool has_dynamic_io_ = false; + + inline static const std::array special_io_names_{ + "beam_idx", + "past_key_values", + "present", + }; OnnxToOvNetworkBindings(OVExeNetwork& exec_network, SubGraphContext& subgraph_context, SessionContext& session_context) { auto populate = [&](auto& input_output_map, const SubGraphContext::string_index_map_t& onnx_input_map, const auto& ov_parameters) { for (const auto& [onnx_name, onnx_param_index] : onnx_input_map) { auto it = std::find_if(ov_parameters.begin(), ov_parameters.end(), [&onnx_name](const auto& ov_parameter_info) { return ov_parameter_info.get_names().contains(onnx_name); }); + bool matched_names = it != ov_parameters.end(); // For Stateful Model Compilation, the ONNX model includes KV cache (past/present) tensors. // However, these tensors are internally converted to a stateful representation, which removes them. // To prevent runtime exceptions, we simply continue processing here. - if ((onnx_name.empty() || onnx_name == "beam_idx" || - onnx_name.find("past_key_values") != std::string::npos || - onnx_name.find("present") != std::string::npos) && - session_context.enable_causallm) { + if (!matched_names && session_context.enable_causallm && + std::any_of(special_io_names_.begin(), special_io_names_.end(), + [&onnx_name](const std::string& name) { return onnx_name.find(name) != std::string::npos; })) { + // This case also requires dynamic shape inference, so we'll mark the bindings as dynamic. + has_dynamic_io_ = true; continue; } - ORT_ENFORCE(it != ov_parameters.end(), backend_utils::log_tag, + ORT_ENFORCE(matched_names, log_tag, "Input names mismatch between OpenVINO and ONNX. ", onnx_name, " doesn't exist in the list of OpenVINO input tensor names"); @@ -82,15 +85,11 @@ struct OnnxToOvNetworkBindings { auto shape = ov_parameters[ov_param_index].get_partial_shape(); auto type = ov_parameters[ov_param_index].get_element_type(); - ParameterInfo info{onnx_name, ov_param_index, onnx_param_index, type, shape}; + ParameterInfo info{onnx_name, ov_param_index, onnx_param_index, type, ParameterShape{shape}}; // Analyze shape dynamism and set flags - if (shape.is_static()) { - // dynamic_flags remains 0 (static) - auto static_shape = shape.get_shape(); - std::transform(static_shape.begin(), static_shape.end(), std::back_inserter(info.onnx_shape), - [](const auto& dim) { return static_cast(dim); }); - } else { + if (!shape.is_static()) { + has_dynamic_io_ = true; // Analyze dynamic dimensions bool has_fully_dynamic = false; bool has_bounded_dynamic = false; @@ -118,7 +117,8 @@ struct OnnxToOvNetworkBindings { populate(network_outputs_, subgraph_context.output_names, exec_network.Get().outputs()); } }; -class InferRequestsQueue; + +class InferRequestPool; class BasicBackend : public IBackend { public: BasicBackend(std::unique_ptr& model_proto, @@ -127,7 +127,7 @@ class BasicBackend : public IBackend { SharedContext& shared_context, ptr_stream_t& model_stream); - void Infer(OrtKernelContext* context) override; + void Infer(OrtKernelContext* context) const override; ~BasicBackend() override = default; ov::CompiledModel GetOVCompiledModel() override { return exe_network_.Get(); @@ -141,79 +141,81 @@ class BasicBackend : public IBackend { void EnableGPUThrottling(ov::AnyMap& device_config); void EnableStreams(); void SetNumThreads(ov::AnyMap& device_config); - void StartAsyncInference(Ort::KernelContext& context, std::shared_ptr infer_request); void ValidateOrtDimsAgainstPartialShape(const std::vector& ort_dims, const ov::PartialShape& partial_shape) const; - void CompleteAsyncInference(Ort::KernelContext& context, std::shared_ptr infer_request); SessionContext& session_context_; SubGraphContext subgraph_context_; SharedContext& shared_context_; - mutable std::mutex compute_lock_; OVExeNetwork exe_network_; std::map> const_outputs_map_; - std::unique_ptr inferRequestsQueue_; + std::unique_ptr infer_req_pool_; + using ort_tensor_key_t = const std::string; - std::map ort_ov_tensor_map; - std::unique_ptr bindings_; + std::unique_ptr bindings_; }; -class InferRequestsQueue { +class InferRequestPool { public: - InferRequestsQueue(OVExeNetwork& net, size_t nireq, std::function initializer) { - OVInferRequestPtr infer_request; - live_threads = nireq; - for (size_t id = 0; id < nireq; id++) { - infer_request = net.CreateInferRequest(); - initializer(infer_request); - infer_requests_.push_back(infer_request); - } - } + struct GuardedInferReq { + OVInferRequestPtr infer_request_; + GuardedInferReq(InferRequestPool& queue, OVInferRequestPtr&& infer_req) : queue_(queue), infer_request_(std::move(infer_req)) {} + ~GuardedInferReq() { queue_.putIdleRequest(std::move(infer_request_)); } + + // Movable but not copyable + ORT_DISALLOW_COPY_AND_ASSIGNMENT(GuardedInferReq); + GuardedInferReq(GuardedInferReq&&) = default; + GuardedInferReq& operator=(GuardedInferReq&&) = default; + + private: + InferRequestPool& queue_; + friend class InferRequestPool; + }; - ~InferRequestsQueue() { - // clearing out the infer_requests_ vector pool in the class's destructor - for (auto& pointer : infer_requests_) { - pointer = nullptr; + InferRequestPool(OVExeNetwork& net, size_t initial_size, std::function initializer) : exe_network_(net), initializer_(std::move(initializer)) { + for (size_t id = 0; id < initial_size; id++) { + infer_requests_.emplace_back(createInferRequest()); } - infer_requests_.erase(std::remove(infer_requests_.begin(), infer_requests_.end(), nullptr), infer_requests_.end()); } + ~InferRequestPool() = default; - void printstatus() { - std::cout << "printing elements of the vector (infer_requests_): " << std::endl; - for (auto i = infer_requests_.begin(); i != infer_requests_.end(); ++i) { - i->get()->QueryStatus(); + GuardedInferReq getRequest() { + std::unique_lock lock(_mutex); + if (infer_requests_.empty()) { + infer_requests_.emplace_back(createInferRequest()); } - std::cout << '\n'; + auto request = std::move(infer_requests_.back()); + infer_requests_.pop_back(); + return GuardedInferReq(*this, std::move(request)); } - void putIdleRequest(OVInferRequestPtr infer_request) { + template + void forEachIdleRequest(Func&& func) { std::unique_lock lock(_mutex); - infer_requests_.push_back(infer_request); - _cv.notify_one(); + for (auto& infer_request : infer_requests_) { + func(infer_request); + } } - OVInferRequestPtr getIdleRequest() { - std::unique_lock lock(_mutex); - if (live_threads == 0) { - return nullptr; + private: + void putIdleRequest(OVInferRequestPtr&& infer_request) { + if (infer_request) { + std::unique_lock lock(_mutex); + infer_requests_.emplace_back(std::move(infer_request)); } - - _cv.wait(lock, [this] { return infer_requests_.size() > 0; }); - auto request = infer_requests_.at(0); - infer_requests_.erase(infer_requests_.begin()); - return request; } - void deleteRequest() { - std::unique_lock lock(_mutex); - live_threads = live_threads - 1; + OVInferRequestPtr createInferRequest() { + auto infer_request = exe_network_.CreateInferRequest(); + initializer_(infer_request); + return infer_request; } private: std::mutex _mutex; - std::condition_variable _cv; std::vector infer_requests_; - int live_threads; + OVExeNetwork& exe_network_; + std::function initializer_; }; } // namespace openvino_ep diff --git a/onnxruntime/core/providers/openvino/ibackend.h b/onnxruntime/core/providers/openvino/ibackend.h index 752668b3c6fbe..ec38425f602eb 100644 --- a/onnxruntime/core/providers/openvino/ibackend.h +++ b/onnxruntime/core/providers/openvino/ibackend.h @@ -14,7 +14,7 @@ namespace openvino_ep { class IBackend { public: - virtual void Infer(OrtKernelContext* context) = 0; + virtual void Infer(OrtKernelContext* context) const = 0; virtual ov::CompiledModel GetOVCompiledModel() = 0; virtual ~IBackend() = default; virtual void RewindKVCache(size_t index) {} diff --git a/onnxruntime/core/providers/openvino/openvino_provider_factory.cc b/onnxruntime/core/providers/openvino/openvino_provider_factory.cc index 0b4e65f72fdf8..bad1d416eeda2 100644 --- a/onnxruntime/core/providers/openvino/openvino_provider_factory.cc +++ b/onnxruntime/core/providers/openvino/openvino_provider_factory.cc @@ -347,14 +347,14 @@ static void ParseProviderInfo(const ProviderOptions& provider_options, ORT_THROW(msg); } - if (pi.device_type.find("NPU") != std::string::npos) { - // For Stateful Compilation i.e. enable_causallm as True, we use the dynamic shapes path. - if (pi.enable_causallm) { - pi.disable_dynamic_shapes = false; - } else { - pi.disable_dynamic_shapes = true; - } - } + // Should likely account for meta devices as well, but for now keep the current behavior. + bool target_devices_support_dynamic_shapes = + pi.device_type.find("GPU") != std::string::npos || + pi.device_type.find("CPU") != std::string::npos || + (pi.device_type.find("NPU") != std::string::npos && + pi.enable_causallm); + + pi.disable_dynamic_shapes = !target_devices_support_dynamic_shapes; } struct OpenVINOProviderFactory : IExecutionProviderFactory { diff --git a/onnxruntime/core/providers/openvino/ov_interface.cc b/onnxruntime/core/providers/openvino/ov_interface.cc index 3afe38ad12e71..38b5f9a52eb3e 100644 --- a/onnxruntime/core/providers/openvino/ov_interface.cc +++ b/onnxruntime/core/providers/openvino/ov_interface.cc @@ -3,6 +3,8 @@ #include "core/providers/openvino/ov_interface.h" +#include + #define ORT_API_MANUAL_INIT #include "core/session/onnxruntime_cxx_api.h" #include "core/providers/shared_library/provider_api.h" @@ -10,12 +12,19 @@ #include "core/providers/openvino/backends/basic_backend.h" #include "core/providers/openvino/ov_stateful_patch_utils.h" -using Exception = ov::Exception; - namespace onnxruntime { namespace openvino_ep { -static const std::string log_tag = "[OpenVINO-EP] "; +template +inline auto OvExceptionBoundary(Func &&func, std::format_string&& fmt, Args&&... args) { + try { + return func(); + } catch (const ov::Exception& e) { + ORT_THROW(log_tag + std::vformat(fmt.get(), std::make_format_args(args...)) + ": " + std::string(e.what())); + } catch (...) { + ORT_THROW(log_tag + std::vformat(fmt.get(), std::make_format_args(args...))); + } +} #ifndef NDEBUG void printDebugInfo(const ov::CompiledModel& obj) { @@ -60,7 +69,7 @@ std::optional queryOVProperty(const std::string& property, const std::stri } std::shared_ptr OVCore::ReadModel(std::string&& model, const std::string& model_path) { - try { + return OvExceptionBoundary([&]() { std::istringstream modelStringStream(std::move(model)); std::istream& modelStream = modelStringStream; // Try to load with FrontEndManager @@ -75,13 +84,10 @@ std::shared_ptr OVCore::ReadModel(std::string&& model, const std::str inputModel = FE->load(params); return FE->convert(inputModel); } else { - ORT_THROW(log_tag + "[OpenVINO-EP] Unknown exception while Reading network"); + ORT_THROW(log_tag + "Unknown exception while Reading network"); } - } catch (const Exception& e) { - ORT_THROW(log_tag + "[OpenVINO-EP] Exception while Reading network: " + std::string(e.what())); - } catch (...) { - ORT_THROW(log_tag + "[OpenVINO-EP] Unknown exception while Reading network"); - } + }, + "Exception while Reading network"); } OVExeNetwork OVCore::StatefulCompileModel(std::shared_ptr& model, @@ -149,14 +155,14 @@ OVExeNetwork OVCore::CompileModel(std::shared_ptr& ie_cnn_netwo ov::AnyMap& device_config, bool enable_causallm, const std::string& name) { - OVExeNetwork exe; - try { + return OvExceptionBoundary([&]() { + OVExeNetwork exe; if (enable_causallm) { - auto mutable_model = ie_cnn_network->clone(); - exe = OVCore::Get()->StatefulCompileModel(mutable_model, hw_target, device_config); + auto mutable_model = ie_cnn_network->clone(); + exe = OVCore::Get()->StatefulCompileModel(mutable_model, hw_target, device_config); } else { - auto obj = core.compile_model(ie_cnn_network, hw_target, device_config); - exe = OVExeNetwork(obj, hw_target); + auto obj = core.compile_model(ie_cnn_network, hw_target, device_config); + exe = OVExeNetwork(obj, hw_target); } #ifndef NDEBUG @@ -164,37 +170,32 @@ OVExeNetwork OVCore::CompileModel(std::shared_ptr& ie_cnn_netwo #endif return exe; - } catch (const Exception& e) { - ORT_THROW(log_tag + " Exception while Loading Network for graph: " + name + e.what()); - } catch (...) { - ORT_THROW(log_tag + " Exception while Loading Network for graph " + name); - } + }, + "Exception while Loading Network for graph {}", name); } OVExeNetwork OVCore::CompileModel(const std::string& onnx_model, std::string& hw_target, ov::AnyMap& device_config, const std::string& name) { - ov::CompiledModel obj; - try { + return OvExceptionBoundary([&]() { + ov::CompiledModel obj; + obj = core.compile_model(onnx_model, ov::Tensor(), hw_target, device_config); #ifndef NDEBUG printDebugInfo(obj); #endif OVExeNetwork exe(obj, hw_target); return exe; - } catch (const Exception& e) { - ORT_THROW(log_tag + " Exception while Loading Network for graph: " + name + e.what()); - } catch (...) { - ORT_THROW(log_tag + " Exception while Loading Network for graph " + name); - } + }, + "Exception while Loading Network for graph {}", name); } OVExeNetwork OVCore::ImportModel(std::istream& model_stream, std::string hw_target, const ov::AnyMap& device_config, std::string name) { - try { + return OvExceptionBoundary([&]() { ov::CompiledModel obj; obj = core.import_model(model_stream, hw_target, device_config); #ifndef NDEBUG @@ -202,11 +203,8 @@ OVExeNetwork OVCore::ImportModel(std::istream& model_stream, #endif OVExeNetwork exe(obj, hw_target); return exe; - } catch (const Exception& e) { - ORT_THROW(log_tag + " Exception while Loading Network for graph: " + name + e.what()); - } catch (...) { - ORT_THROW(log_tag + " Exception while Loading Network for graph " + name); - } + }, + "Exception while Loading Network for graph {}", name); } void OVCore::SetCache(const std::string& cache_dir_path) { @@ -227,20 +225,13 @@ std::vector OVCore::GetAvailableDevices(const std::string& device_t } catch (const ov::Exception&) { // plugin is not created by e.g. invalid env // Empty device list will be returned - } catch (const std::runtime_error& ex) { - // plugin is not created by e.g. invalid env - // Empty device list will be returned - ORT_THROW("[ERROR] [OpenVINO] An exception occurred while trying to create the ", - device_type, - " device: ", - ex.what()); } catch (const std::exception& ex) { - ORT_THROW("[ERROR] [OpenVINO] An exception occurred while trying to create the ", + ORT_THROW(log_tag + "An exception occurred while trying to create the ", device_type, " device: ", ex.what()); } catch (...) { - ORT_THROW("[ERROR] [OpenVINO] Unknown exception occurred while trying to create the ", + ORT_THROW(log_tag + "Unknown exception occurred while trying to create the ", device_type, " device"); } @@ -263,7 +254,7 @@ void OVCore::SetStreams(const std::string& device_type, int num_streams) { } std::shared_ptr OVExeNetwork::CreateInferRequest() { - try { + return OvExceptionBoundary([&]() { auto infReq = compiled_model_obj.create_infer_request(); std::shared_ptr ovInfReq; if (is_stateful_causallm) { @@ -272,87 +263,44 @@ std::shared_ptr OVExeNetwork::CreateInferRequest() { ovInfReq = std::make_shared(std::move(infReq)); } return ovInfReq; - } catch (const Exception& e) { - ORT_THROW(log_tag + "Exception while creating InferRequest object: " + e.what()); - } catch (...) { - ORT_THROW(log_tag + "Exception while creating InferRequest object."); - } + }, + + "Exception while creating InferRequest object"); } OVTensorPtr OVInferRequest::GetTensor(const std::string& input_name) { - try { + return OvExceptionBoundary([&]() { auto tobj = ovInfReq.get_tensor(input_name); OVTensorPtr blob = std::make_shared(tobj); return blob; - } catch (const Exception& e) { - ORT_THROW(log_tag + " Cannot access IE Blob for input: " + input_name + e.what()); - } catch (...) { - ORT_THROW(log_tag + " Cannot access IE Blob for input: " + input_name); - } + }, + " Cannot access IE Blob for input: {}", input_name); } std::string OVInferRequest::GetInputTensorName(uint32_t index) { - try { + return OvExceptionBoundary([&]() { const auto& model = ovInfReq.get_compiled_model(); return *model.input(index).get_names().begin(); - } catch (const Exception& e) { - ORT_THROW(log_tag + " Cannot access IE Blob for input number: ", index, e.what()); - } catch (...) { - ORT_THROW(log_tag + " Cannot access IE Blob for input number: ", index); - } + }, + " Cannot access IE Blob for input number: {}", index); } void OVInferRequest::SetTensor(const std::string& name, OVTensorPtr& blob) { - try { + OvExceptionBoundary([&]() { ovInfReq.set_tensor(name, *(blob.get())); - } catch (const Exception& e) { - ORT_THROW(log_tag + " Cannot set Remote Blob for output: " + name + e.what()); - } catch (...) { - ORT_THROW(log_tag + " Cannot set Remote Blob for output: " + name); - } + }, + " Cannot set Remote Blob for output: {}", name); } uint32_t OVInferRequest::GetNumInputs() { return static_cast(ovInfReq.get_compiled_model().inputs().size()); } -void OVInferRequest::StartAsync() { - try { - ovInfReq.start_async(); - } catch (const Exception& e) { - throw std::runtime_error(log_tag + " Couldn't start Inference: " + e.what()); - } catch (...) { - throw std::runtime_error(log_tag + " In Error Couldn't start Inference"); - } -} - void OVInferRequest::Infer() { - try { + OvExceptionBoundary([&]() { ovInfReq.infer(); - } catch (const Exception& e) { - throw std::runtime_error(log_tag + " Couldn't start Inference: " + e.what()); - } catch (...) { - throw std::runtime_error(log_tag + " In Error Couldn't start Inference"); - } -} - -void OVInferRequest::WaitRequest() { - ovInfReq.wait(); -} - -void OVInferRequest::CancelRequest() { - try { - ovInfReq.cancel(); - } catch (const Exception& e) { - ORT_THROW(log_tag + " Cancel Model Failed: " + e.what()); - } catch (...) { - ORT_THROW(log_tag + " Cancel Mode Failed"); - } -} - -void OVInferRequest::QueryStatus() { - std::cout << "ovInfReq.query_state()" - << " "; + }, + "In Error Couldn't start Inference"); } StatefulOVInferRequest::StatefulOVInferRequest(ov::InferRequest infer_request, std::string device) @@ -449,11 +397,6 @@ void StatefulOVInferRequest::PreProcessInferRequest() { } } -void StatefulOVInferRequest::StartAsync() { - PreProcessInferRequest(); - OVInferRequest::StartAsync(); -} - void StatefulOVInferRequest::Infer() { PreProcessInferRequest(); OVInferRequest::Infer(); @@ -508,6 +451,5 @@ void StatefulOVInferRequest::RewindKVCache(size_t index) { } } } - } // namespace openvino_ep } // namespace onnxruntime diff --git a/onnxruntime/core/providers/openvino/ov_interface.h b/onnxruntime/core/providers/openvino/ov_interface.h index 82a8c27fa035c..581da59bb4cae 100644 --- a/onnxruntime/core/providers/openvino/ov_interface.h +++ b/onnxruntime/core/providers/openvino/ov_interface.h @@ -9,6 +9,8 @@ #include #include #include +#include +#include #include "openvino/openvino.hpp" #include "openvino/runtime/intel_npu/properties.hpp" @@ -30,6 +32,7 @@ typedef ov::ProfilingInfo OVProfilingInfo; typedef ov::Model OVNetwork; typedef std::shared_ptr OVInferRequestPtr; typedef std::shared_ptr OVTensorPtr; + std::optional queryOVProperty(const std::string& property, const std::string& device_type); template @@ -103,20 +106,33 @@ class OVExeNetwork { }; class OVInferRequest { - protected: + struct ov_tensor_data_t { + OVTensorPtr tensor_ptr; + const void* ort_ptr; + }; + + protected: ov::InferRequest ovInfReq; + std::unordered_map bindings_cache_; public: uint32_t GetNumInputs(); OVTensorPtr GetTensor(const std::string& name); std::string GetInputTensorName(uint32_t index); + + // Set tensor described param_info and ort_ptr. Overrides shape in param_info with shape_override. Call infer req tensor if ort_ptr is last set. + void SetTensor(const std::string& name, const ov::element::Type &type, const ov::Shape& shape, void* ort_ptr) { + auto& cached_binding = bindings_cache_[name]; + if (cached_binding.ort_ptr != ort_ptr) { + auto tensor_ptr = std::make_shared(type, shape, const_cast(ort_ptr)); + SetTensor(name, tensor_ptr); + cached_binding = {tensor_ptr, ort_ptr}; + } + } + void SetTensor(const std::string& name, OVTensorPtr& blob); - virtual void StartAsync(); virtual void Infer(); - void WaitRequest(); - void CancelRequest(); - void QueryStatus(); - explicit OVInferRequest(ov::InferRequest infer_request_obj) : ovInfReq(std::move(infer_request_obj)) {} + explicit OVInferRequest(ov::InferRequest obj) : ovInfReq(std::move(obj)) {} OVInferRequest() : ovInfReq(ov::InferRequest()) {} ov::InferRequest& GetNewObj() { return ovInfReq; @@ -128,7 +144,6 @@ class StatefulOVInferRequest : public OVInferRequest { public: explicit StatefulOVInferRequest(ov::InferRequest infer_request, std::string device); - void StartAsync() override; void Infer() override; void RewindKVCache(size_t index) override; void FillTensor(const std::string& tensor_name, const ov::element::Type& type, diff --git a/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc b/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc index 92cd82c2c9420..e2ee859fb26df 100644 --- a/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc +++ b/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc @@ -4099,7 +4099,11 @@ TEST(ReductionOpTest, ReduceSum_noop_axes_input_initializer_opset_18) { 3.0f, 4.0f}); test.AddInput("axes", {0}, {}, true); test.AddOutput("reduced", {1, 2, 2}, {1.0f, 2.0f, 3.0f, 4.0f}); - test.Run(); + test.Run( + OpTester::ExpectResult::kExpectSuccess, + "", + {kOpenVINOExecutionProvider} // OpenVINO: Disabled temporarily + ); } TEST(ReductionOpTest, ReduceSum_empty_axes_input_initializer_opset_18) { From 6d04a2ecf9667bf5a14a85d54208dc34acbb63e8 Mon Sep 17 00:00:00 2001 From: Ankit Maheshkar Date: Tue, 24 Jun 2025 19:41:48 +0530 Subject: [PATCH 049/138] feat: Enable EpContext OVIR Encapsulation (#704) * feat: Enable EpContext OVIR Encapsulation * fix: refactor EpCtx OVIR parsing logic to use ep.context_file_path * fix: Fix logic for parsing model_file_path * fix: enable EPCtx OVIR encapsulation compiled blob caching * fix: fix merge conflicts * fix: fix bugs --- .../providers/openvino/backend_manager.cc | 8 ++- .../core/providers/openvino/backend_utils.cc | 27 +++++++++ .../core/providers/openvino/backend_utils.h | 2 + .../openvino/backends/basic_backend.cc | 38 ++++++++++-- .../core/providers/openvino/contexts.h | 1 + .../openvino/onnx_ctx_model_helper.cc | 39 +++++++++++- .../openvino/onnx_ctx_model_helper.h | 1 + .../core/providers/openvino/ov_interface.cc | 59 ++++++++++++++++++- .../core/providers/openvino/ov_interface.h | 6 ++ 9 files changed, 169 insertions(+), 12 deletions(-) diff --git a/onnxruntime/core/providers/openvino/backend_manager.cc b/onnxruntime/core/providers/openvino/backend_manager.cc index 684f94eed54c3..8887b183c4396 100644 --- a/onnxruntime/core/providers/openvino/backend_manager.cc +++ b/onnxruntime/core/providers/openvino/backend_manager.cc @@ -43,6 +43,9 @@ BackendManager::BackendManager(SessionContext& session_context, session_context_(session_context), shared_context_{shared_context} { subgraph_context_.is_ep_ctx_graph = ep_ctx_handle_.CheckForOVEPCtxNodeInGraph(subgraph); + // If the graph contains a OVIR wrapped node, we check if it has matching xml file name attribute + subgraph_context_.is_ep_ctx_ovir_encapsulated = ep_ctx_handle_.CheckEPCacheContextAttribute(subgraph, + session_context_.onnx_model_path_name.filename().replace_extension("xml").string()); subgraph_context_.model_precision = [&](const GraphViewer& graph_viewer) { // return empty if graph has no inputs or if types are not one of FP32/FP16 @@ -192,9 +195,10 @@ BackendManager::BackendManager(SessionContext& session_context, } } } - if (session_context_.so_context_enable && !subgraph_context_.is_ep_ctx_graph) { + if (session_context_.so_context_enable && + (subgraph_context_.is_ep_ctx_ovir_encapsulated || !subgraph_context_.is_ep_ctx_graph)) { auto status = onnxruntime::openvino_ep::BackendManager::ExportCompiledBlobAsEPCtxNode(subgraph); - if ((!status.IsOK())) { + if (!status.IsOK()) { ORT_THROW(status); } } diff --git a/onnxruntime/core/providers/openvino/backend_utils.cc b/onnxruntime/core/providers/openvino/backend_utils.cc index 7598f7cfffba5..73fbe9a0fa76f 100644 --- a/onnxruntime/core/providers/openvino/backend_utils.cc +++ b/onnxruntime/core/providers/openvino/backend_utils.cc @@ -400,6 +400,33 @@ void DestroyOVTensors(SharedContext::SharedWeights::Metadata::Map& metadata_map) metadata_map.clear(); } +bool IsModelStreamXML(std::istream& model_stream) { + std::streampos originalPos = model_stream.tellg(); + + // first, get the total size of model_stream in bytes + model_stream.seekg(0, std::ios::end); + auto end_pos = model_stream.tellg(); + // Restore the stream position + model_stream.seekg(originalPos); + auto total_size = end_pos - originalPos; + + // Choose 32 bytes to hold content of: + // ' header_check_len); + + // read 32 bytes into header + std::string header(header_check_len, '\0'); + model_stream.read(&header[0], header_check_len); + // Clear any read errors + model_stream.clear(); + // Restore the stream position + model_stream.seekg(originalPos); + + // return true if the header starts with '& performanceMap, void printPerformanceCounts(OVInferRequestPtr request, std::ostream& stream, std::string deviceName); +bool IsModelStreamXML(std::istream& model_stream); + } // namespace backend_utils } // namespace openvino_ep } // namespace onnxruntime diff --git a/onnxruntime/core/providers/openvino/backends/basic_backend.cc b/onnxruntime/core/providers/openvino/backends/basic_backend.cc index 1b7ba1a1b5a82..00a18bb0a45b6 100644 --- a/onnxruntime/core/providers/openvino/backends/basic_backend.cc +++ b/onnxruntime/core/providers/openvino/backends/basic_backend.cc @@ -72,12 +72,38 @@ BasicBackend::BasicBackend(std::unique_ptr& model_pr !session_context_.so_disable_cpu_ep_fallback && !subgraph_context_.is_ep_ctx_graph); if (subgraph_context_.is_ep_ctx_graph) { - // If the blob is held in an EPContext node, then skip FE+Compile - // and directly move on to creating a backend with the executable blob - exe_network_ = OVCore::Get()->ImportModel(*model_stream, - hw_target, - device_config, - subgraph_context_.subgraph_name); + if (subgraph_context_.is_ep_ctx_ovir_encapsulated) { + // model_file_path will use so_context_file_path if the onnx_model_path_name is not available, + // especially in case of CreateSessionFormArray() where user must explicitly + // specify absolute path for so_context_file_path. + auto model_file_path = [this]() { + if (!session_context_.onnx_model_path_name.empty() && + std::filesystem::exists(session_context_.onnx_model_path_name)) return session_context_.onnx_model_path_name; + + ORT_ENFORCE(!session_context_.so_context_file_path.empty() && + std::filesystem::path(session_context_.so_context_file_path).is_absolute() && + std::filesystem::exists(session_context_.so_context_file_path), log_tag + + "Context file path must be non-empty & absolute, when using CreateSessionFormArray() API explicitly." + " Please set a valid absolute path for ep.context_file_path in session options."); + // Return absolute context file path as input to ImportEPCtxOVIREncapsulation() function. + return session_context_.so_context_file_path; + + }; + // If the EPContext node with OVIR Encapsulation, then create + // an executable network from EP_CACHE_CONTEXT using read_model() & compile_model() + exe_network_ = OVCore::Get()->ImportEPCtxOVIREncapsulation(*model_stream, + hw_target, + device_config, + enable_causallm, + model_file_path()); + } else { + // If the blob is held in an EPContext node, then skip FE+Compile + // and directly move on to creating a backend with the executable blob + exe_network_ = OVCore::Get()->ImportModel(*model_stream, + hw_target, + device_config, + subgraph_context_.subgraph_name); + } model_stream.reset(); // Delete stream after it is no longer needed } else if (!session_context_.has_external_weights && !subgraph_context_.has_dynamic_input_shape && diff --git a/onnxruntime/core/providers/openvino/contexts.h b/onnxruntime/core/providers/openvino/contexts.h index 09d48a5e916e1..e2369cf728ea6 100644 --- a/onnxruntime/core/providers/openvino/contexts.h +++ b/onnxruntime/core/providers/openvino/contexts.h @@ -137,6 +137,7 @@ struct SubGraphContext { string_index_map_t output_names; std::string model_precision; bool is_ep_ctx_graph = false; + bool is_ep_ctx_ovir_encapsulated = false; }; } // namespace openvino_ep diff --git a/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.cc b/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.cc index 7bd4f8d96cc55..49a4cb0a7e95a 100644 --- a/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.cc +++ b/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.cc @@ -7,6 +7,7 @@ #include #include "core/providers/openvino/onnx_ctx_model_helper.h" +#include "core/providers/openvino/backend_utils.h" namespace onnxruntime { namespace openvino_ep { @@ -123,6 +124,16 @@ std::unique_ptr EPCtxHandler::GetModelBlobStream(const std::filesy ORT_ENFORCE(std::filesystem::exists(blob_filepath), "Blob file not found: ", blob_filepath.string()); result.reset((std::istream*)new std::ifstream(blob_filepath, std::ios_base::binary | std::ios_base::in)); } + + bool isXML = backend_utils::IsModelStreamXML(*result); + if (!isXML) { + // If the model stream is not an XML (i.e. precompiled blob), the OpenVINO SDK version that it was + // exported with must match the version that is currently running. + ORT_ENFORCE((attrs.count(EP_SDK_VER) == 1) && (attrs.at(EP_SDK_VER).s() == openvino_sdk_version_), + "EPCtx blob was exported / is compatible with OpenVINO SDK version " + attrs.at(EP_SDK_VER).s() + + ", but OpenVINO SDK version currently in use is " + openvino_sdk_version_); + } + LOGS_DEFAULT(VERBOSE) << "[OpenVINO EP] Read blob from EPContext Node"; return result; } @@ -142,7 +153,6 @@ bool EPCtxHandler::CheckForOVEPCtxNode(const Node& node) const { if (node.OpType() == EPCONTEXT_OP) { auto& attrs = node.GetAttributes(); bool result = (attrs.count(SOURCE) == 1) && (attrs.at(SOURCE).s() == kOpenVINOExecutionProvider); - result &= (attrs.count(EP_SDK_VER) == 1) && (attrs.at(EP_SDK_VER).s() == openvino_sdk_version_); result &= attrs.count(EMBED_MODE) == 1; result &= attrs.count(EP_CACHE_CONTEXT) == 1; return result; @@ -155,5 +165,32 @@ InlinedVector EPCtxHandler::GetEPCtxNodes() const { return InlinedVector(epctx_nodes.begin(), epctx_nodes.end()); } +// Check if graph's only node is EPContext & EP_CACHE_CONTEXT attribute has target extension. +// @param graph_viewer: The graph to inspect. +// @param target_attr_extn: The string to search for in the EP_CACHE_CONTEXT attribute. +// @return true if the node exists, is of the correct type, and the attribute contains the extension; false otherwise. +bool EPCtxHandler::CheckEPCacheContextAttribute(const GraphViewer& graph_viewer, const std::string& target_attr_extn) const { + // Only check if the graph has exactly one node + if (graph_viewer.NumberOfNodes() != 1) { + return false; + } + // Get the first node in topological order + auto first_index = *graph_viewer.GetNodesInTopologicalOrder().begin(); + const Node* node = graph_viewer.GetNode(first_index); + if (!node) { + return false; + } + // Check OpType and required attributes + if (node->OpType() != EPCONTEXT_OP) { + return false; + } + const auto& attrs = node->GetAttributes(); + auto it = attrs.find(EP_CACHE_CONTEXT); + if (it != attrs.end()) { + return it->second().s().find(target_attr_extn) != std::string::npos; + } + return false; +} + } // namespace openvino_ep } // namespace onnxruntime diff --git a/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.h b/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.h index ff978bd6534d8..b9ddb40a7a233 100644 --- a/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.h +++ b/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.h @@ -33,6 +33,7 @@ class EPCtxHandler { std::string&& model_blob_str) const; std::unique_ptr GetModelBlobStream(const std::filesystem::path& so_context_file_path, const GraphViewer& graph_viewer) const; InlinedVector GetEPCtxNodes() const; + bool CheckEPCacheContextAttribute(const GraphViewer& graph_viewer, const std::string& target_attr_extn) const; private: const std::string openvino_sdk_version_; diff --git a/onnxruntime/core/providers/openvino/ov_interface.cc b/onnxruntime/core/providers/openvino/ov_interface.cc index 38b5f9a52eb3e..306fa6113b347 100644 --- a/onnxruntime/core/providers/openvino/ov_interface.cc +++ b/onnxruntime/core/providers/openvino/ov_interface.cc @@ -47,7 +47,6 @@ void printDebugInfo(const ov::CompiledModel& obj) { continue; OPENVINO_SUPPRESS_DEPRECATED_END std::cout << " " << item2.first << ": " << item2.second.as() << std::endl; - } } } else { std::cout << " " << cfg << ": " << prop.as() << std::endl; @@ -101,10 +100,10 @@ OVExeNetwork OVCore::StatefulCompileModel(std::shared_ptr& model, LogBasicModelInfo(model); } - LOGS_DEFAULT(INFO) << log_tag << "Converting from Stateless OV Model to Stateful OV Model" << std::endl; bool model_status = IsStateful(model); LOGS_DEFAULT(INFO) << log_tag << "Model IsStateful() Status:\t" << (model_status ? "True" : "False"); if (!model_status) { + LOGS_DEFAULT(INFO) << log_tag << "Converting from Stateless OV Model to Stateful OV Model" << std::endl; PatchStatefulDecoder(model); } @@ -198,15 +197,69 @@ OVExeNetwork OVCore::ImportModel(std::istream& model_stream, return OvExceptionBoundary([&]() { ov::CompiledModel obj; obj = core.import_model(model_stream, hw_target, device_config); + OVExeNetwork exe(obj, hw_target); #ifndef NDEBUG printDebugInfo(exe.Get()); #endif - OVExeNetwork exe(obj, hw_target); return exe; }, "Exception while Loading Network for graph {}", name); } +OVExeNetwork OVCore::ImportEPCtxOVIREncapsulation(std::istream& model_stream, + std::string& hw_target, + const ov::AnyMap& device_config, + bool enable_causallm, + std::filesystem::path model_file_path) { + return OvExceptionBoundary([&]() { + OVExeNetwork exe; + + bool isXML = backend_utils::IsModelStreamXML(model_stream); + + // Helper function to check if file exists and is readable + const auto check_file_access = [&model_file_path](const std::filesystem::path& path) { + try { + if (!std::filesystem::exists(path) || std::filesystem::is_empty(path)) { + ORT_THROW(log_tag + "Required file missing or empty: " + path.string()); + } + std::ifstream file(path); + if (!file) { + ORT_THROW(log_tag + "Required file not readable: " + path.string()); + } + } catch (const std::exception& e) { + ORT_THROW(log_tag + "Exception while checking file access for: " + path.string() + " - " + e.what()); + } + }; + + if (isXML) { + // If the model is XML, we need to load it with the XML content in read_model() + // where weights from bin file is directly consumed + auto xml_file_path = model_file_path.parent_path() / (model_file_path.stem().string() + ".xml"); + + check_file_access(xml_file_path); + + LOGS_DEFAULT(INFO) << log_tag << "Reading OVIR from XML file path: " << xml_file_path.string(); + + // Load the model explicitly with XML contents + std::shared_ptr model = core.read_model(xml_file_path.string()); + + if (enable_causallm) { + exe = OVCore::Get()->StatefulCompileModel(model, hw_target, device_config); + } else { + auto obj = core.compile_model(model, hw_target, device_config); + exe = OVExeNetwork(obj, hw_target); + } + } + +#ifndef NDEBUG + printDebugInfo(exe.Get()); +#endif + return exe; + }, + "Exception while Loading Network from OVIR model file: {}", model_file_path.string()); +} + + void OVCore::SetCache(const std::string& cache_dir_path) { core.set_property(ov::cache_dir(cache_dir_path)); } diff --git a/onnxruntime/core/providers/openvino/ov_interface.h b/onnxruntime/core/providers/openvino/ov_interface.h index 581da59bb4cae..0e019342bc86e 100644 --- a/onnxruntime/core/providers/openvino/ov_interface.h +++ b/onnxruntime/core/providers/openvino/ov_interface.h @@ -86,6 +86,12 @@ struct OVCore : WeakSingleton { std::string hw_target, const ov::AnyMap& device_config, std::string name); + OVExeNetwork ImportEPCtxOVIREncapsulation(std::istream& model_stream, + std::string& hw_target, + const ov::AnyMap& device_config, + bool enable_causallm, + std::filesystem::path model_file_path); + std::vector GetAvailableDevices() const; std::vector GetAvailableDevices(const std::string& device_type) const; void SetCache(const std::string& cache_dir_path); From 278f6a7dee045537c4ac00e00fcf70173a1518c0 Mon Sep 17 00:00:00 2001 From: Yaru Du Date: Tue, 24 Jun 2025 19:09:59 -0700 Subject: [PATCH 050/138] Add operator hardSwish into OVEP (#709) * add harswitsh into support list * change version which indidates when ovep supports this operator for openvino * change support verison which indicates when OV supports the op --- onnxruntime/core/providers/openvino/ov_versions/data_ops.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc b/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc index b88f0d04d21f2..99d6e4b7ab5ef 100644 --- a/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc +++ b/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc @@ -158,6 +158,7 @@ std::vector supported_op_mode = { {"InstanceNormalization", V_2020_4, {"CPU", "GPU"}}, {"HardSigmoid", V_2020_4, {"CPU", "GPU"}}, {"HardMax", V_2022_1, {"CPU", "GPU"}}, + {"HardSwish", V_2025_0, {"CPU", "GPU"}}, {"LayerNormalization", V_2023_0, {"CPU", "GPU"}}, {"LeakyRelu", V_2020_4, {"CPU", "GPU"}}, {"Less", V_2020_4, {"CPU", "GPU"}}, From 90869ffee3fcb551ef34eb69e7bebb45f64f40d7 Mon Sep 17 00:00:00 2001 From: Javier Martinez Date: Tue, 24 Jun 2025 23:08:07 -0700 Subject: [PATCH 051/138] [CVS-169168] Change name of metadata file and add filepath validation (#717) * Change name of metadata file and add filepath validation * Save metadata file path in shared context for use across model compilation * Fix metadata file path initialization * Check that metadata file is created --- .../providers/openvino/backend_manager.cc | 17 +- .../openvino/backends/basic_backend.cc | 20 +- .../core/providers/openvino/contexts.h | 1 + .../openvino/onnx_ctx_model_helper.cc | 2 +- .../openvino/openvino_execution_provider.cc | 54 +- .../core/providers/openvino/ov_interface.cc | 734 +++++++++--------- .../core/providers/openvino/ov_interface.h | 4 +- .../openvino/ov_versions/capability.cc | 2 +- .../openvino/ov_versions/data_ops.cc | 2 +- 9 files changed, 426 insertions(+), 410 deletions(-) diff --git a/onnxruntime/core/providers/openvino/backend_manager.cc b/onnxruntime/core/providers/openvino/backend_manager.cc index 8887b183c4396..e150a7cd00ec6 100644 --- a/onnxruntime/core/providers/openvino/backend_manager.cc +++ b/onnxruntime/core/providers/openvino/backend_manager.cc @@ -45,7 +45,7 @@ BackendManager::BackendManager(SessionContext& session_context, subgraph_context_.is_ep_ctx_graph = ep_ctx_handle_.CheckForOVEPCtxNodeInGraph(subgraph); // If the graph contains a OVIR wrapped node, we check if it has matching xml file name attribute subgraph_context_.is_ep_ctx_ovir_encapsulated = ep_ctx_handle_.CheckEPCacheContextAttribute(subgraph, - session_context_.onnx_model_path_name.filename().replace_extension("xml").string()); + session_context_.onnx_model_path_name.filename().replace_extension("xml").string()); subgraph_context_.model_precision = [&](const GraphViewer& graph_viewer) { // return empty if graph has no inputs or if types are not one of FP32/FP16 @@ -91,21 +91,20 @@ BackendManager::BackendManager(SessionContext& session_context, std::string device_type = session_context_.device_type; auto& sw = shared_context_.shared_weights; - if (session_context_.so_share_ep_contexts) { + if (session_context_.so_share_ep_contexts && !sw.metadata.empty()) { std::filesystem::path weight_filename = session_context_.onnx_model_path_name.parent_path(); - if (sw.external_weight_filename.empty() && !sw.metadata.empty()) { + if (sw.external_weight_filename.empty()) { // Reasonable assumption that all metadata entries have the same external file location sw.external_weight_filename = sw.metadata.begin()->second.location; } weight_filename /= sw.external_weight_filename; std::ifstream weight_file(weight_filename); - if (weight_file) { - if (!sw.mapped_weights) { - sw.mapped_weights = std::make_unique(weight_filename); - } - backend_utils::CreateOVTensors(session_context_.device_type, sw.metadata, *sw.mapped_weights); + ORT_ENFORCE(weight_file, "Initializer file not found: ", weight_filename.string()); + if (!sw.mapped_weights) { + sw.mapped_weights = std::make_unique(weight_filename); } + backend_utils::CreateOVTensors(session_context_.device_type, sw.metadata, *sw.mapped_weights); } if (ModelHasSymbolicInputDims(subgraph)) { @@ -196,7 +195,7 @@ BackendManager::BackendManager(SessionContext& session_context, } } if (session_context_.so_context_enable && - (subgraph_context_.is_ep_ctx_ovir_encapsulated || !subgraph_context_.is_ep_ctx_graph)) { + (subgraph_context_.is_ep_ctx_ovir_encapsulated || !subgraph_context_.is_ep_ctx_graph)) { auto status = onnxruntime::openvino_ep::BackendManager::ExportCompiledBlobAsEPCtxNode(subgraph); if (!status.IsOK()) { ORT_THROW(status); diff --git a/onnxruntime/core/providers/openvino/backends/basic_backend.cc b/onnxruntime/core/providers/openvino/backends/basic_backend.cc index 00a18bb0a45b6..ee74a1b1ee4b3 100644 --- a/onnxruntime/core/providers/openvino/backends/basic_backend.cc +++ b/onnxruntime/core/providers/openvino/backends/basic_backend.cc @@ -78,24 +78,24 @@ BasicBackend::BasicBackend(std::unique_ptr& model_pr // specify absolute path for so_context_file_path. auto model_file_path = [this]() { if (!session_context_.onnx_model_path_name.empty() && - std::filesystem::exists(session_context_.onnx_model_path_name)) return session_context_.onnx_model_path_name; + std::filesystem::exists(session_context_.onnx_model_path_name)) return session_context_.onnx_model_path_name; ORT_ENFORCE(!session_context_.so_context_file_path.empty() && - std::filesystem::path(session_context_.so_context_file_path).is_absolute() && - std::filesystem::exists(session_context_.so_context_file_path), log_tag + - "Context file path must be non-empty & absolute, when using CreateSessionFormArray() API explicitly." - " Please set a valid absolute path for ep.context_file_path in session options."); + std::filesystem::path(session_context_.so_context_file_path).is_absolute() && + std::filesystem::exists(session_context_.so_context_file_path), + log_tag + + "Context file path must be non-empty & absolute, when using CreateSessionFormArray() API explicitly." + " Please set a valid absolute path for ep.context_file_path in session options."); // Return absolute context file path as input to ImportEPCtxOVIREncapsulation() function. return session_context_.so_context_file_path; - }; // If the EPContext node with OVIR Encapsulation, then create // an executable network from EP_CACHE_CONTEXT using read_model() & compile_model() exe_network_ = OVCore::Get()->ImportEPCtxOVIREncapsulation(*model_stream, - hw_target, - device_config, - enable_causallm, - model_file_path()); + hw_target, + device_config, + enable_causallm, + model_file_path()); } else { // If the blob is held in an EPContext node, then skip FE+Compile // and directly move on to creating a backend with the executable blob diff --git a/onnxruntime/core/providers/openvino/contexts.h b/onnxruntime/core/providers/openvino/contexts.h index e2369cf728ea6..6a2b375d733f9 100644 --- a/onnxruntime/core/providers/openvino/contexts.h +++ b/onnxruntime/core/providers/openvino/contexts.h @@ -64,6 +64,7 @@ class SharedContext : public WeakSingleton { fs::path external_weight_filename; std::unique_ptr mapped_weights; Metadata::Map metadata; + fs::path metadata_filepath; } shared_weights; }; diff --git a/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.cc b/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.cc index 49a4cb0a7e95a..9e70756a254aa 100644 --- a/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.cc +++ b/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.cc @@ -131,7 +131,7 @@ std::unique_ptr EPCtxHandler::GetModelBlobStream(const std::filesy // exported with must match the version that is currently running. ORT_ENFORCE((attrs.count(EP_SDK_VER) == 1) && (attrs.at(EP_SDK_VER).s() == openvino_sdk_version_), "EPCtx blob was exported / is compatible with OpenVINO SDK version " + attrs.at(EP_SDK_VER).s() + - ", but OpenVINO SDK version currently in use is " + openvino_sdk_version_); + ", but OpenVINO SDK version currently in use is " + openvino_sdk_version_); } LOGS_DEFAULT(VERBOSE) << "[OpenVINO EP] Read blob from EPContext Node"; diff --git a/onnxruntime/core/providers/openvino/openvino_execution_provider.cc b/onnxruntime/core/providers/openvino/openvino_execution_provider.cc index 5c8293a213f40..7f6a7909f1dec 100644 --- a/onnxruntime/core/providers/openvino/openvino_execution_provider.cc +++ b/onnxruntime/core/providers/openvino/openvino_execution_provider.cc @@ -102,15 +102,24 @@ common::Status OpenVINOExecutionProvider::Compile( graph_body_viewer_0.DomainToVersionMap().at(kOnnxDomain); } - // Temporary code to read metadata before it moves to the .bin - auto& metadata = shared_context_->shared_weights.metadata; - if (session_context_.so_share_ep_contexts && metadata.empty()) { - // Metadata is always read from model location, this could be a source or epctx model - fs::path metadata_filename = session_context_.onnx_model_path_name.parent_path() / "metadata.bin"; - std::ifstream file(metadata_filename, std::ios::binary); - if (file) { - file >> metadata; + // The block below is executed during EP context model inference + auto& metadata = shared_context_->shared_weights.metadata; // Metadata object in memory + if (session_context_.so_share_ep_contexts && + !session_context_.so_context_enable && + metadata.empty()) { + fs::path context_model_file_path = session_context_.so_context_file_path; + if (context_model_file_path.empty()) { + // If ep.context_file_path is not set the input model path is used + context_model_file_path = session_context_.onnx_model_path_name; } + + // Metadata is always read from model location, this could be a source or epctx model + fs::path metadata_filename = context_model_file_path.stem().string() + "_metadata.bin"; + fs::path metadata_file_path = context_model_file_path.parent_path() / metadata_filename; + std::ifstream file(metadata_file_path, std::ios::binary); + ORT_RETURN_IF_NOT(file, "Metadata file was not found: " + metadata_file_path.string()); + shared_context_->shared_weights.metadata_filepath = metadata_file_path; + file >> metadata; } struct OpenVINOEPFunctionState { @@ -173,22 +182,29 @@ common::Status OpenVINOExecutionProvider::Compile( } } - if (session_context_.so_share_ep_contexts) { - fs::path metadata_filename; - if (session_context_.so_context_file_path.empty()) { - metadata_filename = session_context_.onnx_model_path_name.parent_path() / "metadata.bin"; - } else { - metadata_filename = session_context_.so_context_file_path.parent_path() / "metadata.bin"; + // The block below is executed during EP context model generation + if (session_context_.so_context_enable && + session_context_.so_share_ep_contexts && + !metadata.empty()) { + // For models after the first the metadata name comes from the shared context + fs::path metadata_file_path = shared_context_->shared_weights.metadata_filepath; + if (metadata_file_path.empty()) { + metadata_file_path = session_context_.so_context_file_path; + if (metadata_file_path.empty()) { + metadata_file_path = session_context_.onnx_model_path_name; + } + auto metadata_filename = metadata_file_path.stem().string() + "_metadata.bin"; + metadata_file_path.replace_filename(metadata_filename); + shared_context_->shared_weights.metadata_filepath = metadata_file_path; } // Metadata is generated only for shared contexts - // If saving metadata then save it to the provided path or ose the original model path + // If saving metadata then save it to the provided path or use the original model path // Multiple calls to Compile() will update the metadata and for the last call // the resulting file will contain the aggregated content - std::ofstream file(metadata_filename, std::ios::binary); - if (file) { - file << metadata; - } + std::ofstream file{metadata_file_path, std::ios::binary}; + ORT_RETURN_IF_NOT(file, "Metadata file could not be written: ", metadata_file_path); + file << metadata; } return status; diff --git a/onnxruntime/core/providers/openvino/ov_interface.cc b/onnxruntime/core/providers/openvino/ov_interface.cc index 306fa6113b347..918940b9d9917 100644 --- a/onnxruntime/core/providers/openvino/ov_interface.cc +++ b/onnxruntime/core/providers/openvino/ov_interface.cc @@ -16,7 +16,7 @@ namespace onnxruntime { namespace openvino_ep { template -inline auto OvExceptionBoundary(Func &&func, std::format_string&& fmt, Args&&... args) { +inline auto OvExceptionBoundary(Func&& func, std::format_string&& fmt, Args&&... args) { try { return func(); } catch (const ov::Exception& e) { @@ -47,462 +47,462 @@ void printDebugInfo(const ov::CompiledModel& obj) { continue; OPENVINO_SUPPRESS_DEPRECATED_END std::cout << " " << item2.first << ": " << item2.second.as() << std::endl; + } + } + else { + std::cout << " " << cfg << ": " << prop.as() << std::endl; } - } else { - std::cout << " " << cfg << ": " << prop.as() << std::endl; } } } -} #endif -// Function to check if a given OV property is enabled -std::optional queryOVProperty(const std::string& property, const std::string& device_type) { - try { - // Get the property value - auto supported_properties = OVCore::Get()->core.get_property(device_type, ov::supported_properties); - return std::find(supported_properties.begin(), supported_properties.end(), property) != supported_properties.end(); - } catch (const std::exception&) { - return std::nullopt; // Property not found or invalid - } -} - -std::shared_ptr OVCore::ReadModel(std::string&& model, const std::string& model_path) { - return OvExceptionBoundary([&]() { - std::istringstream modelStringStream(std::move(model)); - std::istream& modelStream = modelStringStream; - // Try to load with FrontEndManager - ov::frontend::FrontEndManager manager; - ov::frontend::FrontEnd::Ptr FE; - ov::frontend::InputModel::Ptr inputModel; - - ov::AnyVector params{&modelStream, model_path}; - - FE = manager.load_by_model(params); - if (FE) { - inputModel = FE->load(params); - return FE->convert(inputModel); - } else { - ORT_THROW(log_tag + "Unknown exception while Reading network"); + // Function to check if a given OV property is enabled + std::optional queryOVProperty(const std::string& property, const std::string& device_type) { + try { + // Get the property value + auto supported_properties = OVCore::Get()->core.get_property(device_type, ov::supported_properties); + return std::find(supported_properties.begin(), supported_properties.end(), property) != supported_properties.end(); + } catch (const std::exception&) { + return std::nullopt; // Property not found or invalid } - }, - "Exception while Reading network"); -} - -OVExeNetwork OVCore::StatefulCompileModel(std::shared_ptr& model, - std::string& hw_target, - const ov::AnyMap& device_config) { - ov::CompiledModel compiled_model; - ov::AnyMap config = device_config; - - if (onnxruntime::openvino_ep::backend_utils::IsDebugEnabled()) { - std::cout << "Stateless OV Model Statistic:" << std::endl; - LogBasicModelInfo(model); } - bool model_status = IsStateful(model); - LOGS_DEFAULT(INFO) << log_tag << "Model IsStateful() Status:\t" << (model_status ? "True" : "False"); - if (!model_status) { - LOGS_DEFAULT(INFO) << log_tag << "Converting from Stateless OV Model to Stateful OV Model" << std::endl; - PatchStatefulDecoder(model); - } - - if (onnxruntime::openvino_ep::backend_utils::IsDebugEnabled()) { - std::cout << "Stateful OV Model Statistic:" << std::endl; - LogBasicModelInfo(model); + std::shared_ptr OVCore::ReadModel(std::string && model, const std::string& model_path) { + return OvExceptionBoundary([&]() { + std::istringstream modelStringStream(std::move(model)); + std::istream& modelStream = modelStringStream; + // Try to load with FrontEndManager + ov::frontend::FrontEndManager manager; + ov::frontend::FrontEnd::Ptr FE; + ov::frontend::InputModel::Ptr inputModel; + + ov::AnyVector params{&modelStream, model_path}; + + FE = manager.load_by_model(params); + if (FE) { + inputModel = FE->load(params); + return FE->convert(inputModel); + } else { + ORT_THROW(log_tag + "Unknown exception while Reading network"); + } + }, + "Exception while Reading network"); } - auto kv_pos = GetKVAxesPos(model); + OVExeNetwork OVCore::StatefulCompileModel(std::shared_ptr & model, + std::string & hw_target, + const ov::AnyMap& device_config) { + ov::CompiledModel compiled_model; + ov::AnyMap config = device_config; - if (hw_target.find("NPU") != std::string::npos) { - KVDesc kv_desc; - auto parse_genai_config = [&](const std::string& key, unsigned int default_value) { - return (config.count(key) && !config.at(key).empty() && config.at(key).as() != "0") ? config.at(key).as() : default_value; - }; - - kv_desc.max_prompt_len = parse_genai_config("MAX_PROMPT_LEN", CausalLMConfig().max_prompt_len); - kv_desc.min_response_len = parse_genai_config("MIN_RESPONSE_LEN", CausalLMConfig().min_response_len); + if (onnxruntime::openvino_ep::backend_utils::IsDebugEnabled()) { + std::cout << "Stateless OV Model Statistic:" << std::endl; + LogBasicModelInfo(model); + } - // For compilation, MAX_PROMPT_LEN & MIN_RESPONSE_LEN should not be 0 - if (kv_desc.max_prompt_len == 0 || kv_desc.min_response_len == 0) { - ORT_THROW(log_tag + "MAX_PROMPT_LEN and MIN_RESPONSE_LEN cannot be 0 or empty"); + bool model_status = IsStateful(model); + LOGS_DEFAULT(INFO) << log_tag << "Model IsStateful() Status:\t" << (model_status ? "True" : "False"); + if (!model_status) { + LOGS_DEFAULT(INFO) << log_tag << "Converting from Stateless OV Model to Stateful OV Model" << std::endl; + PatchStatefulDecoder(model); } if (onnxruntime::openvino_ep::backend_utils::IsDebugEnabled()) { - std::cout << "kv_pos.batch = " << kv_pos.batch << std::endl; - std::cout << "kv_pos.seq_len = " << kv_pos.seq_len << std::endl; - std::cout << "kv_desc.max_prompt_len:\t" << kv_desc.max_prompt_len << std::endl; - std::cout << "kv_desc.min_response_len:\t" << kv_desc.min_response_len << std::endl; + std::cout << "Stateful OV Model Statistic:" << std::endl; + LogBasicModelInfo(model); } - UpdateNPUConfig(config, kv_pos, kv_desc); - } else { - // This patches the OV IR model so that it only produces the logits required for sampling. - // Actually either way that happens within NPUW::LLMCompiledModel creation for NPU device, - // while this is here mostly to align this behavior for other devices viz. (CPU, GPU). - ApplySliceBeforeMatmulTransformation(model); - } + auto kv_pos = GetKVAxesPos(model); - LOGS_DEFAULT(INFO) << log_tag << "Compiling OV Model using Stateful Transformation flow"; - compiled_model = OVCore::Get()->core.compile_model(model, hw_target, config); - OVExeNetwork exe(compiled_model, hw_target, true); - return exe; -} + if (hw_target.find("NPU") != std::string::npos) { + KVDesc kv_desc; + auto parse_genai_config = [&](const std::string& key, unsigned int default_value) { + return (config.count(key) && !config.at(key).empty() && config.at(key).as() != "0") ? config.at(key).as() : default_value; + }; + + kv_desc.max_prompt_len = parse_genai_config("MAX_PROMPT_LEN", CausalLMConfig().max_prompt_len); + kv_desc.min_response_len = parse_genai_config("MIN_RESPONSE_LEN", CausalLMConfig().min_response_len); -OVExeNetwork OVCore::CompileModel(std::shared_ptr& ie_cnn_network, - std::string& hw_target, - ov::AnyMap& device_config, - bool enable_causallm, - const std::string& name) { - return OvExceptionBoundary([&]() { - OVExeNetwork exe; - if (enable_causallm) { - auto mutable_model = ie_cnn_network->clone(); - exe = OVCore::Get()->StatefulCompileModel(mutable_model, hw_target, device_config); + // For compilation, MAX_PROMPT_LEN & MIN_RESPONSE_LEN should not be 0 + if (kv_desc.max_prompt_len == 0 || kv_desc.min_response_len == 0) { + ORT_THROW(log_tag + "MAX_PROMPT_LEN and MIN_RESPONSE_LEN cannot be 0 or empty"); + } + + if (onnxruntime::openvino_ep::backend_utils::IsDebugEnabled()) { + std::cout << "kv_pos.batch = " << kv_pos.batch << std::endl; + std::cout << "kv_pos.seq_len = " << kv_pos.seq_len << std::endl; + std::cout << "kv_desc.max_prompt_len:\t" << kv_desc.max_prompt_len << std::endl; + std::cout << "kv_desc.min_response_len:\t" << kv_desc.min_response_len << std::endl; + } + + UpdateNPUConfig(config, kv_pos, kv_desc); } else { - auto obj = core.compile_model(ie_cnn_network, hw_target, device_config); - exe = OVExeNetwork(obj, hw_target); + // This patches the OV IR model so that it only produces the logits required for sampling. + // Actually either way that happens within NPUW::LLMCompiledModel creation for NPU device, + // while this is here mostly to align this behavior for other devices viz. (CPU, GPU). + ApplySliceBeforeMatmulTransformation(model); } + LOGS_DEFAULT(INFO) << log_tag << "Compiling OV Model using Stateful Transformation flow"; + compiled_model = OVCore::Get()->core.compile_model(model, hw_target, config); + OVExeNetwork exe(compiled_model, hw_target, true); + return exe; + } + + OVExeNetwork OVCore::CompileModel(std::shared_ptr & ie_cnn_network, + std::string & hw_target, + ov::AnyMap & device_config, + bool enable_causallm, + const std::string& name) { + return OvExceptionBoundary([&]() { + OVExeNetwork exe; + if (enable_causallm) { + auto mutable_model = ie_cnn_network->clone(); + exe = OVCore::Get()->StatefulCompileModel(mutable_model, hw_target, device_config); + } else { + auto obj = core.compile_model(ie_cnn_network, hw_target, device_config); + exe = OVExeNetwork(obj, hw_target); + } + #ifndef NDEBUG - printDebugInfo(exe.Get()); + printDebugInfo(exe.Get()); #endif - return exe; - }, - "Exception while Loading Network for graph {}", name); -} + return exe; + }, + "Exception while Loading Network for graph {}", name); + } -OVExeNetwork OVCore::CompileModel(const std::string& onnx_model, - std::string& hw_target, - ov::AnyMap& device_config, - const std::string& name) { - return OvExceptionBoundary([&]() { - ov::CompiledModel obj; + OVExeNetwork OVCore::CompileModel(const std::string& onnx_model, + std::string& hw_target, + ov::AnyMap& device_config, + const std::string& name) { + return OvExceptionBoundary([&]() { + ov::CompiledModel obj; - obj = core.compile_model(onnx_model, ov::Tensor(), hw_target, device_config); + obj = core.compile_model(onnx_model, ov::Tensor(), hw_target, device_config); #ifndef NDEBUG - printDebugInfo(obj); + printDebugInfo(obj); #endif - OVExeNetwork exe(obj, hw_target); - return exe; - }, - "Exception while Loading Network for graph {}", name); -} + OVExeNetwork exe(obj, hw_target); + return exe; + }, + "Exception while Loading Network for graph {}", name); + } -OVExeNetwork OVCore::ImportModel(std::istream& model_stream, - std::string hw_target, - const ov::AnyMap& device_config, - std::string name) { - return OvExceptionBoundary([&]() { - ov::CompiledModel obj; - obj = core.import_model(model_stream, hw_target, device_config); - OVExeNetwork exe(obj, hw_target); + OVExeNetwork OVCore::ImportModel(std::istream & model_stream, + std::string hw_target, + const ov::AnyMap& device_config, + std::string name) { + return OvExceptionBoundary([&]() { + ov::CompiledModel obj; + obj = core.import_model(model_stream, hw_target, device_config); + OVExeNetwork exe(obj, hw_target); #ifndef NDEBUG - printDebugInfo(exe.Get()); + printDebugInfo(exe.Get()); #endif - return exe; - }, - "Exception while Loading Network for graph {}", name); -} + return exe; + }, + "Exception while Loading Network for graph {}", name); + } -OVExeNetwork OVCore::ImportEPCtxOVIREncapsulation(std::istream& model_stream, - std::string& hw_target, - const ov::AnyMap& device_config, - bool enable_causallm, - std::filesystem::path model_file_path) { - return OvExceptionBoundary([&]() { - OVExeNetwork exe; - - bool isXML = backend_utils::IsModelStreamXML(model_stream); - - // Helper function to check if file exists and is readable - const auto check_file_access = [&model_file_path](const std::filesystem::path& path) { - try { - if (!std::filesystem::exists(path) || std::filesystem::is_empty(path)) { - ORT_THROW(log_tag + "Required file missing or empty: " + path.string()); - } - std::ifstream file(path); - if (!file) { - ORT_THROW(log_tag + "Required file not readable: " + path.string()); + OVExeNetwork OVCore::ImportEPCtxOVIREncapsulation(std::istream & model_stream, + std::string & hw_target, + const ov::AnyMap& device_config, + bool enable_causallm, + std::filesystem::path model_file_path) { + return OvExceptionBoundary([&]() { + OVExeNetwork exe; + + bool isXML = backend_utils::IsModelStreamXML(model_stream); + + // Helper function to check if file exists and is readable + const auto check_file_access = [&model_file_path](const std::filesystem::path& path) { + try { + if (!std::filesystem::exists(path) || std::filesystem::is_empty(path)) { + ORT_THROW(log_tag + "Required file missing or empty: " + path.string()); + } + std::ifstream file(path); + if (!file) { + ORT_THROW(log_tag + "Required file not readable: " + path.string()); + } + } catch (const std::exception& e) { + ORT_THROW(log_tag + "Exception while checking file access for: " + path.string() + " - " + e.what()); } - } catch (const std::exception& e) { - ORT_THROW(log_tag + "Exception while checking file access for: " + path.string() + " - " + e.what()); - } - }; + }; - if (isXML) { - // If the model is XML, we need to load it with the XML content in read_model() - // where weights from bin file is directly consumed - auto xml_file_path = model_file_path.parent_path() / (model_file_path.stem().string() + ".xml"); + if (isXML) { + // If the model is XML, we need to load it with the XML content in read_model() + // where weights from bin file is directly consumed + auto xml_file_path = model_file_path.parent_path() / (model_file_path.stem().string() + ".xml"); - check_file_access(xml_file_path); + check_file_access(xml_file_path); - LOGS_DEFAULT(INFO) << log_tag << "Reading OVIR from XML file path: " << xml_file_path.string(); + LOGS_DEFAULT(INFO) << log_tag << "Reading OVIR from XML file path: " << xml_file_path.string(); - // Load the model explicitly with XML contents - std::shared_ptr model = core.read_model(xml_file_path.string()); + // Load the model explicitly with XML contents + std::shared_ptr model = core.read_model(xml_file_path.string()); - if (enable_causallm) { - exe = OVCore::Get()->StatefulCompileModel(model, hw_target, device_config); - } else { - auto obj = core.compile_model(model, hw_target, device_config); - exe = OVExeNetwork(obj, hw_target); + if (enable_causallm) { + exe = OVCore::Get()->StatefulCompileModel(model, hw_target, device_config); + } else { + auto obj = core.compile_model(model, hw_target, device_config); + exe = OVExeNetwork(obj, hw_target); + } } - } #ifndef NDEBUG - printDebugInfo(exe.Get()); + printDebugInfo(exe.Get()); #endif - return exe; - }, - "Exception while Loading Network from OVIR model file: {}", model_file_path.string()); -} - - -void OVCore::SetCache(const std::string& cache_dir_path) { - core.set_property(ov::cache_dir(cache_dir_path)); -} - -std::vector OVCore::GetAvailableDevices() const { - std::vector available_devices = core.get_available_devices(); - return available_devices; -} - -std::vector OVCore::GetAvailableDevices(const std::string& device_type) const { - std::vector available_devices; - std::vector devicesIDs; - // Uses logic from OpenVINO to only return available devices of the specified type (e.g. CPU, NPU or GPU) - try { - devicesIDs = core.get_property(device_type, ov::available_devices); - } catch (const ov::Exception&) { - // plugin is not created by e.g. invalid env - // Empty device list will be returned - } catch (const std::exception& ex) { - ORT_THROW(log_tag + "An exception occurred while trying to create the ", - device_type, - " device: ", - ex.what()); - } catch (...) { - ORT_THROW(log_tag + "Unknown exception occurred while trying to create the ", - device_type, - " device"); + return exe; + }, + "Exception while Loading Network from OVIR model file: {}", model_file_path.string()); } - if (devicesIDs.size() > 1 || - (devicesIDs.size() == 1 && devicesIDs[0] == "0")) { - for (const auto& deviceID : devicesIDs) { - available_devices.push_back(device_type + '.' + deviceID); - } - } - if (!devicesIDs.empty()) { - available_devices.push_back(device_type); + void OVCore::SetCache(const std::string& cache_dir_path) { + core.set_property(ov::cache_dir(cache_dir_path)); } - return available_devices; -} - -void OVCore::SetStreams(const std::string& device_type, int num_streams) { - core.set_property(device_type, {ov::num_streams(num_streams)}); -} + std::vector OVCore::GetAvailableDevices() const { + std::vector available_devices = core.get_available_devices(); + return available_devices; + } -std::shared_ptr OVExeNetwork::CreateInferRequest() { - return OvExceptionBoundary([&]() { - auto infReq = compiled_model_obj.create_infer_request(); - std::shared_ptr ovInfReq; - if (is_stateful_causallm) { - ovInfReq = std::make_shared(std::move(infReq), target_device); - } else { - ovInfReq = std::make_shared(std::move(infReq)); + std::vector OVCore::GetAvailableDevices(const std::string& device_type) const { + std::vector available_devices; + std::vector devicesIDs; + // Uses logic from OpenVINO to only return available devices of the specified type (e.g. CPU, NPU or GPU) + try { + devicesIDs = core.get_property(device_type, ov::available_devices); + } catch (const ov::Exception&) { + // plugin is not created by e.g. invalid env + // Empty device list will be returned + } catch (const std::exception& ex) { + ORT_THROW(log_tag + "An exception occurred while trying to create the ", + device_type, + " device: ", + ex.what()); + } catch (...) { + ORT_THROW(log_tag + "Unknown exception occurred while trying to create the ", + device_type, + " device"); } - return ovInfReq; - }, - - "Exception while creating InferRequest object"); -} -OVTensorPtr OVInferRequest::GetTensor(const std::string& input_name) { - return OvExceptionBoundary([&]() { - auto tobj = ovInfReq.get_tensor(input_name); - OVTensorPtr blob = std::make_shared(tobj); - return blob; - }, - " Cannot access IE Blob for input: {}", input_name); -} + if (devicesIDs.size() > 1 || + (devicesIDs.size() == 1 && devicesIDs[0] == "0")) { + for (const auto& deviceID : devicesIDs) { + available_devices.push_back(device_type + '.' + deviceID); + } + } + if (!devicesIDs.empty()) { + available_devices.push_back(device_type); + } -std::string OVInferRequest::GetInputTensorName(uint32_t index) { - return OvExceptionBoundary([&]() { - const auto& model = ovInfReq.get_compiled_model(); - return *model.input(index).get_names().begin(); - }, - " Cannot access IE Blob for input number: {}", index); -} + return available_devices; + } -void OVInferRequest::SetTensor(const std::string& name, OVTensorPtr& blob) { - OvExceptionBoundary([&]() { - ovInfReq.set_tensor(name, *(blob.get())); - }, - " Cannot set Remote Blob for output: {}", name); -} + void OVCore::SetStreams(const std::string& device_type, int num_streams) { + core.set_property(device_type, {ov::num_streams(num_streams)}); + } -uint32_t OVInferRequest::GetNumInputs() { - return static_cast(ovInfReq.get_compiled_model().inputs().size()); -} + std::shared_ptr OVExeNetwork::CreateInferRequest() { + return OvExceptionBoundary([&]() { + auto infReq = compiled_model_obj.create_infer_request(); + std::shared_ptr ovInfReq; + if (is_stateful_causallm) { + ovInfReq = std::make_shared(std::move(infReq), target_device); + } else { + ovInfReq = std::make_shared(std::move(infReq)); + } + return ovInfReq; + }, -void OVInferRequest::Infer() { - OvExceptionBoundary([&]() { - ovInfReq.infer(); - }, - "In Error Couldn't start Inference"); -} + "Exception while creating InferRequest object"); + } -StatefulOVInferRequest::StatefulOVInferRequest(ov::InferRequest infer_request, std::string device) - : OVInferRequest(std::move(infer_request)), target_device(device) { - bool gpu_or_npu = ((device.find("NPU") != std::string::npos) || (device.find("GPU") != std::string::npos)); - if (gpu_or_npu) { - prefill_use_full_chat_history = true; + OVTensorPtr OVInferRequest::GetTensor(const std::string& input_name) { + return OvExceptionBoundary([&]() { + auto tobj = ovInfReq.get_tensor(input_name); + OVTensorPtr blob = std::make_shared(tobj); + return blob; + }, + " Cannot access IE Blob for input: {}", input_name); } -} -void StatefulOVInferRequest::FillTensor(const std::string& tensor_name, const ov::element::Type& type, - const std::vector& shape, int32_t fill_value) { - ov::Tensor tensor = ov::Tensor(type, shape); - std::fill_n(tensor.data(), tensor.get_size(), fill_value); - ovInfReq.set_tensor(tensor_name, tensor); -} + std::string OVInferRequest::GetInputTensorName(uint32_t index) { + return OvExceptionBoundary([&]() { + const auto& model = ovInfReq.get_compiled_model(); + return *model.input(index).get_names().begin(); + }, + " Cannot access IE Blob for input number: {}", index); + } -void StatefulOVInferRequest::CacheTensor(const std::string& tensor_name, std::vector& cache) { - auto tensor = ovInfReq.get_tensor(tensor_name); - auto* pData = tensor.data(); - for (size_t i = 0; i < tensor.get_size(); i++) { - cache.emplace_back(pData[i]); + void OVInferRequest::SetTensor(const std::string& name, OVTensorPtr& blob) { + OvExceptionBoundary([&]() { + ovInfReq.set_tensor(name, *(blob.get())); + }, + " Cannot set Remote Blob for output: {}", name); } -} -void StatefulOVInferRequest::SetTensorFromCache(const std::string& tensor_name, - const std::vector& cache_data) { - auto tensor = ovInfReq.get_tensor(tensor_name); - auto new_shape = tensor.get_shape(); - new_shape[1] = cache_data.size(); + uint32_t OVInferRequest::GetNumInputs() { + return static_cast(ovInfReq.get_compiled_model().inputs().size()); + } - auto new_tensor = ov::Tensor(tensor.get_element_type(), new_shape); - auto* pNewData = new_tensor.data(); - std::memcpy(pNewData, cache_data.data(), cache_data.size() * sizeof(int64_t)); + void OVInferRequest::Infer() { + OvExceptionBoundary([&]() { + ovInfReq.infer(); + }, + "In Error Couldn't start Inference"); + } - ovInfReq.set_tensor(tensor_name, new_tensor); -} + StatefulOVInferRequest::StatefulOVInferRequest(ov::InferRequest infer_request, std::string device) + : OVInferRequest(std::move(infer_request)), target_device(device) { + bool gpu_or_npu = ((device.find("NPU") != std::string::npos) || (device.find("GPU") != std::string::npos)); + if (gpu_or_npu) { + prefill_use_full_chat_history = true; + } + } -std::optional StatefulOVInferRequest::FindTensor(const std::string& tensor_name) { - // Check if tensor exists by examining input names in the compiled model - const auto& model = ovInfReq.get_compiled_model(); - bool tensor_exists = false; + void StatefulOVInferRequest::FillTensor(const std::string& tensor_name, const ov::element::Type& type, + const std::vector& shape, int32_t fill_value) { + ov::Tensor tensor = ov::Tensor(type, shape); + std::fill_n(tensor.data(), tensor.get_size(), fill_value); + ovInfReq.set_tensor(tensor_name, tensor); + } - for (const auto& input : model.inputs()) { - const auto& names = input.get_names(); - if (names.find(tensor_name) != names.end()) { - tensor_exists = true; - break; + void StatefulOVInferRequest::CacheTensor(const std::string& tensor_name, std::vector& cache) { + auto tensor = ovInfReq.get_tensor(tensor_name); + auto* pData = tensor.data(); + for (size_t i = 0; i < tensor.get_size(); i++) { + cache.emplace_back(pData[i]); } } - if (tensor_exists) { - return ovInfReq.get_tensor(tensor_name); - } + void StatefulOVInferRequest::SetTensorFromCache(const std::string& tensor_name, + const std::vector& cache_data) { + auto tensor = ovInfReq.get_tensor(tensor_name); + auto new_shape = tensor.get_shape(); + new_shape[1] = cache_data.size(); - return std::nullopt; -} + auto new_tensor = ov::Tensor(tensor.get_element_type(), new_shape); + auto* pNewData = new_tensor.data(); + std::memcpy(pNewData, cache_data.data(), cache_data.size() * sizeof(int64_t)); -void StatefulOVInferRequest::PreProcessInferRequest() { - // Workaround: Setting the value here as it cannot be set at the ORT GenAI layer currently. - // TODO(ankit): Address this issue and implement the fix at the appropriate layer. - FillTensor("beam_idx", ov::element::i32, {1}, 0); + ovInfReq.set_tensor(tensor_name, new_tensor); + } - // If 'prefill use full chat history' mode is enabled, we need to cache input_ids and position_ids. - if (prefill_use_full_chat_history) { - auto input_ids_tensor = ovInfReq.get_tensor("input_ids"); - CacheTensor("input_ids", cached_input_ids); + std::optional StatefulOVInferRequest::FindTensor(const std::string& tensor_name) { + // Check if tensor exists by examining input names in the compiled model + const auto& model = ovInfReq.get_compiled_model(); + bool tensor_exists = false; - // "position_ids" (GQA with Rotary Embeddings doesnt have position_ids) - check if exists - auto position_ids_opt = FindTensor("position_ids"); - bool has_position_ids = position_ids_opt.has_value(); + for (const auto& input : model.inputs()) { + const auto& names = input.get_names(); + if (names.find(tensor_name) != names.end()) { + tensor_exists = true; + break; + } + } - if (has_position_ids) { - CacheTensor("position_ids", cached_position_ids); + if (tensor_exists) { + return ovInfReq.get_tensor(tensor_name); } - // If we're about to run the prefill model - if (input_ids_tensor.get_size() > 1) { - // Check if the size of the current "input_ids" tensor does not match the size of the cached "input_ids". - // This indicates that we are running a subsequent prompt (not the initial prefill). - if (input_ids_tensor.get_shape()[1] != cached_input_ids.size()) { - // Clear the internal KVCache state. For NPU device, this operation is a no-op. - ovInfReq.reset_state(); + return std::nullopt; + } + + void StatefulOVInferRequest::PreProcessInferRequest() { + // Workaround: Setting the value here as it cannot be set at the ORT GenAI layer currently. + // TODO(ankit): Address this issue and implement the fix at the appropriate layer. + FillTensor("beam_idx", ov::element::i32, {1}, 0); + + // If 'prefill use full chat history' mode is enabled, we need to cache input_ids and position_ids. + if (prefill_use_full_chat_history) { + auto input_ids_tensor = ovInfReq.get_tensor("input_ids"); + CacheTensor("input_ids", cached_input_ids); - // Set tensors using cached values - SetTensorFromCache("input_ids", cached_input_ids); + // "position_ids" (GQA with Rotary Embeddings doesnt have position_ids) - check if exists + auto position_ids_opt = FindTensor("position_ids"); + bool has_position_ids = position_ids_opt.has_value(); - // Only set position_ids if it exists and we have cached values - if (has_position_ids && !cached_position_ids.empty()) { - SetTensorFromCache("position_ids", cached_position_ids); + if (has_position_ids) { + CacheTensor("position_ids", cached_position_ids); + } + + // If we're about to run the prefill model + if (input_ids_tensor.get_size() > 1) { + // Check if the size of the current "input_ids" tensor does not match the size of the cached "input_ids". + // This indicates that we are running a subsequent prompt (not the initial prefill). + if (input_ids_tensor.get_shape()[1] != cached_input_ids.size()) { + // Clear the internal KVCache state. For NPU device, this operation is a no-op. + ovInfReq.reset_state(); + + // Set tensors using cached values + SetTensorFromCache("input_ids", cached_input_ids); + + // Only set position_ids if it exists and we have cached values + if (has_position_ids && !cached_position_ids.empty()) { + SetTensorFromCache("position_ids", cached_position_ids); + } } } } } -} -void StatefulOVInferRequest::Infer() { - PreProcessInferRequest(); - OVInferRequest::Infer(); -} + void StatefulOVInferRequest::Infer() { + PreProcessInferRequest(); + OVInferRequest::Infer(); + } -void StatefulOVInferRequest::RewindKVCache(size_t index) { - LOGS_DEFAULT(INFO) << log_tag << "RewindKVCache: Rewinding OpenVINO-internal KVCache state to index=" << index; + void StatefulOVInferRequest::RewindKVCache(size_t index) { + LOGS_DEFAULT(INFO) << log_tag << "RewindKVCache: Rewinding OpenVINO-internal KVCache state to index=" << index; - if (prefill_use_full_chat_history) { - // Clear the internal KVCache state. For NPU device, this operation is a no-op. - ovInfReq.reset_state(); + if (prefill_use_full_chat_history) { + // Clear the internal KVCache state. For NPU device, this operation is a no-op. + ovInfReq.reset_state(); - // Resize the cached "input_ids" and "position_ids" to the specified index. - if (cached_input_ids.size() > index) { - cached_input_ids.resize(index); - } + // Resize the cached "input_ids" and "position_ids" to the specified index. + if (cached_input_ids.size() > index) { + cached_input_ids.resize(index); + } - if (cached_position_ids.size() > index) { - cached_position_ids.resize(index); - } - } else { - if (index == 0) { - // In this case, since we're resetting the entire KVCache, simply reset the state. - ovInfReq.reset_state(); + if (cached_position_ids.size() > index) { + cached_position_ids.resize(index); + } } else { - // Retrieve KVCache states and trim them to the specified index. - // The following logic is adapted from: - // https://github.com/openvinotoolkit/openvino.genai/blob/releases/2025/1/src/cpp/src/utils.cpp#L329 - auto states = ovInfReq.query_state(); - for (auto& state : states) { - ov::Tensor old_tensor = state.get_state(); - // Tensor shape: [batch_size, num_kv_heads, seq_len, head_size] - auto shape = old_tensor.get_shape(); - - if (shape[2] > index) { - // Update the sequence length dimension to the specified index. - shape[2] = index; - - ov::Coordinate new_shape_begin{0, 0, 0, 0}; - ov::Coordinate new_shape_end{shape}; - - // Create a trimmed tensor with the updated shape. - auto trimmed_tensor = ov::Tensor(old_tensor, new_shape_begin, new_shape_end); - - // Copy the trimmed tensor into a new tensor and update the state. - ov::Tensor new_tensor(old_tensor.get_element_type(), shape); - trimmed_tensor.copy_to(new_tensor); - - state.set_state(new_tensor); + if (index == 0) { + // In this case, since we're resetting the entire KVCache, simply reset the state. + ovInfReq.reset_state(); + } else { + // Retrieve KVCache states and trim them to the specified index. + // The following logic is adapted from: + // https://github.com/openvinotoolkit/openvino.genai/blob/releases/2025/1/src/cpp/src/utils.cpp#L329 + auto states = ovInfReq.query_state(); + for (auto& state : states) { + ov::Tensor old_tensor = state.get_state(); + // Tensor shape: [batch_size, num_kv_heads, seq_len, head_size] + auto shape = old_tensor.get_shape(); + + if (shape[2] > index) { + // Update the sequence length dimension to the specified index. + shape[2] = index; + + ov::Coordinate new_shape_begin{0, 0, 0, 0}; + ov::Coordinate new_shape_end{shape}; + + // Create a trimmed tensor with the updated shape. + auto trimmed_tensor = ov::Tensor(old_tensor, new_shape_begin, new_shape_end); + + // Copy the trimmed tensor into a new tensor and update the state. + ov::Tensor new_tensor(old_tensor.get_element_type(), shape); + trimmed_tensor.copy_to(new_tensor); + + state.set_state(new_tensor); + } } } } } -} } // namespace openvino_ep } // namespace onnxruntime diff --git a/onnxruntime/core/providers/openvino/ov_interface.h b/onnxruntime/core/providers/openvino/ov_interface.h index 0e019342bc86e..fb1757199698b 100644 --- a/onnxruntime/core/providers/openvino/ov_interface.h +++ b/onnxruntime/core/providers/openvino/ov_interface.h @@ -117,7 +117,7 @@ class OVInferRequest { const void* ort_ptr; }; - protected: + protected: ov::InferRequest ovInfReq; std::unordered_map bindings_cache_; @@ -127,7 +127,7 @@ class OVInferRequest { std::string GetInputTensorName(uint32_t index); // Set tensor described param_info and ort_ptr. Overrides shape in param_info with shape_override. Call infer req tensor if ort_ptr is last set. - void SetTensor(const std::string& name, const ov::element::Type &type, const ov::Shape& shape, void* ort_ptr) { + void SetTensor(const std::string& name, const ov::element::Type& type, const ov::Shape& shape, void* ort_ptr) { auto& cached_binding = bindings_cache_[name]; if (cached_binding.ort_ptr != ort_ptr) { auto tensor_ptr = std::make_shared(type, shape, const_cast(ort_ptr)); diff --git a/onnxruntime/core/providers/openvino/ov_versions/capability.cc b/onnxruntime/core/providers/openvino/ov_versions/capability.cc index 45ea822685710..88ddde8610c6e 100644 --- a/onnxruntime/core/providers/openvino/ov_versions/capability.cc +++ b/onnxruntime/core/providers/openvino/ov_versions/capability.cc @@ -38,7 +38,7 @@ GetCapability::GetCapability(const EPCtxHandler& ep_ctx_handler, device_type_ = "CPU"; if (enable_qdq_optimizer) npu_qdq_optimizer_enabled = true; } else if (enable_qdq_optimizer && device_type_.find("GPU") != std::string::npos) { - npu_qdq_optimizer_enabled = true; // see data_ops.cc ~615 where we check for int16 types for gpu, this may change to a better approach later + npu_qdq_optimizer_enabled = true; // see data_ops.cc ~615 where we check for int16 types for gpu, this may change to a better approach later } #if OPENVINO_VERSION_MAJOR == 2024 && OPENVINO_VERSION_MINOR == 5 diff --git a/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc b/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc index 99d6e4b7ab5ef..27d8dd7822c41 100644 --- a/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc +++ b/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc @@ -615,7 +615,7 @@ bool DataOps::type_is_supported(const NodeArg* node_arg, bool is_initializer) { } // experimentally for GPU and qdq stripping mode allow int16 types if (npu_qdq_optimizer_enabled_ && (dtype == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT16 || dtype == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT16)) - return true; + return true; } #ifndef NDEBUG if (openvino_ep::backend_utils::IsDebugEnabled()) { From e141e2ffdcc9ac49a7782a759c47885243d26797 Mon Sep 17 00:00:00 2001 From: Ryan Metcalfe <107415876+RyanMetcalfeInt8@users.noreply.github.com> Date: Wed, 25 Jun 2025 04:04:49 -0400 Subject: [PATCH 052/138] fix: Fix logic in OnnxToOvNetworkBindings for stateful models (#719) Co-authored-by: Ankit Maheshkar --- .../openvino/backends/basic_backend.h | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/onnxruntime/core/providers/openvino/backends/basic_backend.h b/onnxruntime/core/providers/openvino/backends/basic_backend.h index b1d5406fcf3e2..9f0369a3603af 100644 --- a/onnxruntime/core/providers/openvino/backends/basic_backend.h +++ b/onnxruntime/core/providers/openvino/backends/basic_backend.h @@ -68,13 +68,18 @@ struct OnnxToOvNetworkBindings { // For Stateful Model Compilation, the ONNX model includes KV cache (past/present) tensors. // However, these tensors are internally converted to a stateful representation, which removes them. - // To prevent runtime exceptions, we simply continue processing here. - if (!matched_names && session_context.enable_causallm && - std::any_of(special_io_names_.begin(), special_io_names_.end(), - [&onnx_name](const std::string& name) { return onnx_name.find(name) != std::string::npos; })) { - // This case also requires dynamic shape inference, so we'll mark the bindings as dynamic. - has_dynamic_io_ = true; - continue; + // It's also possible that the onnx model does not contain tensors such as beam_idx, whereas our converted + // stateful representation has introduced these new tensors, creating a name mismatch (matched_names=false). + // So, if there is a name mismatch, or the name matches our special io list, we simply continue processing + // here to prevent runtime exceptions. + if (session_context.enable_causallm) { + if (!matched_names || + std::any_of(special_io_names_.begin(), special_io_names_.end(), + [&onnx_name](const std::string& name) { return onnx_name.find(name) != std::string::npos; })) { + // This case also requires dynamic shape inference, so we'll mark the bindings as dynamic. + has_dynamic_io_ = true; + continue; + } } ORT_ENFORCE(matched_names, log_tag, From 7f86fad483076305db2bcacc73491c3a949b8dd0 Mon Sep 17 00:00:00 2001 From: "Dvoretckii, Mikhail" Date: Tue, 10 Jun 2025 07:52:26 -0700 Subject: [PATCH 053/138] [OVEP] Fix UnsupportedOpMode logic for Reshape Assumptions about number of inputs to Reshape nodes were causing crashes. The check has been adjusted to be functional with Reshape only having one input. --- onnxruntime/core/providers/openvino/ov_versions/data_ops.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc b/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc index 27d8dd7822c41..84001c1161efc 100644 --- a/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc +++ b/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc @@ -420,7 +420,8 @@ void DataOps::populate_op_mode_supported() { UnsupportedOpMode obj = {{V_2023_1, V_2023_2, V_2023_3, V_2024_0, V_2024_1, V_2024_2, V_2024_3, V_2024_4, V_2024_5, V_2024_6, V_2025_0, V_2025_1}, [this](const Node* node, const InitializedTensorSet&) { - const auto& input_arg = node->InputDefs()[1]; + const auto& input_args = node->InputDefs(); + const auto& input_arg = (input_args.size() > 1) ? input_args[1] : input_args[0]; auto shape = input_arg->Shape(); // Reshape op with empty dim is Rejected for Myriad // [TODO] Is this condition required anymore with Myriad removed? From dfc8ff07f46c1535f5b94dba5a495130637251ba Mon Sep 17 00:00:00 2001 From: Javier Martinez Date: Fri, 27 Jun 2025 15:56:03 -0700 Subject: [PATCH 054/138] Fix metadata name when ep.context_file_path is not provided (#722) --- .../core/providers/openvino/openvino_execution_provider.cc | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/openvino/openvino_execution_provider.cc b/onnxruntime/core/providers/openvino/openvino_execution_provider.cc index 7f6a7909f1dec..a0aa04293ac37 100644 --- a/onnxruntime/core/providers/openvino/openvino_execution_provider.cc +++ b/onnxruntime/core/providers/openvino/openvino_execution_provider.cc @@ -190,10 +190,12 @@ common::Status OpenVINOExecutionProvider::Compile( fs::path metadata_file_path = shared_context_->shared_weights.metadata_filepath; if (metadata_file_path.empty()) { metadata_file_path = session_context_.so_context_file_path; + std::string name_append{"_metadata.bin"}; if (metadata_file_path.empty()) { metadata_file_path = session_context_.onnx_model_path_name; + name_append = "_ctx" + name_append; } - auto metadata_filename = metadata_file_path.stem().string() + "_metadata.bin"; + auto metadata_filename = metadata_file_path.stem().string() + name_append; metadata_file_path.replace_filename(metadata_filename); shared_context_->shared_weights.metadata_filepath = metadata_file_path; } From 0c1fcfd6863716a6c8195e88beeb398c11f14885 Mon Sep 17 00:00:00 2001 From: Pallavi Gupta Date: Mon, 30 Jun 2025 18:07:51 -0700 Subject: [PATCH 055/138] EPCtx changes for dynamic model with reshape_input provider optional (#720) 1. Throw an ErrorMessage when user provide reshape_input with EPctx graph. 2. Do not Create dynamic_backend for EPctx graph. Update onnxruntime/core/providers/openvino/backend_manager.cc Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- .../core/providers/openvino/backend_manager.cc | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/openvino/backend_manager.cc b/onnxruntime/core/providers/openvino/backend_manager.cc index e150a7cd00ec6..9597f73c38073 100644 --- a/onnxruntime/core/providers/openvino/backend_manager.cc +++ b/onnxruntime/core/providers/openvino/backend_manager.cc @@ -84,6 +84,11 @@ BackendManager::BackendManager(SessionContext& session_context, ptr_stream_t model_stream; std::unique_ptr model_proto; if (subgraph_context_.is_ep_ctx_graph) { + if (!session_context_.reshape.empty()) { + std::string exception_str = + "[OpenVINO-EP] Bounded dynamic model execution using provider option reshape_input is not supported for OVEP EPContext model"; + ORT_THROW(exception_str); + } model_stream = ep_ctx_handle_.GetModelBlobStream(session_context_.so_context_file_path, subgraph); } else { model_proto = GetModelProtoFromFusedNode(fused_node, subgraph, logger); @@ -110,7 +115,10 @@ BackendManager::BackendManager(SessionContext& session_context, if (ModelHasSymbolicInputDims(subgraph)) { subgraph_context_.has_dynamic_input_shape = true; LOGS_DEFAULT(INFO) << "[OpenVINO-EP] Model has symbolic input dims"; - if (!session_context_.disable_dynamic_shapes) { + if ((!session_context_.disable_dynamic_shapes && + (session_context_.device_type.find("CPU") != std::string::npos || + session_context_.device_type.find("GPU") != std::string::npos)) || + (subgraph_context_.is_ep_ctx_graph)) { LOGS_DEFAULT(INFO) << "[OpenVINO-EP] Starting backend initialization. " << "Creating backend Dynamic Shapes"; try { @@ -590,7 +598,7 @@ void BackendManager::Compute(OrtKernelContext* context) { // by rewriting the model to static shaped model at runtime based on input shape. // disable_dynamic_shapes should be set for devices that don't support dynamic shapes. bool need_dynamic_backend = subgraph_context_.has_dynamic_input_shape && - session_context_.disable_dynamic_shapes; + session_context_.disable_dynamic_shapes && !subgraph_context_.is_ep_ctx_graph; if (!need_dynamic_backend) { concrete_backend_->Infer(context); From 13e779245af4c4847d637045a659867e2810c60a Mon Sep 17 00:00:00 2001 From: Ankit Maheshkar Date: Wed, 2 Jul 2025 13:56:51 +0530 Subject: [PATCH 056/138] fix: enable test & lint fixes (#725) --- onnxruntime/core/providers/openvino/backend_manager.cc | 4 ++-- onnxruntime/test/contrib_ops/attention_op_test.cc | 2 +- .../test/providers/cpu/reduction/reduction_ops_test.cc | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/onnxruntime/core/providers/openvino/backend_manager.cc b/onnxruntime/core/providers/openvino/backend_manager.cc index 9597f73c38073..2ba562525e9c3 100644 --- a/onnxruntime/core/providers/openvino/backend_manager.cc +++ b/onnxruntime/core/providers/openvino/backend_manager.cc @@ -116,8 +116,8 @@ BackendManager::BackendManager(SessionContext& session_context, subgraph_context_.has_dynamic_input_shape = true; LOGS_DEFAULT(INFO) << "[OpenVINO-EP] Model has symbolic input dims"; if ((!session_context_.disable_dynamic_shapes && - (session_context_.device_type.find("CPU") != std::string::npos || - session_context_.device_type.find("GPU") != std::string::npos)) || + (session_context_.device_type.find("CPU") != std::string::npos || + session_context_.device_type.find("GPU") != std::string::npos)) || (subgraph_context_.is_ep_ctx_graph)) { LOGS_DEFAULT(INFO) << "[OpenVINO-EP] Starting backend initialization. " << "Creating backend Dynamic Shapes"; diff --git a/onnxruntime/test/contrib_ops/attention_op_test.cc b/onnxruntime/test/contrib_ops/attention_op_test.cc index 4dff0376fcd84..61e5fa05c66c1 100644 --- a/onnxruntime/test/contrib_ops/attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/attention_op_test.cc @@ -2047,7 +2047,7 @@ TEST(AttentionTest, AttentionPastState_dynamic) { test.AddInput("past", past_dims, past_data); test.AddReferenceOutputs("testdata/attention_past_state.onnx", 0.005f); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider}); + test.Run(); } #endif //! defined(__wasm__) diff --git a/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc b/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc index e2ee859fb26df..c56aa3fb5feac 100644 --- a/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc +++ b/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc @@ -4102,8 +4102,8 @@ TEST(ReductionOpTest, ReduceSum_noop_axes_input_initializer_opset_18) { test.Run( OpTester::ExpectResult::kExpectSuccess, "", - {kOpenVINOExecutionProvider} // OpenVINO: Disabled temporarily - ); + {kOpenVINOExecutionProvider} // OpenVINO: Disabled temporarily + ); } TEST(ReductionOpTest, ReduceSum_empty_axes_input_initializer_opset_18) { From 80daa9b575f3645196f199c59241433a9c2ded46 Mon Sep 17 00:00:00 2001 From: Preetha Veeramalai Date: Wed, 2 Jul 2025 20:13:11 +0530 Subject: [PATCH 057/138] Reduce the peak memory even with CPU fallback by moving the fallback within the basic_backend.cc scope (#723) --- .../providers/openvino/backend_manager.cc | 62 +++---- .../openvino/backends/basic_backend.cc | 158 ++++++++++-------- .../openvino/backends/basic_backend.h | 1 + 3 files changed, 111 insertions(+), 110 deletions(-) diff --git a/onnxruntime/core/providers/openvino/backend_manager.cc b/onnxruntime/core/providers/openvino/backend_manager.cc index 2ba562525e9c3..041d9c07e41fe 100644 --- a/onnxruntime/core/providers/openvino/backend_manager.cc +++ b/onnxruntime/core/providers/openvino/backend_manager.cc @@ -153,52 +153,28 @@ BackendManager::BackendManager(SessionContext& session_context, model_stream); } catch (const OnnxRuntimeException& ex) { std::string exception_str = ex.what(); - bool eligible_for_cpu_fallback = device_type.find("NPU") != std::string::npos && - !session_context_.so_disable_cpu_ep_fallback && - !subgraph_context_.is_ep_ctx_graph; -#if defined(OPENVINO_DISABLE_NPU_FALLBACK) - eligible_for_cpu_fallback = false; -#else - if (eligible_for_cpu_fallback) { - LOGS_DEFAULT(VERBOSE) << exception_str; - LOGS_DEFAULT(WARNING) << "Model compilation failed at OV NPU." - << "Falling back to OV CPU for execution"; - session_context_.device_type = "CPU"; - session_context_.precision = "FP32"; - try { - concrete_backend_ = BackendFactory::MakeBackend(model_proto, - session_context_, - subgraph_context_, - shared_context_, - model_stream); - } catch (std::string const& msg) { - ORT_THROW(msg); - } - } -#endif - if (!eligible_for_cpu_fallback) { - if (device_type.find("NPU") != std::string::npos && - exception_str.find("intel_npu") != std::string::npos) { - // Handle NPU device related errors + + if (session_context_.device_type.find("NPU") != std::string::npos && + exception_str.find("intel_npu") != std::string::npos) { + // Handle NPU device related errors #ifndef NDEBUG - ORT_THROW(exception_str + "\nModel needs to be recompiled\n"); + ORT_THROW(exception_str + "\nModel needs to be recompiled\n"); #else - std::string error_message = "UNKNOWN NPU ERROR"; - std::string error_code = "code 0x0"; - std::regex error_message_pattern(R"(\bZE_\w*\b)"); - std::regex error_code_pattern("code 0x[0-9a-fA-F]+"); - std::smatch matches; - if (std::regex_search(exception_str, matches, error_message_pattern)) { - error_message = matches[0]; - } - if (std::regex_search(exception_str, matches, error_code_pattern)) { - error_code = matches[0]; - } - throw std::runtime_error(error_message + ", " + error_code + "\nModel needs to be recompiled\n"); -#endif - } else { - ORT_THROW(exception_str); + std::string error_message = "UNKNOWN NPU ERROR"; + std::string error_code = "code 0x0"; + std::regex error_message_pattern(R"(\bZE_\w*\b)"); + std::regex error_code_pattern("code 0x[0-9a-fA-F]+"); + std::smatch matches; + if (std::regex_search(exception_str, matches, error_message_pattern)) { + error_message = matches[0]; } + if (std::regex_search(exception_str, matches, error_code_pattern)) { + error_code = matches[0]; + } + throw std::runtime_error(error_message + ", " + error_code + "\nModel needs to be recompiled\n"); +#endif + } else { + ORT_THROW(exception_str); } } } diff --git a/onnxruntime/core/providers/openvino/backends/basic_backend.cc b/onnxruntime/core/providers/openvino/backends/basic_backend.cc index ee74a1b1ee4b3..df75f84a5fee0 100644 --- a/onnxruntime/core/providers/openvino/backends/basic_backend.cc +++ b/onnxruntime/core/providers/openvino/backends/basic_backend.cc @@ -36,42 +36,14 @@ BasicBackend::BasicBackend(std::unique_ptr& model_pr if (ValidateSubgraph(const_outputs_map_)) return; - // OV Config + // Pre-requisite is provider_option "context" must be set + auto auto_unified_compile = ((hw_target.find("AUTO") == std::string::npos) || + (session_context_.OpenVINO_Version.at(0) >= 2024 && + session_context_.OpenVINO_Version.at(1) > 2)); ov::AnyMap device_config; - PopulateConfigValue(device_config); - - // Enable caching - EnableCaching(); - - // Setting OpenCL queue throttling for GPU - EnableGPUThrottling(device_config); - - // Enable streams; default=1 unless overridden by user configuration - EnableStreams(); - - // Set the inference_num_threads property of the CPU - SetNumThreads(device_config); - - auto npuw_status = - std::any_of(device_config.begin(), device_config.end(), [&](const std::pair& pair) { - return (pair.first.find("NPU_USE_NPUW") != std::string::npos) && (pair.second.is()) && - (pair.second.as() == "YES"); - }); - - if (npuw_status) { - LOGS_DEFAULT(INFO) << log_tag << "NPUW Enabled during compilation"; - } - - try { - // IO_BUFFER is enabled on GPU HW. - // Pre-requisite is provider_option "context" must be set - auto auto_unified_compile = ((hw_target.find("AUTO") == std::string::npos) || - (session_context_.OpenVINO_Version.at(0) >= 2024 && - session_context_.OpenVINO_Version.at(1) > 2)); - bool disable_cpu_fallback = !(hw_target.find("NPU") != std::string::npos && - !session_context_.so_disable_cpu_ep_fallback && - !subgraph_context_.is_ep_ctx_graph); - if (subgraph_context_.is_ep_ctx_graph) { + SetOVDeviceConfiguration(device_config); + if (subgraph_context_.is_ep_ctx_graph) { + try { if (subgraph_context_.is_ep_ctx_ovir_encapsulated) { // model_file_path will use so_context_file_path if the onnx_model_path_name is not available, // especially in case of CreateSessionFormArray() where user must explicitly @@ -104,41 +76,67 @@ BasicBackend::BasicBackend(std::unique_ptr& model_pr device_config, subgraph_context_.subgraph_name); } - model_stream.reset(); // Delete stream after it is no longer needed - } else if (!session_context_.has_external_weights && - !subgraph_context_.has_dynamic_input_shape && - !session_context_.so_context_enable && - session_context_.reshape.empty() && - !enable_causallm && - auto_unified_compile) { - // Unified OV compile_model is efficient when ov model caching is enabled - // Unified OV compile_model API is supported with AUTO from version 2024.3 and above - // Inputs with static dimensions - // Not enabled for models with external weights and when ep context is set. - const std::string model = model_proto->SerializeAsString(); - // we have the serialized string, so we can release model proto to lower the peak memory consumption - if (disable_cpu_fallback) model_proto.reset(); - exe_network_ = OVCore::Get()->CompileModel(model, - hw_target, - device_config, - subgraph_context_.subgraph_name); - } else { // For all other types use ov::ov_core read_model() to generate OV IR - // followed by ov::ov_core compile_model() - std::string model = model_proto->SerializeAsString(); - // Reset model proto only when cpu fallback is disabled or when the model has dynamic input shapes. - // This is to avoid memory peak usage when the model is large. - if (!subgraph_context.has_dynamic_input_shape && disable_cpu_fallback) { - model_proto.reset(); + model_stream.reset(); + } catch (const char* msg) { + ORT_THROW(msg); + } // Delete stream after it is no longer needed + } else { + std::string model = model_proto->SerializeAsString(); + if (!subgraph_context.has_dynamic_input_shape) { + model_proto.reset(); + } + try { + // SetOVDeviceConfiguration(device_config); + if (!session_context_.has_external_weights && + !subgraph_context_.has_dynamic_input_shape && + !session_context_.so_context_enable && + session_context_.reshape.empty() && + !enable_causallm && + auto_unified_compile) { + // Unified OV compile_model is efficient when ov model caching is enabled + // Unified OV compile_model API is supported with AUTO from version 2024.3 and above + // Inputs with static dimensions + // Not enabled for models with external weights and when ep context is set. + + exe_network_ = OVCore::Get()->CompileModel(model, + hw_target, + device_config, + subgraph_context_.subgraph_name); + } else { // For all other types use ov::ov_core read_model() to generate OV IR + // followed by ov::ov_core compile_model() + auto ov_model = CreateOVModel(std::move(model), session_context_, const_outputs_map_); + exe_network_ = OVCore::Get()->CompileModel( + ov_model, hw_target, device_config, enable_causallm, subgraph_context_.subgraph_name); + } + LOGS_DEFAULT(INFO) << log_tag << "Loaded model to the plugin"; + } catch (const OnnxRuntimeException& ex) { + std::string exception_str = ex.what(); + bool eligible_for_cpu_fallback = session_context_.device_type.find("NPU") != std::string::npos && + !session_context_.so_disable_cpu_ep_fallback && + !subgraph_context_.is_ep_ctx_graph; +#if defined(OPENVINO_DISABLE_NPU_FALLBACK) + eligible_for_cpu_fallback = false; +#endif + if (eligible_for_cpu_fallback) { + LOGS_DEFAULT(WARNING) << "Model compilation failed at OV NPU." + << "Falling back to OV CPU for execution"; + session_context_.device_type = "CPU"; + session_context_.precision = "FP32"; + device_config.clear(); + SetOVDeviceConfiguration(device_config); + try { + // Recreate the model with CPU device type + auto ov_model = CreateOVModel(std::move(model), session_context_, const_outputs_map_); + exe_network_ = OVCore::Get()->CompileModel( + ov_model, hw_target, device_config, enable_causallm, subgraph_context_.subgraph_name); + } catch (std::string const& msg) { + ORT_THROW(msg); + } + } else { + ORT_THROW(ex.what()); } - auto ov_model = CreateOVModel(std::move(model), session_context_, const_outputs_map_); - exe_network_ = OVCore::Get()->CompileModel( - ov_model, hw_target, device_config, enable_causallm, subgraph_context_.subgraph_name); } - LOGS_DEFAULT(INFO) << log_tag << "Loaded model to the plugin"; - } catch (const char* msg) { - ORT_THROW(msg); } - int num_infer_req = (session_context_.num_of_threads > 0) ? session_context_.num_of_threads : 1; std::function initializer = [](OVInferRequestPtr) {}; auto metadata = shared_context_.shared_weights.metadata; @@ -385,6 +383,32 @@ void BasicBackend::SetNumThreads(ov::AnyMap& device_config) { device_config.emplace(ov::inference_num_threads(session_context_.num_of_threads)); } +void BasicBackend::SetOVDeviceConfiguration(ov::AnyMap& device_config) { + PopulateConfigValue(device_config); + + // Enable caching + EnableCaching(); + + // Setting OpenCL queue throttling for GPU + EnableGPUThrottling(device_config); + + // Enable streams; default=1 unless overridden by user configuration + EnableStreams(); + + // Set the inference_num_threads property of the CPU + SetNumThreads(device_config); + + auto npuw_status = + std::any_of(device_config.begin(), device_config.end(), [&](const std::pair& pair) { + return (pair.first.find("NPU_USE_NPUW") != std::string::npos) && (pair.second.is()) && + (pair.second.as() == "YES"); + }); + + if (npuw_status) { + LOGS_DEFAULT(INFO) << log_tag << "NPUW Enabled during compilation"; + } +} + void BasicBackend::ValidateOrtDimsAgainstPartialShape(const std::vector& ort_dims, const ov::PartialShape& partial_shape) const { // Check if the number of dimensions matches diff --git a/onnxruntime/core/providers/openvino/backends/basic_backend.h b/onnxruntime/core/providers/openvino/backends/basic_backend.h index 9f0369a3603af..5c75a9ae183e2 100644 --- a/onnxruntime/core/providers/openvino/backends/basic_backend.h +++ b/onnxruntime/core/providers/openvino/backends/basic_backend.h @@ -146,6 +146,7 @@ class BasicBackend : public IBackend { void EnableGPUThrottling(ov::AnyMap& device_config); void EnableStreams(); void SetNumThreads(ov::AnyMap& device_config); + void SetOVDeviceConfiguration(ov::AnyMap& device_config); void ValidateOrtDimsAgainstPartialShape(const std::vector& ort_dims, const ov::PartialShape& partial_shape) const; From e2ec2b38696f535d0d16021a6a55b1b7bbc73718 Mon Sep 17 00:00:00 2001 From: Javier Martinez Date: Thu, 3 Jul 2025 03:32:04 -0700 Subject: [PATCH 058/138] Add QDQ scale propagation pass (#713) * Add pass to perform QDQ stripping and propagate scales * Fix disconnected outptu node * Fixes to support session.disable_quant_qdq output, remove dangling nodes and duplicate DQ nodes * Fix lack of scales updates and remove stray QDQ nodes in certain models * Address issues with Linux CI * Fix for double QDQ issue --- cmake/onnxruntime_providers_openvino.cmake | 2 +- .../optimizer/double_qdq_pairs_remover.cc | 1 + .../providers/openvino/backend_manager.cc | 15 +- .../providers/openvino/ov_protobuf_utils.cpp | 24 + .../providers/openvino/ov_protobuf_utils.h | 10 + .../qdq_transformations/qdq_scales_fix.cpp | 946 ++++++++++++++++++ .../qdq_transformations/qdq_scales_fix.h | 19 + 7 files changed, 1014 insertions(+), 3 deletions(-) create mode 100644 onnxruntime/core/providers/openvino/ov_protobuf_utils.cpp create mode 100644 onnxruntime/core/providers/openvino/ov_protobuf_utils.h create mode 100644 onnxruntime/core/providers/openvino/qdq_transformations/qdq_scales_fix.cpp create mode 100644 onnxruntime/core/providers/openvino/qdq_transformations/qdq_scales_fix.h diff --git a/cmake/onnxruntime_providers_openvino.cmake b/cmake/onnxruntime_providers_openvino.cmake index d7cb2d5ea0d0f..552f4cd3b8988 100644 --- a/cmake/onnxruntime_providers_openvino.cmake +++ b/cmake/onnxruntime_providers_openvino.cmake @@ -49,7 +49,7 @@ endif() add_dependencies(onnxruntime_providers_openvino onnxruntime_providers_shared ${onnxruntime_EXTERNAL_DEPENDENCIES}) target_include_directories(onnxruntime_providers_openvino SYSTEM PUBLIC ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR} ${OpenVINO_INCLUDE_DIR} ${OPENVINO_INCLUDE_DIR_LIST} ${PYTHON_INCLUDE_DIRS} $ENV{OPENCL_INCS} $ENV{OPENCL_INCS}/../../cl_headers/) - target_link_libraries(onnxruntime_providers_openvino ${ONNXRUNTIME_PROVIDERS_SHARED} Boost::mp11 ${OPENVINO_LIB_LIST} ${ABSEIL_LIBS} Eigen3::Eigen) + target_link_libraries(onnxruntime_providers_openvino ${ONNXRUNTIME_PROVIDERS_SHARED} Boost::mp11 ${OPENVINO_LIB_LIST} ${ABSEIL_LIBS} Eigen3::Eigen onnx_proto) target_compile_definitions(onnxruntime_providers_openvino PRIVATE FILE_NAME=\"onnxruntime_providers_openvino.dll\") diff --git a/onnxruntime/core/optimizer/double_qdq_pairs_remover.cc b/onnxruntime/core/optimizer/double_qdq_pairs_remover.cc index 1841dfa2791e0..7f214e656e0ab 100644 --- a/onnxruntime/core/optimizer/double_qdq_pairs_remover.cc +++ b/onnxruntime/core/optimizer/double_qdq_pairs_remover.cc @@ -52,6 +52,7 @@ static void ApplyNewInputValue(Graph& graph, Node& node, QDQ::InputIndex index, input_init.ToProto(new_input_tensor); auto new_name = graph.GenerateNodeArgName("DoubleQDQRemoved_" + node.InputDefs()[index]->Name()); new_input_tensor.set_name(new_name); + new_input_tensor.add_dims(1); NodeArg& new_input = graph_utils::AddInitializerWithExternalData(graph, new_input_tensor); graph_utils::ReplaceNodeInput(node, index, new_input); } diff --git a/onnxruntime/core/providers/openvino/backend_manager.cc b/onnxruntime/core/providers/openvino/backend_manager.cc index 041d9c07e41fe..253bae3d92a36 100644 --- a/onnxruntime/core/providers/openvino/backend_manager.cc +++ b/onnxruntime/core/providers/openvino/backend_manager.cc @@ -20,6 +20,7 @@ #include "core/providers/openvino/ov_interface.h" #include "core/providers/openvino/ov_versions/capability.h" #include "core/providers/openvino/qdq_transformations/qdq_stripping.h" +#include "core/providers/openvino/qdq_transformations/qdq_scales_fix.h" namespace onnxruntime { namespace openvino_ep { @@ -429,8 +430,7 @@ BackendManager::GetModelProtoFromFusedNode(const onnxruntime::Node& fused_node, const auto& onnx_model_path_name = subgraph.ModelPath(); // QDQ stripping enabled only for the NPU and experimentally on the GPU - if ((session_context_.device_type.find("NPU") != std::string::npos || - session_context_.device_type.find("GPU") != std::string::npos) && + if ((session_context_.device_type.find("NPU") != std::string::npos) && (enable_ovep_qdq_optimizer || session_context_.so_share_ep_contexts)) { std::unique_ptr model; Status status = CreateModelWithStrippedQDQNodes(subgraph, logger, session_context_.so_share_ep_contexts, enable_ovep_qdq_optimizer, model, shared_context_.shared_weights); @@ -440,6 +440,17 @@ BackendManager::GetModelProtoFromFusedNode(const onnxruntime::Node& fused_node, DumpOpenVINOEPModel(onnx_model_path_name, model_proto.get(), fused_node); ORT_ENFORCE(status.IsOK(), status.ErrorMessage()); return model_proto; + } else if ((session_context_.device_type.find("GPU") != std::string::npos) && + enable_ovep_qdq_optimizer) { + // Create a copy of the model + std::unique_ptr model; + Status status = qdq_scales_fix::Transform(subgraph, logger, model); + auto model_proto = model->ToProto(); + model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); + print_model_proto_duration(); + DumpOpenVINOEPModel(onnx_model_path_name, model_proto.get(), fused_node); + ORT_ENFORCE(status.IsOK(), status.ErrorMessage()); + return model_proto; } else { LOGS_DEFAULT(INFO) << "[OpenVINO-EP] OVEP QDQ optimization pass is disabled"; auto model = subgraph.CreateModel(logger); diff --git a/onnxruntime/core/providers/openvino/ov_protobuf_utils.cpp b/onnxruntime/core/providers/openvino/ov_protobuf_utils.cpp new file mode 100644 index 0000000000000..e28330e0bd433 --- /dev/null +++ b/onnxruntime/core/providers/openvino/ov_protobuf_utils.cpp @@ -0,0 +1,24 @@ +// Copyright (C) Intel Corporation +// Licensed under the MIT License + +#include "ov_protobuf_utils.h" + +#include "core/graph/onnx_protobuf.h" +#include "core/common/common.h" + +namespace onnxruntime { +namespace openvino_ep { +float get_float_initializer_data(const void* initializer) { + const auto* tp = reinterpret_cast(initializer); + ORT_ENFORCE((tp->has_data_type() && (tp->data_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT))); + // ORT_ENFORCE(initializer.dims_size() == 1); + return tp->float_data(0); +} +void set_float_initializer_data(const void* initializer, float data) { + auto* tp = (ONNX_NAMESPACE::TensorProto*)(initializer); + ORT_ENFORCE((tp->has_data_type() && (tp->data_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT))); + // ORT_ENFORCE(initializer.dims_size() == 1); + tp->set_float_data(0, data); +} +} // namespace openvino_ep +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/openvino/ov_protobuf_utils.h b/onnxruntime/core/providers/openvino/ov_protobuf_utils.h new file mode 100644 index 0000000000000..2a6d914ee2920 --- /dev/null +++ b/onnxruntime/core/providers/openvino/ov_protobuf_utils.h @@ -0,0 +1,10 @@ +// Copyright (C) Intel Corporation +// Licensed under the MIT License + +#pragma once +namespace onnxruntime { +namespace openvino_ep { +float get_float_initializer_data(const void* initializer); +void set_float_initializer_data(const void* initializer, float data); +} +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/openvino/qdq_transformations/qdq_scales_fix.cpp b/onnxruntime/core/providers/openvino/qdq_transformations/qdq_scales_fix.cpp new file mode 100644 index 0000000000000..571aa57c99f33 --- /dev/null +++ b/onnxruntime/core/providers/openvino/qdq_transformations/qdq_scales_fix.cpp @@ -0,0 +1,946 @@ +// Copyright (C) Intel Corporation +// Licensed under the MIT License + +#include "qdq_scales_fix.h" +#include "core/providers/openvino/ov_protobuf_utils.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace onnxruntime { +namespace openvino_ep { + +namespace qdq_scales_fix { + +namespace fs = std::filesystem; +using NodeRef = std::reference_wrapper; +struct GraphNode; +float get_initializer_value(const Graph& graph, const std::string& initializer_name); + +template +bool contains(V&& begin, V&& end, const T& val) { + for (V& iter = begin; iter != end; iter.operator++()) { + if (iter->Name() == val) { + return true; + } + } + return false; +} + +template +bool contains(const R& vec, const T& val) { + for (auto iter = vec.begin(); iter != vec.end(); iter++) { + if ((*iter)->Name() == val) { + return true; + } + } + return false; +} + +bool contains(const std::vector& container, const std::string& value) { + return std::find(container.begin(), container.end(), value) != container.end(); +} + +struct GraphNode { + GraphNode() = delete; + + template + GraphNode(const N& node, const std::string& op_type = {}) { + node_name = node.Name(); + if constexpr (std::is_same_v) { + node_ptr = &node; + this->op_type = node.OpType(); + for (const auto iter : node.InputDefs()) { + node_input_name.push_back(iter->Name()); + } + for (const auto iter : node.OutputDefs()) { + node_output_name.push_back(iter->Name()); + } + } else { + this->op_type = op_type; + //** node_input_name = [] + //** node_output_name = [] + } + + if (op_type == "output") { + down_to_output = true; + } + } + + bool operator==(const GraphNode&) const = default; + + void add_edge_to(GraphNode& dst_node) { + to_node.push_back(&dst_node); + } + + void add_edge_from(GraphNode& src_node) { + from_node.push_back(&src_node); + } + + std::vector apply_scale_to_graph(float scale_adj) { + std::vector affected_dq; + + auto extend = [&affected_dq, scale_adj](const std::vector& new_nodes) { + affected_dq.insert(affected_dq.end(), new_nodes.begin(), new_nodes.end()); + }; + + if (op_type == "DequantizeLinear") { + scale_factor *= scale_adj; + affected_dq.push_back(this); + } else if ((op_type == "Add") || (op_type == "QuantizeLinear")) { + for (auto node : from_node) { + extend(node->apply_scale_to_graph(scale_adj)); + } + } else if (op_type == "Conv") { + // just adjust w&b for conv&mul, stop propagate + for (auto node : from_node) { + if (node->op_type == "DequantizeLinear") { + extend(node->apply_scale_to_graph(scale_adj)); + } + } + } else if ((op_type == "Mul") || (op_type == "MatMul")) { + bool find_dq{false}; + for (auto node : from_node) { + if (node->op_type == "DequantizeLinear" && !find_dq) { + find_dq = true; + extend(node->apply_scale_to_graph(scale_adj)); + } + } + if (!find_dq) { + // cannot scale dq from here, choose input 0 to propagate + extend(from_node.back()->from_node[0]->apply_scale_to_graph(scale_adj)); + } + } else { + ORT_THROW("Unknown case, node: %s", ToString().data()); + } + + return affected_dq; + } + + std::vector down_propagate_scale() { + std::vector affected_nodes; + + if (processed) { + return affected_nodes; + } + + if ((op_type == "InstanceNormalization") || (op_type == "Softmax")) { + // pass + } else if (op_type == "Add") { + auto up_new_nodes = up_propagate_scale(); + affected_nodes.insert(affected_nodes.end(), up_new_nodes.begin(), up_new_nodes.end()); + + for (auto node : to_node) { + auto down_new_nodes = node->down_propagate_scale(); + affected_nodes.insert(affected_nodes.end(), down_new_nodes.begin(), down_new_nodes.end()); + } + } else { + affected_nodes.push_back(this); + processed = true; + + for (auto node : to_node) { + auto new_nodes = node->down_propagate_scale(); + affected_nodes.insert(affected_nodes.end(), new_nodes.begin(), new_nodes.end()); + } + } + return affected_nodes; + } + + std::vector up_propagate_scale() { + std::vector affected_nodes; + + if (processed) { + return affected_nodes; + } + + if ((op_type == "InstanceNormalization") || (op_type == "Softmax")) { + ORT_THROW("Cannot propagate up through norm layer"); + } else if (op_type == "Conv") { + affected_nodes.push_back(this); + processed = true; + + for (auto node : from_node) { + if (node->op_type == "DequantizeLinear") { + affected_nodes.push_back(node); + } + } + } else if ((op_type == "Mul") || (op_type == "MatMul")) { + affected_nodes.push_back(this); + processed = true; + bool find_dq{false}; + + for (auto node : from_node) { + if ((node->op_type == "DequantizeLinear") && !find_dq) { + find_dq = true; + affected_nodes.push_back(node); + } + } + if (!find_dq) { + auto new_nodes = from_node.back()->from_node[0]->up_propagate_scale(); + affected_nodes.insert(affected_nodes.end(), new_nodes.begin(), new_nodes.end()); + } + } else { + affected_nodes.push_back(this); + processed = true; + + for (auto node : from_node) { + auto new_nodes = node->up_propagate_scale(); + affected_nodes.insert(affected_nodes.end(), new_nodes.begin(), new_nodes.end()); + } + } + + return affected_nodes; + } + + bool down_propagate_to_output() { + if (down_to_output.has_value()) { + return down_to_output.value(); + } + + bool local_down_to_output{false}; + if (op_type == "output") { + local_down_to_output = true; + } else if ((op_type == "InstanceNormalization") || (op_type == "Softmax")) { + local_down_to_output = false; + } else { + for (auto node : to_node) { + local_down_to_output = local_down_to_output || node->down_propagate_to_output(); + } + } + + down_to_output = local_down_to_output; + return local_down_to_output; + } + + std::string ToString() const { + // auto string = std::format("op={} name={} queued={} visited={} scale_factor={}", + // op_type, + // node_name, + // queued, + // visited, + // scale_factor); + auto print_node_vector = [](const std::vector& nodes) -> std::string { + // auto comp = [](const GraphNode* left, const GraphNode* right) -> bool { + // return left->node_name < right->node_name; + // }; + // std::sort(nodes.begin(), nodes.end(), comp); + std::string ret = "["; + for (size_t i = 0, size = nodes.size(); auto pnode : nodes) { + if (pnode->node_name.size() == 0) continue; + ret += pnode->node_name; + if (++i < size) { + ret += ", "; + } + } + ret += "]"; + return ret; + }; + std::string from_node_str = print_node_vector(from_node); + std::string to_node_str = print_node_vector(to_node); + + auto print_string_vector = [](const std::vector& nodes) -> std::string { + // std::sort(nodes.begin(), nodes.end()); + std::string ret = "["; + for (size_t i = 0, size = nodes.size(); const auto& node : nodes) { + ret += node; + if (++i < size) { + ret += ", "; + } + } + ret += "]"; + return ret; + }; + std::string node_input_name_str = print_string_vector(node_input_name); + std::string node_output_name_str = print_string_vector(node_output_name); + + auto print_bool = [](bool val) -> std::string { + return (val) ? "True" : "False"; + }; + + auto print_opt_bool = [print_bool](std::optional val) -> std::string { + return (val.has_value()) ? print_bool(val.value()) : "None"; + }; + + auto string = std::format("node_name={} op_type={} scale_factor={:.2f} visited={} queued={} down_to_output={} processed={} from_node={} to_node={} node_input_name={} node_output_name={}", + node_name, + op_type, + scale_factor, + visited, + print_bool(queued), + print_opt_bool(down_to_output), + print_bool(processed), + from_node_str, + to_node_str, + node_input_name_str, + node_output_name_str); + return string; + } + + const Node* node_ptr{nullptr}; + std::string node_name; + std::string op_type; + std::vector node_input_name; + std::vector node_output_name; + std::vector from_node; + std::vector to_node; + float scale_factor{1.f}; + int visited{0}; + bool queued{false}; + std::optional down_to_output; + bool processed{false}; +}; + +struct CustomGraph { + CustomGraph() = delete; + CustomGraph(Graph& graph) : original_graph{graph} {} + + void sort() { + auto comp_node = [](const GraphNode& left, const GraphNode& right) -> bool { + return left.node_name < right.node_name; + }; + nodes.sort(comp_node); + + for (auto& node : nodes) { + auto comp_pnode = [](const GraphNode* left, const GraphNode* right) -> bool { + return left->node_name < right->node_name; + }; + std::sort(node.from_node.begin(), node.from_node.end(), comp_pnode); + std::sort(node.to_node.begin(), node.to_node.end(), comp_pnode); + } + } + + void add_node(const GraphNode& node) { + nodes.push_back(node); + } + + void add_edge(GraphNode& src, GraphNode& dst) { + src.add_edge_to(dst); + dst.add_edge_from(src); + } + + auto get_start_nodes() { + std::list start_nodes; + + for (auto& node : nodes) { + if (node.from_node.empty()) { + start_nodes.push_back(&node); + node.queued = true; + } + } + return start_nodes; + } + + void initailize_search(float threshold = 1.f, bool scale_output = false) { + remove_qdq(threshold, scale_output); + for (auto& node : nodes) { + node.visited = 0; + node.queued = false; + } + } + + void init_propagate() { + for (auto& node : nodes) { + node.processed = false; + } + } + + void remove_qdq_pair(const GraphNode& node, std::list& removed) { + auto& q = node; + InlinedVector dq_ptrs; + + for (auto& child : q.to_node) { + if (child->node_ptr && child->node_ptr->OpType() == "DequantizeLinear") { + dq_ptrs.push_back(child); + } + } + + if (dq_ptrs.empty()) { + return; + } + + for (std::size_t i = 1; i < dq_ptrs.size(); ++i) { + if (dq_ptrs[i]->node_input_name[1] != dq_ptrs[0]->node_input_name[1] || + dq_ptrs[i]->node_input_name[2] != dq_ptrs[0]->node_input_name[2]) { + return; + } + } + + auto& prev = *node.from_node[0]; + const auto& q_node = *q.node_ptr; + + bool is_prev_input = (prev.node_ptr == nullptr); + std::string prev_output_name = is_prev_input ? prev.node_name : prev.node_output_name[0]; + + InlinedVector> output_replacements; + for (auto dq_ptr : dq_ptrs) { + for (auto dst_node : dq_ptr->to_node) { + for (auto& scr_node : dst_node->from_node) { + if (*dq_ptr == *scr_node) { + scr_node = &prev; + } + } + + auto it = std::find(dst_node->node_input_name.begin(), dst_node->node_input_name.end(), dq_ptr->node_output_name[0]); + if (it != dst_node->node_input_name.end()) { + *it = prev_output_name; + } + } + for (auto& output : original_graph.GetOutputs()) { + if (output->Name() == dq_ptr->node_output_name[0]) { + const NodeArg* replacement_arg = nullptr; + if (!is_prev_input) { + replacement_arg = prev.node_ptr->OutputDefs()[0]; + } else { + replacement_arg = original_graph.GetNodeArg(prev.node_name); + ORT_ENFORCE(replacement_arg != nullptr, "Input not found: " + prev.node_name); + } + output_replacements.emplace_back(output, replacement_arg); + } + } + } + + prev.to_node.erase(std::remove(prev.to_node.begin(), prev.to_node.end(), &q), prev.to_node.end()); + for (auto dq_ptr : dq_ptrs) { + for (auto dst_node : dq_ptr->to_node) { + auto it = std::find(prev.to_node.begin(), prev.to_node.end(), dst_node); + if (it == prev.to_node.end()) { + prev.to_node.push_back(dst_node); + } + } + } + auto q_iter = std::find(nodes.begin(), nodes.end(), q); + if (q_iter != nodes.end()) { + removed.splice(removed.end(), nodes, q_iter); + } + + for (auto dq_ptr : dq_ptrs) { + auto dq_iter = std::find(nodes.begin(), nodes.end(), *dq_ptr); + if (dq_iter != nodes.end()) { + removed.splice(removed.end(), nodes, dq_iter); + } + } + + auto remove_edge = [this](const Node& src, const Node& dst, int src_arg, int dst_arg) { + original_graph.RemoveEdge(src.Index(), dst.Index(), src_arg, dst_arg); + }; + + auto in_edge = q_node.InputEdgesBegin(); + ORT_ENFORCE(in_edge != q_node.InputEdgesEnd(), "Q node must have an input edge"); + const int prev_output_index = in_edge->GetSrcArgIndex(); + + if (in_edge != q_node.InputEdgesEnd()) { + remove_edge(in_edge->GetNode(), q_node, + in_edge->GetSrcArgIndex(), in_edge->GetDstArgIndex()); + } + for (auto dq_ptr : dq_ptrs) { + auto& dq_node_ref = *dq_ptr->node_ptr; + + for (auto edge_it = dq_node_ref.InputEdgesBegin(); edge_it != dq_node_ref.InputEdgesEnd(); ++edge_it) { + if (edge_it->GetNode().Index() == q_node.Index()) { + remove_edge(edge_it->GetNode(), dq_node_ref, edge_it->GetSrcArgIndex(), edge_it->GetDstArgIndex()); + break; + } + } + + std::vector> output_edges; // (dst_node_index, src_arg, dst_arg) + for (auto out_edge_it = dq_node_ref.OutputEdgesBegin(); out_edge_it != dq_node_ref.OutputEdgesEnd(); ++out_edge_it) { + output_edges.emplace_back(out_edge_it->GetNode().Index(), + out_edge_it->GetSrcArgIndex(), + out_edge_it->GetDstArgIndex()); + } + + for (const auto& edge : output_edges) { + original_graph.RemoveEdge(dq_node_ref.Index(), std::get<0>(edge), + std::get<1>(edge), std::get<2>(edge)); + } + + if (!is_prev_input) { + for (const auto& edge : output_edges) { + original_graph.AddEdge(prev.node_ptr->Index(), + std::get<0>(edge), + prev_output_index, + std::get<2>(edge)); + } + } + } + + if (!output_replacements.empty()) { + auto outputs = original_graph.GetOutputs(); + for (auto& output : outputs) { + for (const auto& replacement : output_replacements) { + if (output == replacement.first) { + output = replacement.second; + break; + } + } + } + original_graph.SetOutputs(outputs); + } + + original_graph.RemoveNode(q_node.Index()); + for (auto dq_ptr : dq_ptrs) { + original_graph.RemoveNode(dq_ptr->node_ptr->Index()); + } + } + + std::list remove_qdq(float threshold = 1.f, bool scale_output = false) { + std::list removed; + std::vector nodes_copy; + std::for_each(nodes.begin(), nodes.end(), [&nodes_copy](GraphNode& node) { nodes_copy.push_back(&node); }); + for (auto node : nodes_copy) { + if (std::find(nodes.begin(), nodes.end(), *node) == nodes.end()) { + continue; + } + + if ((node->op_type == "QuantizeLinear") && + (node->to_node[0]->op_type == "DequantizeLinear")) { + const auto& zero_point_name = node->node_input_name[2]; + const auto p_initializer = original_graph.GetConstantInitializer(zero_point_name, false); + bool is_16_bit = p_initializer->has_data_type() && + (p_initializer->data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT16 || + p_initializer->data_type() == ONNX_NAMESPACE::TensorProto_DataType_UINT16); + if (!is_16_bit) + continue; + if (!scale_output && node->down_propagate_to_output()) { + remove_qdq_pair(*node, removed); + continue; + } + + auto scale_name = node->node_input_name[1]; // Scale + auto scale_value = get_initializer_value(original_graph, scale_name); + if (scale_value / node->scale_factor < threshold) { + remove_qdq_pair(*node, removed); + } + } + } + + // Reconnect graph outputs if disconnected + bool update_outputs{false}; + auto outputs = original_graph.GetOutputs(); + for (auto output : outputs) { + bool found{false}; + for (auto node : original_graph.Nodes()) { + if (contains(node->OutputNodesBegin(), node->OutputNodesEnd(), output->Name())) { + found = true; + break; + } + } + + if (!found) { + // Connect the last valid node to the graph output + for (auto node : std::ranges::reverse_view(original_graph.Nodes())) { + if (!node->OutputDefs().empty()) { + const auto& name = (*node->OutputDefs().begin())->Name(); + auto& node_arg = original_graph.GetOrCreateNodeArg(name, output->TypeAsProto()); + output = &node_arg; + update_outputs = true; + } + } + } + } + + if (update_outputs) { + original_graph.SetOutputs(outputs); + } + + return removed; + } + + void dump_custom_graph(fs::path path) { + if (auto file = std::ofstream(path)) { + std::vector node_ref; + for (auto& node : nodes) { + node_ref.emplace_back(&node); + } + + for (const auto& node : node_ref) { + std::string node_str = node->ToString(); + file << node_str << "\n"; + } + } + } + + std::list nodes; + std::list removed_nodes; + Graph& original_graph; +}; + +template +T* get_mutable_initializer_data(const Graph& graph, const std::string& name) { + auto initializer = graph.GetConstantInitializer(name, true); + if (!initializer) return nullptr; + + if constexpr (std::is_same_v) { + if (initializer->data_type() != ONNX_NAMESPACE::TensorProto_DataType_FLOAT) + return nullptr; + } + + return reinterpret_cast(const_cast(initializer->raw_data().data())); +} + +std::size_t get_initializer_size(const Graph& graph, const std::string& name) { + auto initializer = graph.GetConstantInitializer(name, true); + if (!initializer) return 0; + + std::size_t size = 1; + if (!initializer->dims_size()) + return 0; + for (int i = 0; i < initializer->dims_size(); ++i) { + size *= initializer->dims()[i]; + } + return size; +} + +float get_initializer_value(const Graph& graph, const std::string& initializer_name) { + const auto p_initializer = graph.GetConstantInitializer(initializer_name, false); + + if (p_initializer->has_raw_data()) { + auto raw_data = get_mutable_initializer_data(graph, initializer_name); + auto size = get_initializer_size(graph, initializer_name); + ORT_ENFORCE(size == 1, "Expected an initializer to be of size 1"); + return raw_data[0]; + } + else + return get_float_initializer_data(p_initializer); +} + +void update_initializer_value(Graph& graph, const std::string& initializer_name, const float new_value) { + const auto p_initializer = graph.GetConstantInitializer(initializer_name, false); + + if (p_initializer == nullptr) { + return; + } + + const auto& initializer = *p_initializer; + + // Verify 1D tensor + ORT_ENFORCE(initializer.dims_size() == 1); + ORT_ENFORCE(initializer.data_type() == onnx::TensorProto_DataType_FLOAT); + + // Create new tensor with updated value + auto new_tensor = onnx::TensorProto::Create(); + new_tensor->copy_from(p_initializer); + *(float*)new_tensor->mutable_raw_data()->data() = new_value; + graph.RemoveInitializedTensor(initializer_name); + graph.AddInitializedTensor(*new_tensor); +} + +CustomGraph generate_graph_from_onnx(Graph& graph) { + CustomGraph gen_graph{graph}; + + for (auto pnode : graph.Nodes()) { + if (pnode->NodeType() == Node::Type::Fused) continue; + gen_graph.nodes.emplace_back(*pnode); + } + + for (auto& src_node : gen_graph.nodes) { + for (auto& dst_node : gen_graph.nodes) { + if (src_node == dst_node) { + continue; + } + + for (auto& src_output : src_node.node_output_name) { + if (contains(dst_node.node_input_name, src_output)) { + gen_graph.add_edge(src_node, dst_node); + } + } + } + } + + for (auto& input_node : graph.GetInputs()) { + auto& cur_input = gen_graph.nodes.emplace_back(*input_node, "input"); + for (auto& dst_node : gen_graph.nodes) { + for (const auto& dst_output : dst_node.node_input_name) { + if (dst_output == input_node->Name()) { + gen_graph.add_edge(cur_input, dst_node); + } + } + } + } + + for (auto& output_node : graph.GetOutputs()) { + auto& cur_output = gen_graph.nodes.emplace_back(*output_node, "output"); + for (auto& src_node : gen_graph.nodes) { + for (const auto& dst_outputs : src_node.node_output_name) { + if (dst_outputs == output_node->Name()) { + gen_graph.add_edge(src_node, cur_output); + } + } + } + } + + gen_graph.sort(); + return gen_graph; +} + +bool scale_graph(CustomGraph& gen_graph, + float threshold = 1.f, + float ratio = 10, + bool scale_output = false) { + bool needs_second_run = false; + gen_graph.initailize_search(threshold, scale_output); + auto q = gen_graph.get_start_nodes(); + auto pred = [](const GraphNode* left, const GraphNode* right) -> bool { + return left->node_name < right->node_name; + }; + q.sort(pred); + + while (!q.empty()) { + auto cur_node = q.front(); + q.pop_front(); + if (static_cast(cur_node->visited) < cur_node->from_node.size()) { + cur_node->queued = false; + } else { + if (cur_node->op_type == "QuantizeLinear" && + cur_node->to_node[0]->op_type == "DequantizeLinear") { + needs_second_run = true; + auto scale_name = *std::next(cur_node->node_input_name.begin()); + auto scale_value = get_initializer_value(gen_graph.original_graph, scale_name); + + // QDQ pair with scale over 1 + if (scale_value / cur_node->scale_factor > threshold) { + gen_graph.init_propagate(); + // adjust previous op scale to threshold / 10 + auto scale_adj = scale_value / cur_node->scale_factor / threshold * ratio; + + // find related const dq to scale down + auto affected_dq = cur_node->apply_scale_to_graph(scale_adj); + std::vector affected_nodes; + + // then propage to graph to update scale + for (auto& dq : affected_dq) { + auto cur_affected = dq->down_propagate_scale(); + affected_nodes.insert(affected_nodes.end(), cur_affected.begin(), cur_affected.end()); + } + + for (auto& node : affected_nodes) { + bool found = std::find(affected_dq.begin(), affected_dq.end(), node) != affected_dq.end(); + if (!found) { + node->scale_factor *= scale_adj; + } + } + + auto removed_qdq = gen_graph.remove_qdq(threshold, scale_output); + for (auto& qdq : removed_qdq) { + try { + q.remove(&qdq); + } catch (...) { + } + } + + gen_graph.removed_nodes.splice(gen_graph.removed_nodes.end(), removed_qdq); + + cur_node = cur_node->to_node[0]; + } + } + + for (auto dst : cur_node->to_node) { + dst->visited += 1; + if (!dst->queued) { + dst->queued = true; + q.push_back(dst); + } + } + } + } + + for (auto& node : gen_graph.nodes) { + if (node.op_type == "DequantizeLinear" && node.scale_factor != 1.0f) { + const auto& scale_name = node.node_input_name[1]; + + auto scale_data = get_mutable_initializer_data(gen_graph.original_graph, scale_name); + if (scale_data) { + const auto scale_size = get_initializer_size(gen_graph.original_graph, scale_name); + if (!scale_size) { + auto it = gen_graph.original_graph.GetConstantInitializer(scale_name, true); + auto cur_scale = get_float_initializer_data(it); + cur_scale /= node.scale_factor; + set_float_initializer_data(it, cur_scale); + } else { + for (std::size_t i = 0; i < scale_size; ++i) { + scale_data[i] /= node.scale_factor; + } + } + } + + node.scale_factor = 1.0f; + } + } + return needs_second_run; +} + + +Status copy_model(const GraphViewer& src_graph_viewer, + const logging::Logger& logger, std::unique_ptr& model) { + model = src_graph_viewer.CreateModel(logger); + const auto& src_graph = src_graph_viewer.GetGraph(); + auto& dst_graph = model->MainGraph(); + + const auto& inputs = src_graph.GetInputs(); + const auto& outputs = src_graph.GetOutputs(); + + struct InputReplacement { + NodeArg* graph_input; + NodeArg* identity_output; + }; + std::unordered_map input_replacement_map; + + struct OutputReplacement { + NodeArg* intermediate_arg; + NodeArg* original_output; + }; + std::unordered_map output_replacement_map; + + InlinedVector dst_graph_inputs; + dst_graph_inputs.reserve(inputs.size()); + for (auto& input : inputs) { + const auto& input_name = input->Name(); + auto input_arg = src_graph.GetNodeArg(input_name); + + auto& dst_input_arg = dst_graph.GetOrCreateNodeArg(input_name, input_arg->TypeAsProto()); + dst_graph_inputs.push_back(&dst_input_arg); + + auto output_name = input_name + "_identity_output"; + auto& identity_output_arg = dst_graph.GetOrCreateNodeArg(output_name, input_arg->TypeAsProto()); + + input_replacement_map[input_name] = {&dst_input_arg, &identity_output_arg}; + } + + InlinedVector dst_graph_outputs; + for (auto& output : outputs) { + const auto& output_name = output->Name(); + auto output_arg = src_graph.GetNodeArg(output_name); + + std::string intermediate_name = "tmp_" + output_name; + auto& intermediate_out = dst_graph.GetOrCreateNodeArg(intermediate_name, output_arg->TypeAsProto()); + + auto& original_out = dst_graph.GetOrCreateNodeArg(output_name, output_arg->TypeAsProto()); + + output_replacement_map[output_name] = {&intermediate_out, &original_out}; + dst_graph_outputs.push_back(&original_out); + } + + dst_graph.SetInputs(dst_graph_inputs); + dst_graph.SetOutputs(dst_graph_outputs); + dst_graph.SetName(src_graph.Name()); + + for (const auto& name : src_graph_viewer.GetOuterScopeNodeArgNames()) { + auto node_arg = src_graph.GetNodeArg(name); + ORT_RETURN_IF_NOT(node_arg != nullptr, "Outer scope node arg name '" + name + "'was added but does not exist. "); + dst_graph.AddOuterScopeNodeArg(name); + } + + for (auto& input : inputs) { + const auto& input_name = input->Name(); + auto it = input_replacement_map.find(input_name); + ORT_RETURN_IF_NOT(it != input_replacement_map.end(), "Missing replacement for input: " + input_name); + + InputReplacement& repl = it->second; + InlinedVector input_args = {repl.graph_input}; + InlinedVector output_args = {repl.identity_output}; + + std::string node_name = "IdentityInsertion_" + input_name; + dst_graph.AddNode(node_name, "Identity", "Inserted identity node", + input_args, output_args, + nullptr, ""); + } + + for (auto pnode : src_graph.Nodes()) { + if (pnode->NodeType() == Node::Type::Fused) continue; + + InlinedVector new_input_args; + for (auto input_arg : pnode->InputDefs()) { + if (!input_arg) { + new_input_args.push_back(nullptr); + continue; + } + + auto it = input_replacement_map.find(input_arg->Name()); + if (it != input_replacement_map.end()) { + new_input_args.push_back(it->second.identity_output); + } else { + auto& new_arg = dst_graph.GetOrCreateNodeArg(input_arg->Name(), input_arg->TypeAsProto()); + new_input_args.push_back(&new_arg); + } + } + InlinedVector new_output_args; + for (auto output_arg : pnode->OutputDefs()) { + if (output_arg == nullptr) { + new_output_args.push_back(nullptr); + continue; + } + + auto it_output = output_replacement_map.find(output_arg->Name()); + if (it_output != output_replacement_map.end()) { + new_output_args.push_back(it_output->second.intermediate_arg); + } else { + auto& new_arg = dst_graph.GetOrCreateNodeArg(output_arg->Name(), output_arg->TypeAsProto()); + new_output_args.push_back(&new_arg); + } + } + + dst_graph.AddNode(pnode->Name(), pnode->OpType(), pnode->Description(), + new_input_args, new_output_args, + &pnode->GetAttributes(), pnode->Domain()); + } + + for (auto& output : outputs) { + const std::string& output_name = output->Name(); + auto it = output_replacement_map.find(output_name); + if (it == output_replacement_map.end()) continue; + + OutputReplacement& repl = it->second; + InlinedVector input_args = {repl.intermediate_arg}; + InlinedVector output_args = {repl.original_output}; + + std::string node_name = "IdentityInsertion_" + output_name; + dst_graph.AddNode(node_name, "Identity", "Inserted identitynode", + input_args, output_args, nullptr, ""); + } + + for (auto& [name, tensor_proto] : src_graph.GetAllInitializedTensors()) { + dst_graph.AddInitializedTensor(*tensor_proto); + } + + for (auto node_arg : src_graph.GetInputsIncludingInitializers()) { + auto check_inputs = [node_arg](auto input_node_arg) { + return input_node_arg->Name() == node_arg->Name(); + }; + if (std::find_if(dst_graph_inputs.begin(), dst_graph_inputs.end(), check_inputs) != dst_graph_inputs.end()) + continue; + + auto src_tensor_proto = src_graph.GetConstantInitializer(node_arg->Name(), true); + if (src_tensor_proto) { + auto dst_tensor_proto = onnx::TensorProto::Create(); + dst_tensor_proto->copy_from(src_tensor_proto); + dst_graph.AddInitializedTensor(*dst_tensor_proto); + } + } + + ORT_RETURN_IF_ERROR(dst_graph.Resolve()); + return Status::OK(); +} + +Status Transform(const GraphViewer& src_graph_viewer, + const logging::Logger& logger, + /*out*/ std::unique_ptr& model) { + auto status = copy_model(src_graph_viewer, logger, model); + auto g = generate_graph_from_onnx(model->MainGraph()); + + float threshold{1.f}; + float ratio{10.f}; + bool scale_output{false}; + auto needs_second_run = scale_graph(g, threshold, ratio, scale_output); + if (needs_second_run) + scale_graph(g, threshold * 100, ratio, scale_output); + return status; +} +} // namespace qdq_scales_fix +} // namespace openvino_ep +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/openvino/qdq_transformations/qdq_scales_fix.h b/onnxruntime/core/providers/openvino/qdq_transformations/qdq_scales_fix.h new file mode 100644 index 0000000000000..c54c531e1bd40 --- /dev/null +++ b/onnxruntime/core/providers/openvino/qdq_transformations/qdq_scales_fix.h @@ -0,0 +1,19 @@ +// Copyright (C) Intel Corporation +// Licensed under the MIT License + +#pragma once + +#include "core/providers/shared_library/provider_api.h" + +namespace onnxruntime { +class GraphViewer; + +namespace openvino_ep { + +namespace qdq_scales_fix { +Status Transform(const GraphViewer& src_graph, + const logging::Logger& logger, + /*out*/ std::unique_ptr& model); +} +} // namespace openvino_ep +} // namespace onnxruntime From 3238a90233b7a6427a4fca1ae658688315951684 Mon Sep 17 00:00:00 2001 From: sfatimar Date: Thu, 3 Jul 2025 17:55:46 +0530 Subject: [PATCH 059/138] Added to remove MAC CI Warnings (#726) Co-authored-by: TejalKhade28 Co-authored-by: Ankit Maheshkar --- onnxruntime/test/perftest/performance_runner.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/onnxruntime/test/perftest/performance_runner.cc b/onnxruntime/test/perftest/performance_runner.cc index 8ec9694227c14..24a2bd633b96c 100644 --- a/onnxruntime/test/perftest/performance_runner.cc +++ b/onnxruntime/test/perftest/performance_runner.cc @@ -218,6 +218,7 @@ Status PerformanceRunner::RunParallelDuration() { // Join tpool->Schedule([this, &counter, &m, &cv]() { + ORT_UNUSED_PARAMETER(this); std::unique_lock lock(m); cv.wait(lock, [&counter]() { return counter == 0; }); }); From 1695972e93205723eb7df4a294646b24c66a5724 Mon Sep 17 00:00:00 2001 From: Ankit Maheshkar Date: Thu, 3 Jul 2025 20:57:28 +0530 Subject: [PATCH 060/138] fix: Undo perf_runner changes (#727) --- onnxruntime/test/perftest/performance_runner.cc | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/onnxruntime/test/perftest/performance_runner.cc b/onnxruntime/test/perftest/performance_runner.cc index 24a2bd633b96c..faf0c34193717 100644 --- a/onnxruntime/test/perftest/performance_runner.cc +++ b/onnxruntime/test/perftest/performance_runner.cc @@ -203,9 +203,8 @@ Status PerformanceRunner::RunParallelDuration() { counter++; tpool->Schedule([this, &counter, &m, &cv]() { auto status = RunOneIteration(); - if (!status.IsOK()) { + if (!status.IsOK()) std::cerr << status.ErrorMessage(); - } // Simplified version of Eigen::Barrier std::lock_guard lg(m); counter--; @@ -217,11 +216,8 @@ Status PerformanceRunner::RunParallelDuration() { } while (duration_seconds.count() < performance_test_config_.run_config.duration_in_seconds); // Join - tpool->Schedule([this, &counter, &m, &cv]() { - ORT_UNUSED_PARAMETER(this); - std::unique_lock lock(m); - cv.wait(lock, [&counter]() { return counter == 0; }); - }); + std::unique_lock lock(m); + cv.wait(lock, [&counter]() { return counter == 0; }); return Status::OK(); } From 54151b1fb636242ca181362292afb7688b00e081 Mon Sep 17 00:00:00 2001 From: Ryan Metcalfe <107415876+RyanMetcalfeInt8@users.noreply.github.com> Date: Tue, 8 Jul 2025 10:32:03 -0400 Subject: [PATCH 061/138] Enable dynamic path for NPU when enable_causallm is true (#732) Co-authored-by: Ankit Maheshkar --- onnxruntime/core/providers/openvino/backend_manager.cc | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/openvino/backend_manager.cc b/onnxruntime/core/providers/openvino/backend_manager.cc index 253bae3d92a36..65532c31e14bd 100644 --- a/onnxruntime/core/providers/openvino/backend_manager.cc +++ b/onnxruntime/core/providers/openvino/backend_manager.cc @@ -118,7 +118,9 @@ BackendManager::BackendManager(SessionContext& session_context, LOGS_DEFAULT(INFO) << "[OpenVINO-EP] Model has symbolic input dims"; if ((!session_context_.disable_dynamic_shapes && (session_context_.device_type.find("CPU") != std::string::npos || - session_context_.device_type.find("GPU") != std::string::npos)) || + session_context_.device_type.find("GPU") != std::string::npos || + (session_context_.device_type.find("NPU") != std::string::npos && + session_context_.enable_causallm) )) || (subgraph_context_.is_ep_ctx_graph)) { LOGS_DEFAULT(INFO) << "[OpenVINO-EP] Starting backend initialization. " << "Creating backend Dynamic Shapes"; From 8d36ad264a01e5071fab7bb544d9a29f41c539de Mon Sep 17 00:00:00 2001 From: Ryan Metcalfe <107415876+RyanMetcalfeInt8@users.noreply.github.com> Date: Thu, 10 Jul 2025 02:13:12 -0400 Subject: [PATCH 062/138] Allow zero-element tensors to get set (#737) --- .../core/providers/openvino/ov_interface.h | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/onnxruntime/core/providers/openvino/ov_interface.h b/onnxruntime/core/providers/openvino/ov_interface.h index fb1757199698b..a2067ce10485c 100644 --- a/onnxruntime/core/providers/openvino/ov_interface.h +++ b/onnxruntime/core/providers/openvino/ov_interface.h @@ -133,6 +133,22 @@ class OVInferRequest { auto tensor_ptr = std::make_shared(type, shape, const_cast(ort_ptr)); SetTensor(name, tensor_ptr); cached_binding = {tensor_ptr, ort_ptr}; + } else if (ort_ptr==nullptr) { + // a null ort_ptr is expected for a tensor that has 0 elements. + // for example, a tensor of shape=[1, 8, 0, 64], which is valid. + // So, we check to see if at least one shape entry is 0. + auto contains_zero = [](const ov::Shape& shape) { + for (auto& s : shape) + if (s == 0) return true; + return false; + }; + if (contains_zero(shape)) { + // if there are zero elements (i.e. at least one shape entry is 0), + // then create and set the tensor anyway. + auto tensor_ptr = std::make_shared(type, shape); + SetTensor(name, tensor_ptr); + cached_binding = {tensor_ptr, ort_ptr}; + } } } From 89ccd8175408e92679afb95d2c182cc88acbdbe4 Mon Sep 17 00:00:00 2001 From: sfatimar Date: Fri, 11 Jul 2025 22:36:00 +0530 Subject: [PATCH 063/138] Cluster Change to avoid Dangling DQLinear (#739) * Cluster Change to avoid Dangling DQLinear * Error in subgraph --------- Co-authored-by: TejalKhade28 --- .../openvino/ov_versions/capability.cc | 24 +++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/onnxruntime/core/providers/openvino/ov_versions/capability.cc b/onnxruntime/core/providers/openvino/ov_versions/capability.cc index 88ddde8610c6e..85c7489fd75d6 100644 --- a/onnxruntime/core/providers/openvino/ov_versions/capability.cc +++ b/onnxruntime/core/providers/openvino/ov_versions/capability.cc @@ -168,17 +168,33 @@ std::vector> GetCapability::Execute() { auto connected_clusters = GetConnectedClusters(graph_viewer_, ng_clusters); int no_of_clusters = 0; + std::vector prev_cluster; + bool try_next_cluster = false; for (auto this_cluster : connected_clusters) { + bool omit_subgraph = false; + if (try_next_cluster) { + // no need to check previous cluster + for (auto idx : prev_cluster) { + if ((std::find(this_cluster.begin(), this_cluster.end(), idx)) == this_cluster.end()) { + this_cluster.emplace_back(idx); + } + } + try_next_cluster = false; + } + // If subgraph has less then three, graph is considered trivial unless its an epctx cluster - if (this_cluster.size() < 3) { + if (!try_next_cluster && this_cluster.size() < 3) { bool is_epctx_node = false; for (auto node_idx : this_cluster) { if (graph_viewer_.GetNode(node_idx)->OpType() == "EPContext") is_epctx_node = true; } - if (!is_epctx_node) - continue; + if (!is_epctx_node) { + omit_subgraph = true; + prev_cluster = this_cluster; + try_next_cluster = true; + } } std::vector cluster_graph_inputs, cluster_inputs, cluster_outputs; @@ -190,7 +206,7 @@ std::vector> GetCapability::Execute() { cluster_inputs, cluster_outputs); - bool omit_subgraph = false; + // Omitting zero dim subgraphs for (auto index : this_cluster) { const Node* node = graph_viewer_.GetNode(index); From 05126ff4c4c7e2bed14c5aa289bd633a80d7f96f Mon Sep 17 00:00:00 2001 From: Preetha Veeramalai Date: Tue, 15 Jul 2025 22:12:01 -0700 Subject: [PATCH 064/138] Fix the model copies and redefinitions for CPU fallback (#728) * Fix the model copies and redefinitions for CPU fallback * OV compatibility is not needed --------- Co-authored-by: sfatimar --- .../openvino/backends/basic_backend.cc | 48 +++++++++---------- 1 file changed, 23 insertions(+), 25 deletions(-) diff --git a/onnxruntime/core/providers/openvino/backends/basic_backend.cc b/onnxruntime/core/providers/openvino/backends/basic_backend.cc index df75f84a5fee0..61235ef2138b5 100644 --- a/onnxruntime/core/providers/openvino/backends/basic_backend.cc +++ b/onnxruntime/core/providers/openvino/backends/basic_backend.cc @@ -36,10 +36,6 @@ BasicBackend::BasicBackend(std::unique_ptr& model_pr if (ValidateSubgraph(const_outputs_map_)) return; - // Pre-requisite is provider_option "context" must be set - auto auto_unified_compile = ((hw_target.find("AUTO") == std::string::npos) || - (session_context_.OpenVINO_Version.at(0) >= 2024 && - session_context_.OpenVINO_Version.at(1) > 2)); ov::AnyMap device_config; SetOVDeviceConfiguration(device_config); if (subgraph_context_.is_ep_ctx_graph) { @@ -81,42 +77,46 @@ BasicBackend::BasicBackend(std::unique_ptr& model_pr ORT_THROW(msg); } // Delete stream after it is no longer needed } else { + std::shared_ptr ov_model; std::string model = model_proto->SerializeAsString(); if (!subgraph_context.has_dynamic_input_shape) { model_proto.reset(); } + bool eligible_for_cpu_fallback = session_context_.device_type.find("NPU") != std::string::npos && + !session_context_.so_disable_cpu_ep_fallback && + !subgraph_context_.is_ep_ctx_graph; +#if defined(OPENVINO_DISABLE_NPU_FALLBACK) + eligible_for_cpu_fallback = false; +#endif + auto auto_unified_compile = (hw_target.find("AUTO") == std::string::npos); + + // Unified compile is efficient with cahce_dir cached model loading that bypass Read Model + // Does not support model with exteral weights, dynamic input shape, Epctx onnx cached model, + // reshape, enable_causallm, and for NPU CPU fallback + + auto is_unified_compile = (!session_context_.has_external_weights && + !subgraph_context_.has_dynamic_input_shape && + !session_context_.so_context_enable && + session_context_.reshape.empty() && + !enable_causallm && + !eligible_for_cpu_fallback && + auto_unified_compile); try { - // SetOVDeviceConfiguration(device_config); - if (!session_context_.has_external_weights && - !subgraph_context_.has_dynamic_input_shape && - !session_context_.so_context_enable && - session_context_.reshape.empty() && - !enable_causallm && - auto_unified_compile) { - // Unified OV compile_model is efficient when ov model caching is enabled - // Unified OV compile_model API is supported with AUTO from version 2024.3 and above - // Inputs with static dimensions - // Not enabled for models with external weights and when ep context is set. - + if (is_unified_compile) { exe_network_ = OVCore::Get()->CompileModel(model, hw_target, device_config, subgraph_context_.subgraph_name); } else { // For all other types use ov::ov_core read_model() to generate OV IR // followed by ov::ov_core compile_model() - auto ov_model = CreateOVModel(std::move(model), session_context_, const_outputs_map_); + ov_model = CreateOVModel(std::move(model), session_context_, const_outputs_map_); exe_network_ = OVCore::Get()->CompileModel( ov_model, hw_target, device_config, enable_causallm, subgraph_context_.subgraph_name); } LOGS_DEFAULT(INFO) << log_tag << "Loaded model to the plugin"; } catch (const OnnxRuntimeException& ex) { std::string exception_str = ex.what(); - bool eligible_for_cpu_fallback = session_context_.device_type.find("NPU") != std::string::npos && - !session_context_.so_disable_cpu_ep_fallback && - !subgraph_context_.is_ep_ctx_graph; -#if defined(OPENVINO_DISABLE_NPU_FALLBACK) - eligible_for_cpu_fallback = false; -#endif + if (eligible_for_cpu_fallback) { LOGS_DEFAULT(WARNING) << "Model compilation failed at OV NPU." << "Falling back to OV CPU for execution"; @@ -125,8 +125,6 @@ BasicBackend::BasicBackend(std::unique_ptr& model_pr device_config.clear(); SetOVDeviceConfiguration(device_config); try { - // Recreate the model with CPU device type - auto ov_model = CreateOVModel(std::move(model), session_context_, const_outputs_map_); exe_network_ = OVCore::Get()->CompileModel( ov_model, hw_target, device_config, enable_causallm, subgraph_context_.subgraph_name); } catch (std::string const& msg) { From 776bedf737c02500c0ab907d8d4ee9418e819050 Mon Sep 17 00:00:00 2001 From: sfatimar Date: Wed, 16 Jul 2025 22:43:21 +0530 Subject: [PATCH 065/138] Revert "Cluster Change to avoid Dangling DQLinear (#739)" (#743) This reverts commit 89ccd8175408e92679afb95d2c182cc88acbdbe4. --- .../openvino/ov_versions/capability.cc | 24 ++++--------------- 1 file changed, 4 insertions(+), 20 deletions(-) diff --git a/onnxruntime/core/providers/openvino/ov_versions/capability.cc b/onnxruntime/core/providers/openvino/ov_versions/capability.cc index 85c7489fd75d6..88ddde8610c6e 100644 --- a/onnxruntime/core/providers/openvino/ov_versions/capability.cc +++ b/onnxruntime/core/providers/openvino/ov_versions/capability.cc @@ -168,33 +168,17 @@ std::vector> GetCapability::Execute() { auto connected_clusters = GetConnectedClusters(graph_viewer_, ng_clusters); int no_of_clusters = 0; - std::vector prev_cluster; - bool try_next_cluster = false; for (auto this_cluster : connected_clusters) { - bool omit_subgraph = false; - if (try_next_cluster) { - // no need to check previous cluster - for (auto idx : prev_cluster) { - if ((std::find(this_cluster.begin(), this_cluster.end(), idx)) == this_cluster.end()) { - this_cluster.emplace_back(idx); - } - } - try_next_cluster = false; - } - // If subgraph has less then three, graph is considered trivial unless its an epctx cluster - if (!try_next_cluster && this_cluster.size() < 3) { + if (this_cluster.size() < 3) { bool is_epctx_node = false; for (auto node_idx : this_cluster) { if (graph_viewer_.GetNode(node_idx)->OpType() == "EPContext") is_epctx_node = true; } - if (!is_epctx_node) { - omit_subgraph = true; - prev_cluster = this_cluster; - try_next_cluster = true; - } + if (!is_epctx_node) + continue; } std::vector cluster_graph_inputs, cluster_inputs, cluster_outputs; @@ -206,7 +190,7 @@ std::vector> GetCapability::Execute() { cluster_inputs, cluster_outputs); - + bool omit_subgraph = false; // Omitting zero dim subgraphs for (auto index : this_cluster) { const Node* node = graph_viewer_.GetNode(index); From bc3dc45c1a06cc514d393efad13a86df08f97f14 Mon Sep 17 00:00:00 2001 From: Ankit Maheshkar Date: Wed, 16 Jul 2025 23:45:11 +0530 Subject: [PATCH 066/138] Sync ORT main 16 07 25 (#744) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * [webgpu] Update wgsl_templates README.md (#25336) ### Description Fix a broken URL and numbering in the ordered list in README.md. ### Motivation and Context See Above. * [webgpu] Move the early return after copying for ScatterND (#25345) ### Description For ScatterND, if the indices are empty (nothing to update), it becomes a copy operation. So we should move the early return after copying. * [EP ABI] Utility to serialize OrtGraph to GraphProto (#25292) ### Description - Provides utility functions that serialize an `OrtGraph` to a `GraphProto` or `ModelProto`. - Header-only file that can be copied to a project that builds with ORT and ONNX. - Available in [include/onnxruntime/core/providers/utils/ort_graph_to_proto.h](https://github.com/microsoft/onnxruntime/blob/adrianl/ep-abi-ort-graph-to-onnx-protobuf/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h) - Updates the `Node_GetSubgraphs` API function to also return the attribute names associated with each subgraph. This is required to determine which subgraph corresponds to a given attribute. - Adds `Graph_GetNumOperatorSets` and `Graph_GetOperatorSets` API functions to get the opset version for each domain. ### Motivation and Context Provide a utility to facilitate porting of existing execution providers to the new EP ABI. The utilities introduced by this PR convert an `OrtGraph` into an ONNX protobuf representation, which some existing EPs currently convert to their internal representation. Ideally, we would prefer a more direct conversion from a `OrtGraph` to the EP's internal representation, but this is a large effort. These utilities enable an incremental transition. * Update vcpkg.json: remove optional-lite (#25339) The library is not used. C++ itself already has std::optional. * Move buffer release or cache from OnRefresh to ReleaseBuffer in BucketCacheManager (#25276) ### Description This PR is to move buffer release or cache from OnRefresh to ReleaseBuffer in BucketCacheManager. ### Motivation and Context The OnRefresh is executed after a batch(16) ep runs and inside the batch runs, the buffer can not be really reused which is a waste for gpu buffer resources. This PR proposed a strightforward optimization that release or cache the buffer early in ReleaseBuffer instead of OnRefresh to improve the buffer cache or release efficiency which will improve the peak and average GPU memory usage. The experimental result also shows a reasonable memory optimization without perf regressions. #### Phi3 Optimization Strategy | Peak Memory (MB) | Avg Memory (MB) | Token Gen Latency (ms) | Tokens/sec -- | -- | -- | -- | -- Default Bucket | 3603.83 | 3127.05 | 7.17 | 139.50 Default Bucket with Early Release Optimization | 3534.77 (+1.92%) | 3073.97 (+1.70%) | 7.14 (+0.36%) | 140.01 (+0.36%) #### Deepseek-R1 Optimization Strategy | Peak Memory (MB) | Avg Memory (MB) | Token Gen Latency (ms) | Tokens/sec -- | -- | -- | -- | -- Default Bucket | 2089.03 | 1716.15 | 6.07 | 164.67 Default Bucket with Early Release Optimization | 2034.00 (+2.63%) | 1674.49 (+2.43%) | 6.09 (-0.20%) | 164.34 (-0.20%) #### LLama3.2-1B Optimization Strategy | Peak Memory (MB) | Avg Memory (MB) | Token Gen Latency (ms) | Tokens/sec -- | -- | -- | -- | -- Default Bucket | 1736.03 | 1424.64 | 3.37 | 296.53 Default Bucket with Early Release Optimization | 1659.78 (+4.39%) | 1366.78 (+4.06%) | 3.41 (-1.09%) | 293.34 (-1.08%) * [web] Fix "npm run pull:wasm" script (#25330) ### Description following up for #25267 * [MLAS] DequantizeLinear int8/uint8 (#24818) ### Description - Adds multithreaded vectorized implementations of DequantizeLinear for int8 and uint8 inputs: - Intel SSE 2 - ARM NEON - All other architectures fallback to a multithreaded scalar reference implementation (previous was not multithreaded). - **Note**: only enabled if ORT is built for client/on-device workloads (`ORT_CLIENT_PACKAGE_BUILD` is defined). INT8 DequantizeLinear latency on Intel Core i9-10920X with 4 intra op threads (SSE 2 implementation) | Number of elements | Baseline latency (us) | Multithreaded+SIMD latency (us) | Speedup | | ----------------------- | ---------------------- | ------------------------------------ | ---------- | | 10 K | 1 | 1 | 1 | | 20 K | 2 | 2 | 1 | | 40 K | 5 | 5 | 1 | | 80 K | 11 | 4 | 2.75 | | 100 K | 14 | 5 | 2.80 | | 150 K | 21 | 7 | 3.00 | | 200 K | 28 | 8 | 3.50 | | 400 K | 68 | 15 | 4.53 | | 600 K | 107 | 21 | 5.10 | | 800 K | 142 | 28 | 5.07 | | 1 M | 187 | 42 | 4.45 | | 2 M | 376 | 102 | 3.69 | | 4 M | 880 | 236 | 3.73 | | 6 M | 1547 | 557 | 2.78 | | 8 M | 2438 | 1097 | 2.22 | | 10 M | 3192 | 1464 | 2.18 | | 100 M | 38718 | 17733 | 2.18 | INT8 DequantizeLinear latency on Snapdragon 8cx gen 3 @ 3.4GHz with 4 intra op threads (NEON implementation) | Number of elements | Baseline latency (us) | Multithreaded+SIMD latency (us) | Speedup | | ----------------------- | ---------------------- | ------------------------------------ | ---------- | | 10 K | 1 | 1 | 1 | | 20 K | 1 | 1 | 1 | | 40 K | 3 | 3 | 1 | | 80 K | 7 | 4 | 1.75 | | 100 K | 9 | 3 | 3.00 | | 150 K | 14 | 5 | 2.80 | | 200 K | 18 | 6 | 3.00 | | 400 K | 38 | 10 | 3.80 | | 600 K | 61 | 15 | 4.07 | | 800 K | 76 | 19 | 4.00 | | 1 M | 98 | 24 | 4.08 | | 2 M | 204 | 48 | 4.25 | | 4 M | 424 | 112 | 3.79 | | 6 M | 677 | 384 | 1.76 | | 8 M | 919 | 621 | 1.48 | | 10 M | 1132 | 776 | 1.46 | | 100 M | 11842 | 10566 | 1.12 | ### Motivation and Context Improves latency of quantized QDQ models that with large DQs that dominate the inference latency. * [CPU] GQA supports head_sink input for smooth softmax (#25269) ### Description It is an extension of [Smooth Softmax](https://github.com/microsoft/onnxruntime/pull/21867) feature. The difference is that each head has a learnable smooth factor that adding to the denominator of softmax. The smooth factor is like an extra element that joins the softmax. The usage of the smooth factor in softmax is like the following: ```math softmax_{i} = \frac{exp(x_{i})}{exp(s)+ \sum_{j} exp(x_{j})} ``` The head_sink is a float tensor with length of number of attention heads. For h-th head, `head_sink[h]` is used as smooth factor s. When head_sink is not provided, constant 0 is used as smooth factor s. Changes: - [x] Update operator spec to add an optional new input `head_sink` - [x] Implement CPU (MLAS) kernel. - [x] Update test_gqa_cpu.py to test it. CUDA kernel will be updated later in a separate PR. * Add PackageVersion parameter to NuGet packaging stage (#25315) Fix: `Microsoft.ML.OnnxRuntime.Managed.nupkg` artifact from GPU pipeline does not have package version. ![image](https://github.com/user-attachments/assets/4a6135ab-4774-4aa6-aeb1-d5b06948ba8f) * [QNN EP] Fix pool with reshape name conflicts (#25332) Naming conflicts when expand-pool2d-squeeze (implemented as reshape) logic is invoked during ONNX -> QNN op lowering. Model with multiple pool 1D ops would hit this issue. * Added creation of QDQ for TopK node (#25309) - Added TopK in registry.py so as to create QDQ nodes for the op - Ensure that both the input and output quantization params are equal - Added unit test to verify the creation of QDQ nodes for TopK ### Description: Added support for creation of QDQ nodes for TopK when quantized with ORT static quantization tool ### Motivation and Context: Currently there is support to form a node unit for TopK operator when QDQ nodes are present and both the input and output quantization params are equal. But there was no support to create QDQ nodes for TopK operator in the ORT static quantization tool * [WebNN] Refactor webnn op input rank check and add validation for ops (#25185) ### Description Development for webnn op input rank range check ### Motivation and Context - refactor webnn op input rank check - add validation for various ops - take `gemm` op as an example to perform inputs rank check of decomposed ops @honry @fdwr PTAL * Make TRT plugins optional (#25261) ### Description The parser does no longer link agains the plugin library but also loads it dynamic. Due to that I think we should also make the library optional in ORT. @chilo-ms * [EP ABI] Add Graph_GetGraphView API to get a OrtGraph from a subset of nodes (#25191) Added an API that creates a sub-graph from a set of nodes in an OrtGraph. This API is needed in the GetCapability EP ABI porting when EP wants to check whether a 'sub-graph' of the graph is supported by the hardware backend. * [webgpu] a few optimization to WGSL template (#25333) ### Description This change is a follow up to #25130. - consume duktape from vcpkg if --use_vcpkg is specified - ~~add a Windows CI pipeline for dynamic WGSL template~~ (Will do in a separate PR) - upgrade wgsl-template package from 0.1.10 to 0.1.13 - support adding contribop folder as input * add --client_package_build option (#25351) add a build option to enable default options more appropriate for client/on-device workloads. initial use case will be to set the default thread pool allow_spinning policy , which we want to default to 0/false for builds targeted for client/on-device workloads. --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * [WebNN] Fix bug in Float16Array availability check (#25354) The `from` is not a property of `Float16Array` but an inherited function, we can use `Float16Array['from']` to check if it is available. * [EP ABI] Add Node_GetEpType API (#25350) Add a new API `Node_GetEpType` to get the EP that the node is assigned to run on. This API is needed when porting the plugin TRT EP in `GetCapability` where ep needs to know whether the subgraph(s) of the control flow node is assigned to the ep and then to add this control flow op to the support list. * QNN-EP: DSPQueue Polling (#25361) ### Description Enable DSP queue polling when performance profile is burst * [QNN_EP] Implement Efficient Mode API (#25146) ### Description - Set context priority to low when workload type is Efficient - Set context priority to command line configured value if Default - Error out otherwise (invalid argument) * Add Compile API to set the location for the context binary file (#25356) Add Compile API ModelCompilationOptions_SetEpContextBinaryInformation to set the folder path and model name so that the EP can get the right place to dump the [model_name]_[ep].bin file. * add build matrix for wgsl template (#25352) ### Description Windows WebGPU CI: add build matrix for wgsl template * [JSEP] Fix inputShape index OOB in slice.ts (#25364) Use `inputShape.length - 1` instead of `inputShape.length` to avoid out-of-bounds access. * [webgpu] extend cast version to 23 (#25235) * Fix a security warning (#18979) Description (reference: https://github.com/advisories/GHSA-5crp-9r3c-p9vr) Newtonsoft.Json prior to version 13.0.1 is vulnerable to Insecure Defaults due to improper handling of expressions with high nesting level that lead to StackOverFlow exception or high CPU and RAM usage. Exploiting this vulnerability results in Denial Of Service (DoS). To mitigate the issue one either need to update Newtonsoft.Json to 13.0.1 or set MaxDepth parameter in the JsonSerializerSettings. ``` JsonConvert.DefaultSettings = () => new JsonSerializerSettings { MaxDepth = 128 }; ``` This file is the only place using `JsonConvert`, so I blindly put this fix and hope the warning will disappear. * Fix AutoEpSelection and OrtEpLibrary tests when using AuthenticAMD (#24754) * Missing datatype in assertion (#23578) * [EP ABI] Update to use Node_GetEpName (#25363) Change to use `Node_GetEpName` API name to avoid confusion. For plugin EPs, the EP factory can use whatever name that registered with ORT, so make the API name `Node_GetEpName` to align with `OrtEpFactory.GetName.` * Bump clang-format from 20.1.7 to 20.1.8 (#25381) * Fix number of layers in Whisper export (#25375) ### Description This PR fixes the number of hidden layers used during the export of Whisper by always using the number of hidden layers in the decoder. ### Motivation and Context Most of the Whisper models contain the same number of hidden layers in the encoder and decoder. However, Whisper large v3 turbo contains 32 hidden layers in the encoder and only 4 hidden layers in the decoder. This PR also fixes [this issue](https://github.com/microsoft/onnxruntime-genai/issues/1611). * Bump transformers from 4.48.0 to 4.52.1 in /onnxruntime/python/tools/transformers/models/llama (#25328) Bumps [transformers](https://github.com/huggingface/transformers) from 4.48.0 to 4.52.1.
Release notes

Sourced from transformers's releases.

Patch release v4.51.3

A mix of bugs were fixed in this patch; very exceptionally, we diverge from semantic versioning to merge GLM-4 in this patch release.

  • Handle torch ver in flexattn (#37400)
  • handle torch version edge cases (#37399)
  • Add glm4 (#37388)

Patch Release 4.51.2

This is another round of bug fixes, but they are a lot more minor and outputs were not really affected!

Patch release v4.51.1

Since the release of Llama 4, we have fixed a few issues that we are now releasing in patch v4.51.1

  • Fixing flex attention for torch=2.6.0 (#37285)
  • more fixes for post-training llama4 (#37329)
  • Remove HQQ from caching allocator warmup (#37347)
  • fix derived berts _init_weights (#37341)
  • Fix init empty weights without accelerate (#37337)
  • Fix deepspeed with quantization (#37324)
  • fix llama4 training (#37319)
  • fix flex attn when optional args aren't passed (#37327)
  • Multiple llama4 fixe (#37353)

Thanks all for your patience

v4.51.0: Llama 4, Phi4-Multimodal, DeepSeek-v3, Qwen3

New Model Additions

Llama 4

image

Llama 4, developed by Meta, introduces a new auto-regressive Mixture-of-Experts (MoE) architecture.This generation includes two models:

  • The highly capable Llama 4 Maverick with 17B active parameters out of ~400B total, with 128 experts.
  • The efficient Llama 4 Scout also has 17B active parameters out of ~109B total, using just 16 experts.

Both models leverage early fusion for native multimodality, enabling them to process text and image inputs. Maverick and Scout are both trained on up to 40 trillion tokens on data encompassing 200 languages (with specific fine-tuning support for 12 languages including Arabic, Spanish, German, and Hindi).

For deployment, Llama 4 Scout is designed for accessibility, fitting on a single server-grade GPU via on-the-fly 4-bit or 8-bit quantization, while Maverick is available in BF16 and FP8 formats. These models are released under the custom Llama 4 Community License Agreement, available on the model repositories

Getting started with Llama 4 using transformers is straightforward. Make sure you have transformers v4.51.0 or later installed:

pip install -U transformers[hf_xet]
</tr></table>

... (truncated)

Commits

[![Dependabot compatibility score](https://dependabot-badges.githubapp.com/badges/compatibility_score?dependency-name=transformers&package-manager=pip&previous-version=4.48.0&new-version=4.52.1)](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores) Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting `@dependabot rebase`. [//]: # (dependabot-automerge-start) [//]: # (dependabot-automerge-end) ---
Dependabot commands and options
You can trigger Dependabot actions by commenting on this PR: - `@dependabot rebase` will rebase this PR - `@dependabot recreate` will recreate this PR, overwriting any edits that have been made to it - `@dependabot merge` will merge this PR after your CI passes on it - `@dependabot squash and merge` will squash and merge this PR after your CI passes on it - `@dependabot cancel merge` will cancel a previously requested merge and block automerging - `@dependabot reopen` will reopen this PR if it is closed - `@dependabot close` will close this PR and stop Dependabot recreating it. You can achieve the same result by closing it manually - `@dependabot show ignore conditions` will show all of the ignore conditions of the specified dependency - `@dependabot ignore this major version` will close this PR and stop Dependabot creating any more for this major version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this minor version` will close this PR and stop Dependabot creating any more for this minor version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this dependency` will close this PR and stop Dependabot creating any more for this dependency (unless you reopen the PR or upgrade to it yourself) You can disable automated security fix PRs for this repo from the [Security Alerts page](https://github.com/microsoft/onnxruntime/network/alerts).
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> * Bump ruff from 0.12.2 to 0.12.3 (#25382) Bumps [ruff](https://github.com/astral-sh/ruff) from 0.12.2 to 0.12.3.
Release notes

Sourced from ruff's releases.

0.12.3

Release Notes

Preview features

  • [flake8-bugbear] Support non-context-manager calls in B017 (#19063)
  • [flake8-use-pathlib] Add autofixes for PTH100, PTH106, PTH107, PTH108, PTH110, PTH111, PTH112, PTH113, PTH114, PTH115, PTH117, PTH119, PTH120 (#19213)
  • [flake8-use-pathlib] Add autofixes for PTH203, PTH204, PTH205 (#18922)

Bug fixes

  • [flake8-return] Fix false-positive for variables used inside nested functions in RET504 (#18433)
  • Treat form feed as valid whitespace before a line continuation (#19220)
  • [flake8-type-checking] Fix syntax error introduced by fix (TC008) (#19150)
  • [pyupgrade] Keyword arguments in super should suppress the UP008 fix (#19131)

Documentation

  • [flake8-pyi] Make example error out-of-the-box (PYI007, PYI008) (#19103)
  • [flake8-simplify] Make example error out-of-the-box (SIM116) (#19111)
  • [flake8-type-checking] Make example error out-of-the-box (TC001) (#19151)
  • [flake8-use-pathlib] Make example error out-of-the-box (PTH210) (#19189)
  • [pycodestyle] Make example error out-of-the-box (E272) (#19191)
  • [pycodestyle] Make example not raise unnecessary SyntaxError (E114) (#19190)
  • [pydoclint] Make example error out-of-the-box (DOC501) (#19218)
  • [pylint, pyupgrade] Fix syntax errors in examples (PLW1501, UP028) (#19127)
  • [pylint] Update missing-maxsplit-arg docs and error to suggest proper usage (PLC0207) (#18949)
  • [flake8-bandit] Make example error out-of-the-box (S412) (#19241)

Contributors

... (truncated)

Changelog

Sourced from ruff's changelog.

0.12.3

Preview features

  • [flake8-bugbear] Support non-context-manager calls in B017 (#19063)
  • [flake8-use-pathlib] Add autofixes for PTH100, PTH106, PTH107, PTH108, PTH110, PTH111, PTH112, PTH113, PTH114, PTH115, PTH117, PTH119, PTH120 (#19213)
  • [flake8-use-pathlib] Add autofixes for PTH203, PTH204, PTH205 (#18922)

Bug fixes

  • [flake8-return] Fix false-positive for variables used inside nested functions in RET504 (#18433)
  • Treat form feed as valid whitespace before a line continuation (#19220)
  • [flake8-type-checking] Fix syntax error introduced by fix (TC008) (#19150)
  • [pyupgrade] Keyword arguments in super should suppress the UP008 fix (#19131)

Documentation

  • [flake8-pyi] Make example error out-of-the-box (PYI007, PYI008) (#19103)
  • [flake8-simplify] Make example error out-of-the-box (SIM116) (#19111)
  • [flake8-type-checking] Make example error out-of-the-box (TC001) (#19151)
  • [flake8-use-pathlib] Make example error out-of-the-box (PTH210) (#19189)
  • [pycodestyle] Make example error out-of-the-box (E272) (#19191)
  • [pycodestyle] Make example not raise unnecessary SyntaxError (E114) (#19190)
  • [pydoclint] Make example error out-of-the-box (DOC501) (#19218)
  • [pylint, pyupgrade] Fix syntax errors in examples (PLW1501, UP028) (#19127)
  • [pylint] Update missing-maxsplit-arg docs and error to suggest proper usage (PLC0207) (#18949)
  • [flake8-bandit] Make example error out-of-the-box (S412) (#19241)
Commits
  • 5bc81f2 Bump 0.12.3 (#19279)
  • 6908e26 Filter ruff_linter::VERSION out of SARIF output tests (#19280)
  • 25c4295 [ty] Avoid stale diagnostics for open files diagnostic mode (#19273)
  • 426fa4b [ty] Add signature help provider to playground (#19276)
  • b0b65c2 [ty] Initial implementation of signature help provider (#19194)
  • 08bc6d2 Add simple integration tests for all output formats (#19265)
  • f2ae12b [flake8-return] Fix false-positive for variables used inside nested functio...
  • 965f415 [ty] Add a --quiet mode (#19233)
  • 83b5bbf Treat form feed as valid whitespace before a line continuation (#19220)
  • 87f6f08 [ty] Make check_file a salsa query (#19255)
  • Additional commits viewable in compare view

[![Dependabot compatibility score](https://dependabot-badges.githubapp.com/badges/compatibility_score?dependency-name=ruff&package-manager=pip&previous-version=0.12.2&new-version=0.12.3)](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores) Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting `@dependabot rebase`. [//]: # (dependabot-automerge-start) Dependabot will merge this PR once CI passes on it, as requested by @fs-eire. [//]: # (dependabot-automerge-end) ---
Dependabot commands and options
You can trigger Dependabot actions by commenting on this PR: - `@dependabot rebase` will rebase this PR - `@dependabot recreate` will recreate this PR, overwriting any edits that have been made to it - `@dependabot merge` will merge this PR after your CI passes on it - `@dependabot squash and merge` will squash and merge this PR after your CI passes on it - `@dependabot cancel merge` will cancel a previously requested merge and block automerging - `@dependabot reopen` will reopen this PR if it is closed - `@dependabot close` will close this PR and stop Dependabot recreating it. You can achieve the same result by closing it manually - `@dependabot show ignore conditions` will show all of the ignore conditions of the specified dependency - `@dependabot ignore this major version` will close this PR and stop Dependabot creating any more for this major version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this minor version` will close this PR and stop Dependabot creating any more for this minor version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this dependency` will close this PR and stop Dependabot creating any more for this dependency (unless you reopen the PR or upgrade to it yourself)
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> * [QNN EP] Upgrade QNN to 2.36.1 (#25388) ### Description Update Qnn default version to 2.36.1.250708 Co-authored-by: Jeff Kilpatrick * Add vendor id to OrtEpFactory and default ORT logger to CreateEpFactories (#25365) ### Description Add vendor id to OrtEpFactory. It's easier to get the vendor id than name on other platforms. Update the selection policy to prefer match on vendor id with fallback to vendor name. Add default ORT logger to CreateEpFactories. The OrtEpFactory currently has no way to log informational messages or issues. CreateEp is given the session logger for use by the OrtEp instance so that part of things is good. Misc cleanups. Make usage of ORT_API2_STATUS and ORT_API_T consistent on onnxruntime_ep_c_api.h. See ort_version_supported in some EP factories where it was missed. ### Motivation and Context Vendor id is easier to match against OrtHardwareDevice when doing auto EP selection. OrtEpFactory should have a logger. Last chance to cleanup APIs before 1.23 release * Bump lintrunner-adapters from 0.12.4 to 0.12.5 (#25380) * [WebNN] Add rank range validation for rest ops (#25383) - Add common rank range validation to base_op_builder.cc - Handle specific rank range validation for rest ops - Remove duplicated input_shape validation - Fix some typos BTW * Fix some test issues when WebGPU and DML are enabled in the same build (#25401) ### Description Fix some test setups where both EPs being in the same build wasn't expected. ### Motivation and Context * Fix SigLIP casual mask bug (#25360) ### Description SigLIP architecture inside the vision encoder should not use a causal mask on the attention. This change will fix Phi 4 MM accuracy issues we have seen. ### Motivation and Context --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * [CPU] GQA supports attention scores output (#25319) ### Description 1. Add optional output to CPU impl of GQA op for storing attention scores (QK). Buffer is of shape (B, N, S, T) and can either be fp16 or fp32, depending on the type of other inputs 2. Add `qk_output` attribute to GQA, which controls if attention scores should be saved before or after softmax is applied 3. Add unit tests to cover this use case 4. Added asserts on other EPs if this feature is used * [QNN-EP] Support GridSample of linear mode for ONNX opset 20+ (#25408) [QNN-EP] Support GridSample of linear mode for ONNX opset 20+ * [QNN-EP] Update ScatterND op to reject only QNN-CPU (#25403) Current limitation is more than necessary -- only reject when targeting QNN CPU. * Fix 2 device discovery issues. (#25397) ### Description Fix vendor and device id conversion from SetupApi info. Detect Remote Display Adapter and skip. This results in a bogus device appearing when you're connected to a machine using remote desktop. ### Motivation and Context * [webgpu] fix Slice implementation (#25415) ### Description Bugfix: crash when dim_value is 0 ### Motivation and Context Thanks to @skottmckay who found the bug. --------- Signed-off-by: dependabot[bot] Co-authored-by: Jianhui Dai Co-authored-by: Jiajia Qin Co-authored-by: Adrian Lizarraga Co-authored-by: Changming Sun Co-authored-by: Fei Chen Co-authored-by: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Co-authored-by: Tianlei Wu Co-authored-by: vraspar Co-authored-by: qti-yuduo Co-authored-by: Akupadhye Co-authored-by: Wang Ning Co-authored-by: Maximilian Müller <44298237+gedoensmax@users.noreply.github.com> Co-authored-by: Chi Lo <54722500+chilo-ms@users.noreply.github.com> Co-authored-by: George Wu Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: Wanming Lin Co-authored-by: quic-calvnguy Co-authored-by: Hector Li Co-authored-by: Jie Chen Co-authored-by: xhcao Co-authored-by: Wei-Sheng Chin Co-authored-by: quic-hungjuiw Co-authored-by: Ian Hunter Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: kunal-vaishnavi <115581922+kunal-vaishnavi@users.noreply.github.com> Co-authored-by: Jeff Kilpatrick Co-authored-by: Jeff Kilpatrick Co-authored-by: Scott McKay Co-authored-by: Nenad Banfic <46795300+nenad1002@users.noreply.github.com> Co-authored-by: derdeljan-msft --- .github/workflows/windows_webgpu.yml | 2 + cmake/CMakeLists.txt | 1 + cmake/adjust_global_compile_flags.cmake | 5 + .../external/onnxruntime_external_deps.cmake | 25 +- cmake/onnxruntime_mlas.cmake | 1 + cmake/onnxruntime_providers_tensorrt.cmake | 23 +- cmake/onnxruntime_providers_webgpu.cmake | 11 +- cmake/vcpkg.json | 9 +- .../EndToEndTests.Mobile.Automation/Tests.cs | 4 +- .../TestResultProcessor.cs | 3 +- docs/ContribOperators.md | 10 +- docs/OperatorKernels.md | 6 +- include/onnxruntime/core/graph/graph.h | 5 +- .../core/providers/utils/ort_graph_to_proto.h | 718 ++++++++++++++++++ .../core/session/onnxruntime_c_api.h | 99 ++- .../core/session/onnxruntime_cxx_api.h | 2 + .../core/session/onnxruntime_cxx_inline.h | 9 + .../core/session/onnxruntime_ep_c_api.h | 90 ++- .../onnxruntime_session_options_config_keys.h | 4 +- js/web/lib/wasm/jsep/webgpu/ops/slice.ts | 2 +- js/web/script/pull-prebuilt-wasm-artifacts.ts | 20 +- .../contrib_ops/cpu/bert/attention_common.h | 6 + .../contrib_ops/cpu/bert/attention_helper.h | 6 +- .../contrib_ops/cpu/bert/gqa_attention_base.h | 74 +- .../cpu/bert/group_query_attention.cc | 11 +- .../cpu/bert/group_query_attention_helper.h | 31 + .../cuda/bert/group_query_attention.cc | 6 + .../rocm/bert/group_query_attention.cu | 4 + .../webgpu/bert/group_query_attention.cc | 4 + onnxruntime/core/common/cpuid_info.cc | 2 +- onnxruntime/core/graph/abi_graph_types.h | 22 +- .../core/graph/contrib_ops/bert_defs.cc | 85 ++- onnxruntime/core/graph/ep_api_types.cc | 83 +- onnxruntime/core/graph/ep_api_types.h | 45 +- onnxruntime/core/graph/graph.cc | 4 + onnxruntime/core/graph/graph_viewer.cc | 12 +- .../core/graph/model_editor_api_types.h | 14 +- onnxruntime/core/mlas/inc/mlas.h | 16 + onnxruntime/core/mlas/lib/compute.cpp | 17 +- onnxruntime/core/mlas/lib/dequantize.cpp | 395 ++++++++++ onnxruntime/core/mlas/lib/mlasi.h | 22 + onnxruntime/core/mlas/lib/platform.cpp | 2 + .../core/platform/windows/device_discovery.cc | 79 +- .../core/providers/cpu/math/softmax_shared.cc | 2 +- onnxruntime/core/providers/cpu/ml/ml_common.h | 2 +- .../cpu/quantization/quantize_linear.cc | 98 ++- .../providers/cuda/cuda_provider_factory.cc | 8 + .../src/ExecutionProvider.cpp | 5 +- .../nv_tensorrt_rtx/nv_execution_provider.cc | 2 +- .../qnn/builder/opbuilder/pool_op_builder.cc | 42 +- .../builder/opbuilder/simple_op_builder.cc | 14 +- .../qnn/builder/qnn_backend_manager.cc | 50 +- .../qnn/builder/qnn_backend_manager.h | 10 +- .../providers/qnn/qnn_execution_provider.cc | 59 +- .../providers/qnn/qnn_execution_provider.h | 7 +- .../providers/qnn/qnn_provider_factory.cc | 14 +- .../tensorrt_execution_provider_custom_ops.cc | 44 +- .../core/providers/webgpu/buffer_manager.cc | 25 +- .../core/providers/webgpu/tensor/cast.cc | 20 +- .../providers/webgpu/tensor/scatter_nd.cc | 22 +- .../core/providers/webgpu/tensor/slice.cc | 4 +- .../webgpu/webgpu_execution_provider.cc | 8 +- .../providers/webgpu/wgsl_templates/README.md | 4 +- .../webgpu/wgsl_templates/package-lock.json | 8 +- .../webgpu/wgsl_templates/package.json | 2 +- .../core/providers/webnn/builders/helper.cc | 126 +-- .../core/providers/webnn/builders/helper.h | 34 + .../builders/impl/argmax_min_op_builder.cc | 18 - .../webnn/builders/impl/base_op_builder.cc | 7 +- .../webnn/builders/impl/binary_op_builder.cc | 5 +- .../webnn/builders/impl/concat_op_builder.cc | 3 +- .../webnn/builders/impl/conv_op_builder.cc | 2 +- .../webnn/builders/impl/cumsum_op_builder.cc | 4 - .../webnn/builders/impl/dropout_op_builder.cc | 20 +- .../webnn/builders/impl/einsum_op_builder.cc | 90 ++- .../impl/gatherElements_op_builder.cc | 6 +- .../builders/impl/gatherND_op_builder.cc | 6 +- .../webnn/builders/impl/gather_op_builder.cc | 28 +- .../webnn/builders/impl/gemm_op_builder.cc | 44 +- .../webnn/builders/impl/gru_op_builder.cc | 3 +- .../webnn/builders/impl/logical_op_builder.cc | 4 +- .../webnn/builders/impl/lrn_op_builder.cc | 15 +- .../webnn/builders/impl/lstm_op_builder.cc | 3 +- .../builders/impl/matMulNBits_op_builder.cc | 19 +- .../webnn/builders/impl/max_min_op_builder.cc | 24 +- .../builders/impl/normalization_op_builder.cc | 87 +-- .../webnn/builders/impl/pool_op_builder.cc | 14 - .../webnn/builders/impl/qdq_op_builder.cc | 3 +- .../builders/impl/reduction_op_builder.cc | 8 +- .../webnn/builders/impl/reshape_op_builder.cc | 5 - .../impl/rotaryEmbedding_op_builder.cc | 2 +- .../impl/scatterElements_op_builder.cc | 6 +- .../builders/impl/scatterND_op_builder.cc | 6 +- .../webnn/builders/impl/slice_op_builder.cc | 21 +- .../webnn/builders/impl/softmax_op_builder.cc | 19 - .../impl/squeeze_unsqueeze_op_builder.cc | 3 - .../webnn/builders/impl/ternary_op_builder.cc | 3 +- .../webnn/builders/impl/tile_op_builder.cc | 9 - .../builders/impl/triangular_op_builder.cc | 9 - .../core/providers/webnn/builders/map_info.h | 4 +- .../providers/webnn/builders/model_builder.h | 2 +- onnxruntime/core/session/compile_api.cc | 30 + onnxruntime/core/session/compile_api.h | 2 + onnxruntime/core/session/ep_api_utils.h | 4 + .../core/session/ep_factory_internal.cc | 4 +- .../core/session/ep_factory_internal.h | 4 +- .../core/session/ep_library_internal.cc | 9 +- .../session/ep_library_provider_bridge.cc | 1 + onnxruntime/core/session/inference_session.cc | 12 + .../core/session/model_compilation_options.cc | 36 +- .../core/session/model_compilation_options.h | 10 + onnxruntime/core/session/onnxruntime_c_api.cc | 134 +++- onnxruntime/core/session/ort_apis.h | 10 +- .../core/session/provider_policy_context.cc | 8 +- onnxruntime/core/util/qmath.h | 49 ++ onnxruntime/core/util/thread_utils.h | 6 + .../tools/quantization/base_quantizer.py | 2 +- .../python/tools/quantization/registry.py | 1 + .../transformers/fusion_attention_clip.py | 70 +- .../models/llama/requirements.txt | 2 +- .../models/whisper/convert_to_onnx.py | 2 +- .../models/whisper/requirements.txt | 2 +- .../models/whisper/whisper_decoder.py | 7 +- .../whisper/whisper_encoder_decoder_init.py | 4 +- .../models/whisper/whisper_helper.py | 4 +- .../models/whisper/whisper_inputs.py | 6 +- .../models/whisper/whisper_jump_times.py | 2 +- onnxruntime/test/autoep/library/ep.cc | 12 +- onnxruntime/test/autoep/library/ep.h | 6 +- onnxruntime/test/autoep/library/ep_factory.cc | 7 + onnxruntime/test/autoep/library/ep_factory.h | 2 + .../test/contrib_ops/matmul_4bits_test.cc | 6 +- onnxruntime/test/ep_graph/test_ep_graph.cc | 254 ++++++- .../test/ep_graph/test_ep_graph_utils.cc | 1 + .../test/ep_graph/test_ep_graph_utils.h | 1 + .../test/framework/ep_plugin_provider_test.cc | 14 +- .../test/mlas/bench/bench_computesoftmax.cpp | 4 +- .../mlas/unittest/test_dequantizelinear.cpp | 75 ++ .../test/mlas/unittest/test_softmax.cpp | 4 +- .../test/providers/cpu/math/softmax_test.cc | 3 +- .../cpu/tensor/quantize_linear_test.cc | 26 + .../cpu/tensor/scatter_nd_op_test.cc | 11 + .../test/providers/qnn/qnn_ep_context_test.cc | 80 +- .../test/providers/qnn/simple_op_htp_test.cc | 32 + .../test/python/quantization/test_op_topk.py | 103 +++ .../phi-4-v-instruct-vision-attention.onnx | Bin 0 -> 7729 bytes .../test/python/transformers/test_gqa_cpu.py | 559 +++++++++++--- .../test/python/transformers/test_gqa_cuda.py | 3 +- .../transformers/test_paged_attention_cuda.py | 3 +- .../python/transformers/test_phi_vision.py | 70 +- .../three_layer_nested_subgraph_v2.onnx | Bin 0 -> 1892 bytes requirements-lintrunner.txt | 6 +- tools/ci_build/build.py | 3 + tools/ci_build/build_args.py | 10 + ...arm64-v8a-QNN-crosscompile-ci-pipeline.yml | 2 +- .../c-api-noopenmp-packaging-pipelines.yml | 2 +- .../custom-nuget-packaging-pipeline.yml | 2 +- .../azure-pipelines/linux-qnn-ci-pipeline.yml | 2 +- .../azure-pipelines/py-packaging-pipeline.yml | 2 +- .../qnn-ep-nuget-packaging-pipeline.yml | 2 +- .../stages/nuget-cuda-packaging-stage.yml | 3 + .../stages/py-cpu-packaging-stage.yml | 2 +- .../templates/android-java-api-aar-test.yml | 2 +- .../templates/android-java-api-aar.yml | 2 +- .../azure-pipelines/templates/c-api-cpu.yml | 2 +- .../templates/jobs/download_linux_qnn_sdk.yml | 2 +- .../templates/jobs/download_win_qnn_sdk.yml | 2 +- .../templates/py-linux-qnn.yml | 2 +- .../templates/py-win-arm64-qnn.yml | 2 +- .../templates/py-win-arm64ec-qnn.yml | 2 +- .../templates/py-win-x64-qnn.yml | 2 +- .../azure-pipelines/templates/qnn-ep-win.yml | 4 +- .../win-qnn-arm64-ci-pipeline.yml | 4 +- .../azure-pipelines/win-qnn-ci-pipeline.yml | 2 +- 174 files changed, 3985 insertions(+), 877 deletions(-) create mode 100644 include/onnxruntime/core/providers/utils/ort_graph_to_proto.h create mode 100644 onnxruntime/core/mlas/lib/dequantize.cpp create mode 100644 onnxruntime/test/mlas/unittest/test_dequantizelinear.cpp create mode 100644 onnxruntime/test/python/quantization/test_op_topk.py create mode 100644 onnxruntime/test/python/transformers/test_data/models/phi-4-v-instruct-vision-attention.onnx create mode 100644 onnxruntime/test/testdata/three_layer_nested_subgraph_v2.onnx diff --git a/.github/workflows/windows_webgpu.yml b/.github/workflows/windows_webgpu.yml index 70e8ea7e2792f..996e0d816d51a 100644 --- a/.github/workflows/windows_webgpu.yml +++ b/.github/workflows/windows_webgpu.yml @@ -22,6 +22,7 @@ jobs: strategy: matrix: vcpkg_option: [novcpkg, vcpkg] + wgsl_template: [static, dynamic] env: OrtPackageId: Microsoft.ML.OnnxRuntime OnnxRuntimeBuildDirectory: ${{ github.workspace }} @@ -123,6 +124,7 @@ jobs: --build_nodejs ` --build_java ` --use_webgpu ` + --wgsl_template ${{ matrix.wgsl_template }} ` ${{ matrix.vcpkg_option == 'vcpkg' && '--use_vcpkg' || '' }} ` --cmake_extra_defines ` onnxruntime_BUILD_UNIT_TESTS=ON ` diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index fb4238731ffc3..b01110b2a4a03 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -151,6 +151,7 @@ option(onnxruntime_DISABLE_SPARSE_TENSORS "Disable sparse tensors data types" OF option(onnxruntime_DISABLE_OPTIONAL_TYPE "Disable optional type" OFF) option(onnxruntime_DISABLE_FLOAT8_TYPES "Disable float 8 types" OFF) option(onnxruntime_MINIMAL_BUILD "Exclude as much as possible from the build. Support ORT format models. No support for ONNX format models." OFF) +option(onnxruntime_CLIENT_PACKAGE_BUILD "Enables default settings that are more appropriate for client/on-device workloads." OFF) cmake_dependent_option(onnxruntime_DISABLE_RTTI "Disable RTTI" ON "NOT onnxruntime_ENABLE_PYTHON;NOT onnxruntime_USE_CUDA" OFF) # For now onnxruntime_DISABLE_EXCEPTIONS will only work with onnxruntime_MINIMAL_BUILD, more changes (ONNX, non-CPU EP, ...) are required to run this standalone cmake_dependent_option(onnxruntime_DISABLE_EXCEPTIONS "Disable exception handling. Requires onnxruntime_MINIMAL_BUILD currently." ON "onnxruntime_MINIMAL_BUILD;NOT onnxruntime_ENABLE_PYTHON" OFF) diff --git a/cmake/adjust_global_compile_flags.cmake b/cmake/adjust_global_compile_flags.cmake index 59d99ade131cd..6d517003fa6b6 100644 --- a/cmake/adjust_global_compile_flags.cmake +++ b/cmake/adjust_global_compile_flags.cmake @@ -95,6 +95,11 @@ if (onnxruntime_MINIMAL_BUILD) endif() endif() +# ORT build with default settings more appropriate for client/on-device workloads. +if (onnxruntime_CLIENT_PACKAGE_BUILD) + add_compile_definitions(ORT_CLIENT_PACKAGE_BUILD) +endif() + if (onnxruntime_ENABLE_LTO) include(CheckIPOSupported) check_ipo_supported(RESULT ipo_enabled OUTPUT ipo_output) diff --git a/cmake/external/onnxruntime_external_deps.cmake b/cmake/external/onnxruntime_external_deps.cmake index e8f6bbe895d29..228906030d14c 100644 --- a/cmake/external/onnxruntime_external_deps.cmake +++ b/cmake/external/onnxruntime_external_deps.cmake @@ -774,13 +774,24 @@ if (onnxruntime_USE_WEBGPU) endif() if (NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten" AND onnxruntime_WGSL_TEMPLATE STREQUAL "dynamic") - onnxruntime_fetchcontent_declare( - duktape - URL ${DEP_URL_duktape} - URL_HASH SHA1=${DEP_SHA1_duktape} - EXCLUDE_FROM_ALL - ) - onnxruntime_fetchcontent_makeavailable(duktape) + if(onnxruntime_USE_VCPKG) + find_package(unofficial-duktape CONFIG REQUIRED) + add_library(duktape_static ALIAS unofficial::duktape::duktape) + else() + onnxruntime_fetchcontent_declare( + duktape + URL ${DEP_URL_duktape} + URL_HASH SHA1=${DEP_SHA1_duktape} + EXCLUDE_FROM_ALL + ) + onnxruntime_fetchcontent_makeavailable(duktape) + + if(NOT TARGET duktape_static) + add_library(duktape_static STATIC "${duktape_SOURCE_DIR}/src/duktape.c") + target_compile_features(duktape_static PRIVATE c_std_99) + target_include_directories(duktape_static INTERFACE $) + endif() + endif() endif() endif() diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index f8f5546ae9465..47e7779d93b33 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -31,6 +31,7 @@ onnxruntime_add_static_library(onnxruntime_mlas ${MLAS_SRC_DIR}/eltwise.cpp ${MLAS_SRC_DIR}/erf.cpp ${MLAS_SRC_DIR}/compute.cpp + ${MLAS_SRC_DIR}/dequantize.cpp ${MLAS_SRC_DIR}/quantize.cpp ${MLAS_SRC_DIR}/qgemm_kernel_default.cpp ${MLAS_SRC_DIR}/qladd.cpp diff --git a/cmake/onnxruntime_providers_tensorrt.cmake b/cmake/onnxruntime_providers_tensorrt.cmake index 69c81a5ec7b9d..4184e0b049afc 100644 --- a/cmake/onnxruntime_providers_tensorrt.cmake +++ b/cmake/onnxruntime_providers_tensorrt.cmake @@ -72,10 +72,9 @@ endif() # TensorRT 10 GA onwards, the TensorRT libraries will have major version appended to the end on Windows, - # for example, nvinfer_10.dll, nvinfer_plugin_10.dll, nvonnxparser_10.dll ... + # for example, nvinfer_10.dll, nvonnxparser_10.dll ... if (WIN32 AND TRT_GREATER_OR_EQUAL_TRT_10_GA) set(NVINFER_LIB "nvinfer_${NV_TENSORRT_MAJOR}") - set(NVINFER_PLUGIN_LIB "nvinfer_plugin_${NV_TENSORRT_MAJOR}") set(PARSER_LIB "nvonnxparser_${NV_TENSORRT_MAJOR}") endif() @@ -83,15 +82,11 @@ set(NVINFER_LIB "nvinfer") endif() - if (NOT NVINFER_PLUGIN_LIB) - set(NVINFER_PLUGIN_LIB "nvinfer_plugin") - endif() - if (NOT PARSER_LIB) set(PARSER_LIB "nvonnxparser") endif() - MESSAGE(STATUS "Looking for ${NVINFER_LIB} and ${NVINFER_PLUGIN_LIB}") + MESSAGE(STATUS "Looking for ${NVINFER_LIB}") find_library(TENSORRT_LIBRARY_INFER ${NVINFER_LIB} HINTS ${TENSORRT_ROOT} @@ -101,14 +96,6 @@ MESSAGE(STATUS "Can't find ${NVINFER_LIB}") endif() - find_library(TENSORRT_LIBRARY_INFER_PLUGIN ${NVINFER_PLUGIN_LIB} - HINTS ${TENSORRT_ROOT} - PATH_SUFFIXES lib lib64 lib/x64) - - if (NOT TENSORRT_LIBRARY_INFER_PLUGIN) - MESSAGE(STATUS "Can't find ${NVINFER_PLUGIN_LIB}") - endif() - if (onnxruntime_USE_TENSORRT_BUILTIN_PARSER) MESSAGE(STATUS "Looking for ${PARSER_LIB}") @@ -120,7 +107,7 @@ MESSAGE(STATUS "Can't find ${PARSER_LIB}") endif() - set(TENSORRT_LIBRARY ${TENSORRT_LIBRARY_INFER} ${TENSORRT_LIBRARY_INFER_PLUGIN} ${TENSORRT_LIBRARY_NVONNXPARSER}) + set(TENSORRT_LIBRARY ${TENSORRT_LIBRARY_INFER} ${TENSORRT_LIBRARY_NVONNXPARSER}) MESSAGE(STATUS "Find TensorRT libs at ${TENSORRT_LIBRARY}") else() if (TRT_GREATER_OR_EQUAL_TRT_10_GA) @@ -153,7 +140,7 @@ endif() # Static libraries are just nvonnxparser_static on all platforms set(onnxparser_link_libs nvonnxparser_static) - set(TENSORRT_LIBRARY ${TENSORRT_LIBRARY_INFER} ${TENSORRT_LIBRARY_INFER_PLUGIN}) + set(TENSORRT_LIBRARY ${TENSORRT_LIBRARY_INFER}) MESSAGE(STATUS "Find TensorRT libs at ${TENSORRT_LIBRARY}") endif() @@ -161,7 +148,7 @@ # nvonnxparser_static is linked against tensorrt libraries in onnx-tensorrt # See https://github.com/onnx/onnx-tensorrt/blob/8af13d1b106f58df1e98945a5e7c851ddb5f0791/CMakeLists.txt#L121 # However, starting from TRT 10 GA, nvonnxparser_static doesn't link against tensorrt libraries. - # Therefore, the above code finds ${TENSORRT_LIBRARY_INFER} and ${TENSORRT_LIBRARY_INFER_PLUGIN}. + # Therefore, the above code finds ${TENSORRT_LIBRARY_INFER}. if(onnxruntime_CUDA_MINIMAL) set(trt_link_libs ${CMAKE_DL_LIBS} ${TENSORRT_LIBRARY}) else() diff --git a/cmake/onnxruntime_providers_webgpu.cmake b/cmake/onnxruntime_providers_webgpu.cmake index 5b80b1262464d..2865ad33b39f4 100644 --- a/cmake/onnxruntime_providers_webgpu.cmake +++ b/cmake/onnxruntime_providers_webgpu.cmake @@ -172,10 +172,12 @@ file(MAKE_DIRECTORY ${WGSL_GENERATED_DIR}) # Find all WGSL template input files - file(GLOB_RECURSE WGSL_TEMPLATE_FILES "${ONNXRUNTIME_ROOT}/core/providers/webgpu/*.wgsl.template") + file(GLOB_RECURSE WGSL_TEMPLATE_FILES + "${ONNXRUNTIME_ROOT}/core/providers/webgpu/*.wgsl.template" + "${ONNXRUNTIME_ROOT}/contrib_ops/webgpu/*.wgsl.template") # Set wgsl-gen command line options as a list - set(WGSL_GEN_OPTIONS "-i" "../" "--output" "${WGSL_GENERATED_DIR}" "-I" "wgsl_template_gen/" "--preserve-code-ref" "--verbose") + set(WGSL_GEN_OPTIONS "-i" "${ONNXRUNTIME_ROOT}/core/providers/webgpu/" "-i" "${ONNXRUNTIME_ROOT}/contrib_ops/webgpu/" "--output" "${WGSL_GENERATED_DIR}" "-I" "wgsl_template_gen/" "--preserve-code-ref" "--verbose") if (onnxruntime_WGSL_TEMPLATE STREQUAL "static") if (CMAKE_BUILD_TYPE STREQUAL "Debug") list(APPEND WGSL_GEN_OPTIONS "--generator" "static-cpp-literal") @@ -207,10 +209,9 @@ # Add the generated directory to include paths target_include_directories(onnxruntime_providers_webgpu PRIVATE ${WGSL_GENERATED_ROOT}) elseif(onnxruntime_WGSL_TEMPLATE STREQUAL "dynamic") - add_library(duktape_static STATIC "${duktape_SOURCE_DIR}/src/duktape.c") - target_compile_features(duktape_static PRIVATE c_std_99) target_link_libraries(onnxruntime_providers_webgpu duktape_static) - target_include_directories(onnxruntime_providers_webgpu PRIVATE ${duktape_SOURCE_DIR}/src) + onnxruntime_add_include_to_target(onnxruntime_providers_webgpu duktape_static) + # Define the path to the generated templates.js file target_compile_definitions(onnxruntime_providers_webgpu PRIVATE "ORT_WGSL_TEMPLATES_JS_PATH=\"${WGSL_GENERATED_TEMPLATES_JS}\"") diff --git a/cmake/vcpkg.json b/cmake/vcpkg.json index 7c6b2fed36d1b..373ecec440921 100644 --- a/cmake/vcpkg.json +++ b/cmake/vcpkg.json @@ -43,7 +43,6 @@ "ms-gsl", "nlohmann-json", "onnx", - "optional-lite", { "name": "protobuf", "version>=": "3.21.12" @@ -94,6 +93,10 @@ "webgpu-ep": { "description": "Build with WebGPU EP", "dependencies": [] + }, + "webgpu-ep-wgsl-template-dynamic": { + "description": "Build with WebGPU EP with dynamic WGSL template code generator", + "dependencies": ["duktape"] } }, "overrides": [ @@ -104,6 +107,10 @@ { "name": "flatbuffers", "version": "23.5.26" + }, + { + "name": "duktape", + "version": "2.7.0#2" } ] } diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests.Mobile/EndToEndTests.Mobile.Automation/Tests.cs b/csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests.Mobile/EndToEndTests.Mobile.Automation/Tests.cs index c28830ec72157..6e6190b8227b8 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests.Mobile/EndToEndTests.Mobile.Automation/Tests.cs +++ b/csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests.Mobile/EndToEndTests.Mobile.Automation/Tests.cs @@ -40,10 +40,12 @@ public void RunPlatformUnitTest() var serializedResultSummary = _app.Invoke(_getResultsBackdoorMethodName)?.ToString(); Assert.IsNotEmpty(serializedResultSummary, "Test results were not returned"); + // Fix security issue (overflow with too much nesting): GHSA-5crp-9r3c-p9vr + JsonConvert.DefaultSettings = () => new JsonSerializerSettings { MaxDepth = 128 }; var testSummary = JsonConvert.DeserializeObject(serializedResultSummary); Assert.AreEqual(testSummary.Failed, 0, $"{testSummary.Failed} tests failed"); _app.Screenshot("Post-testing"); } } -} \ No newline at end of file +} diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Devices/TestResultProcessor.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Devices/TestResultProcessor.cs index 8419d261e4a41..625cc2c54055c 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Devices/TestResultProcessor.cs +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Devices/TestResultProcessor.cs @@ -45,8 +45,9 @@ public TestResultSummary GetResults() public string GetSerializedResults() { var resultSummary = GetResults(); + JsonConvert.DefaultSettings = () => new JsonSerializerSettings { MaxDepth = 128 }; var serializedResultSummary = JsonConvert.SerializeObject(resultSummary, Formatting.Indented); return serializedResultSummary; } } -} \ No newline at end of file +} diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index b80918e6615e1..f3dcde1abe37a 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -2545,6 +2545,8 @@ This version of the operator has been available since version 1 of the 'com.micr
left_window_size for local attention (like Mistral). Default value is -1 meaning unused.
num_heads : int (required)
Number of attention heads for q
+
qk_output : int
+
Output values of QK matrix multiplication before (1) or after (2) softmax normalization. Default value is 0 (don't output).
rotary_interleaved : int
Rotate using interleaved pattern. Default value is 0 (False).
scale : float
@@ -2555,7 +2557,7 @@ This version of the operator has been available since version 1 of the 'com.micr
Softcap value for attention weights. Default value is 0.
-#### Inputs (7 - 11) +#### Inputs (7 - 12)
query : T
@@ -2580,9 +2582,11 @@ This version of the operator has been available since version 1 of the 'com.micr
2D tensor with shape (batch_size, sequence_length). When processing the first prompt the kernel uses only the first element
attention_bias (optional) : T
additional add to QxK' with shape (batch_size or 1, num_heads or 1, sequence_length, total_sequence_length)
+
head_sink (optional) : T
+
1D tensor with shape (num_heads). Each head has a smooth factor adding to the denominator of softmax.
-#### Outputs +#### Outputs (3 - 4)
output : T
@@ -2591,6 +2595,8 @@ This version of the operator has been available since version 1 of the 'com.micr
present state key with support for format BNSH. When past_key uses same tensor as present_key(k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length +kv_sequence_length.
present_value : T
present state value with support for format BNSH. When past_value uses same tensor as present_value(k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length +kv_sequence_length.
+
output_qk (optional) : T
+
Values of QK matrix multiplication, either before or after softmax normalization
#### Type Constraints diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 1ffcabee8cc10..fa6c731231405 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -538,7 +538,7 @@ Do not modify directly.* |Gelu|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(float)| |GreedySearch|*in* input_ids:**I**
*in* max_length:**I**
*in* min_length:**I**
*in* repetition_penalty:**T**
*in* vocab_mask:**I**
*in* prefix_vocab_mask:**I**
*in* attention_mask:**I**
*out* sequences:**I**|1+|**T** = tensor(float)| |GridSample|*in* X:**T1**
*in* Grid:**T1**
*out* Y:**T2**|1+|**T1** = tensor(float)
**T2** = tensor(float)| -|GroupQueryAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* seqlens_k:**M**
*in* total_sequence_length:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*in* position_ids:**tensor(int64)**
*in* attention_bias:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**M** = tensor(int32)
**T** = tensor(float), tensor(float16)| +|GroupQueryAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* seqlens_k:**M**
*in* total_sequence_length:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*in* position_ids:**tensor(int64)**
*in* attention_bias:**T**
*in* head_sink:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**
*out* output_qk:**T**|1+|**M** = tensor(int32)
**T** = tensor(float), tensor(float16)| |Inverse|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |MatMulBnb4|*in* A:**T1**
*in* B:**T2**
*in* absmax:**T1**
*out* Y:**T1**|1+|**T1** = tensor(float)
**T2** = tensor(uint8)| |MatMulFpQ4|*in* A:**T1**
*in* B:**T2**
*in* B_shape:**T3**
*out* Y:**T1**|1+|**T1** = tensor(float)
**T2** = tensor(uint8)
**T3** = tensor(int64)| @@ -942,7 +942,7 @@ Do not modify directly.* |GreedySearch|*in* input_ids:**I**
*in* max_length:**I**
*in* min_length:**I**
*in* repetition_penalty:**T**
*in* vocab_mask:**I**
*in* prefix_vocab_mask:**I**
*in* attention_mask:**I**
*out* sequences:**I**|1+|**T** = tensor(float), tensor(float16)| |GridSample|*in* X:**T1**
*in* Grid:**T1**
*out* Y:**T2**|1+|**T1** = tensor(float)
**T2** = tensor(float)| |GroupNorm|*in* X:**T**
*in* gamma:**M**
*in* beta:**M**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| -|GroupQueryAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* seqlens_k:**M**
*in* total_sequence_length:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*in* position_ids:**tensor(int64)**
*in* attention_bias:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**M** = tensor(int32)
**T** = tensor(bfloat16), tensor(float16)| +|GroupQueryAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* seqlens_k:**M**
*in* total_sequence_length:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*in* position_ids:**tensor(int64)**
*in* attention_bias:**T**
*in* head_sink:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**
*out* output_qk:**T**|1+|**M** = tensor(int32)
**T** = tensor(bfloat16), tensor(float16)| |Inverse|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |Irfft|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |LongformerAttention|*in* input:**T**
*in* weight:**T**
*in* bias:**T**
*in* mask:**T**
*in* global_weight:**T**
*in* global_bias:**T**
*in* global:**G**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| @@ -1420,7 +1420,7 @@ Do not modify directly.* |FusedMatMulActivation|*in* A:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| |Gelu|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| |GroupNorm|*in* X:**T**
*in* gamma:**M**
*in* beta:**M**
*out* Y:**T**|1+|**M** = tensor(float), tensor(float16)
**T** = tensor(float), tensor(float16)| -|GroupQueryAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* seqlens_k:**M**
*in* total_sequence_length:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*in* position_ids:**tensor(int64)**
*in* attention_bias:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**M** = tensor(int32)
**T** = tensor(float), tensor(float16)| +|GroupQueryAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* seqlens_k:**M**
*in* total_sequence_length:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*in* position_ids:**tensor(int64)**
*in* attention_bias:**T**
*in* head_sink:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**
*out* output_qk:**T**|1+|**M** = tensor(int32)
**T** = tensor(float), tensor(float16)| |MatMulIntegerToFloat|*in* A:**T1**
*in* B:**T2**
*in* a_scale:**T3**
*in* b_scale:**T3**
*in* a_zero_point:**T1**
*in* b_zero_point:**T2**
*in* bias:**T3**
*out* Y:**T3**|1+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(float), tensor(float16)| |MatMulNBits|*in* A:**T1**
*in* B:**T2**
*in* scales:**T1**
*in* zero_points:**T3**
*in* g_idx:**T4**
*in* bias:**T1**
*out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(uint8)| |MultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* key_padding_mask:**M**
*in* attention_bias:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* past_sequence_length:**M**
*in* cache_indirection:**M**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**
*out* qk:**QK**|1+|**M** = tensor(int32)
**T** = tensor(float), tensor(float16)| diff --git a/include/onnxruntime/core/graph/graph.h b/include/onnxruntime/core/graph/graph.h index 54e03a31fceef..c18a42cc1bbc1 100644 --- a/include/onnxruntime/core/graph/graph.h +++ b/include/onnxruntime/core/graph/graph.h @@ -952,9 +952,12 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi return const_cast(this)->GetNodeArg(name); } - // search this and up through any parent_graph_ instance for a NodeArg + // Searches for a NodeArg in the current graph and its parent graphs, and returns the corresponding mutable NodeArg NodeArg* GetNodeArgIncludingParentGraphs(const std::string& node_arg_name); + // Searches for a NodeArg in the current graph and its parent graphs, and returns the corresponding const NodeArg + const NodeArg* GetNodeArgIncludingParentGraphs(const std::string& node_arg_name) const; + /** Gets a mutable NodeArg by name. Creates a new NodeArg that is owned by this Graph if not found. @param name The NodeArg name. @param[in] p_arg_type Optional TypeProto to use if the NodeArg needs to be created. diff --git a/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h b/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h new file mode 100644 index 0000000000000..37665542f614f --- /dev/null +++ b/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h @@ -0,0 +1,718 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +/* + SUMMARY: + Utilities to serialize an OrtGraph into an ONNX GraphProto or ModelProto. Can be used by execution provider + implementations that need to convert an OrtGraph instance into an ONNX protobuf model. + + Users may copy this file and modify as needed. + + USAGE: + This is a header-only implementation that includes both the function declarations and definitions. Copy this file + into a project that links with both ONNX Runtime and ONNX. + + Define the ORT_EP_UTILS_ORT_GRAPH_TO_PROTO_IMPL preprocessor macro before the #include statement in exactly one C++ + file to define the implementation. Example: + + #define ORT_EP_UTILS_ORT_GRAPH_TO_PROTO_IMPL + #include "ort_graph_to_proto.h" + + Other compilation units that depend on these utilities should include this file without defining the + preprocessor macro. + + Example program snippets are shown below. Refer to the function declarations for detailed usage information. + + EXAMPLE SNIPPET (initializers stored within TensorProto): + + ```C++ + #define ORT_EP_UTILS_ORT_GRAPH_TO_PROTO_IMPL + #include "ort_graph_to_proto.h" + + OrtStatus* ORT_API_CALL GetCapabilityImpl(OrtEp* this_ptr, const OrtGraph* ort_graph, + OrtEpGraphSupportInfo* graph_support_info) { + onnx::GraphProto graph_proto; + OrtEpUtils::OrtGraphToProto(*ort_graph, graph_proto); + + // graph_proto stores initializers internally + } + ``` + + EXAMPLE SNIPPET (large initializers stored in external file): + + ```C++ + #define ORT_EP_UTILS_ORT_GRAPH_TO_PROTO_IMPL + #include "ort_graph_to_proto.h" + + OrtStatus* ORT_API_CALL GetCapabilityImpl(OrtEp* this_ptr, const OrtGraph* ort_graph, + OrtEpGraphSupportInfo* graph_support_info) { + std::string external_file_path = "weights.bin"; + std::ofstream out_file(external_file_path, std::ios::binary); + + auto handle_initializer_data = [&external_file_path, &out_file](const OrtValueInfo* value_info, + const void* data, size_t bytes, + bool& is_external, std::string& location, + int64_t& offset) -> Ort::Status { + // OrtValueInfo* could be used to query initializer's name, type, shape, consumers, etc. + (void)value_info; + + if (bytes <= 127) { + is_external = false; // Keep small initializers stored inside the TensorProto. + return Ort::Status{nullptr}; + } + + offset = out_file.tellp(); + location = external_file_path; + out_file.write(static_cast(data), bytes); + out_file.flush(); + is_external = true; // True if is external initializer + return Ort::Status{nullptr}; + } + + ONNX_NAMESPACE::GraphProto graph_proto; + OrtEpUtils::OrtGraphToProto(*ort_graph, graph_proto, handle_initializer_data); + + // graph_proto stores large initializers in an external file + } + ``` +*/ + +#ifndef INCLUDE_ONNXRUNTIME_CORE_PROVIDERS_UTILS_ORT_GRAPH_TO_PROTO_H_ +#define INCLUDE_ONNXRUNTIME_CORE_PROVIDERS_UTILS_ORT_GRAPH_TO_PROTO_H_ + +#include +#include "core/session/onnxruntime_cxx_api.h" +#include "onnx/onnx_pb.h" + +namespace OrtEpUtils { + +/// +/// Signature of user-provided function to handle initializer data. Called by OrtGraphToProto() for every initializer. +/// +/// If the function sets the `is_external` output parameter to false, OrtGraphToProto() stores initializer data +/// within the TensorProto as raw_data. +/// +/// Otherwise, if the function sets `is_external` to true, OrtGraphToProto() assumes that this function stores the +/// initializer data in a file. In this case, OrtGraphToProto() configures the corresponding TensorProto to point the +/// location and offset returned via the `location` and `offset` output parameters. +/// +/// It is recommended to keep small initializers with byte size <= 127 stored inline the TensorProto to ensure +/// ONNX shape inference works correctly with the serialized ONNX model. +/// +/// OrtValueInfo for the initializer. Can be used to query name, type, shape, +/// and consumer nodes. +/// Opaque pointer to the initializer data. +/// Size in bytes of the initializer data. +/// Output parameter set to true if the initializer data is stored externally. The +/// implementer is responsible for writing the initializer data to file. If set to false, +/// the initializer will be stored within the TensorProto. +/// Output parameter set to the location (e.g., file) into which the initializer is stored +/// by the implementer of this function. Ignored if `is_external` is set to false. +/// Output parameter set to the offset (e.g., file offset) into which the initializer is stored +/// by the implementer of this function. Ignored if `is_external` is set to false. +/// An Ort::Status indicating success or an error. Serialization exits if this returns an error. +using HandleInitializerDataFunc = std::function; + +/// +/// Serializes the provided OrtGraph to a onnx::GraphProto. +/// Allows the caller to provide a function that specifies whether an initializer should be stored +/// within a TensorProto, written to a file, or remain as an in-memory external initializer (not valid ONNX). +/// +/// OrtGraph instance to serialize. +/// Destination GraphProto into which to serialize the input OrtGraph. +/// Optional function called to allow the user to determine +/// where the initializer data is stored. +/// An Ort::Status indicating success or an error. +Ort::Status OrtGraphToProto(const OrtGraph& ort_graph, + onnx::GraphProto& graph_proto, + HandleInitializerDataFunc handle_initializer_data_func = nullptr); + +/// +/// Serializes the provided top-level OrtGraph to a onnx::ModelProto. +/// Allows the caller to provide a function that specifies whether an initializer should be stored +/// within a TensorProto, written to a file, or remain as an in-memory external initializer (not valid ONNX). +/// +/// OrtGraph instance to serialize. +/// Destination ModelProto into which to serialize the input OrtGraph. +/// Optional function called to allow the user to determine +/// where the initializer data is stored. +/// An Ort::Status indicating success or an error. +Ort::Status OrtGraphToProto(const OrtGraph& ort_graph, + onnx::ModelProto& model_proto, + HandleInitializerDataFunc handle_initializer_data_func = nullptr); +} // namespace OrtEpUtils + +// End of header +#endif // INCLUDE_ONNXRUNTIME_CORE_PROVIDERS_UTILS_ORT_GRAPH_TO_PROTO_H_ + +// +// IMPLEMENTATION BELOW +// +#ifdef ORT_EP_UTILS_ORT_GRAPH_TO_PROTO_IMPL + +#include +#include +#include +#include +#include +#include + +#define ORT_EP_UTILS_C_RETURN_IF_ERROR(fn) \ + do { \ + OrtStatus* _status = (fn); \ + if (_status != nullptr) { \ + return Ort::Status{_status}; \ + } \ + } while (0) + +#define ORT_EP_UTILS_CXX_RETURN_IF_ERROR(fn) \ + do { \ + Ort::Status _status = (fn); \ + if (!_status.IsOK()) { \ + return _status; \ + } \ + } while (0) + +#define ORT_EP_UTILS_C_RETURN_IF(cond, ort_api, msg) \ + do { \ + if ((cond)) { \ + return Ort::Status{(ort_api).CreateStatus(ORT_FAIL, (msg))}; \ + } \ + } while (0) + +namespace OrtEpUtils { + +static Ort::Status GetOrtValueInfoTensorTypeShape(const OrtValueInfo& ort_value_info, + bool get_symbolic_dims, + /*out*/ ONNXTensorElementDataType& elem_type, + /*out*/ std::vector& dims, + /*out*/ std::vector& symbolic_dims); +static Ort::Status OrtValueInfoToProto(const OrtValueInfo& ort_value_info, onnx::ValueInfoProto& value_info_proto); +static Ort::Status OrtOpAttrToProto(const OrtOpAttr& ort_attr, onnx::AttributeProto& attr_proto); + +Ort::Status OrtGraphToProto(const OrtGraph& ort_graph, + onnx::GraphProto& graph_proto, + HandleInitializerDataFunc handle_initializer_data_func) { + const OrtApi& ort_api = Ort::GetApi(); + + // + // Set GraphProto metadata + // + const char* graph_name = nullptr; + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetName(&ort_graph, &graph_name)); + graph_proto.set_name(graph_name); + graph_proto.set_doc_string("Serialized from OrtGraph"); + + // + // Set GraphProto inputs and outputs + // + size_t num_graph_inputs = 0; + size_t num_graph_outputs = 0; + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetNumInputs(&ort_graph, &num_graph_inputs)); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetNumOutputs(&ort_graph, &num_graph_outputs)); + + std::vector graph_inputs(num_graph_inputs); + std::vector graph_outputs(num_graph_outputs); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetInputs(&ort_graph, graph_inputs.data(), graph_inputs.size())); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetOutputs(&ort_graph, graph_outputs.data(), graph_outputs.size())); + + for (const OrtValueInfo* ort_value_info : graph_inputs) { + onnx::ValueInfoProto* value_info_proto = graph_proto.mutable_input()->Add(); + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtValueInfoToProto(*ort_value_info, *value_info_proto)); + } + + for (const OrtValueInfo* ort_value_info : graph_outputs) { + onnx::ValueInfoProto* value_info_proto = graph_proto.mutable_output()->Add(); + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtValueInfoToProto(*ort_value_info, *value_info_proto)); + } + + // + // Set GraphProto nodes, value_infos, and initializers. + // + + // Use std::maps to store OrtValueInfos for GraphProto.value_info and GraphProto.initializer. + // A std::map maintains its elements in a stable ordering. + std::map value_infos; // For GraphProto.value_info + std::map initializer_value_infos; // For GraphProto.initializer + + // Helper function to collect an OrtValueInfo into `value_infos` or `initializer_value_infos`. + // Optionally returns the OrtValueInfo name to the caller. + auto collect_value_info = [&ort_api, &value_infos, + &initializer_value_infos](const OrtValueInfo& ort_value_info, + /*out*/ const char** value_name_out = nullptr) -> Ort::Status { + const char* value_name = nullptr; + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetValueInfoName(&ort_value_info, &value_name)); + + if (value_name_out != nullptr) { + *value_name_out = value_name; + } + + if (value_infos.count(value_name) != 0 || initializer_value_infos.count(value_name) != 0) { + return Ort::Status{nullptr}; // Already processed this OrtValueInfo. + } + + bool is_required_graph_input = false; + bool is_optional_graph_input = false; + bool is_graph_output = false; + bool is_constant_initializer = false; + bool is_from_outer_scope = false; + + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ValueInfo_IsRequiredGraphInput(&ort_value_info, &is_required_graph_input)); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ValueInfo_IsOptionalGraphInput(&ort_value_info, &is_optional_graph_input)); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ValueInfo_IsGraphOutput(&ort_value_info, &is_graph_output)); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ValueInfo_IsConstantInitializer(&ort_value_info, &is_constant_initializer)); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ValueInfo_IsFromOuterScope(&ort_value_info, &is_from_outer_scope)); + + // Don't add graph inputs or graph outputs to GraphProto's list of value_infos. + // Do add initializers (constant and non-constant) to GraphProto's list of initializer tensors. + // For values defined in an outer scope, just add the value info but not the initializer. + if (is_from_outer_scope) { + value_infos.emplace(value_name, &ort_value_info); + } else if (is_optional_graph_input) { + initializer_value_infos.emplace(value_name, &ort_value_info); + } else if (is_constant_initializer) { + value_infos.emplace(value_name, &ort_value_info); + initializer_value_infos.emplace(value_name, &ort_value_info); + } else if (!is_required_graph_input && !is_graph_output) { + value_infos.emplace(value_name, &ort_value_info); // This is an internal OrtValueInfo. + } + + return Ort::Status{nullptr}; + }; + + size_t num_nodes = 0; + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetNumNodes(&ort_graph, &num_nodes)); + + std::vector nodes(num_nodes); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetNodes(&ort_graph, nodes.data(), nodes.size())); + + // Loop through all nodes (topological order): add NodeProto instances to GraphProto and track OrtValueInfos + // that will be stored in GraphProto.value_info and GraphProto.initializer. + for (size_t i = 0; i < num_nodes; i++) { + const OrtNode* ort_node = nodes[i]; + onnx::NodeProto* node_proto = graph_proto.add_node(); + + const char* node_name = nullptr; + const char* node_domain = nullptr; + const char* node_op_type = nullptr; + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetName(ort_node, &node_name)); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetDomain(ort_node, &node_domain)); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetOperatorType(ort_node, &node_op_type)); + + node_proto->set_name(node_name); + node_proto->set_domain(node_domain); + node_proto->set_op_type(node_op_type); + + size_t num_inputs = 0; + size_t num_implicit_inputs = 0; + size_t num_outputs = 0; + size_t num_attrs = 0; + size_t num_subgraphs = 0; + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetNumInputs(ort_node, &num_inputs)); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetNumImplicitInputs(ort_node, &num_implicit_inputs)); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetNumOutputs(ort_node, &num_outputs)); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetNumAttributes(ort_node, &num_attrs)); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetNumSubgraphs(ort_node, &num_subgraphs)); + + // Handle node attributes + if (num_attrs > 0) { + std::vector ort_attrs(num_attrs); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetAttributes(ort_node, ort_attrs.data(), ort_attrs.size())); + + for (const OrtOpAttr* ort_attr : ort_attrs) { + OrtOpAttrType attr_type = OrtOpAttrType::ORT_OP_ATTR_UNDEFINED; + + Ort::Status status{ort_api.OpAttr_GetType(ort_attr, &attr_type)}; + if (!status.IsOK()) { + // This is an attribute type that ORT does not support via ReadOpAttr(), like subgraphs, so skip it. + // Can use Node_GetSubgraphs to get subgraphs. + continue; + } + + onnx::AttributeProto* attr_proto = node_proto->add_attribute(); + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtOpAttrToProto(*ort_attr, *attr_proto)); + } + } + + // Handle node subgraphs + if (num_subgraphs > 0) { + std::vector ort_subgraphs(num_subgraphs); + std::vector subgraph_attr_names(num_subgraphs); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetSubgraphs(ort_node, ort_subgraphs.data(), ort_subgraphs.size(), + subgraph_attr_names.data())); + + for (size_t subgraph_idx = 0; subgraph_idx < num_subgraphs; subgraph_idx++) { + const OrtGraph* ort_subgraph = ort_subgraphs[subgraph_idx]; + const char* subgraph_attr_name = subgraph_attr_names[subgraph_idx]; + + onnx::AttributeProto* attr_proto = node_proto->add_attribute(); + onnx::GraphProto* subgraph_proto = attr_proto->mutable_g(); + + attr_proto->set_name(subgraph_attr_name); + attr_proto->set_type(onnx::AttributeProto_AttributeType_GRAPH); + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtGraphToProto(*ort_subgraph, *subgraph_proto)); + } + } + + // Handle node inputs + if (num_inputs > 0) { + std::vector ort_inputs(num_inputs); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetInputs(ort_node, ort_inputs.data(), ort_inputs.size())); + + for (const OrtValueInfo* ort_value_info : ort_inputs) { + if (ort_value_info == nullptr) { + // missing optional input. + node_proto->add_input(""); + continue; + } + + const char* value_name = nullptr; + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(collect_value_info(*ort_value_info, &value_name)); + + node_proto->add_input(value_name); + } + } + + // Handle implicit inputs to this node. + if (num_implicit_inputs > 0) { + std::vector ort_implicit_inputs(num_implicit_inputs); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetImplicitInputs(ort_node, ort_implicit_inputs.data(), + ort_implicit_inputs.size())); + + for (const OrtValueInfo* ort_value_info : ort_implicit_inputs) { + assert(ort_value_info != nullptr); + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(collect_value_info(*ort_value_info, /*value_name_out*/ nullptr)); + } + } + + // Handle node outputs + if (num_outputs > 0) { + std::vector ort_outputs(num_outputs); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetOutputs(ort_node, ort_outputs.data(), ort_outputs.size())); + + for (const OrtValueInfo* ort_value_info : ort_outputs) { + if (ort_value_info == nullptr) { + // missing optional output. + node_proto->add_output(""); + continue; + } + + const char* value_name = nullptr; + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(collect_value_info(*ort_value_info, &value_name)); + + node_proto->add_output(value_name); + } + } + } + + // Add value_infos to GraphProto as ValueInfoProto objects. + for (const std::pair& entry : value_infos) { + onnx::ValueInfoProto* value_info_proto = graph_proto.mutable_value_info()->Add(); + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtValueInfoToProto(*entry.second, *value_info_proto)); + } + + // Add initializers to GraphProto as TensorProto objects. + for (const std::pair& entry : initializer_value_infos) { + const OrtValueInfo* initializer_value_info = entry.second; + std::string initializer_name = std::string{entry.first}; // Need a null-terminated string. + std::vector initializer_dims; + std::vector initializer_sym_dims; + ONNXTensorElementDataType initializer_elem_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(GetOrtValueInfoTensorTypeShape(*initializer_value_info, /*get_sym_dims*/ false, + initializer_elem_type, initializer_dims, + initializer_sym_dims)); + + onnx::TensorProto* tensor_proto = graph_proto.add_initializer(); + tensor_proto->set_name(initializer_name); + tensor_proto->set_data_type(initializer_elem_type); + + auto* tensor_proto_dims = tensor_proto->mutable_dims(); + for (int64_t dim : initializer_dims) { + tensor_proto_dims->Add(dim); + } + + const OrtValue* ort_value = nullptr; + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ValueInfo_GetInitializerValue(initializer_value_info, &ort_value)); + + const void* data = nullptr; + size_t data_bytes = 0; + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetTensorData(ort_value, &data)); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetTensorSizeInBytes(ort_value, &data_bytes)); + + std::string ext_location; + int64_t ext_offset = 0; + bool is_external = false; + + if (handle_initializer_data_func != nullptr) { + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(handle_initializer_data_func(initializer_value_info, data, data_bytes, + is_external, ext_location, ext_offset)); + } + + if (is_external) { + tensor_proto->set_data_location(onnx::TensorProto_DataLocation_EXTERNAL); + auto* ext_data_entries = tensor_proto->mutable_external_data(); + onnx::StringStringEntryProto* location_entry = ext_data_entries->Add(); + onnx::StringStringEntryProto* offset_entry = ext_data_entries->Add(); + + location_entry->set_key("location"); + location_entry->set_value(ext_location); + offset_entry->set_key("offset"); + offset_entry->set_value(std::to_string(ext_offset)); + } else { + // User wants to store data inline the TensorProto's raw_data + tensor_proto->set_data_location(onnx::TensorProto_DataLocation_DEFAULT); + tensor_proto->set_raw_data(data, data_bytes); + } + } + + return Ort::Status{nullptr}; +} + +Ort::Status OrtGraphToProto(const OrtGraph& ort_graph, + onnx::ModelProto& model_proto, + HandleInitializerDataFunc handle_initializer_data_func) { + const OrtApi& ort_api = Ort::GetApi(); + + // Check that OrtGraph is a top-level graph (no parent node). + const OrtNode* parent_node = nullptr; + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetParentNode(&ort_graph, &parent_node)); + ORT_EP_UTILS_C_RETURN_IF(parent_node != nullptr, ort_api, "Cannot serialize nested OrtGraph into a ModelProto"); + + // Set model description. + model_proto.set_doc_string("Serialized from OrtGraph"); + model_proto.set_producer_name("ort_ep_utils::OrtGraphToProto"); + + // Set ir version. + int64_t ir_version = 0; + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetOnnxIRVersion(&ort_graph, &ir_version)); + model_proto.set_ir_version(ir_version); + + // Set operator sets. + size_t num_operator_sets = 0; + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetNumOperatorSets(&ort_graph, &num_operator_sets)); + ORT_EP_UTILS_C_RETURN_IF(num_operator_sets == 0, ort_api, "OrtGraph should have at least one operator set."); + + std::vector domains(num_operator_sets, nullptr); + std::vector opset_versions(num_operator_sets); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetOperatorSets(&ort_graph, domains.data(), opset_versions.data(), + num_operator_sets)); + + auto* operator_sets = model_proto.mutable_opset_import(); + + for (size_t i = 0; i < num_operator_sets; ++i) { + onnx::OperatorSetIdProto* operator_set = operator_sets->Add(); + operator_set->set_domain(domains[i]); + operator_set->set_version(opset_versions[i]); + } + + model_proto.clear_graph(); + onnx::GraphProto* graph_proto = model_proto.mutable_graph(); + + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtGraphToProto(ort_graph, *graph_proto, handle_initializer_data_func)); + + return Ort::Status{nullptr}; +} + +static Ort::Status GetOrtValueInfoTensorTypeShape(const OrtValueInfo& ort_value_info, + bool get_symbolic_dims, + /*out*/ ONNXTensorElementDataType& elem_type, + /*out*/ std::vector& dims, + /*out*/ std::vector& symbolic_dims) { + const OrtApi& ort_api = Ort::GetApi(); + + const OrtTypeInfo* ort_type_info = nullptr; + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetValueInfoTypeInfo(&ort_value_info, &ort_type_info)); + + ONNXType ort_onnx_type = ONNX_TYPE_UNKNOWN; + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetOnnxTypeFromTypeInfo(ort_type_info, &ort_onnx_type)); + ORT_EP_UTILS_C_RETURN_IF(ort_onnx_type != ONNX_TYPE_TENSOR, ort_api, "Expected OrtValueInfo to represent a Tensor"); + + const OrtTensorTypeAndShapeInfo* ort_type_shape = nullptr; + ONNXTensorElementDataType ort_elem_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.CastTypeInfoToTensorInfo(ort_type_info, &ort_type_shape)); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetTensorElementType(ort_type_shape, &ort_elem_type)); + + size_t num_dims = 0; + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetDimensionsCount(ort_type_shape, &num_dims)); + + std::vector ort_dims(num_dims, 0); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetDimensions(ort_type_shape, ort_dims.data(), ort_dims.size())); + + elem_type = ort_elem_type; + dims = std::move(ort_dims); + + if (get_symbolic_dims) { + std::vector ort_dim_syms(num_dims, nullptr); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetSymbolicDimensions(ort_type_shape, ort_dim_syms.data(), + ort_dim_syms.size())); + + symbolic_dims.reserve(num_dims); + for (const char* sym_dim : ort_dim_syms) { + symbolic_dims.push_back(sym_dim); + } + } + + return Ort::Status{nullptr}; +} + +// Create an onnx::ValueInfoProto from an OrtValueInfo (name, type, shape). +static Ort::Status OrtValueInfoToProto(const OrtValueInfo& ort_value_info, + onnx::ValueInfoProto& value_info_proto) { + const OrtApi& ort_api = Ort::GetApi(); + + std::vector ort_dims; + std::vector ort_dim_syms; + ONNXTensorElementDataType ort_elem_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; + + // We currently only support ONNX tensors. Support for other types (e.g., ONNX_TYPE_SEQUENCE) can be added later. + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(GetOrtValueInfoTensorTypeShape(ort_value_info, /*get_sym_dims*/ true, + ort_elem_type, ort_dims, ort_dim_syms)); + + const char* value_name = nullptr; + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetValueInfoName(&ort_value_info, &value_name)); + value_info_proto.set_name(value_name); + + onnx::TypeProto_Tensor* type_proto_tensor = value_info_proto.mutable_type()->mutable_tensor_type(); + type_proto_tensor->set_elem_type(ort_elem_type); + + onnx::TensorShapeProto* shape_proto = type_proto_tensor->mutable_shape(); + + for (size_t dim_idx = 0; dim_idx < ort_dims.size(); dim_idx++) { + onnx::TensorShapeProto_Dimension* dim_proto = shape_proto->add_dim(); + + if (ort_dims[dim_idx] >= 0) { + dim_proto->set_dim_value(ort_dims[dim_idx]); + } else { + const std::string& dim_param = ort_dim_syms[dim_idx]; + + // If dim_param is empty, leave dim_proto with neither the dim_value or dim_param set, + // which represents an unknown dimension. + if (!dim_param.empty()) { + dim_proto->set_dim_param(dim_param); + } + } + } + + return Ort::Status{nullptr}; +} + +static Ort::Status OrtOpAttrToProto(const OrtOpAttr& ort_attr, onnx::AttributeProto& attr_proto) { + const OrtApi& ort_api = Ort::GetApi(); + + const char* attr_name = nullptr; + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.OpAttr_GetName(&ort_attr, &attr_name)); + attr_proto.set_name(attr_name); + + size_t total_attr_bytes = 0; + OrtOpAttrType attr_type = OrtOpAttrType::ORT_OP_ATTR_UNDEFINED; + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.OpAttr_GetType(&ort_attr, &attr_type)); + + switch (attr_type) { + case OrtOpAttrType::ORT_OP_ATTR_INT: { + attr_proto.set_type(onnx::AttributeProto_AttributeType_INT); + + int64_t i_val = 0; + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ReadOpAttr(&ort_attr, attr_type, &i_val, sizeof(i_val), &total_attr_bytes)); + attr_proto.set_i(i_val); + break; + } + case OrtOpAttrType::ORT_OP_ATTR_INTS: { + attr_proto.set_type(onnx::AttributeProto_AttributeType_INTS); + + // First call to ReadOpAttr gets the total byte size. Second call reads the data. + Ort::Status status{ort_api.ReadOpAttr(&ort_attr, attr_type, nullptr, 0, &total_attr_bytes)}; + std::vector i_vals(total_attr_bytes / sizeof(int64_t)); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ReadOpAttr(&ort_attr, attr_type, i_vals.data(), total_attr_bytes, + &total_attr_bytes)); + + auto* ints = attr_proto.mutable_ints(); + for (int64_t val : i_vals) { + ints->Add(val); + } + break; + } + case OrtOpAttrType::ORT_OP_ATTR_FLOAT: { + attr_proto.set_type(onnx::AttributeProto_AttributeType_FLOAT); + + float f_val = 0.0f; + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ReadOpAttr(&ort_attr, attr_type, &f_val, sizeof(f_val), &total_attr_bytes)); + attr_proto.set_f(f_val); + break; + } + case OrtOpAttrType::ORT_OP_ATTR_FLOATS: { + attr_proto.set_type(onnx::AttributeProto_AttributeType_FLOATS); + + // First call to ReadOpAttr gets the total byte size. Second call reads the data. + Ort::Status status{ort_api.ReadOpAttr(&ort_attr, attr_type, nullptr, 0, &total_attr_bytes)}; + std::vector f_vals(total_attr_bytes / sizeof(float)); + + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ReadOpAttr(&ort_attr, attr_type, f_vals.data(), total_attr_bytes, + &total_attr_bytes)); + + auto* floats = attr_proto.mutable_floats(); + for (float val : f_vals) { + floats->Add(val); + } + break; + } + case OrtOpAttrType::ORT_OP_ATTR_STRING: { + attr_proto.set_type(onnx::AttributeProto_AttributeType_STRING); + + // First call to ReadOpAttr gets the total byte size. Second call reads the data. + Ort::Status status{ort_api.ReadOpAttr(&ort_attr, attr_type, nullptr, 0, &total_attr_bytes)}; + std::string* str = attr_proto.mutable_s(); + + str->resize(total_attr_bytes, '\0'); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ReadOpAttr(&ort_attr, attr_type, str->data(), total_attr_bytes, + &total_attr_bytes)); + + str->resize(total_attr_bytes - 1); // remove extra ending terminating '\0' character. + break; + } + case OrtOpAttrType::ORT_OP_ATTR_STRINGS: { + attr_proto.set_type(onnx::AttributeProto_AttributeType_STRINGS); + + // First call to ReadOpAttr gets the total byte size. Second call reads the data. + Ort::Status status{ort_api.ReadOpAttr(&ort_attr, attr_type, nullptr, 0, &total_attr_bytes)}; + std::vector chars(total_attr_bytes, '\0'); + + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ReadOpAttr(&ort_attr, attr_type, chars.data(), total_attr_bytes, + &total_attr_bytes)); + + auto* strs = attr_proto.mutable_strings(); + + // Strings are all in a single buffer, each separated with a '\0'. + // Extract each string and add it to the STRINGS attribute array. + char* at = chars.data(); + char* end = at + chars.size(); + + while (at < end) { + char* str_begin = at; + + while (*at && at < end) { + at++; + } + + strs->Add()->assign(str_begin, at - str_begin); + if (at < end) { + assert(*at == '\0'); + at++; // Skip '\0' to get to the beginning of the next string. + } + } + + break; + } + default: { + std::string err_msg = "Unexpected OrtOpAttrType with value " + std::to_string(static_cast(attr_type)); + return Ort::Status(err_msg.c_str(), ORT_FAIL); + } + } + + return Ort::Status{nullptr}; +} + +} // namespace OrtEpUtils +#endif // ORT_EP_UTILS_ORT_GRAPH_TO_PROTO_IMPL diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 86c0b60db2bc4..82e782112974f 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -66,6 +66,7 @@ extern "C" { #define _In_reads_(X) #define _Inout_updates_(X) #define _Out_writes_(X) +#define _Out_writes_opt_(X) #define _Inout_updates_all_(X) #define _Out_writes_bytes_all_(X) #define _Out_writes_all_(X) @@ -4749,6 +4750,8 @@ struct OrtApi { * \param[in] len Number of bytes allowed to store in data * \param[out] out Number of bytes required to save the data when the call failed, or the real number of bytes saved to data on success * + * \note Does not support reading graph attributes. Refer to Node_GetSubgraphs. + * * \since Version 1.17. */ ORT_API2_STATUS(ReadOpAttr, _In_ const OrtOpAttr* op_attr, _In_ OrtOpAttrType type, _Inout_ void* data, _In_ size_t len, _Out_ size_t* out); @@ -5568,6 +5571,45 @@ struct OrtApi { */ ORT_API2_STATUS(Graph_GetOnnxIRVersion, _In_ const OrtGraph* graph, _Out_ int64_t* onnx_ir_version); + /** \brief Returns the number of operator sets that the graph's model uses. + * + * \note An operator set is uniquely identified by the (domain, opset_version) pair. All models must have at + * least one entry that specifies which entry of the ONNX operator set is used. The ONNX domain is represented by + * an empty string. + * + * \param[in] graph The OrtGraph instance. + * \param[out] num_operator_sets Output parameter set to the number of operator sets that the graph's model uses. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23. + */ + ORT_API2_STATUS(Graph_GetNumOperatorSets, _In_ const OrtGraph* graph, _Out_ size_t* num_operator_sets); + + /** \brief Returns the operator sets that the graph's model uses. + * + * \note An operator set is uniquely identified by the (domain, opset_version) pair. All models must have at + * least one entry that specifies which entry of the ONNX operator set is used. The ONNX domain is represented by + * an empty string. + * + * \param[in] graph The OrtGraph instance. + * \param[out] domains Pre-allocated array of `num_operator_sets` elements that is filled with + * null-terminated domain names. + * \param[out] opset_versions Pre-allocated array of `num_operator_sets` elements that is filled with + * the opset version of the corresponding domain in the `domains` array. + * \param[in] num_operator_sets The size of the `domains` and `opset_versions` arrays. + * Typical usage sets this to the result of Graph_GetNumOperatorSets(). + * An error status is returned if `num_operator_sets` is less than the actual number + * of operator sets. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23. + */ + ORT_API2_STATUS(Graph_GetOperatorSets, _In_ const OrtGraph* graph, + _Out_writes_(num_operator_sets) const char** domains, + _Out_writes_(num_operator_sets) int64_t* opset_versions, _In_ size_t num_operator_sets); + /** \brief Returns the number of graph inputs. * * \note The count includes initializers that are included in the list of graph inputs. @@ -5706,6 +5748,24 @@ struct OrtApi { */ ORT_API2_STATUS(Graph_GetParentNode, _In_ const OrtGraph* graph, _Outptr_result_maybenull_ const OrtNode** node); + /** \brief Returns an OrtGraph that contains a subset of nodes in the source OrtGraph. + * + * Note: + * The lifetime of "dst_graph" is tied to that of "src_graph", as they both internally reference + * the same underlying graph. + * + * \param[in] src_graph The source OrtGraph instance. + * \param[in] nodes A subset of the nodes/OrtNodes in 'graph'. + * \param[in] num_nodes Number of nodes. + * \param[out] dst_sub_graph An OrtGraph created from a given set of nodes. Must be released by calling ReleaseGraph. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23. + */ + ORT_API2_STATUS(Graph_GetGraphView, _In_ const OrtGraph* src_graph, _In_ const OrtNode** nodes, + _In_ size_t num_nodes, _Outptr_ OrtGraph** dst_graph); + /// @} /// \name OrtNode @@ -5933,20 +5993,24 @@ struct OrtApi { /** \brief Get the subgraphs, as OrtGraph instances, contained by the given node. * - * \note Only certain operator types (e.g., If and Loop) contain nested subgraphs. + * \note Only certain operator types (e.g., If and Loop) contain nested subgraphs. ONNX nodes store subgraphs in + * their attributes, however, this function must be used to obtain subgraphs from an OrtNode. * * \param[in] node The OrtNode instance. * \param[out] subgraphs Pre-allocated array of `num_subgraphs` elements that is filled with the node's subgraphs. * \param[in] num_subgraphs The size of the `num_subgraphs` array. * Typical usage sets this to the result of Node_GetNumSubgraphs(). An error status is * returned if `num_subgraphs` is less than the number of node subgraphs. + * \param[out] attribute_names Optional pre-allocated array of `num_subgraphs` elements that is filled with the + * attribute names that correspond to the subgraphs. Ignored if set to NULL. * * \snippet{doc} snippets.dox OrtStatus Return Value * * \since Version 1.23. */ ORT_API2_STATUS(Node_GetSubgraphs, _In_ const OrtNode* node, - _Out_writes_(num_subgraphs) const OrtGraph** subgraphs, _In_ size_t num_subgraphs); + _Out_writes_(num_subgraphs) const OrtGraph** subgraphs, _In_ size_t num_subgraphs, + _Out_writes_opt_(num_subgraphs) const char** attribute_names); /** \brief Get the node's parent OrtGraph instance. * @@ -5962,6 +6026,19 @@ struct OrtApi { */ ORT_API2_STATUS(Node_GetGraph, _In_ const OrtNode* node, _Outptr_result_maybenull_ const OrtGraph** graph); + /** \brief Returns the execution provider name that this node is assigned to run on. + * Returns NULL if the node has not been assigned to any execution provider yet. + * For plugin execution providers, the name is the one returned by OrtEp::GetName. + * + * \param[in] node The OrtNode instance. + * \param[out] out Output execution provider type and can be NULL if node has not been assigned. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23. + */ + ORT_API2_STATUS(Node_GetEpName, _In_ const OrtNode* node, _Outptr_result_maybenull_ const char** out); + /// @} /// \name OrtRunOptions @@ -6810,6 +6887,24 @@ struct OrtCompileApi { */ ORT_API2_STATUS(ModelCompilationOptions_SetFlags, _In_ OrtModelCompilationOptions* model_compile_options, size_t flags); + + /** Sets information related to EP context binary file. + * + * EP uses this information to decide the location and context binary file name. + * Used while compiling model with input and output in memory buffer + * + * \param[in] model_compile_options The OrtModelCompilationOptions instance. + * \param[in] output_directory Null terminated string of the path (wchar on Windows, char otherwise). + * \param[in] model_name Null terminated string of the model name (wchar on Windows, char otherwise). + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23. + */ + ORT_API2_STATUS(ModelCompilationOptions_SetEpContextBinaryInformation, + _In_ OrtModelCompilationOptions* model_compile_options, + _In_ const ORTCHAR_T* output_directory, + _In_ const ORTCHAR_T* model_name); }; /* diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index c59baa59c91a5..d1b08f127fa2a 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -1161,6 +1161,8 @@ struct ModelCompilationOptions : detail::Base { size_t initializer_size_threshold); ///< Wraps OrtApi::ModelCompilationOptions_SetOutputModelExternalInitializersFile ModelCompilationOptions& SetOutputModelBuffer(OrtAllocator* allocator, void** output_model_buffer_ptr, size_t* output_model_buffer_size_ptr); ///< Wraps OrtApi::ModelCompilationOptions_SetOutputModelBuffer + ModelCompilationOptions& SetEpContextBinaryInformation(const ORTCHAR_T* output_directory, + const ORTCHAR_T* model_name); ///< Wraps OrtApi::ModelCompilationOptions_SetEpContextBinaryInformation ModelCompilationOptions& SetFlags(size_t flags); ///< Wraps OrtApi::ModelCompilationOptions_SetFlags }; diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index 612adc81d3309..ba5d53e6c2dd0 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -819,6 +819,15 @@ inline ModelCompilationOptions& ModelCompilationOptions::SetOutputModelPath( return *this; } +inline ModelCompilationOptions& ModelCompilationOptions::SetEpContextBinaryInformation( + const ORTCHAR_T* output_directory, const ORTCHAR_T* model_name) { + Ort::ThrowOnError(GetCompileApi().ModelCompilationOptions_SetEpContextBinaryInformation( + this->p_, + output_directory, + model_name)); + return *this; +} + inline ModelCompilationOptions& ModelCompilationOptions::SetOutputModelExternalInitializersFile( const ORTCHAR_T* file_path, size_t initializer_size_threshold) { Ort::ThrowOnError(GetCompileApi().ModelCompilationOptions_SetOutputModelExternalInitializersFile( diff --git a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h index 44c7bb6ee424a..5d00ce4940d02 100644 --- a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h @@ -358,7 +358,7 @@ struct OrtEp { * * \since Version 1.22. */ - const char*(ORT_API_CALL* GetName)(_In_ const OrtEp* this_ptr); + ORT_API_T(const char*, GetName, _In_ const OrtEp* this_ptr); /** \brief Get information about the nodes supported by the OrtEp instance. * @@ -376,8 +376,8 @@ struct OrtEp { * * \since Version 1.23. */ - OrtStatus*(ORT_API_CALL* GetCapability)(_In_ OrtEp* this_ptr, _In_ const OrtGraph* graph, - _Inout_ OrtEpGraphSupportInfo* graph_support_info); + ORT_API2_STATUS(GetCapability, _In_ OrtEp* this_ptr, _In_ const OrtGraph* graph, + _Inout_ OrtEpGraphSupportInfo* graph_support_info); /** \brief Compile OrtGraph instances assigned to the OrtEp. Implementer must set a OrtNodeComputeInfo instance * for each OrtGraph in order to define its computation function. @@ -416,10 +416,10 @@ struct OrtEp { * * \since Version 1.23. */ - OrtStatus*(ORT_API_CALL* Compile)(_In_ OrtEp* this_ptr, _In_ const OrtGraph** graphs, - _In_ const OrtNode** fused_nodes, _In_ size_t count, - _Out_writes_all_(count) OrtNodeComputeInfo** node_compute_infos, - _Out_writes_(count) OrtNode** ep_context_nodes); + ORT_API2_STATUS(Compile, _In_ OrtEp* this_ptr, _In_ const OrtGraph** graphs, + _In_ const OrtNode** fused_nodes, _In_ size_t count, + _Out_writes_all_(count) OrtNodeComputeInfo** node_compute_infos, + _Out_writes_(count) OrtNode** ep_context_nodes); /** \brief Release OrtNodeComputeInfo instances. * @@ -429,9 +429,9 @@ struct OrtEp { * * \since Version 1.23. */ - void(ORT_API_CALL* ReleaseNodeComputeInfos)(_In_ OrtEp* this_ptr, - OrtNodeComputeInfo** node_compute_infos, - _In_ size_t num_node_compute_infos); + ORT_API_T(void, ReleaseNodeComputeInfos, _In_ OrtEp* this_ptr, + OrtNodeComputeInfo** node_compute_infos, + _In_ size_t num_node_compute_infos); /** \brief Get the EP's preferred data layout. * @@ -445,8 +445,7 @@ struct OrtEp { * * \since Version 1.23. */ - OrtStatus*(ORT_API_CALL* GetPreferredDataLayout)(_In_ OrtEp* this_ptr, - _Out_ OrtEpDataLayout* preferred_data_layout); + ORT_API2_STATUS(GetPreferredDataLayout, _In_ OrtEp* this_ptr, _Out_ OrtEpDataLayout* preferred_data_layout); /** \brief Given an op with domain `domain` and type `op_type`, determine whether an associated node's data layout * should be converted to `target_data_layout`. @@ -470,11 +469,10 @@ struct OrtEp { * * \since Version 1.23. */ - OrtStatus*(ORT_API_CALL* ShouldConvertDataLayoutForOp)(_In_ OrtEp* this_ptr, - _In_z_ const char* domain, - _In_z_ const char* op_type, - _In_ OrtEpDataLayout target_data_layout, - _Outptr_ int* should_convert); + ORT_API2_STATUS(ShouldConvertDataLayoutForOp, _In_ OrtEp* this_ptr, + _In_z_ const char* domain, _In_z_ const char* op_type, + _In_ OrtEpDataLayout target_data_layout, + _Outptr_ int* should_convert); /** \brief Set dynamic options on this EP. * @@ -492,10 +490,10 @@ struct OrtEp { * * \since Version 1.23. */ - OrtStatus*(ORT_API_CALL* SetDynamicOptions)(_In_ OrtEp* this_ptr, - _In_reads_(num_options) const char* const* option_keys, - _In_reads_(num_options) const char* const* option_values, - _In_ size_t num_options); + ORT_API2_STATUS(SetDynamicOptions, _In_ OrtEp* this_ptr, + _In_reads_(num_options) const char* const* option_keys, + _In_reads_(num_options) const char* const* option_values, + _In_ size_t num_options); /** \brief Called by ORT to notify the EP of the start of a run. * @@ -508,8 +506,7 @@ struct OrtEp { * * \since Version 1.23. */ - OrtStatus*(ORT_API_CALL* OnRunStart)(_In_ OrtEp* this_ptr, - _In_ const OrtRunOptions* run_options); + ORT_API2_STATUS(OnRunStart, _In_ OrtEp* this_ptr, _In_ const OrtRunOptions* run_options); /** \brief Called by ORT to notify the EP of the end of a run. * @@ -524,9 +521,7 @@ struct OrtEp { * * \since Version 1.23. */ - OrtStatus*(ORT_API_CALL* OnRunEnd)(_In_ OrtEp* this_ptr, - _In_ const OrtRunOptions* run_options, - _In_ bool sync_stream); + ORT_API2_STATUS(OnRunEnd, _In_ OrtEp* this_ptr, _In_ const OrtRunOptions* run_options, _In_ bool sync_stream); }; /** \brief The function signature that ORT will call to create OrtEpFactory instances. @@ -586,7 +581,7 @@ struct OrtEpFactory { * * \since Version 1.22. */ - const char*(ORT_API_CALL* GetName)(const OrtEpFactory* this_ptr); + ORT_API_T(const char*, GetName, const OrtEpFactory* this_ptr); /** \brief Get the name of vendor who owns the execution provider that the factory creates. * @@ -597,7 +592,7 @@ struct OrtEpFactory { * * \since Version 1.22. */ - const char*(ORT_API_CALL* GetVendor)(const OrtEpFactory* this_ptr); // return EP vendor + ORT_API_T(const char*, GetVendor, const OrtEpFactory* this_ptr); // return EP vendor /** \brief Get information from the execution provider about OrtHardwareDevice support. * @@ -616,12 +611,12 @@ struct OrtEpFactory { * * \since Version 1.22. */ - OrtStatus*(ORT_API_CALL* GetSupportedDevices)(_In_ OrtEpFactory* this_ptr, - _In_reads_(num_devices) const OrtHardwareDevice* const* devices, - _In_ size_t num_devices, - _Inout_ OrtEpDevice** ep_devices, - _In_ size_t max_ep_devices, - _Out_ size_t* num_ep_devices); + ORT_API2_STATUS(GetSupportedDevices, _In_ OrtEpFactory* this_ptr, + _In_reads_(num_devices) const OrtHardwareDevice* const* devices, + _In_ size_t num_devices, + _Inout_ OrtEpDevice** ep_devices, + _In_ size_t max_ep_devices, + _Out_ size_t* num_ep_devices); /** \brief Function to create an OrtEp instance for use in a Session. * @@ -647,12 +642,12 @@ struct OrtEpFactory { * * \since Version 1.22. */ - OrtStatus*(ORT_API_CALL* CreateEp)(_In_ OrtEpFactory* this_ptr, - _In_reads_(num_devices) const OrtHardwareDevice* const* devices, - _In_reads_(num_devices) const OrtKeyValuePairs* const* ep_metadata_pairs, - _In_ size_t num_devices, - _In_ const OrtSessionOptions* session_options, - _In_ const OrtLogger* logger, _Outptr_ OrtEp** ep); + ORT_API2_STATUS(CreateEp, _In_ OrtEpFactory* this_ptr, + _In_reads_(num_devices) const OrtHardwareDevice* const* devices, + _In_reads_(num_devices) const OrtKeyValuePairs* const* ep_metadata_pairs, + _In_ size_t num_devices, + _In_ const OrtSessionOptions* session_options, + _In_ const OrtLogger* logger, _Outptr_ OrtEp** ep); /** \brief Release the OrtEp instance. * @@ -661,7 +656,18 @@ struct OrtEpFactory { * * \since Version 1.22. */ - void(ORT_API_CALL* ReleaseEp)(OrtEpFactory* this_ptr, struct OrtEp* ep); + ORT_API_T(void, ReleaseEp, OrtEpFactory* this_ptr, struct OrtEp* ep); + + /** \brief Get the vendor id who owns the execution provider that the factory creates. + * + * This is typically the PCI vendor ID. See https://pcisig.com/membership/member-companies + * + * \param[in] this_ptr The OrtEpFactory instance. + * \return vendor_id The vendor ID of the execution provider the factory creates. + * + * \since Version 1.23. + */ + ORT_API_T(uint32_t, GetVendorId, const OrtEpFactory* this_ptr); /** \brief Get the version of the execution provider that the factory creates. * @@ -675,7 +681,7 @@ struct OrtEpFactory { * * \since Version 1.23. */ - const char*(ORT_API_CALL* GetVersion)(_In_ const OrtEpFactory* this_ptr); + ORT_API_T(const char*, GetVersion, _In_ const OrtEpFactory* this_ptr); /** \brief Create an OrtAllocator for the given OrtMemoryInfo. * diff --git a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h index 97e53e6acee5a..314cf76cc8044 100644 --- a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h +++ b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h @@ -148,7 +148,9 @@ static const char* const kOrtSessionOptionsUseDeviceAllocatorForInitializers = " // Configure whether to allow the inter_op/intra_op threads spinning a number of times before blocking // "0": thread will block if found no job to run -// "1": default, thread will spin a number of times before blocking +// "1": thread will spin a number of times before blocking +// The default is "0" when ORT is built with "ORT_CLIENT_PACKAGE_BUILD" and "1" otherwise. +// Thread spinning is disabled by default for client/on-device workloads to reduce cpu utilization and improve power efficiency. static const char* const kOrtSessionOptionsConfigAllowInterOpSpinning = "session.inter_op.allow_spinning"; static const char* const kOrtSessionOptionsConfigAllowIntraOpSpinning = "session.intra_op.allow_spinning"; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/slice.ts b/js/web/lib/wasm/jsep/webgpu/ops/slice.ts index 5a837fd1e0bfa..c2085342efd80 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/slice.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/slice.ts @@ -98,7 +98,7 @@ const calculateInputIndicesImpl = ( `fn calculateInputIndices(output_indices: ${output.type.indices}) -> ${input.type.indices} { var input_indices: ${input.type.indices}; var carry = 0u; - for (var i = ${inputShape.length}; i >= 0; i--) { + for (var i = ${inputShape.length - 1}; i >= 0; i--) { let input_shape_i = ${getElementAt('uniforms.input_shape', 'i', inputShape.length)}; let steps_i = ${getElementAt('uniforms.steps', 'i', inputShape.length)}; let signs_i = ${getElementAt('uniforms.signs', 'i', inputShape.length)}; diff --git a/js/web/script/pull-prebuilt-wasm-artifacts.ts b/js/web/script/pull-prebuilt-wasm-artifacts.ts index c3300f7272bb9..87008f51ff4b9 100644 --- a/js/web/script/pull-prebuilt-wasm-artifacts.ts +++ b/js/web/script/pull-prebuilt-wasm-artifacts.ts @@ -38,7 +38,6 @@ Usage: Options: -d --debug specify the debug build type of the artifacts to download. -l --latest if set, will always use the latest build, even if it is not completed yet. - --webgpu-ep if set, will use the webgpu EP wasm build instead of the default(JSEP) one. -h --help print this message and exit `; @@ -81,9 +80,8 @@ try { // The following code checks both the command line arguments and the npm_config_* environment variables to get the correct values. const debug = args.debug || process.env.npm_config_d || process.env.npm_config_debug; const latest = args.latest || process.env.npm_config_l || process.env.npm_config_latest; -const webgpuEp = args['webgpu-ep'] || process.env.npm_config_webgpu_ep; -const folderName = (debug ? 'Debug_wasm' : 'Release_wasm') + (webgpuEp ? '_webgpu' : ''); +const folderName = debug ? 'Debug_wasm' : 'Release_wasm'; const allowImcomplete = latest; const run = args._[0]; // The first non-option argument @@ -151,13 +149,17 @@ async function downloadArtifactsForRun(run: any): Promise { if (!fs.existsSync(WASM_FOLDER)) { fs.mkdirSync(WASM_FOLDER); } else { - // TODO: revise artifacts download - const filesToDelete = ['ort-wasm-simd-threaded.jsep.mjs', 'ort-wasm-simd-threaded.jsep.wasm']; - if (!folderName.endsWith('_webgpu')) { - filesToDelete.push('ort-wasm-simd-threaded.mjs', 'ort-wasm-simd-threaded.wasm'); - } fs.readdirSync(WASM_FOLDER).forEach((file) => { - if (filesToDelete.includes(file)) { + if ( + [ + 'ort-wasm-simd-threaded.jsep.mjs', + 'ort-wasm-simd-threaded.jsep.wasm', + 'ort-wasm-simd-threaded.jsep.mjs', + 'ort-wasm-simd-threaded.jsep.wasm', + 'ort-wasm-simd-threaded.mjs', + 'ort-wasm-simd-threaded.wasm', + ].includes(file) + ) { const filePath = path.join(WASM_FOLDER, file); console.log(`Deleting old file: ${filePath}`); fs.unlinkSync(filePath); diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_common.h b/onnxruntime/contrib_ops/cpu/bert/attention_common.h index 243f611da49e1..80d374d3f0b25 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_common.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_common.h @@ -53,6 +53,12 @@ enum AttentionKernelType { AttentionKernel_Default }; +enum class QKOutputType : int { + NO_OUTPUT = 0, + BEFORE_SOFTMAX = 1, + AFTER_SOFTMAX = 2 +}; + constexpr bool LAYOUT_BSNH = false; constexpr bool LAYOUT_BNSH = true; diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/attention_helper.h index ac32a4445f3ca..aef47edd5fcd2 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_helper.h @@ -17,13 +17,13 @@ namespace onnxruntime { namespace contrib { template -inline void ComputeSmoothSoftmaxInplace(T* score, int N, int D, ThreadPool* tp) { - MlasComputeSoftmax(score, score, N, D, false, true, tp); +inline void ComputeSmoothSoftmaxInplace(T* score, int D, float sink, ThreadPool* tp) { + MlasComputeSoftmax(score, score, 1, D, false, true, sink, tp); } template inline void ComputeAttentionSoftmaxInplace(T* score, int N, int D, ThreadPool* tp) { - MlasComputeSoftmax(score, score, N, D, false, false, tp); + MlasComputeSoftmax(score, score, N, D, false, false, 0.0f, tp); } template diff --git a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h index c79508cbae273..0d5117709c18a 100644 --- a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h +++ b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h @@ -35,6 +35,8 @@ class GQAAttentionBase { use_smooth_softmax_ = info.GetAttrOrDefault("smooth_softmax", 0) == 1; local_window_size_ = has_local ? static_cast(info.GetAttrOrDefault("local_window_size", -1)) : -1; + + qk_output_ = static_cast(info.GetAttrOrDefault("qk_output", static_cast(QKOutputType::NO_OUTPUT))); } int num_heads_; // number of attention heads of Q @@ -44,6 +46,7 @@ class GQAAttentionBase { bool do_rotary_; // whether or not to use rotary embeddings bool rotary_interleaved_; int local_window_size_; + int qk_output_; bool use_smooth_softmax_; @@ -51,12 +54,14 @@ class GQAAttentionBase { Status ApplyAttention(const T* Q, // Q data with shape BxNxSxH const T* K, // K data with shape BxN_kvxSxH const T* V, // V data with shape BxN_kvxSxH + const T* head_sink, // Head sink for smooth softmax, nullptr if not used const Tensor* attention_bias, // Attention bias to add to QxK' const Tensor* past_key, // past K input tensor (if not using past state) const Tensor* past_value, // past V input tensor (if not using past state) Tensor* output, // output tensor Tensor* present_key, // present K output tensor (if separating present KV) Tensor* present_value, // present V output tensor (if separating present KV) + Tensor* output_qk, // output QK buffer const Tensor* seqlens_k, // past sequence lengths tensor GroupQueryAttentionParameters& parameters, // attention parameters AllocatorPtr allocator, // allocator for temporary tensors @@ -64,6 +69,7 @@ class GQAAttentionBase { const bool is_prompt = parameters.is_first_prompt; const int batch_size = parameters.batch_size; const int sequence_length = parameters.sequence_length; + const int total_sequence_length = parameters.total_sequence_length; const int head_size = parameters.head_size; const int hidden_size = parameters.hidden_size; const bool packed_qkv = parameters.is_packed_qkv; @@ -79,8 +85,7 @@ class GQAAttentionBase { // Compute the attention score. bool gqa_mlas_supported = MlasGQASupported(CblasNoTrans, CblasTrans) && MlasGQASupported(CblasNoTrans, CblasNoTrans); - size_t bytes = SafeInt(batch_size) * num_heads_ * sequence_length * seqlen_present_kv_cache * - (gqa_mlas_supported ? sizeof(T) : sizeof(float)); + size_t bytes = SafeInt(batch_size) * num_heads_ * sequence_length * seqlen_present_kv_cache * (gqa_mlas_supported ? sizeof(T) : sizeof(float)); auto attention_probs = allocator->Alloc(bytes); BufferUniquePtr scratch_buffer(attention_probs, BufferDeleter(allocator)); @@ -96,11 +101,13 @@ class GQAAttentionBase { const T* k = packed_qkv ? Q + num_heads_ * sequence_length * head_size : K; + T* output_qk_buffer = output_qk != nullptr ? output_qk->MutableData() : nullptr; + if (gqa_mlas_supported) { - ComputeAttentionProbs(static_cast(attention_probs), Q, k, seqlens_k->Data(), attention_bias_data, - batch_size, sequence_length, attention_bias_shape, seqlen_past_kv_cache, seqlen_present_kv_cache, - head_size, past_key_data, present_key_data, past_present_share_buffer, packed_qkv, is_prompt, - tp, allocator); + ComputeAttentionProbs(static_cast(attention_probs), Q, k, head_sink, seqlens_k->Data(), attention_bias_data, + batch_size, sequence_length, total_sequence_length, attention_bias_shape, seqlen_past_kv_cache, + seqlen_present_kv_cache, head_size, past_key_data, present_key_data, output_qk_buffer, + past_present_share_buffer, packed_qkv, is_prompt, tp, allocator); // Compute the attentionScore * Value: out(B, N, S, H_v) = attention_probs(B, N, S, T) x V(B, N, T, H_v) const T* v = packed_qkv ? Q + (num_heads_ + kv_num_heads_) * sequence_length * head_size : V; @@ -110,10 +117,10 @@ class GQAAttentionBase { hidden_size, past_value_data, present_value_data, past_present_share_buffer, packed_qkv, is_prompt, tp, allocator); } else { - ComputeAttentionProbs(static_cast(attention_probs), Q, k, seqlens_k->Data(), attention_bias_data, - batch_size, sequence_length, attention_bias_shape, seqlen_past_kv_cache, seqlen_present_kv_cache, - head_size, past_key_data, present_key_data, past_present_share_buffer, packed_qkv, is_prompt, - tp, allocator); + ComputeAttentionProbs(static_cast(attention_probs), Q, k, head_sink, seqlens_k->Data(), attention_bias_data, + batch_size, sequence_length, total_sequence_length, attention_bias_shape, seqlen_past_kv_cache, + seqlen_present_kv_cache, head_size, past_key_data, present_key_data, output_qk_buffer, + past_present_share_buffer, packed_qkv, is_prompt, tp, allocator); // Compute the attentionScore * Value: out(B, N, S, H_v) = attention_probs(B, N, S, T) x V(B, N, T, H_v) const T* v = packed_qkv ? Q + (num_heads_ + kv_num_heads_) * sequence_length * head_size : V; @@ -136,16 +143,19 @@ class GQAAttentionBase { void ComputeAttentionProbs(U* attention_probs, // output buffer with size BxNxSxT const T* Q, // Q data. Its size is BxNxSxH const T* K, // k data. Its size is BxNxLxH + const T* head_sink, // for smooth softmax. Its size is N. const int32_t* seqlens_k, // total - 1 sequence lengths tensor const T* attention_bias, // optional attention bias const size_t batch_size, // batch size of self-attention const size_t sequence_length, // sequence length of self-attention (S) + const size_t total_sequence_length, // total sequence length (T) const gsl::span attention_bias_shape, // shape of the attention bias const size_t past_buffer_sequence_length, // sequence length of past state const size_t present_buffer_sequence_length, // sequence length of present state const size_t head_size, // head size of self-attention const T* past_key, // past key only T* present_key, // present key only + T* output_qk, // output QK buffer const bool past_present_share_buffer, // whether present key and value share the same buffer const bool packed_qkv, // whether Q, K, V are packed const bool is_prompt, // whether it is prompt @@ -197,6 +207,11 @@ class GQAAttentionBase { const ptrdiff_t output_offset = SafeInt(i) * sequence_length * present_buffer_sequence_length; U* output = attention_probs + output_offset; + T* output_qk_thread = nullptr; + if (output_qk != nullptr) { + const ptrdiff_t output_qk_offset = SafeInt(sequence_length) * total_sequence_length * (batch_index * num_heads_ + head_index); + output_qk_thread = output_qk + output_qk_offset; + } // Compute attention bias offset based on the batch and head indexes // Attention bias is of shape (B or 1, H or 1, S, T) so handle broadcasting @@ -310,12 +325,6 @@ class GQAAttentionBase { } } - if (use_smooth_softmax_) { - ComputeSmoothSoftmaxInplace(output_softmax + start_offset, 1, static_cast(window_size), nullptr); - } else { - ComputeAttentionSoftmaxInplace(output_softmax + start_offset, 1, static_cast(window_size), nullptr); - } - // set causal [seq_causal_length, total_seqlen) to 0.f for (size_t total_seq_id = seq_causal_length; total_seq_id < total_seqlen; total_seq_id++) { if constexpr (std::is_same::value) { @@ -325,11 +334,30 @@ class GQAAttentionBase { } } + if (qk_output_ == static_cast(QKOutputType::BEFORE_SOFTMAX)) { + WriteOutputQKHeadChunk(output_qk_thread, output_softmax, total_sequence_length); + } + + if (use_smooth_softmax_ || head_sink != nullptr) { + float sink = (head_sink != nullptr) ? static_cast(head_sink[head_index]) : 0.0f; + ComputeSmoothSoftmaxInplace(output_softmax + start_offset, static_cast(window_size), sink, nullptr); + } else { + ComputeAttentionSoftmaxInplace(output_softmax + start_offset, 1, static_cast(window_size), nullptr); + } + + if (qk_output_ == static_cast(QKOutputType::AFTER_SOFTMAX)) { + WriteOutputQKHeadChunk(output_qk_thread, output_softmax, total_sequence_length); + } + output_softmax += present_buffer_sequence_length; if (attention_bias_thread != nullptr) { attention_bias_thread += attention_total_seqlen; } + + if (output_qk_thread != nullptr) { + output_qk_thread += total_sequence_length; + } } } }); @@ -455,6 +483,20 @@ class GQAAttentionBase { SafeInt(sequence_length) * batch_size * num_heads_ * head_size); } } + + template + void WriteOutputQKHeadChunk(T* output_qk, const U* attention_probs, size_t total_sequence_length) const { + if (output_qk == nullptr) { + return; + } + + if constexpr (std::is_same_v) { + std::memcpy(output_qk, attention_probs, SafeInt(total_sequence_length) * sizeof(T)); + } else { + static_assert(std::is_same_v && std::is_same_v); + MlasConvertFloatToHalfBuffer(static_cast(attention_probs), output_qk, total_sequence_length); + } + } }; } // namespace contrib diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc index a912bd6e6b43c..eb1560ac8e341 100644 --- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc @@ -95,6 +95,11 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { Tensor* present_k = context->Output(1, present_k_shape); Tensor* present_v = context->Output(2, present_v_shape); + std::vector output_qk_shape{static_cast(batch_size), static_cast(num_heads_), static_cast(parameters.sequence_length), static_cast(parameters.total_sequence_length)}; + Tensor* output_qk = context->Output(3, output_qk_shape); + + ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckOutputs(output_qk, qk_output_)); + AllocatorPtr allocator; ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); @@ -206,10 +211,12 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); + const T* head_sink_data = (head_sink != nullptr) ? head_sink->Data() : nullptr; + // Compute the attention score and apply the score to V return ApplyAttention(q_rotary, packed_qkv ? nullptr : k_rotary, packed_qkv ? nullptr : V.Get().Data(), - attention_bias, past_key, past_value, output, present_k, present_v, - seqlens_k, parameters, allocator, context); + head_sink_data, attention_bias, past_key, past_value, output, present_k, present_v, + output_qk, seqlens_k, parameters, allocator, context); } } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h index 0f66119540b03..f01ce985658aa 100644 --- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h @@ -398,6 +398,37 @@ Status CheckCustomAttentionInputs(const T* position_ids, return Status::OK(); } +template +Status CheckOutputs(const T* output_qk, int qk_output) { + const bool is_valid_qk_output = qk_output == static_cast(QKOutputType::NO_OUTPUT) || + qk_output == static_cast(QKOutputType::BEFORE_SOFTMAX) || + qk_output == static_cast(QKOutputType::AFTER_SOFTMAX); + if (!is_valid_qk_output) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "qk_output attribute received unsupported value ", qk_output); + } + + if (qk_output != static_cast(QKOutputType::NO_OUTPUT) && output_qk == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "qk_output attribute was configured but output buffer was not provided"); + } + + return Status::OK(); +} + +inline Status CheckNoQKOutput(int num_outputs, int qk_output) { + if (num_outputs > 3) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "output_qk optional output is not supported"); + } + + if (qk_output != static_cast(QKOutputType::NO_OUTPUT)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "qk_output attribute is not supported"); + } + + return Status::OK(); +} + } // namespace group_query_attention_helper } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc index 68c4b01d2db20..9cb93cbcd3f32 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc @@ -109,6 +109,12 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { parameters.do_rotary = do_rotary_; parameters.rotary_interleaved = rotary_interleaved_; + // The current GQA CUDA implementation will never be able to have a QK output. + // GQA CUDA uses either flash attention or memory efficient attention. Neither kernel supports returning the QK output. + ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckNoQKOutput( + context->OutputCount(), + static_cast(Info().GetAttrOrDefault("qk_output", static_cast(QKOutputType::NO_OUTPUT))))); + if (do_rotary_ && (cos_cache == nullptr || sin_cache == nullptr)) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "cos_cache and sin_cache must be passed to GroupQueryAttention when do_rotary = 1"); diff --git a/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu b/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu index 85aef55908506..09a6550549614 100644 --- a/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu +++ b/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu @@ -213,6 +213,10 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* ctx) const { parameters.do_rotary = do_rotary_; parameters.rotary_interleaved = rotary_interleaved_; + ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckNoQKOutput( + context->OutputCount(), + static_cast(Info().GetAttrOrDefault("qk_output", static_cast(QKOutputType::NO_OUTPUT))))); + if (do_rotary_ && (cos_cache == nullptr || sin_cache == nullptr)) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "cos_cache and sin_cache must be passed to GroupQueryAttention when do_rotary = 1"); diff --git a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc index f3334b13dc645..1f039177b0a21 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc @@ -178,6 +178,10 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& head_sink, params)); + ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckNoQKOutput( + context.OutputCount(), + static_cast(Info().GetAttrOrDefault("qk_output", static_cast(QKOutputType::NO_OUTPUT))))); + WebgpuAttentionParameters parameters(params); TensorShapeVector output_shape(3); output_shape[0] = static_cast(parameters.batch_size_); diff --git a/onnxruntime/core/common/cpuid_info.cc b/onnxruntime/core/common/cpuid_info.cc index 8ea593f107833..c4667d53c0674 100644 --- a/onnxruntime/core/common/cpuid_info.cc +++ b/onnxruntime/core/common/cpuid_info.cc @@ -170,7 +170,7 @@ std::string CPUIDInfo::GetX86Vendor(int32_t* data) { uint32_t CPUIDInfo::GetVendorId(const std::string& vendor) { if (vendor == "GenuineIntel") return 0x8086; - if (vendor == "GenuineAMD") return 0x1022; + if (vendor == "AuthenticAMD") return 0x1022; if (vendor.find("Qualcomm") == 0) return 'Q' | ('C' << 8) | ('O' << 16) | ('M' << 24); if (vendor.find("NV") == 0) return 0x10DE; return 0; diff --git a/onnxruntime/core/graph/abi_graph_types.h b/onnxruntime/core/graph/abi_graph_types.h index c3dd9321ebb0b..47fbe08da41ff 100644 --- a/onnxruntime/core/graph/abi_graph_types.h +++ b/onnxruntime/core/graph/abi_graph_types.h @@ -247,8 +247,11 @@ struct OrtNode { /// Gets the node's subgraphs (e.g., subgraphs contained by an If or Loop node). ///
/// Buffer into which to copy the subgraphs. + /// Optional buffer into which to copy the attribute name for each subgraph. + /// If set, must point to a buffer with the same number of elements as `subgraphs`. /// A status indicating success or an error. - virtual onnxruntime::Status GetSubgraphs(gsl::span subgraphs) const = 0; + virtual onnxruntime::Status GetSubgraphs(gsl::span subgraphs, + const char** opt_attribute_names) const = 0; /// /// Gets the node's parent graph, which is the graph that contains this node. @@ -280,6 +283,23 @@ struct OrtGraph { /// The model's ONNX IR version. virtual int64_t GetOnnxIRVersion() const = 0; + /// + /// Gets the number of operator sets (domain, opset version) the graph's model relies on. + /// + /// Output parameter set to the number of operator sets. + /// A status indicating success or an error. + virtual onnxruntime::Status GetNumOperatorSets(size_t& num_operator_sets) const = 0; + + /// + /// Gets the operator sets the graph's model relies on. An operator set is uniquely identified by a + /// (domain, opset version) pair. + /// + /// Buffer into which to copy the domains. + /// Buffer into which to copy the opset version for each domain. + /// A status indicating success or an error. + virtual onnxruntime::Status GetOperatorSets(gsl::span domains, + gsl::span opset_versions) const = 0; + /// /// Returns the number of graph inputs, including initializers that appear in the list of graph inputs. /// diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index f2757c2c96471..e2b17aa84d2b1 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -6,6 +6,7 @@ #include "core/graph/contrib_ops/quantization_defs.h" #include "core/graph/contrib_ops/onnx_function_util.h" #include "core/graph/contrib_ops/shape_inference_functions.h" +#include "contrib_ops/cpu/bert/attention_common.h" // Suppress a warning: global initializer calls a non-constexpr function 'symbol' which is from // ONNX_OPERATOR_SET_SCHEMA_EX macro and only happens in debug build #if defined(_WIN32) && !defined(NDEBUG) @@ -232,7 +233,8 @@ void MultiHeadAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& c // Type and shape inference for group query attention and sparse attention. void BaseGroupQueryAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& ctx, int past_key_index = -1, - int use_max_past_present_buffer = -1) { + int use_max_past_present_buffer = -1, + int output_qk_index = -1) { ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 0, 0); int64_t kv_sequence_length = -1; @@ -277,13 +279,20 @@ void BaseGroupQueryAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceConte } } - if (ctx.getNumOutputs() > 1) { // has present output + if (ctx.getNumOutputs() >= 3) { // has present output // copy the type from query to present key ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 0, 1); // copy the type from query to present value ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 0, 2); + int64_t total_sequence_length_value = 0; + const auto* total_sequence_length_data = ctx.getInputData(6); + if (total_sequence_length_data != nullptr) { + const auto& data = ParseData(total_sequence_length_data); + total_sequence_length_value = static_cast(data[0]); + } + if (past_key_index >= 0 && hasInputShape(ctx, past_key_index)) { auto& past_shape = getInputShape(ctx, past_key_index); auto& past_dims = past_shape.dim(); @@ -299,30 +308,25 @@ void BaseGroupQueryAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceConte ONNX_NAMESPACE::propagateShapeFromInputToOutput(ctx, static_cast(past_key_index) + 1, 2); } else if (use_max_past_present_buffer == 0) { if (kv_sequence_length > 0 && past_dims[2].has_dim_value()) { - int64_t total_sequence_length = kv_sequence_length + past_dims[2].dim_value(); + const int64_t present_sequence_length = kv_sequence_length + past_dims[2].dim_value(); ONNX_NAMESPACE::TensorShapeProto present_shape; for (auto& dim : past_dims) { *present_shape.add_dim() = dim; } - // shape of present key/value is (batch_size, kv_num_heads, total_sequence_length, head_size) - present_shape.mutable_dim(2)->set_dim_value(total_sequence_length); + // shape of present key/value is (batch_size, kv_num_heads, present_sequence_length, head_size) + present_shape.mutable_dim(2)->set_dim_value(present_sequence_length); updateOutputShape(ctx, 1, present_shape); updateOutputShape(ctx, 2, present_shape); } } else if (use_max_past_present_buffer == -1) { - const auto* total_sequence_length_data = ctx.getInputData(6); - if (total_sequence_length_data != nullptr && past_dims[2].has_dim_value()) { - int64_t total_sequence_length_value = 0; - const auto& data = ParseData(total_sequence_length_data); - total_sequence_length_value = static_cast(data[0]); - + if (total_sequence_length_value > 0 && past_dims[2].has_dim_value()) { // present_sequence_length = max(past_sequence_length, total_sequence_length) - int64_t present_sequence_length = total_sequence_length_value > past_dims[2].dim_value() - ? total_sequence_length_value - : past_dims[2].dim_value(); + const int64_t present_sequence_length = total_sequence_length_value > past_dims[2].dim_value() + ? total_sequence_length_value + : past_dims[2].dim_value(); ONNX_NAMESPACE::TensorShapeProto present_shape; for (auto& dim : past_dims) { @@ -336,19 +340,50 @@ void BaseGroupQueryAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceConte updateOutputShape(ctx, 2, present_shape); } } + + if (output_qk_index >= 0) { + const bool did_supply_qk_buffer = ctx.hasOutput(output_qk_index); + const int64_t qk_output_type = getAttribute(ctx, "qk_output", static_cast(QKOutputType::NO_OUTPUT)); + + if (qk_output_type == static_cast(QKOutputType::NO_OUTPUT) && did_supply_qk_buffer) { + fail_shape_inference("Output QK buffer was provided but qk_output attribute was not configured"); + } + + if (qk_output_type != static_cast(QKOutputType::NO_OUTPUT) && !did_supply_qk_buffer) { + fail_shape_inference("Output QK buffer was not provided but qk_output attribute was configured"); + } + + int64_t num_heads = getAttribute(ctx, "num_heads", 0); + if (did_supply_qk_buffer && hasInputShape(ctx, 0) && total_sequence_length_value > 0 && num_heads > 0) { + ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 0, output_qk_index); + + auto& query_shape = getInputShape(ctx, 0); + auto& query_dims = query_shape.dim(); + + if (query_dims[0].has_dim_value() && query_dims[1].has_dim_value()) { + ONNX_NAMESPACE::TensorShapeProto output_qk_shape; + *output_qk_shape.add_dim() = query_dims[0]; // batch_size + output_qk_shape.add_dim()->set_dim_value(num_heads); // num_heads + *output_qk_shape.add_dim() = query_dims[1]; // sequence_length + output_qk_shape.add_dim()->set_dim_value(total_sequence_length_value); // total_sequence_length + updateOutputShape(ctx, output_qk_index, output_qk_shape); + } + } + } } } } -void GroupQueryAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& ctx, int past_key_index) { +void GroupQueryAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& ctx, int past_key_index, int qk_output_index) { // TODO(aciddelgado): propagate output shapes depending if kv-share buffer is on or not constexpr int use_max_past_present_buffer = -1; - BaseGroupQueryAttentionTypeAndShapeInference(ctx, past_key_index, use_max_past_present_buffer); + BaseGroupQueryAttentionTypeAndShapeInference(ctx, past_key_index, use_max_past_present_buffer, qk_output_index); } void SparseAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& ctx, int past_key_index) { constexpr int use_max_past_present_buffer = 1; - BaseGroupQueryAttentionTypeAndShapeInference(ctx, past_key_index, use_max_past_present_buffer); + constexpr int qk_output_index = -1; + BaseGroupQueryAttentionTypeAndShapeInference(ctx, past_key_index, use_max_past_present_buffer, qk_output_index); } constexpr const char* Attention_ver1_doc = R"DOC( @@ -1127,6 +1162,10 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "Use a smooth factor in softmax.", AttributeProto::INT, static_cast(-1)) + .Attr("qk_output", + "Output values of QK matrix multiplication before (1) or after (2) softmax normalization. Default value is 0 (don't output).", + AttributeProto::INT, + static_cast(QKOutputType::NO_OUTPUT)) .Input(0, "query", "Query with shape (batch_size, sequence_length, hidden_size), or packed QKV with shape" @@ -1184,6 +1223,11 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "additional add to QxK' with shape (batch_size or 1, num_heads or 1, sequence_length, total_sequence_length)", "T", OpSchema::Optional) + .Input(11, + "head_sink", + "1D tensor with shape (num_heads). Each head has a smooth factor adding to the denominator of softmax.", + "T", + OpSchema::Optional) .Output(0, "output", "3D output tensor with shape (batch_size, sequence_length, hidden_size)", @@ -1200,10 +1244,15 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "(k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length +" "kv_sequence_length.", "T") + .Output(3, + "output_qk", + "Values of QK matrix multiplication, either before or after softmax normalization", + "T", + OpSchema::Optional) .TypeConstraint("T", {"tensor(float16)", "tensor(bfloat16)", "tensor(float)"}, "Constrain input and output to float tensors.") .TypeConstraint("M", {"tensor(int32)"}, "Constrain mask to int tensor.") .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { - GroupQueryAttentionTypeAndShapeInference(ctx, 3); + GroupQueryAttentionTypeAndShapeInference(ctx, 3, 3); })); constexpr const char* PagedAttention_ver1_doc = R"DOC( diff --git a/onnxruntime/core/graph/ep_api_types.cc b/onnxruntime/core/graph/ep_api_types.cc index 698c7422a1e2a..f57543416a68f 100644 --- a/onnxruntime/core/graph/ep_api_types.cc +++ b/onnxruntime/core/graph/ep_api_types.cc @@ -129,11 +129,12 @@ Status EpNode::Create(const Node& node, const EpGraph* ep_graph, ConvertNodeArgsToValueInfos(ep_graph, value_infos_map, node_implicit_inputs, ep_node_implicit_inputs); - std::vector> node_subgraphs = node.GetSubgraphs(); - ep_node_subgraphs.reserve(node_subgraphs.size()); + std::unordered_map> subgraphs_map = node.GetAttributeNameToSubgraphMap(); + ep_node_subgraphs.reserve(subgraphs_map.size()); - for (gsl::not_null subgraph : node_subgraphs) { + for (const auto& [attr_name, subgraph] : subgraphs_map) { SubgraphState subgraph_state; + subgraph_state.attribute_name = attr_name; subgraph_state.subgraph_viewer = std::make_unique(*subgraph); ORT_RETURN_IF_ERROR(EpGraph::Create(*subgraph_state.subgraph_viewer, subgraph_state.ep_subgraph)); subgraph_state.ep_subgraph->SetParentNode(ep_node.get()); @@ -233,12 +234,17 @@ Status EpNode::GetNumSubgraphs(size_t& num_subgraphs) const { return Status::OK(); } -Status EpNode::GetSubgraphs(gsl::span dst) const { +Status EpNode::GetSubgraphs(gsl::span subgraphs, + const char** opt_attribute_names) const { const size_t num_subgraphs = subgraphs_.size(); - ORT_RETURN_IF_ERROR((CheckCopyDestination("node attributes", num_subgraphs, dst))); + ORT_RETURN_IF_ERROR((CheckCopyDestination("node subgraphs", num_subgraphs, subgraphs))); for (size_t i = 0; i < num_subgraphs; ++i) { - dst[i] = subgraphs_[i].ep_subgraph.get(); + subgraphs[i] = subgraphs_[i].ep_subgraph.get(); + + if (opt_attribute_names) { + opt_attribute_names[i] = subgraphs_[i].attribute_name.c_str(); + } } return Status::OK(); @@ -270,6 +276,10 @@ const OrtOpAttr* EpNode::GetAttribute(const std::string& name) const { } } +const std::string& EpNode::GetEpName() const { + return node_.GetExecutionProviderType(); +} + // // EpValueInfo // @@ -499,10 +509,34 @@ void EpGraph::IndexToEpNodeMap::SetEpNode(NodeIndex node_index, EpNode* ep_node) EpGraph::EpGraph(const GraphViewer& graph_viewer, PrivateTag) : OrtGraph(OrtGraphIrApi::kEpApi), graph_viewer_(graph_viewer) {} +EpGraph::EpGraph(std::unique_ptr graph_viewer, + std::unique_ptr indexed_sub_graph, + PrivateTag) + : OrtGraph(OrtGraphIrApi::kEpApi), + graph_viewer_(*graph_viewer.get()), + owned_graph_viewer_(std::move(graph_viewer)), + owned_indexed_sub_graph_(std::move(indexed_sub_graph)) {} + // Static class function to create a std::unique_ptr. Status EpGraph::Create(const GraphViewer& graph_viewer, /*out*/ std::unique_ptr& result) { auto ep_graph = std::make_unique(graph_viewer, PrivateTag{}); + return CreateImpl(std::move(ep_graph), graph_viewer, result); +} + +// Static class function to create a std::unique_ptr. +Status EpGraph::Create(std::unique_ptr src_graph_viewer, + std::unique_ptr src_indexed_sub_graph, + /*out*/ std::unique_ptr& result) { + auto& graph_viewer = *src_graph_viewer.get(); + auto ep_graph = std::make_unique(std::move(src_graph_viewer), + std::move(src_indexed_sub_graph), + PrivateTag{}); + + return CreateImpl(std::move(ep_graph), graph_viewer, result); +} + +Status EpGraph::CreateImpl(std::unique_ptr ep_graph, const GraphViewer& graph_viewer, /*out*/ std::unique_ptr& result) { AllocatorPtr initializer_allocator = CPUAllocator::DefaultInstance(); std::unordered_map> value_infos_map; @@ -660,6 +694,43 @@ const std::string& EpGraph::GetName() const { return graph_viewer_.Name(); } int64_t EpGraph::GetOnnxIRVersion() const { return graph_viewer_.GetOnnxIRVersion(); } +Status EpGraph::GetNumOperatorSets(size_t& num_operator_sets) const { + num_operator_sets = graph_viewer_.DomainToVersionMap().size(); + return Status::OK(); +} + +Status EpGraph::GetOperatorSets(gsl::span domains, + gsl::span opset_versions) const { + const std::unordered_map& domain_to_version = graph_viewer_.DomainToVersionMap(); + size_t num_operator_sets = domain_to_version.size(); + + ORT_RETURN_IF_ERROR((CheckCopyDestination("operator set domains", num_operator_sets, domains))); + ORT_RETURN_IF_ERROR((CheckCopyDestination("operator set versions", num_operator_sets, opset_versions))); + + // Collect (domain, version) pairs and sort them by domain to ensure user always gets a stable ordering. + std::vector> pairs; + pairs.reserve(num_operator_sets); + + for (const auto& [domain, version] : domain_to_version) { + pairs.emplace_back(domain.c_str(), version); + } + + std::sort(pairs.begin(), pairs.end(), + [](const std::pair& a, const std::pair& b) -> bool { + return std::strcmp(a.first, b.first) < 0; + }); + + // Copy sorted (domain, version) pairs into the destination buffers. + size_t index = 0; + for (const auto& [domain_c_str, version] : pairs) { + domains[index] = domain_c_str; + opset_versions[index] = version; + index++; + } + + return Status::OK(); +} + size_t EpGraph::GetNumInputs() const { return inputs_.size(); } diff --git a/onnxruntime/core/graph/ep_api_types.h b/onnxruntime/core/graph/ep_api_types.h index 4240f5636b7ae..d3921e051e18a 100644 --- a/onnxruntime/core/graph/ep_api_types.h +++ b/onnxruntime/core/graph/ep_api_types.h @@ -111,6 +111,7 @@ struct EpNode : public OrtNode { struct SubgraphState { SubgraphState() = default; SubgraphState(SubgraphState&& other) = default; + std::string attribute_name; std::unique_ptr subgraph_viewer; // The graph_viewer wrapped by EpGraph below. std::unique_ptr ep_subgraph; }; @@ -182,7 +183,8 @@ struct EpNode : public OrtNode { Status GetNumSubgraphs(size_t& num_subgraphs) const override; // Gets the subgraphs contained by this node. - Status GetSubgraphs(gsl::span subgraphs) const override; + Status GetSubgraphs(gsl::span subgraphs, + const char** opt_attribute_names) const override; // Gets this node's parent graph, which is the graph that directly contains this node. Status GetGraph(const OrtGraph*& parent_graph) const override; @@ -206,6 +208,9 @@ struct EpNode : public OrtNode { // Helper that gets the node's attributes by name. const OrtOpAttr* GetAttribute(const std::string& name) const; + // Helper that gets the execution provider name that this node is assigned to run on. + const std::string& GetEpName() const; + private: // Back pointer to containing graph. Useful when traversing through nested subgraphs. // Will be nullptr if the EpNode was created without an owning graph. @@ -249,15 +254,32 @@ struct EpGraph : public OrtGraph { public: EpGraph(const GraphViewer& graph_viewer, PrivateTag); + EpGraph(std::unique_ptr graph_viewer, + std::unique_ptr indexed_sub_graph, + PrivateTag); /// /// Creates an instance of EpGraph, which wraps a GraphViewer. + /// This call is used when creating an EpGraph from a GraphViewer instance. The GraphViewer instance is not onwed by this EpGraph. /// /// /// /// static Status Create(const GraphViewer& graph_viewer, /*out*/ std::unique_ptr& result); + /// + /// Creates an instance of EpGraph, which wraps a GraphViewer. + /// This call is used when creating an EpGraph from a subset of nodes in another EpGraph. + /// In this case, due to the implementation of OrtApis::Graph_GetGraphView, the new EpGraph instance + /// must take ownership of both the GraphViewer and IndexedSubGraph. + /// + /// + /// + /// + static Status Create(std::unique_ptr graph_viewer, + std::unique_ptr indexed_sub_graph, + /*out*/ std::unique_ptr& result); + // Defines ToExternal() and ToInternal() functions to convert between OrtGraph and EpGraph. DEFINE_ORT_GRAPH_IR_TO_EXTERNAL_INTERNAL_FUNCS(OrtGraph, EpGraph, OrtGraphIrApi::kEpApi) @@ -271,6 +293,14 @@ struct EpGraph : public OrtGraph { // Returns the model's ONNX IR version. int64_t GetOnnxIRVersion() const override; + // Gets the number of operator sets that the graph's model uses. + Status GetNumOperatorSets(size_t& num_operator_sets) const override; + + // Gets the operator sets that the graph's model uses. An operator set is uniquely identified by a + // (domain, opset version) pair. + Status GetOperatorSets(gsl::span domains, + gsl::span opset_versions) const override; + // Get the number of graph inputs, including initializers that are listed as graph inputs. size_t GetNumInputs() const override; @@ -321,9 +351,22 @@ struct EpGraph : public OrtGraph { const OrtValue* GetInitializerValue(std::string_view name) const; private: + /// + /// The real implementation of creating an EpGraph instance. + /// Please use one of the above 'Create' functions that internally call this function, and avoid calling this function directly. + /// + /// + /// + /// + /// + static Status CreateImpl(std::unique_ptr ep_graph, const GraphViewer& graph_viewer, /*out*/ std::unique_ptr& result); + const GraphViewer& graph_viewer_; const EpNode* parent_node_ = nullptr; + std::unique_ptr owned_graph_viewer_ = nullptr; + std::unique_ptr owned_indexed_sub_graph_ = nullptr; + std::vector> nodes_; IndexToEpNodeMap index_to_ep_node_; diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index ca40bad2b4250..4d3091520d876 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -1818,6 +1818,10 @@ NodeArg* Graph::GetNodeArgIncludingParentGraphs(const std::string& node_arg_name return node_arg; } +const NodeArg* Graph::GetNodeArgIncludingParentGraphs(const std::string& node_arg_name) const { + return const_cast(this)->GetNodeArgIncludingParentGraphs(node_arg_name); +} + void Graph::ReverseDFSFrom(gsl::span from, const std::function& enter, const std::function& leave, diff --git a/onnxruntime/core/graph/graph_viewer.cc b/onnxruntime/core/graph/graph_viewer.cc index 1842c2b4a0d1f..948ebaa5f7e15 100644 --- a/onnxruntime/core/graph/graph_viewer.cc +++ b/onnxruntime/core/graph/graph_viewer.cc @@ -168,7 +168,15 @@ GraphViewer::GraphViewer(const Graph& graph, const IndexedSubGraph* filter_info) filtered_node_inputs_including_initializers_.reserve(metadef->inputs.size()); for (const auto& input : metadef->inputs) { - const auto* nodearg = graph.GetNodeArg(input); + // NodeArgs from the current scope or any outer scopes should be handled correctly. + // + // There is an edge case where the model consists of a graph with subgraphs nested across three levels. + // In this scenario, a third-layer subgraph consumes an input from the first-layer graph (not an initializer). + // When constructing a new GraphViewer for the second- and third-layer subgraphs, + // the second-layer graph may not have the corresponding value_info for that first-layer input, + // because the second-layer graph itself doesn't consume it. + // Therefore, when working within the second-layer graph, we need to search outer scopes for the missing value_info. + const auto* nodearg = graph.GetNodeArgIncludingParentGraphs(input); ORT_ENFORCE(nodearg, "Mismatch between Graph and IndexedSubGraph. Input not found:", input); filtered_node_inputs_including_initializers_.push_back(nodearg); if (!graph.IsInitializedTensor(input)) { @@ -177,7 +185,7 @@ GraphViewer::GraphViewer(const Graph& graph, const IndexedSubGraph* filter_info) } for (const auto& output : metadef->outputs) { - const auto* nodearg = graph.GetNodeArg(output); + const auto* nodearg = graph.GetNodeArgIncludingParentGraphs(output); ORT_ENFORCE(nodearg, "Mismatch between Graph and IndexedSubGraph. Output not found:", output); filtered_node_outputs_.push_back(nodearg); } diff --git a/onnxruntime/core/graph/model_editor_api_types.h b/onnxruntime/core/graph/model_editor_api_types.h index 6330a42c115db..6e7e17374bb59 100644 --- a/onnxruntime/core/graph/model_editor_api_types.h +++ b/onnxruntime/core/graph/model_editor_api_types.h @@ -136,7 +136,8 @@ struct ModelEditorNode : public OrtNode { "OrtModelEditorApi does not support getting the subgraphs for OrtNode"); } - Status GetSubgraphs(gsl::span /*subgraphs*/) const override { + Status GetSubgraphs(gsl::span /*subgraphs*/, + const char** /*opt_attribute_names*/) const override { return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "OrtModelEditorApi does not support getting the subgraphs for OrtNode"); } @@ -176,6 +177,17 @@ struct ModelEditorGraph : public OrtGraph { return ONNX_NAMESPACE::Version::IR_VERSION; } + Status GetNumOperatorSets(size_t& /*num_operator_sets*/) const override { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, + "OrtModelEditorApi does not support getting the graph's operator sets."); + } + + Status GetOperatorSets(gsl::span /*domains*/, + gsl::span /*opset_versions*/) const override { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, + "OrtModelEditorApi does not support getting the graph's operator sets."); + } + size_t GetNumInputs() const override { return inputs.size(); } Status GetInputs(gsl::span /*result*/) const override { diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h index 3575e30721af7..4d85c35461825 100644 --- a/onnxruntime/core/mlas/inc/mlas.h +++ b/onnxruntime/core/mlas/inc/mlas.h @@ -1020,6 +1020,7 @@ MlasComputeSoftmax( size_t D, bool LogSoftmax, bool SmoothSoftmax, + float Sink, MLAS_THREADPOOL* ThreadPool ); @@ -1223,6 +1224,21 @@ MlasQuantizeLinearS4( int8_t ZeroPoint ); +// +// Linear dequantization routines. +// + +template +void +MLASCALL +MlasDequantizeLinear( + const InputType* Input, + float* Output, + size_t N, + float Scale, + InputType ZeroPoint + ); + /** * @brief Requantize a block of the intermediate buffer to the output buffer, * optionally adding the supplied bias diff --git a/onnxruntime/core/mlas/lib/compute.cpp b/onnxruntime/core/mlas/lib/compute.cpp index 96a2398796777..669c73d2b9c06 100644 --- a/onnxruntime/core/mlas/lib/compute.cpp +++ b/onnxruntime/core/mlas/lib/compute.cpp @@ -74,6 +74,7 @@ struct MLAS_SOFTMAX_WORK_BLOCK { ptrdiff_t ThreadCountN; bool LogSoftmax; bool SmoothSoftmax; + float Sink; const T* Input; T* Output; size_t N; @@ -850,6 +851,7 @@ Return Value: const size_t D = WorkBlock->D; const bool LogSoftmax = WorkBlock->LogSoftmax; const bool SmoothSoftmax = WorkBlock->SmoothSoftmax; + const float Sink = WorkBlock->Sink; const float* Input = WorkBlock->Input + n * D; float* Output = WorkBlock->Output + n * D; @@ -880,11 +882,12 @@ Return Value: #else float Maximum = MlasReduceMaximumF32Kernel(Input, D); #endif - float NegativeMaximum = -Maximum; - if (SmoothSoftmax && NegativeMaximum > 0.0f) { - NegativeMaximum = 0.0f; + if (SmoothSoftmax && Sink > Maximum) { + Maximum = Sink; } + float NegativeMaximum = -Maximum; + // // Compute the exponential function for each element of the row (save to Temp if provided) and // compute the sum of these exponential functions. @@ -897,7 +900,7 @@ Return Value: #endif if (SmoothSoftmax) { - Accumulation += expf(NegativeMaximum); + Accumulation += expf(Sink + NegativeMaximum); } if (LogSoftmax) { @@ -1014,6 +1017,7 @@ MlasComputeSoftmax( size_t D, bool LogSoftmax, bool SmoothSoftmax, + float Sink, MLAS_THREADPOOL* ThreadPool ) /*++ @@ -1039,6 +1043,8 @@ Routine Description: SmoothSoftmax - Supplies true if a smooth factor is used in softmax operation. + Sink - Supplies the smooth factor to use in the softmax operation. + ThreadPool - Supplies the thread pool object to use, else nullptr if the base library threading support should be used. @@ -1060,6 +1066,7 @@ Return Value: WorkBlock.Output = Output; WorkBlock.N = N; WorkBlock.D = D; + WorkBlock.Sink = Sink; // // Compute the number of target threads given the complexity of the softmax @@ -1097,6 +1104,7 @@ MlasComputeSoftmax( size_t D, bool LogSoftmax, bool SmoothSoftmax, + float Sink, MLAS_THREADPOOL* ThreadPool ); @@ -1110,6 +1118,7 @@ MlasComputeSoftmax( size_t D, bool LogSoftmax, bool SmoothSoftmax, + float Sink, MLAS_THREADPOOL* ThreadPool ); diff --git a/onnxruntime/core/mlas/lib/dequantize.cpp b/onnxruntime/core/mlas/lib/dequantize.cpp new file mode 100644 index 0000000000000..175d3f668ac39 --- /dev/null +++ b/onnxruntime/core/mlas/lib/dequantize.cpp @@ -0,0 +1,395 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + dequantize.cpp + +Abstract: + + This module implements routines to dequantize buffers. + + The dequantization formula as specified in the ONNX operator documentation is: + + Output = (Input - ZeroPoint) * Scale + +--*/ + +#include "mlasi.h" + +// +// DequantizeLinear reference implementation using the C++ runtime. +// + +template +static +MLAS_FORCEINLINE +void +MlasDequantizeLinearRefImpl( + const InputType* Input, + float* Output, + size_t N, + float Scale, + InputType ZeroPoint + ) +/*++ + +Routine Description: + + This routine quantizes the input buffer using the supplied quantization + parameters. + +Arguments: + + Input - Supplies the input buffer with quantized data. + + Output - Supplies the output buffer. + + N - Supplies the number of elements to process. + + Scale - Supplies the quantization scale. + + ZeroPoint - Supplies the quantization zero point value. + +Return Value: + + None. + +--*/ +{ + int32_t ZeroPointS32 = static_cast(ZeroPoint); + + for (size_t n = 0; n < N; n++) { + Output[n] = static_cast(static_cast(Input[n]) - ZeroPointS32) * Scale; + } +} + +#if defined(MLAS_SSE2_INTRINSICS) +// Implementation for Intel SSE 2. Refer to the Intel Intrisics Guide: +// https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html + +void +MLASCALL +MlasDequantizeLinearS8Kernel( + const int8_t* Input, + float* Output, + size_t N, + float Scale, + int8_t ZeroPoint + ) +{ + const __m128 ScaleVector = MlasBroadcastFloat32x4(Scale); + const __m128i ZeroPointS16Vector = _mm_set1_epi16(static_cast(ZeroPoint)); // Broadcast zp to 8 int16s + const __m128i Zeros = _mm_setzero_si128(); + + while (N >= 16) { + // Load a vector of 16 int8s: [0 ... 15] + __m128i VectorS8 = _mm_loadu_si128(reinterpret_cast(Input)); + + // Sign-extend into 2 vectors of 8 int16s + __m128i SignMaskS8 = _mm_cmpgt_epi8(Zeros, VectorS8); // 0xFF for every negative byte in VectorS8 + __m128i VectorS16_0 = _mm_unpacklo_epi8(VectorS8, SignMaskS8); // [0 ... 7] + __m128i VectorS16_1 = _mm_unpackhi_epi8(VectorS8, SignMaskS8); // [8 ... 15] + + // Subtract the zero-points in int16 domain. + VectorS16_0 = _mm_sub_epi16(VectorS16_0, ZeroPointS16Vector); + VectorS16_1 = _mm_sub_epi16(VectorS16_1, ZeroPointS16Vector); + + // Sign-extend into 4 vectors of 4 int32s + __m128i SignMaskS16_0 = _mm_cmpgt_epi16(Zeros, VectorS16_0); + __m128i VectorS32_0 = _mm_unpacklo_epi16(VectorS16_0, SignMaskS16_0); // [0 ... 3] + __m128i VectorS32_1 = _mm_unpackhi_epi16(VectorS16_0, SignMaskS16_0); // [4 ... 7] + + __m128i SignMaskS16_1 = _mm_cmpgt_epi16(Zeros, VectorS16_1); + __m128i VectorS32_2 = _mm_unpacklo_epi16(VectorS16_1, SignMaskS16_1); // [8 ... 11] + __m128i VectorS32_3 = _mm_unpackhi_epi16(VectorS16_1, SignMaskS16_1); // [12 ... 15] + + // Cast each int32x4 to float and multiply by the scale vector. + __m128 VectorF32_0 = _mm_mul_ps(_mm_cvtepi32_ps(VectorS32_0), ScaleVector); + __m128 VectorF32_1 = _mm_mul_ps(_mm_cvtepi32_ps(VectorS32_1), ScaleVector); + __m128 VectorF32_2 = _mm_mul_ps(_mm_cvtepi32_ps(VectorS32_2), ScaleVector); + __m128 VectorF32_3 = _mm_mul_ps(_mm_cvtepi32_ps(VectorS32_3), ScaleVector); + + // Store each int32x4 into the output. + _mm_storeu_ps(Output + 0, VectorF32_0); + _mm_storeu_ps(Output + 4, VectorF32_1); + _mm_storeu_ps(Output + 8, VectorF32_2); + _mm_storeu_ps(Output + 12, VectorF32_3); + + Input += 16; + Output += 16; + N -= 16; + } + + // Handle leftover elements (< 16) with the scalar reference implementation. + MlasDequantizeLinearRefImpl(Input, Output, N, Scale, ZeroPoint); +} + +void +MLASCALL +MlasDequantizeLinearU8Kernel( + const uint8_t* Input, + float* Output, + size_t N, + float Scale, + uint8_t ZeroPoint + ) +{ + const __m128 ScaleVector = MlasBroadcastFloat32x4(Scale); + const __m128i ZeroPointS16Vector = _mm_set1_epi16(static_cast(ZeroPoint)); // Broadcast zp to 8 int16s + const __m128i Zeros = _mm_setzero_si128(); + + while (N >= 16) { + // Load a vector of 16 uint8s: [0 ... 15] + __m128i VectorU8 = _mm_loadu_si128(reinterpret_cast(Input)); + + // Zero-extend into 2 vectors of 8 uint16s + __m128i VectorU16_0 = _mm_unpacklo_epi8(VectorU8, Zeros); // [0 ... 7] + __m128i VectorU16_1 = _mm_unpackhi_epi8(VectorU8, Zeros); // [8 ... 15] + + // Subtract the zero-points as uint16s. Due to two's compliment, negative results can be reinterpreted as int16 + __m128i VectorS16_0 = _mm_sub_epi16(VectorU16_0, ZeroPointS16Vector); + __m128i VectorS16_1 = _mm_sub_epi16(VectorU16_1, ZeroPointS16Vector); + + // Sign-extend into 4 vectors of 4 int32s + __m128i SignMaskS16_0 = _mm_cmpgt_epi16(Zeros, VectorS16_0); + __m128i VectorS32_0 = _mm_unpacklo_epi16(VectorS16_0, SignMaskS16_0); // [0 ... 3] + __m128i VectorS32_1 = _mm_unpackhi_epi16(VectorS16_0, SignMaskS16_0); // [4 ... 7] + + __m128i SignMaskS16_1 = _mm_cmpgt_epi16(Zeros, VectorS16_1); + __m128i VectorS32_2 = _mm_unpacklo_epi16(VectorS16_1, SignMaskS16_1); // [8 ... 11] + __m128i VectorS32_3 = _mm_unpackhi_epi16(VectorS16_1, SignMaskS16_1); // [12 ... 15] + + // Cast each int32x4 to float and multiply by the scale vector. + __m128 VectorF32_0 = _mm_mul_ps(_mm_cvtepi32_ps(VectorS32_0), ScaleVector); + __m128 VectorF32_1 = _mm_mul_ps(_mm_cvtepi32_ps(VectorS32_1), ScaleVector); + __m128 VectorF32_2 = _mm_mul_ps(_mm_cvtepi32_ps(VectorS32_2), ScaleVector); + __m128 VectorF32_3 = _mm_mul_ps(_mm_cvtepi32_ps(VectorS32_3), ScaleVector); + + // Store each int32x4 into the output. + _mm_storeu_ps(Output + 0, VectorF32_0); + _mm_storeu_ps(Output + 4, VectorF32_1); + _mm_storeu_ps(Output + 8, VectorF32_2); + _mm_storeu_ps(Output + 12, VectorF32_3); + + Input += 16; + Output += 16; + N -= 16; + } + + // Handle leftover elements (< 16) with the scalar reference implementation. + MlasDequantizeLinearRefImpl(Input, Output, N, Scale, ZeroPoint); +} + +template<> +void +MLASCALL +MlasDequantizeLinear( + const int8_t* Input, + float* Output, + size_t N, + float Scale, + int8_t ZeroPoint + ) +{ +#if defined(MLAS_TARGET_AMD64) + GetMlasPlatform().DequantizeLinearS8Kernel( +#else + MlasDequantizeLinearS8Kernel( +#endif + Input, Output, N, Scale, ZeroPoint); +} + +template<> +void +MLASCALL +MlasDequantizeLinear( + const uint8_t* Input, + float* Output, + size_t N, + float Scale, + uint8_t ZeroPoint + ) +{ +#if defined(MLAS_TARGET_AMD64) + GetMlasPlatform().DequantizeLinearU8Kernel( +#else + MlasDequantizeLinearU8Kernel( +#endif + Input, Output, N, Scale, ZeroPoint); +} +#elif defined(MLAS_NEON64_INTRINSICS) +// Implementation for ARM64 NEON. Refer to the ARM instrinsics guide: +// https://developer.arm.com/architectures/instruction-sets/intrinsics/ + +void +MLASCALL +MlasDequantizeLinearS8Kernel( + const int8_t* Input, + float* Output, + size_t N, + float Scale, + int8_t ZeroPoint + ) +{ + const float32x4_t ScaleVector = MlasBroadcastFloat32x4(Scale); + const int16x8_t ZeroPointVector = vdupq_n_s16(ZeroPoint); // Broadcast ZeroPoint (sign-extended to 16bits) + + while (N >= 16) { + // Load a vector of 16 int8s: [0 ... 15] + int8x16_t VectorS8 = vld1q_s8(Input); + + // Sign-extend into 2 vectors of 8 int16s + int16x8_t VectorS16_0 = vmovl_s8(vget_low_s8(VectorS8)); // [0 ... 7] + int16x8_t VectorS16_1 = vmovl_s8(vget_high_s8(VectorS8)); // [8 ... 15] + + // Subtract the zero-points in int16 domain. + VectorS16_0 = vsubq_s16(VectorS16_0, ZeroPointVector); + VectorS16_1 = vsubq_s16(VectorS16_1, ZeroPointVector); + + // Sign-extend into 4 vectors of 4 int32s + int32x4_t VectorS32_0 = vmovl_s16(vget_low_s16(VectorS16_0)); // [0 ... 3] + int32x4_t VectorS32_1 = vmovl_s16(vget_high_s16(VectorS16_0)); // [4 ... 7] + int32x4_t VectorS32_2 = vmovl_s16(vget_low_s16(VectorS16_1)); // [8 ... 11] + int32x4_t VectorS32_3 = vmovl_s16(vget_high_s16(VectorS16_1)); // [12 ... 15] + + // Cast each int32x4 to float and multiply by the scale vector. + float32x4_t VectorF32_0 = vmulq_f32(vcvtq_f32_s32(VectorS32_0), ScaleVector); + float32x4_t VectorF32_1 = vmulq_f32(vcvtq_f32_s32(VectorS32_1), ScaleVector); + float32x4_t VectorF32_2 = vmulq_f32(vcvtq_f32_s32(VectorS32_2), ScaleVector); + float32x4_t VectorF32_3 = vmulq_f32(vcvtq_f32_s32(VectorS32_3), ScaleVector); + + // Store each int32x4 into the output. + vst1q_f32(Output + 0, VectorF32_0); + vst1q_f32(Output + 4, VectorF32_1); + vst1q_f32(Output + 8, VectorF32_2); + vst1q_f32(Output + 12, VectorF32_3); + + N -= 16; + Input += 16; + Output += 16; + } + + // Handle leftover elements (< 16) with the scalar reference implementation. + MlasDequantizeLinearRefImpl(Input, Output, N, Scale, ZeroPoint); +} + +void +MLASCALL +MlasDequantizeLinearU8Kernel( + const uint8_t* Input, + float* Output, + size_t N, + float Scale, + uint8_t ZeroPoint + ) +{ + const float32x4_t ScaleVector = MlasBroadcastFloat32x4(Scale); + const uint8x8_t ZeroPointVector = vdup_n_u8(ZeroPoint); // Broadcast ZeroPoint to 8 uint8s + + while (N >= 16) { + // Load a vector of 16 uint8s: [0 ... 15] + uint8x16_t VectorU8 = vld1q_u8(Input); + + // Subtract zero-point. The vsubl_u8 instruction zero-extends its arguments to uint16 first. + // The reinterpret from uint16x8 to int16x8 is actually a NOP. + int16x8_t VectorS16_0 = vreinterpretq_s16_u16(vsubl_u8(vget_low_u8(VectorU8), ZeroPointVector)); // [0 ... 7] + int16x8_t VectorS16_1 = vreinterpretq_s16_u16(vsubl_u8(vget_high_u8(VectorU8), ZeroPointVector)); // [8 ... 15] + + // Sign-extend into 4 vectors of 4 int32s + int32x4_t VectorS32_0 = vmovl_s16(vget_low_s16(VectorS16_0)); // [0 ... 3] + int32x4_t VectorS32_1 = vmovl_s16(vget_high_s16(VectorS16_0)); // [4 ... 7] + int32x4_t VectorS32_2 = vmovl_s16(vget_low_s16(VectorS16_1)); // [8 ... 11] + int32x4_t VectorS32_3 = vmovl_s16(vget_high_s16(VectorS16_1)); // [12 ... 15] + + // Cast each int32x4 to float and multiply by the scale vector. + float32x4_t VectorF32_0 = vmulq_f32(vcvtq_f32_s32(VectorS32_0), ScaleVector); + float32x4_t VectorF32_1 = vmulq_f32(vcvtq_f32_s32(VectorS32_1), ScaleVector); + float32x4_t VectorF32_2 = vmulq_f32(vcvtq_f32_s32(VectorS32_2), ScaleVector); + float32x4_t VectorF32_3 = vmulq_f32(vcvtq_f32_s32(VectorS32_3), ScaleVector); + + // Store each int32x4 into the output. + vst1q_f32(Output + 0, VectorF32_0); + vst1q_f32(Output + 4, VectorF32_1); + vst1q_f32(Output + 8, VectorF32_2); + vst1q_f32(Output + 12, VectorF32_3); + + N -= 16; + Input += 16; + Output += 16; + } + + // Handle leftover elements (< 16) with the scalar reference implementation. + MlasDequantizeLinearRefImpl(Input, Output, N, Scale, ZeroPoint); +} + +template<> +void +MLASCALL +MlasDequantizeLinear( + const int8_t* Input, + float* Output, + size_t N, + float Scale, + int8_t ZeroPoint + ) +{ + MlasDequantizeLinearS8Kernel(Input, Output, N, Scale, ZeroPoint); +} + +template<> +void +MLASCALL +MlasDequantizeLinear( + const uint8_t* Input, + float* Output, + size_t N, + float Scale, + uint8_t ZeroPoint + ) +{ + MlasDequantizeLinearU8Kernel(Input, Output, N, Scale, ZeroPoint); +} +#else +// Implementation that uses the scalar reference implementation. + +template +void +MLASCALL +MlasDequantizeLinear( + const InputType* Input, + float* Output, + size_t N, + float Scale, + InputType ZeroPoint + ) +{ + MlasDequantizeLinearRefImpl(Input, Output, N, Scale, ZeroPoint); +} + +template +void +MLASCALL +MlasDequantizeLinear( + const int8_t* Input, + float* Output, + size_t N, + float Scale, + int8_t ZeroPoint + ); + +template +void +MLASCALL +MlasDequantizeLinear( + const uint8_t* Input, + float* Output, + size_t N, + float Scale, + uint8_t ZeroPoint + ); + +#endif diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index 0af3cd2e33b02..0879d1b0ba510 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -747,6 +747,24 @@ void float Scale, int8_t ZeroPoint); +typedef +void +(MLASCALL MLAS_DEQUANTIZE_LINEAR_U8_KERNEL)( + const uint8_t* Input, + float* Output, + size_t N, + float Scale, + uint8_t ZeroPoint); + +typedef +void +(MLASCALL MLAS_DEQUANTIZE_LINEAR_S8_KERNEL)( + const int8_t* Input, + float* Output, + size_t N, + float Scale, + int8_t ZeroPoint); + template struct MLAS_QUANT_KERNEL { @@ -903,6 +921,8 @@ extern "C" { MLAS_QUANTIZE_LINEAR_S4_KERNEL MlasQuantizeLinearS4Kernel; MLAS_QUANTIZE_LINEAR_U4_KERNEL MlasQuantizeLinearU4Kernel; #if defined(MLAS_TARGET_AMD64) + MLAS_DEQUANTIZE_LINEAR_S8_KERNEL MlasDequantizeLinearS8Kernel; + MLAS_DEQUANTIZE_LINEAR_U8_KERNEL MlasDequantizeLinearU8Kernel; MLAS_COMPUTE_UNARY_FLOAT_KERNEL MlasErfKernelFma3; MLAS_COMPUTE_UNARY_FLOAT_KERNEL MlasComputeExpF32KernelFma3; MLAS_COMPUTE_UNARY_FLOAT_KERNEL MlasComputeExpF32KernelAvx512F; @@ -1246,6 +1266,8 @@ struct MLAS_PLATFORM { MLAS_QUANTIZE_LINEAR_U16_KERNEL* QuantizeLinearU16Kernel; MLAS_QUANTIZE_LINEAR_S4_KERNEL* QuantizeLinearS4Kernel; MLAS_QUANTIZE_LINEAR_U4_KERNEL* QuantizeLinearU4Kernel; + MLAS_DEQUANTIZE_LINEAR_S8_KERNEL* DequantizeLinearS8Kernel; + MLAS_DEQUANTIZE_LINEAR_U8_KERNEL* DequantizeLinearU8Kernel; uint32_t NchwcBlockSize; uint32_t PreferredBufferAlignment; int32_t MaximumThreadCount; diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index 45d3a876beb86..45bba5363d4f2 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -285,6 +285,8 @@ Return Value: this->QuantizeLinearU16Kernel = MlasQuantizeLinearU16Kernel; this->QuantizeLinearS4Kernel = MlasQuantizeLinearS4Kernel; this->QuantizeLinearU4Kernel = MlasQuantizeLinearU4Kernel; + this->DequantizeLinearS8Kernel = MlasDequantizeLinearS8Kernel; + this->DequantizeLinearU8Kernel = MlasDequantizeLinearU8Kernel; #ifndef __APPLE__ #ifndef FORCE_GENERIC_ALGORITHMS this->CastF16ToF32Kernel = &MlasCastF16ToF32KernelSse; diff --git a/onnxruntime/core/platform/windows/device_discovery.cc b/onnxruntime/core/platform/windows/device_discovery.cc index dcc030cb3467d..fa645939a6395 100644 --- a/onnxruntime/core/platform/windows/device_discovery.cc +++ b/onnxruntime/core/platform/windows/device_discovery.cc @@ -89,23 +89,10 @@ uint64_t GetLuidKey(LUID luid) { return (uint64_t(luid.HighPart) << 32) | luid.LowPart; } -// Converts a wide string (up to 4 characters) representing a hardware ID component (e.g., "ABCD" from "VEN_ABCD") -// into a uint32_t. The conversion is done in a little-endian manner, meaning the first character -// of the string becomes the least significant byte of the integer, and the fourth character -// becomes the most significant byte. -uint32_t WStringToUint32Id(const std::wstring& vendor_name) { - uint32_t vendor_id = 0; - for (size_t i = 0; i < 4 && i < vendor_name.size(); ++i) { - // For little-endian, place each character at the appropriate byte position - // First character goes into lowest byte, last character into highest byte - vendor_id |= static_cast(vendor_name[i] & 0xFF) << (i * 8); - } - return vendor_id; -} - // returns info for display and processor entries. key is (vendor_id << 32 | device_id) // npus: (vendor_id << 32 | device_id) for devices we think are NPUs from DXCORE -std::unordered_map GetDeviceInfoSetupApi(const std::unordered_set& npus) { +std::unordered_map GetDeviceInfoSetupApi(const std::unordered_set& npus, + bool& have_remote_display_adapter) { std::unordered_map device_info; const GUID local_DXCORE_ADAPTER_ATTRIBUTE_D3D12_GENERIC_ML = {0xb71b0d41, 0x1088, 0x422f, 0xa2, 0x7c, 0x2, 0x50, 0xb7, 0xd3, 0xa9, 0x88}; @@ -151,8 +138,7 @@ std::unordered_map GetDeviceInfoSetupApi(const std::unorde if (auto idx = hardware_id.find(prefix); idx != std::wstring::npos) { auto id = hardware_id.substr(idx + prefix.size(), 4); if (id.size() == 4) { - // DXCore reports vendor and device IDs as 32-bit integer representations of the ASCII string. - return WStringToUint32Id(id); + return static_cast(std::stoul(id, nullptr, 16)); } } @@ -170,6 +156,11 @@ std::unordered_map GetDeviceInfoSetupApi(const std::unorde // Won't always have a vendor id from an ACPI entry. ACPI is not defined for this purpose. if (vendor_id == 0 && device_id == 0) { + static const std::wstring remote_display_adapter_id(L"RdpIdd_IndirectDisplay"); + if (guid == GUID_DEVCLASS_DISPLAY && remote_display_adapter_id == buffer) { + have_remote_display_adapter = true; + } + continue; } @@ -305,7 +296,7 @@ std::unordered_map GetDeviceInfoSetupApi(const std::unorde } // returns LUID to DeviceInfo -std::unordered_map GetDeviceInfoD3D12() { +std::unordered_map GetDeviceInfoD3D12(bool have_remote_display_adapter) { std::unordered_map device_info; ComPtr factory; @@ -314,6 +305,8 @@ std::unordered_map GetDeviceInfoD3D12() { return device_info; } + UINT num_adapters = 0; + ComPtr adapter; for (UINT i = 0; factory->EnumAdapters1(i, adapter.ReleaseAndGetAddressOf()) != DXGI_ERROR_NOT_FOUND; ++i) { DXGI_ADAPTER_DESC1 desc; @@ -339,9 +332,12 @@ std::unordered_map GetDeviceInfoD3D12() { info.metadata[L"LUID"] = std::to_wstring(key); info.metadata[L"DxgiAdapterNumber"] = std::to_wstring(i); info.metadata[L"DxgiVideoMemory"] = std::to_wstring(desc.DedicatedVideoMemory / (1024 * 1024)) + L" MB"; + + ++num_adapters; } - // iterate by high-performance GPU preference to add that info + // iterate by high-performance GPU preference to add that info. + UINT cur_adapter = 0; for (UINT i = 0; factory->EnumAdapterByGpuPreference( i, DXGI_GPU_PREFERENCE_HIGH_PERFORMANCE, IID_PPV_ARGS(adapter.ReleaseAndGetAddressOf())) != DXGI_ERROR_NOT_FOUND; @@ -352,12 +348,41 @@ std::unordered_map GetDeviceInfoD3D12() { } uint64_t key = GetLuidKey(desc.AdapterLuid); - auto it = device_info.find(key); - if (it != device_info.end()) { - DeviceInfo& info = it->second; - info.metadata[L"DxgiHighPerformanceIndex"] = std::to_wstring(i); + if (it == device_info.end()) { + continue; } + + DeviceInfo& info = it->second; + + // try and drop the Microsoft Remote Display Adapter. it does not have the DXGI_ADAPTER_FLAG_SOFTWARE flag set + // and the vendor id, device id and description are the same as the real device. the LUID is different to the real + // device. + // Assumption: it will have the worst performance index of the devices we're considering so we only check the + // last adapter + if (num_adapters > 1 && have_remote_display_adapter && cur_adapter == num_adapters - 1) { + ComPtr output; + if (adapter->EnumOutputs(0, &output) == DXGI_ERROR_NOT_FOUND) { + // D3D_DRIVER_TYPE_WARP. Software based or disabled adapter. + // An adapter can be disabled in an RDP session. e.g. integrated GPU is disabled if there's a discrete GPU + + // if we have seen this vendor_id+device_id combination with a different LUID before we drop it. + if (std::any_of(device_info.begin(), device_info.end(), + [key, &info](const auto& entry) { + const auto& entry_info = entry.second; + return key != entry.first && + info.vendor_id == entry_info.vendor_id && + info.device_id == entry_info.device_id; + })) { + device_info.erase(key); + continue; + } + } + } + + info.metadata[L"DxgiHighPerformanceIndex"] = std::to_wstring(i); + + ++cur_adapter; } return device_info; @@ -497,10 +522,12 @@ std::unordered_set DeviceDiscovery::DiscoverDevicesForPlatfor } } - // d3d12 info. key is luid - std::unordered_map luid_to_d3d12_info = GetDeviceInfoD3D12(); // setupapi_info. key is vendor_id+device_id - std::unordered_map setupapi_info = GetDeviceInfoSetupApi(npus); + bool have_remote_display_adapter = false; // set if we see the RdpIdd_IndirectDisplay hardware ID. + std::unordered_map setupapi_info = GetDeviceInfoSetupApi(npus, have_remote_display_adapter); + + // d3d12 info. key is luid + std::unordered_map luid_to_d3d12_info = GetDeviceInfoD3D12(have_remote_display_adapter); // Ensure we have at least one CPU bool found_cpu = false; diff --git a/onnxruntime/core/providers/cpu/math/softmax_shared.cc b/onnxruntime/core/providers/cpu/math/softmax_shared.cc index 2817dda9d0085..e123414b03b21 100644 --- a/onnxruntime/core/providers/cpu/math/softmax_shared.cc +++ b/onnxruntime/core/providers/cpu/math/softmax_shared.cc @@ -99,7 +99,7 @@ common::Status SoftmaxCPU(size_t N, float* Ydata, bool logarithmic, onnxruntime::concurrency::ThreadPool* thread_pool) { - MlasComputeSoftmax(Xdata, Ydata, N, D, logarithmic, false, thread_pool); + MlasComputeSoftmax(Xdata, Ydata, N, D, logarithmic, false, 0.0f, thread_pool); return Status::OK(); } diff --git a/onnxruntime/core/providers/cpu/ml/ml_common.h b/onnxruntime/core/providers/cpu/ml/ml_common.h index 3359b2a69fe83..f7cc2523adbf6 100644 --- a/onnxruntime/core/providers/cpu/ml/ml_common.h +++ b/onnxruntime/core/providers/cpu/ml/ml_common.h @@ -445,7 +445,7 @@ void batched_update_scores_inplace(gsl::span scores, int64_t num_batches_in, } if (use_mlas) { - MlasComputeSoftmax(s, s, num_batches, onnxruntime::narrow(batch_size), false, false, threadpool); + MlasComputeSoftmax(s, s, num_batches, onnxruntime::narrow(batch_size), false, false, 0.0f, threadpool); } else { while (s < s_end) { gsl::span scores_for_batch(s, s + batch_size); diff --git a/onnxruntime/core/providers/cpu/quantization/quantize_linear.cc b/onnxruntime/core/providers/cpu/quantization/quantize_linear.cc index adb2aee171f39..c691be6ffd0e8 100644 --- a/onnxruntime/core/providers/cpu/quantization/quantize_linear.cc +++ b/onnxruntime/core/providers/cpu/quantization/quantize_linear.cc @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include #include #include "core/framework/element_type_lists.h" #include "core/framework/float8.h" @@ -301,14 +302,31 @@ struct DequantizeLinearApply { * @param[in] zero_point same shape as scale */ void op(size_t M, size_t K, size_t N, const T* input, - const OutT* scale, OutT* output, const T* zero_point) { + const OutT* scale, OutT* output, const T* zero_point, concurrency::ThreadPool* thread_pool) { for (size_t m = 0; m < M; m++) { for (size_t k = 0; k < K; k++) { +#if defined(ORT_CLIENT_PACKAGE_BUILD) + // TODO: Only using multithreaded/SIMD DQ when ORT is built for client/on-device workloads. + // Make this the default behavior after more testing. + if constexpr (std::is_same_v || std::is_same_v) { + ParDequantizeLinearStd(input, output, N, scale[k], zero_point ? zero_point[k] : 0, thread_pool); + input += N; + output += N; + } else { + auto zp = zero_point ? static_cast(zero_point[k]) : 0; + auto sc = static_cast(scale[k]); + for (size_t n = 0; n < N; n++) { + *output++ = static_cast(static_cast(static_cast(*input++) - zp) * sc); + } + } +#else + ORT_UNUSED_PARAMETER(thread_pool); auto zp = zero_point ? static_cast(zero_point[k]) : 0; auto sc = static_cast(scale[k]); for (size_t n = 0; n < N; n++) { *output++ = static_cast(static_cast(static_cast(*input++) - zp) * sc); } +#endif // defined(ORT_CLIENT_PACKAGE_BUILD) } } } @@ -327,7 +345,8 @@ struct DequantizeLinearApply { * @param[in] zero_point same shape as scale */ void op(size_t M, size_t K, size_t N, size_t quant_block_size, - const T* input, const OutT* scale, OutT* output, const T* zero_point) { + const T* input, const OutT* scale, OutT* output, const T* zero_point, concurrency::ThreadPool* thread_pool) { + ORT_UNUSED_PARAMETER(thread_pool); if (zero_point) { for (size_t m = 0; m < M; m++) { for (size_t bd = 0; bd < K; bd += quant_block_size) { @@ -368,7 +387,8 @@ template struct DequantizeLinearApply { // per-tensor/layer or per-axis quantization void op(size_t M, size_t K, size_t N, - const T* input, const OutT* scale, OutT* output, const T* zero_point) { + const T* input, const OutT* scale, OutT* output, const T* zero_point, concurrency::ThreadPool* thread_pool) { + ORT_UNUSED_PARAMETER(thread_pool); size_t input_index = 0; for (size_t m = 0; m < M; m++) { @@ -394,7 +414,8 @@ struct DequantizeLinearApply { // Blocked quantization // TODO(fajin) : add mlas kernel to utilize multithreading, refer MlasDequantizeBlockwise. void op(size_t M, size_t K, size_t N, size_t quant_block_size, - const T* input, const OutT* scale, OutT* output, const T* zero_point) { + const T* input, const OutT* scale, OutT* output, const T* zero_point, concurrency::ThreadPool* thread_pool) { + ORT_UNUSED_PARAMETER(thread_pool); size_t input_index = 0; if (zero_point) { @@ -440,36 +461,36 @@ struct DequantizeLinearApply { #if !defined(DISABLE_FLOAT8_TYPES) -#define DEQUANTIZE_LINEAR_APPLY_FLOAT8(T) \ - template \ - struct DequantizeLinearApply { \ - /* Per-tensor/layer or per-axis quantization */ \ - void op(size_t M, size_t K, size_t N, \ - const T* input, const OutT* scale, OutT* output, const T*) { \ - for (size_t m = 0; m < M; m++) { \ - for (size_t bd = 0; bd < K; bd++) { \ - auto sc = scale[bd]; \ - for (size_t bs = 0; bs < N; bs++, input++) { \ - *output++ = static_cast(input->ToFloat() * sc); \ - } \ - } \ - } \ - } \ - /* Blocked quantization */ \ - void op(size_t M, size_t K, size_t N, size_t quant_block_size, \ - const T* input, const OutT* scale, OutT* output, const T*) { \ - for (size_t m = 0; m < M; m++) { \ - for (size_t bd = 0; bd < K; bd += quant_block_size) { \ - for (size_t qb = 0, qb_end = std::min(quant_block_size, K - bd); qb < qb_end; ++qb) { \ - for (size_t bs = 0; bs < N; bs++, input++) { \ - auto sc = static_cast(scale[bs]); \ - *output++ = static_cast(input->ToFloat() * sc); \ - } \ - } \ - scale += N; \ - } \ - } \ - } \ +#define DEQUANTIZE_LINEAR_APPLY_FLOAT8(T) \ + template \ + struct DequantizeLinearApply { \ + /* Per-tensor/layer or per-axis quantization */ \ + void op(size_t M, size_t K, size_t N, \ + const T* input, const OutT* scale, OutT* output, const T*, concurrency::ThreadPool*) { \ + for (size_t m = 0; m < M; m++) { \ + for (size_t bd = 0; bd < K; bd++) { \ + auto sc = scale[bd]; \ + for (size_t bs = 0; bs < N; bs++, input++) { \ + *output++ = static_cast(input->ToFloat() * sc); \ + } \ + } \ + } \ + } \ + /* Blocked quantization */ \ + void op(size_t M, size_t K, size_t N, size_t quant_block_size, \ + const T* input, const OutT* scale, OutT* output, const T*, concurrency::ThreadPool*) { \ + for (size_t m = 0; m < M; m++) { \ + for (size_t bd = 0; bd < K; bd += quant_block_size) { \ + for (size_t qb = 0, qb_end = std::min(quant_block_size, K - bd); qb < qb_end; ++qb) { \ + for (size_t bs = 0; bs < N; bs++, input++) { \ + auto sc = static_cast(scale[bs]); \ + *output++ = static_cast(input->ToFloat() * sc); \ + } \ + } \ + scale += N; \ + } \ + } \ + } \ }; DEQUANTIZE_LINEAR_APPLY_FLOAT8(Float8E4M3FN) @@ -513,6 +534,7 @@ Status DequantizeLinear::Compute(OpKernelContext* ctx) const { const auto to = x_scale.GetElementType(); const T* input = x.Data(); constexpr bool is_4bit = boost::mp11::mp_contains, T>::value; + concurrency::ThreadPool* thread_pool = ctx->GetOperatorThreadPool(); if (to == ONNX_NAMESPACE::TensorProto::FLOAT) { const float* scale = x_scale.Data(); @@ -522,12 +544,12 @@ Status DequantizeLinear::Compute(OpKernelContext* ctx) const { static_cast(broadcast_dim), static_cast(process_block_size), static_cast(block_size_), - input, scale, output, zero_point); + input, scale, output, zero_point, thread_pool); } else { DequantizeLinearApply().op(static_cast(process_block_count), static_cast(broadcast_dim), static_cast(process_block_size), - input, scale, output, zero_point); + input, scale, output, zero_point, thread_pool); } } else if (to == ONNX_NAMESPACE::TensorProto::FLOAT16) { const MLFloat16* scale = x_scale.Data(); @@ -537,12 +559,12 @@ Status DequantizeLinear::Compute(OpKernelContext* ctx) const { static_cast(broadcast_dim), static_cast(process_block_size), static_cast(block_size_), - input, scale, output, zero_point); + input, scale, output, zero_point, thread_pool); } else { DequantizeLinearApply().op(static_cast(process_block_count), static_cast(broadcast_dim), static_cast(process_block_size), - input, scale, output, zero_point); + input, scale, output, zero_point, thread_pool); } } else if (to == ONNX_NAMESPACE::TensorProto::BFLOAT16) { ORT_THROW("DequantizeLinear into BFLOAT16 is not implemented yet."); diff --git a/onnxruntime/core/providers/cuda/cuda_provider_factory.cc b/onnxruntime/core/providers/cuda/cuda_provider_factory.cc index 2de496a9168a0..f00bf51ae143d 100644 --- a/onnxruntime/core/providers/cuda/cuda_provider_factory.cc +++ b/onnxruntime/core/providers/cuda/cuda_provider_factory.cc @@ -313,8 +313,10 @@ CUDA_Provider* GetProvider() { // OrtEpApi infrastructure to be able to use the CUDA EP as an OrtEpFactory for auto EP selection. struct CudaEpFactory : OrtEpFactory { CudaEpFactory(const OrtApi& ort_api_in) : ort_api{ort_api_in} { + ort_version_supported = ORT_API_VERSION; GetName = GetNameImpl; GetVendor = GetVendorImpl; + GetVendorId = GetVendorIdImpl; GetVersion = GetVersionImpl; GetSupportedDevices = GetSupportedDevicesImpl; CreateEp = CreateEpImpl; @@ -331,6 +333,11 @@ struct CudaEpFactory : OrtEpFactory { return factory->vendor.c_str(); } + static uint32_t GetVendorIdImpl(const OrtEpFactory* this_ptr) noexcept { + const auto* factory = static_cast(this_ptr); + return factory->vendor_id; + } + static const char* ORT_API_CALL GetVersionImpl(const OrtEpFactory* /*this_ptr*/) noexcept { return ORT_VERSION; } @@ -374,6 +381,7 @@ struct CudaEpFactory : OrtEpFactory { const OrtApi& ort_api; const std::string ep_name{kCudaExecutionProvider}; // EP name const std::string vendor{"Microsoft"}; // EP vendor name + uint32_t vendor_id{0x1414}; // Microsoft vendor ID }; extern "C" { diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp index a5066a41981e5..9611cb82d5a62 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp @@ -781,7 +781,10 @@ namespace Dml // this branch could be reached with a bad custom operator or malformed file. If // a legitimate case reaches here and DML needs to support a new input/output type // besides tensors, then remove the assert. - assert(false); + + // If the model has nodes that use Optional we will arrive here. It's a valid ONNX model but + // TryGetTensorDataType doesn't handle Optional. + // assert(false); nodeContainsSupportedDataTypes = false; return; } diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc index 711d81186bad1..c5b6507ac847b 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc @@ -1304,7 +1304,7 @@ std::vector NvExecutionProvider::CreatePreferredAllocators() { AllocatorCreationInfo pinned_allocator_info( [](OrtDevice::DeviceId device_id) { - return std::make_unique(device_id, CUDA_PINNED); + return std::make_unique(CUDA_PINNED, device_id); }, narrow(device_id_)); diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/pool_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/pool_op_builder.cc index 86b684f8c6ebd..21947a22e2b92 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/pool_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/pool_op_builder.cc @@ -235,7 +235,7 @@ Status PoolOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(reshape_input, reshape_input_info)); bool needs_reshape = false; - const std::string reshape4d = input_names[0] + "_pre_reshape"; + const std::string reshape_prior_out = input_names[0] + "_prior_reshape"; if (input_shape.size() == 3) { needs_reshape = true; // build new_shape = {N, 1, C, L} @@ -245,25 +245,24 @@ Status PoolOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra input_shape[1], input_shape[2]}; - const std::string reshape_node_name = "pre_reshape"; - QnnTensorWrapper rw( - reshape4d, + QnnTensorWrapper reshape_prior_tensor( + reshape_prior_out, QNN_TENSOR_TYPE_NATIVE, reshape_input_info.qnn_data_type, reshape_input_info.quant_param.Copy(), std::move(new_shape)); - ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(rw)), - "Failed to add reshape-4d tensor."); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(reshape_prior_tensor)), + "Failed to add reshape prior tensor."); ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode( - reshape_node_name, + utils::GetNodeName(node_unit) + "_reshape_prior", QNN_OP_PACKAGE_NAME_QTI_AISW, - "Reshape", + QNN_OP_RESHAPE, {input_names[0]}, - {reshape4d}, + {reshape_prior_out}, {}, do_op_validation), - "Failed to create reshape-4d node."); - input_names[0] = reshape4d; + "Failed to create reshape prior node for pool op."); + input_names[0] = reshape_prior_out; input_shape = {input_shape[0], 1, input_shape[1], input_shape[2]}; } @@ -446,9 +445,7 @@ Status PoolOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra } const auto& outputs = node_unit.Outputs(); const std::string real_out = outputs[0].node_arg.Name(); - const std::string pool_name = "poolmax2d"; - const std::string pool_out = real_out + "_post_reshape"; - const std::string post_reshape_node_name = "post_reshape"; + const std::string pool_out = real_out + "_reshape_after"; const std::string qnn_op = GetQnnOpType(op_type); TensorInfo output_info{}; ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(node_unit.Outputs()[0], output_info)); @@ -466,33 +463,34 @@ Status PoolOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra "Failed to add tensor for pool_out"); ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode( - pool_name, + utils::GetNodeName(node_unit) + "_pool2d", QNN_OP_PACKAGE_NAME_QTI_AISW, qnn_op, - {reshape4d}, + {reshape_prior_out}, {pool_out}, std::move(param_tensor_names), do_op_validation), - "Failed to create QNN Pool node for rank-3 input."); + "Failed to create pool node for rank-3 input."); std::vector final_shape3d = output_info.shape; - QnnTensorWrapper reshape_back_tensor( + QnnTensorWrapper reshape_after_tensor( real_out, tensor_type, output_info.qnn_data_type, output_info.quant_param.Copy(), std::move(final_shape3d)); - ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(reshape_back_tensor)), "Failed to add tensor."); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(reshape_after_tensor)), + "Failed to add reshape after tensor."); ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode( - post_reshape_node_name, + utils::GetNodeName(node_unit) + "_reshape_after", QNN_OP_PACKAGE_NAME_QTI_AISW, - "Reshape", + QNN_OP_RESHAPE, {pool_out}, {real_out}, {}, do_op_validation), - "Failed to create reshape-back node."); + "Failed to create reshape after node for pool op."); return Status::OK(); } diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc index 2650316dd07ac..502ea86b689f4 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc @@ -9,7 +9,7 @@ namespace onnxruntime { namespace qnn { -// Operator which only need to hanle node inputs & outputs, no attributes or no need to handle attributes +// Operator which only need to handle node inputs & outputs, no attributes or no need to handle attributes class SimpleOpBuilder : public BaseOpBuilder { public: SimpleOpBuilder() : BaseOpBuilder("SimpleOpBuilder") {} @@ -38,7 +38,7 @@ class SimpleOpBuilder : public BaseOpBuilder { const logging::Logger& logger, bool do_op_validation) const ORT_MUST_USE_RESULT; - static constexpr std::array gridsample_supported_modes = {"bilinear", "nearest"}; + static constexpr std::array gridsample_supported_modes = {"bilinear", "nearest", "linear"}; static constexpr std::array gridsample_supported_padding_modes = {"zeros", "border", "reflection"}; static constexpr std::array scatternd_supported_reduction = {"none", "add", "mul"}; }; @@ -60,8 +60,8 @@ Status SimpleOpBuilder::ExplicitOpCheck(QnnModelWrapper& qnn_model_wrapper, // To DO: Remove once QNN CPU supports ScatterND const auto qnn_backend_type = qnn_model_wrapper.GetQnnBackendType(); if (op_type == "ScatterND") { - ORT_RETURN_IF_NOT(qnn_backend_type == QnnBackendType::HTP, - "QNN EP only supports ScatterND op on HTP backend. Falling back to ORT CPU."); + ORT_RETURN_IF(qnn_backend_type == QnnBackendType::CPU, + "QNN EP does not support ScatterND op on CPU backend. Falling back to ORT CPU."); } // ONNX's Min, Max, and Sum operators accept a variable number of inputs (i.e., variadic). @@ -233,12 +233,12 @@ Status ProcessGridSampleAttributes(QnnModelWrapper& qnn_model_wrapper, std::string mode = node_helper.Get("mode", "linear"); Qnn_Scalar_t mode_qnn_scalar = QNN_SCALAR_INIT; mode_qnn_scalar.dataType = QNN_DATATYPE_UINT_32; - if ("bilinear" == mode) { + if ("linear" == mode || "bilinear" == mode) { mode_qnn_scalar.uint32Value = QNN_OP_GRID_SAMPLE_MODE_BILINEAR; } else if ("nearest" == mode) { mode_qnn_scalar.uint32Value = QNN_OP_GRID_SAMPLE_MODE_NEAREST; } else { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "GridSample mode only support bilinear & nearest."); + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "GridSample mode only support [linear, bilinear, nearest]."); } QnnParamWrapper mode_param(node_unit.Index(), node_unit.Name(), QNN_OP_GRID_SAMPLE_PARAM_MODE, mode_qnn_scalar); param_tensor_names.push_back(mode_param.GetParamTensorName()); @@ -254,7 +254,7 @@ Status ProcessGridSampleAttributes(QnnModelWrapper& qnn_model_wrapper, } else if ("reflection" == padding_mode) { padding_mode_qnn_scalar.uint32Value = QNN_OP_GRID_SAMPLE_PADDING_MODE_REFLECTION; } else { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "GridSample padding_mode only support zeros, border & reflection."); + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "GridSample padding_mode only support [zeros, border, reflection]."); } QnnParamWrapper padding_mode_param(node_unit.Index(), node_unit.Name(), QNN_OP_GRID_SAMPLE_PARAM_PADDING_MODE, padding_mode_qnn_scalar); param_tensor_names.push_back(padding_mode_param.GetParamTensorName()); diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc index d22edaf33eb1c..3dc103046424e 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc @@ -839,6 +839,23 @@ Status QnnBackendManager::CreateContextVtcmBackupBufferSharingEnabled(std::unord return Status::OK(); } +Status QnnBackendManager::SetContextPriority(ContextPriority context_priority) { + QnnContext_Config_t context_priority_config = QNN_CONTEXT_CONFIG_INIT; + ORT_RETURN_IF_ERROR(SetQnnContextConfig(context_priority, context_priority_config)); + + QnnContext_Config_t* configs[] = {&context_priority_config, nullptr}; + for (const auto& context_handle : contexts_) { + auto result = qnn_interface_.contextSetConfig(context_handle, (const QnnContext_Config_t**)configs); + ORT_RETURN_IF(QNN_CONTEXT_NO_ERROR != result, "Failed to set context priority for context handle: ", context_handle); + } + + return Status::OK(); +} + +Status QnnBackendManager::ResetContextPriority() { + return SetContextPriority(context_priority_); +} + Status QnnBackendManager::CreateContext(bool enable_htp_weight_sharing) { if (true == context_created_) { LOGS_DEFAULT(INFO) << "Context created already."; @@ -1426,13 +1443,33 @@ Status QnnBackendManager::SetHtpPowerConfig(uint32_t htp_power_config_client_id, return Status::OK(); } -Status QnnBackendManager::SetRpcControlLatency(uint32_t htp_power_config_client_id, - uint32_t rpc_control_latency) { +Status QnnBackendManager::SetRpcPowerConfigs(uint32_t htp_power_config_client_id, + uint32_t rpc_control_latency, + uint32_t rpc_polling_time) { // This function is called in QNN EP's OnRunStart() even if QNN backend setup failed and the model is assigned // to a different EP. Therefore, we have to check that backend setup actually completed before trying to // set RPC control latency. Otherwise, this causes a segfault because the QNN backend library is unloaded. ORT_RETURN_IF_NOT(backend_setup_completed_, "Cannot set HTP RPC control latency if backend setup is not complete."); + + constexpr int kNumRpcPollingPowerConfigs = 2; + std::vector rpc_power_configs; + rpc_power_configs.reserve(kNumRpcPollingPowerConfigs); + + // Set rpc control latency here if (rpc_control_latency != 0) { + auto& rpc_control_latency_cfg = rpc_power_configs.emplace_back(); + rpc_control_latency_cfg.option = QNN_HTP_PERF_INFRASTRUCTURE_POWER_CONFIGOPTION_RPC_CONTROL_LATENCY; + rpc_control_latency_cfg.rpcControlLatencyConfig = rpc_control_latency; + } + + // Note: v68 does not support rpc polling mode + if (rpc_polling_time != 0) { + auto& rpc_polling_time_cfg = rpc_power_configs.emplace_back(); + rpc_polling_time_cfg.option = QNN_HTP_PERF_INFRASTRUCTURE_POWER_CONFIGOPTION_RPC_POLLING_TIME; + rpc_polling_time_cfg.rpcPollingTimeConfig = rpc_polling_time; + } + + if (rpc_power_configs.size() > 0) { QnnDevice_Infrastructure_t qnn_device_infra = nullptr; auto status = qnn_interface_.deviceGetInfrastructure(&qnn_device_infra); ORT_RETURN_IF(QNN_SUCCESS != status, "backendGetPerfInfrastructure failed."); @@ -1442,15 +1479,6 @@ Status QnnBackendManager::SetRpcControlLatency(uint32_t htp_power_config_client_ "HTP infra type = ", htp_infra->infraType, ", which is not perf infra type."); QnnHtpDevice_PerfInfrastructure_t& htp_perf_infra = htp_infra->perfInfra; - // Set rpc control latency here, but note that v68 doesn't support rpc polling mode. - constexpr int kNumRpcPollingPowerConfigs = 2; - std::vector rpc_power_configs(kNumRpcPollingPowerConfigs); - QnnHtpPerfInfrastructure_PowerConfig_t& rpc_control_latency_cfg = rpc_power_configs[0]; - // v68 doesn't support this. - QnnHtpPerfInfrastructure_PowerConfig_t& rpc_polling_time = rpc_power_configs[1]; - rpc_control_latency_cfg.option = QNN_HTP_PERF_INFRASTRUCTURE_POWER_CONFIGOPTION_RPC_CONTROL_LATENCY; - rpc_polling_time.option = QNN_HTP_PERF_INFRASTRUCTURE_POWER_CONFIGOPTION_RPC_POLLING_TIME; - rpc_control_latency_cfg.rpcControlLatencyConfig = rpc_control_latency; std::vector perf_power_configs_ptr = ObtainNullTermPtrVector(rpc_power_configs); status = htp_perf_infra.setPowerConfig(htp_power_config_client_id, perf_power_configs_ptr.data()); diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h index 3e68df3024565..2a71c7391b180 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h @@ -159,8 +159,9 @@ class QnnBackendManager : public std::enable_shared_from_this Status SetHtpPowerConfig(uint32_t htp_power_config_client_id, HtpPerformanceMode htp_performance_mode); - Status SetRpcControlLatency(uint32_t htp_power_config_client_id, - uint32_t rpc_control_latency); + Status SetRpcPowerConfigs(uint32_t htp_power_config_client_id, + uint32_t rpc_control_latency, + uint32_t rpc_polling_time); const QNN_INTERFACE_VER_TYPE& GetQnnInterface() { return qnn_interface_; } @@ -219,6 +220,11 @@ class QnnBackendManager : public std::enable_shared_from_this // For each node name, a mapping to the context handle will be created void ProcessContextFromBinListAsync(Qnn_ContextHandle_t handle, void* notifyParam); + // Sets the context priority to the given value, if valid + Status SetContextPriority(ContextPriority context_priority); + // Resets the context priority to the session default as defined by context_priority_ + Status ResetContextPriority(); + private: Status LoadBackend(); diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc index 236447cc95c3d..3acb3347acee1 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc @@ -1356,7 +1356,8 @@ QNNExecutionProvider::PerThreadContext::PerThreadContext(qnn::QnnBackendManager* uint32_t device_id, uint32_t core_id, qnn::HtpPerformanceMode default_htp_performance_mode, - uint32_t default_rpc_control_latency) + uint32_t default_rpc_control_latency, + uint32_t default_rpc_polling_time) : qnn_backend_manager_(qnn_backend_manager) { Status rt = qnn_backend_manager_->CreateHtpPowerCfgId(device_id, core_id, htp_power_config_id_); is_htp_power_config_id_valid_ = rt.IsOK(); @@ -1367,9 +1368,10 @@ QNNExecutionProvider::PerThreadContext::PerThreadContext(qnn::QnnBackendManager* ORT_IGNORE_RETURN_VALUE(qnn_backend_manager_->SetHtpPowerConfig(htp_power_config_id_, default_htp_performance_mode)); } - if (default_rpc_control_latency > 0) { - ORT_IGNORE_RETURN_VALUE(qnn_backend_manager_->SetRpcControlLatency(htp_power_config_id_, - default_rpc_control_latency)); + if (default_rpc_control_latency > 0 || default_rpc_polling_time > 0) { + ORT_IGNORE_RETURN_VALUE(qnn_backend_manager_->SetRpcPowerConfigs(htp_power_config_id_, + default_rpc_control_latency, + default_rpc_polling_time)); } } } @@ -1400,7 +1402,8 @@ QNNExecutionProvider::PerThreadContext& QNNExecutionProvider::GetPerThreadContex if (context_state_.retired_context_pool.empty()) { uint32_t core_id = 0; context = std::make_shared(qnn_backend_manager_.get(), device_id_, core_id, - default_htp_performance_mode_, default_rpc_control_latency_); + default_htp_performance_mode_, default_rpc_control_latency_, + default_rpc_polling_time_); } else { context = context_state_.retired_context_pool.back(); context_state_.retired_context_pool.pop_back(); @@ -1468,15 +1471,21 @@ Status QNNExecutionProvider::OnRunStart(const onnxruntime::RunOptions& run_optio LOGS_DEFAULT(VERBOSE) << "rpc_control_latency: " << rpc_control_latency; } + uint32_t rpc_polling_time = 0; + if (qnn::HtpPerformanceMode::kHtpBurst != htp_performance_mode) { + rpc_polling_time = 9999; + } + if (GetPerThreadContext().IsHtpPowerConfigIdValid()) { if (qnn::HtpPerformanceMode::kHtpDefault != htp_performance_mode) { ORT_RETURN_IF_ERROR(qnn_backend_manager_->SetHtpPowerConfig(GetPerThreadContext().GetHtpPowerConfigId(), htp_performance_mode)); } - if (rpc_control_latency > 0) { - ORT_RETURN_IF_ERROR(qnn_backend_manager_->SetRpcControlLatency(GetPerThreadContext().GetHtpPowerConfigId(), - rpc_control_latency)); + if (rpc_control_latency > 0 || rpc_polling_time > 0) { + ORT_RETURN_IF_ERROR(qnn_backend_manager_->SetRpcPowerConfigs(GetPerThreadContext().GetHtpPowerConfigId(), + rpc_control_latency, + rpc_polling_time)); } } @@ -1545,4 +1554,38 @@ OrtDevice QNNExecutionProvider::GetOrtDeviceByMemType(OrtMemType /* em_type */) return default_device_; } +Status QNNExecutionProvider::SetEpDynamicOptions(gsl::span keys, + gsl::span values) { + if (keys.size() != values.size()) { + LOGS_DEFAULT(ERROR) << "SetEpDynamicOptions: number of keys (" << keys.size() + << ") does not equal number of values (" << values.size() << ")."; + } + auto key_it = keys.begin(); + auto value_it = values.begin(); + + while (key_it != keys.end() && value_it != values.end()) { + std::string key(*key_it); + std::string value(*value_it); + + if (key == kOrtEpDynamicOptionsWorkloadType) { + if (value == "Default") { + ORT_RETURN_IF_ERROR(qnn_backend_manager_->ResetContextPriority()); + } else if (value == "Efficient") { + ORT_RETURN_IF_ERROR(qnn_backend_manager_->SetContextPriority(qnn::ContextPriority::LOW)); + } else { + LOGS_DEFAULT(ERROR) << "Invalid EP Workload Type: " << value; + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid EP Workload Type."); + } + } else { + LOGS_DEFAULT(ERROR) << "EP Dynamic Option \"" << key << "\" is not currently supported."; + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unsupported EP Dynamic Option"); + } + + key_it++; + value_it++; + } + + return Status::OK(); +} + } // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.h b/onnxruntime/core/providers/qnn/qnn_execution_provider.h index 06f9726ae96cf..6adf613932d66 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.h +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.h @@ -57,6 +57,9 @@ class QNNExecutionProvider : public IExecutionProvider { OrtDevice GetOrtDeviceByMemType(OrtMemType mem_type) const override; + Status SetEpDynamicOptions(gsl::span keys, + gsl::span value) override; + private: std::unordered_set GetSupportedNodes(const GraphViewer& graph_viewer, const std::unordered_map& node_unit_map, @@ -96,6 +99,7 @@ class QNNExecutionProvider : public IExecutionProvider { uint32_t device_id_ = 0; qnn::HtpPerformanceMode default_htp_performance_mode_ = qnn::HtpPerformanceMode::kHtpDefault; uint32_t default_rpc_control_latency_ = 0; + uint32_t default_rpc_polling_time_ = 0; bool enable_HTP_FP16_precision_ = true; bool share_ep_contexts_ = false; bool stop_share_ep_contexts_ = false; @@ -116,7 +120,8 @@ class QNNExecutionProvider : public IExecutionProvider { PerThreadContext(qnn::QnnBackendManager* qnn_backend_manager, uint32_t device_id, uint32_t core_id, qnn::HtpPerformanceMode default_htp_performance_mode, - uint32_t default_rpc_control_latency); + uint32_t default_rpc_control_latency, + uint32_t default_rpc_polling_time); ~PerThreadContext(); ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(PerThreadContext); diff --git a/onnxruntime/core/providers/qnn/qnn_provider_factory.cc b/onnxruntime/core/providers/qnn/qnn_provider_factory.cc index c679ea1adb286..785177ce37788 100644 --- a/onnxruntime/core/providers/qnn/qnn_provider_factory.cc +++ b/onnxruntime/core/providers/qnn/qnn_provider_factory.cc @@ -125,8 +125,10 @@ struct QnnEpFactory : OrtEpFactory { OrtHardwareDeviceType hw_type, const char* qnn_backend_type) : ort_api{ort_api_in}, ep_name{ep_name}, ort_hw_device_type{hw_type}, qnn_backend_type{qnn_backend_type} { + ort_version_supported = ORT_API_VERSION; GetName = GetNameImpl; GetVendor = GetVendorImpl; + GetVendorId = GetVendorIdImpl; GetVersion = GetVersionImpl; GetSupportedDevices = GetSupportedDevicesImpl; CreateEp = CreateEpImpl; @@ -142,7 +144,12 @@ struct QnnEpFactory : OrtEpFactory { static const char* GetVendorImpl(const OrtEpFactory* this_ptr) noexcept { const auto* factory = static_cast(this_ptr); - return factory->vendor.c_str(); + return factory->ep_vendor.c_str(); + } + + static uint32_t GetVendorIdImpl(const OrtEpFactory* this_ptr) noexcept { + const auto* factory = static_cast(this_ptr); + return factory->ep_vendor_id; } static const char* ORT_API_CALL GetVersionImpl(const OrtEpFactory* /*this_ptr*/) noexcept { @@ -195,8 +202,9 @@ struct QnnEpFactory : OrtEpFactory { } const OrtApi& ort_api; - const std::string ep_name; // EP name - const std::string vendor{"Microsoft"}; // EP vendor name + const std::string ep_name; // EP name + const std::string ep_vendor{"Microsoft"}; // EP vendor name + uint32_t ep_vendor_id{0x1414}; // Microsoft vendor ID // Qualcomm vendor ID. Refer to the ACPI ID registry (search Qualcomm): https://uefi.org/ACPI_ID_List const uint32_t vendor_id{'Q' | ('C' << 8) | ('O' << 16) | ('M' << 24)}; diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc index 90a4294fb47f0..1e9fafe8aa323 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc @@ -7,6 +7,25 @@ #include "tensorrt_execution_provider_custom_ops.h" #include "tensorrt_execution_provider.h" +// The filename extension for a shared library is different per platform +#ifdef _WIN32 +#define LIBRARY_PREFIX +#define LIBRARY_EXTENSION ORT_TSTR(".dll") +#elif defined(__APPLE__) +#define LIBRARY_PREFIX "lib" +#define LIBRARY_EXTENSION ".dylib" +#else +#define LIBRARY_PREFIX "lib" +#define LIBRARY_EXTENSION ".so" +#endif + +#ifdef _WIN32 +#define ORT_DEF2STR_HELPER(x) L#x +#else +#define ORT_DEF2STR_HELPER(X) #X +#endif +#define ORT_DEF2STR(x) ORT_DEF2STR_HELPER(x) + namespace onnxruntime { extern TensorrtLogger& GetTensorrtLogger(bool verbose); @@ -58,8 +77,31 @@ common::Status CreateTensorRTCustomOpDomainList(std::vector& // Get all registered TRT plugins from registry LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Getting all registered TRT plugins from TRT plugin registry ..."; TensorrtLogger trt_logger = GetTensorrtLogger(false); - initLibNvInferPlugins(&trt_logger, ""); + try { + void* library_handle = nullptr; + const auto& env = onnxruntime::GetDefaultEnv(); +#if NV_TENSORRT_MAJOR < 10 + auto full_path = env.GetRuntimePath() + + PathString(LIBRARY_PREFIX ORT_TSTR("nvinfer_plugin") LIBRARY_EXTENSION); +#else +#ifdef _WIN32 + auto full_path = PathString(LIBRARY_PREFIX ORT_TSTR("nvinfer_plugin_" ORT_DEF2STR(NV_TENSORRT_MAJOR)) LIBRARY_EXTENSION); +#else + auto full_path = PathString(LIBRARY_PREFIX ORT_TSTR("nvinfer_plugin") LIBRARY_EXTENSION ORT_TSTR("." ORT_DEF2STR(NV_TENSORRT_MAJOR))); +#endif +#endif + + ORT_THROW_IF_ERROR(env.LoadDynamicLibrary(full_path, false, &library_handle)); + bool (*dyn_initLibNvInferPlugins)(void* logger, char const* libNamespace); + ORT_THROW_IF_ERROR(env.GetSymbolFromLibrary(library_handle, "initLibNvInferPlugins", (void**)&dyn_initLibNvInferPlugins)); + if (!dyn_initLibNvInferPlugins(&trt_logger, "")) { + LOGS_DEFAULT(INFO) << "[TensorRT EP] Default plugin library was found but was not able to initialize default plugins."; + } + LOGS_DEFAULT(INFO) << "[TensorRT EP] Default plugins successfully loaded."; + } catch (const std::exception&) { + LOGS_DEFAULT(INFO) << "[TensorRT EP] Default plugin library is not on the path and is therefore ignored"; + } int num_plugin_creator = 0; auto plugin_creators = getPluginRegistry()->getAllCreators(&num_plugin_creator); std::unordered_set registered_plugin_names; diff --git a/onnxruntime/core/providers/webgpu/buffer_manager.cc b/onnxruntime/core/providers/webgpu/buffer_manager.cc index e8140a4d59eab..113a3f31be7f9 100644 --- a/onnxruntime/core/providers/webgpu/buffer_manager.cc +++ b/onnxruntime/core/providers/webgpu/buffer_manager.cc @@ -193,27 +193,21 @@ class BucketCacheManager : public IBufferCacheManager { } void ReleaseBuffer(WGPUBuffer buffer) override { - pending_buffers_.emplace_back(buffer); - } + auto buffer_size = static_cast(wgpuBufferGetSize(buffer)); - void OnRefresh(GraphCaptureState /*graph_capture_state*/) override { - for (auto& buffer : pending_buffers_) { - auto buffer_size = static_cast(wgpuBufferGetSize(buffer)); - auto it = buckets_.find(buffer_size); - if (it != buckets_.end() && it->second.size() < buckets_limit_[buffer_size]) { - it->second.emplace_back(buffer); - } else { - wgpuBufferRelease(buffer); - } + auto it = buckets_.find(buffer_size); + if (it != buckets_.end() && it->second.size() < buckets_limit_[buffer_size]) { + it->second.emplace_back(buffer); + } else { + wgpuBufferRelease(buffer); } + } - pending_buffers_.clear(); + void OnRefresh(GraphCaptureState /*graph_capture_state*/) override { + // no-op } ~BucketCacheManager() { - for (auto& buffer : pending_buffers_) { - wgpuBufferRelease(buffer); - } for (auto& pair : buckets_) { for (auto& buffer : pair.second) { wgpuBufferRelease(buffer); @@ -242,7 +236,6 @@ class BucketCacheManager : public IBufferCacheManager { } std::unordered_map buckets_limit_; std::unordered_map> buckets_; - std::vector pending_buffers_; std::vector buckets_keys_; }; diff --git a/onnxruntime/core/providers/webgpu/tensor/cast.cc b/onnxruntime/core/providers/webgpu/tensor/cast.cc index 7f92ea4ed3776..313a96ba25509 100644 --- a/onnxruntime/core/providers/webgpu/tensor/cast.cc +++ b/onnxruntime/core/providers/webgpu/tensor/cast.cc @@ -52,10 +52,28 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( .TypeConstraint("T1", CastOpTypeConstraints()) .TypeConstraint("T2", CastOpTypeConstraints()), Cast); +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Cast, + kOnnxDomain, + 19, 20, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T1", CastOpTypeConstraints()) + .TypeConstraint("T2", CastOpTypeConstraints()), + Cast); +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Cast, + kOnnxDomain, + 21, 22, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T1", CastOpTypeConstraints()) + .TypeConstraint("T2", CastOpTypeConstraints()), + Cast); ONNX_OPERATOR_KERNEL_EX( Cast, kOnnxDomain, - 19, + 23, kWebGpuExecutionProvider, (*KernelDefBuilder::Create()) .TypeConstraint("T1", CastOpTypeConstraints()) diff --git a/onnxruntime/core/providers/webgpu/tensor/scatter_nd.cc b/onnxruntime/core/providers/webgpu/tensor/scatter_nd.cc index f13e86c185928..9f07e2d2a3988 100644 --- a/onnxruntime/core/providers/webgpu/tensor/scatter_nd.cc +++ b/onnxruntime/core/providers/webgpu/tensor/scatter_nd.cc @@ -146,24 +146,24 @@ Status ScatterND::ComputeInternal(ComputeContext& context) const { const auto* updates = context.Input(2); const auto& input_shape = input->Shape(); const auto& indices_shape = indices->Shape(); - auto indices_rank = indices_shape.NumDimensions(); - auto last_index_dimension = static_cast(indices_shape[indices_rank - 1]); - auto num_updates_elements = static_cast(input_shape.SizeFromDimension(last_index_dimension)); - // TODO: support bool with components 4. - const size_t components = 1; - auto output_size = static_cast((indices_shape.SizeToDimension(indices_rank - 1) + components - 1) / components); auto* output = context.Output(0, input_shape); - if (output_size == 0) { - // If the output tensor is empty, we can return early. - return Status::OK(); - } - MLDataType data_type = input->DataType(); const void* source = input->DataRaw(); void* target = output->MutableDataRaw(); // If source and target pointers are not equal (non-inplace operation), we need to copy the data. if (target != source) { ORT_RETURN_IF_ERROR(Info().GetDataTransferManager().CopyTensor(*input, *output)); } + if (indices_shape.Size() == 0) { + // If the indices are empty, we can return early. + return Status::OK(); + } + auto indices_rank = indices_shape.NumDimensions(); + auto last_index_dimension = static_cast(indices_shape[indices_rank - 1]); + auto num_updates_elements = static_cast(input_shape.SizeFromDimension(last_index_dimension)); + // TODO: support bool with components 4. + const size_t components = 1; + auto output_size = static_cast((indices_shape.SizeToDimension(indices_rank - 1) + components - 1) / components); + MLDataType data_type = input->DataType(); ScatterNDProgram program(reduction_, data_type); program .CacheHint(static_cast(reduction_)) diff --git a/onnxruntime/core/providers/webgpu/tensor/slice.cc b/onnxruntime/core/providers/webgpu/tensor/slice.cc index 39432db5113d1..7e8b434431781 100644 --- a/onnxruntime/core/providers/webgpu/tensor/slice.cc +++ b/onnxruntime/core/providers/webgpu/tensor/slice.cc @@ -172,8 +172,8 @@ Status Slice::ComputeInternal(ComputeContext& context) const { } if (step < 0) { // we are slicing in reverse - start = std::clamp(start, int64_t{0}, dim_value - 1); - end = std::clamp(end, int64_t{-1}, dim_value - 1); + start = dim_value > 0 ? std::clamp(start, int64_t{0}, dim_value - 1) : 0; + end = dim_value > 0 ? std::clamp(end, int64_t{-1}, dim_value - 1) : -1; // note that we are flipping start and end to switch to forward step signs.push_back(-1); steps.push_back(static_cast(-step)); diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc index 460d220ecf1b9..6e09f494f4a8d 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc @@ -123,7 +123,9 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 6, 8, Cast); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 9, 12, Cast); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 18, Cast); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 19, Cast); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 19, 20, Cast); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 21, 22, Cast); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 23, Cast); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 11, float, Clip); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 12, 12, float, Clip); @@ -455,7 +457,9 @@ std::unique_ptr RegisterKernels() { KERNEL_CREATE_INFO_VERSIONED(6, 8, Cast), KERNEL_CREATE_INFO_VERSIONED(9, 12, Cast), KERNEL_CREATE_INFO_VERSIONED(13, 18, Cast), - KERNEL_CREATE_INFO(19, Cast), + KERNEL_CREATE_INFO_VERSIONED(19, 20, Cast), + KERNEL_CREATE_INFO_VERSIONED(21, 22, Cast), + KERNEL_CREATE_INFO(23, Cast), // // activations BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/webgpu/wgsl_templates/README.md b/onnxruntime/core/providers/webgpu/wgsl_templates/README.md index c1a62e7fa7858..6bd2f98cc5713 100644 --- a/onnxruntime/core/providers/webgpu/wgsl_templates/README.md +++ b/onnxruntime/core/providers/webgpu/wgsl_templates/README.md @@ -64,7 +64,7 @@ This section includes instructions for how to use the template system in the dev 1. Create WGSL template files in `.wgsl.template` extension. - [Reference: Template Syntax](https://github.com/fs-eire/wgsl-template?tab=readme-ov-file#template-syntax) - - [Reference: Built-in Utilities](#Utilities) + - [Reference: Built-in Utilities](https://github.com/fs-eire/wgsl-template?tab=readme-ov-file#Utilities) - [Example: Pad](../tensor/pad.wgsl.template) 2. In the implementation of `YourProgram::GenerateShaderCode()`, load and use the generated template files. @@ -117,4 +117,4 @@ This section includes instructions for how to use the template system in the dev 1. Build ORT once with dynamic template mode 2. Launch wgsl-gen in watch mode 3. Run ORT to debug/validate the shader - 4. Make changes to the template files, and repeat step (3) + 4. Make changes to the template files, and repeat step (c) diff --git a/onnxruntime/core/providers/webgpu/wgsl_templates/package-lock.json b/onnxruntime/core/providers/webgpu/wgsl_templates/package-lock.json index 7cde6c17f54e9..df1940ed6416b 100644 --- a/onnxruntime/core/providers/webgpu/wgsl_templates/package-lock.json +++ b/onnxruntime/core/providers/webgpu/wgsl_templates/package-lock.json @@ -9,13 +9,13 @@ "version": "1.0.0", "license": "MIT", "dependencies": { - "@fs-eire/wgsl-template": "^0.1.3" + "@fs-eire/wgsl-template": "^0.1.13" } }, "node_modules/@fs-eire/wgsl-template": { - "version": "0.1.10", - "resolved": "https://registry.npmjs.org/@fs-eire/wgsl-template/-/wgsl-template-0.1.10.tgz", - "integrity": "sha512-F5qQZxNweZ3ZD3d9RNc/g3nTiW7jyaAVi7SlMOL4wOfXh+Nm/qca2DISNTf3kjpVqkoazMJGbZ6TPQ4a/vjw0g==", + "version": "0.1.13", + "resolved": "https://registry.npmjs.org/@fs-eire/wgsl-template/-/wgsl-template-0.1.13.tgz", + "integrity": "sha512-SOQjVCQCUmXb9qYr2E3CKNs88/FzINuhFJiobBEkSAsyKtJby9oFWGZnrEO+hIl/oDTLA01LbjiDxuf6TGHE/w==", "license": "MIT", "dependencies": { "minimist": "^1.2.8" diff --git a/onnxruntime/core/providers/webgpu/wgsl_templates/package.json b/onnxruntime/core/providers/webgpu/wgsl_templates/package.json index 34831ccddeb33..246e7365531e0 100644 --- a/onnxruntime/core/providers/webgpu/wgsl_templates/package.json +++ b/onnxruntime/core/providers/webgpu/wgsl_templates/package.json @@ -10,6 +10,6 @@ "author": "", "license": "MIT", "dependencies": { - "@fs-eire/wgsl-template": "^0.1.3" + "@fs-eire/wgsl-template": "^0.1.13" } } diff --git a/onnxruntime/core/providers/webnn/builders/helper.cc b/onnxruntime/core/providers/webnn/builders/helper.cc index e821265fff80d..142d64caa64aa 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.cc +++ b/onnxruntime/core/providers/webnn/builders/helper.cc @@ -99,69 +99,93 @@ bool IsTensorShapeSupported(const NodeArg& node_arg, const std::string& parent_n return true; } -// Check if all input tensor ranks of the given node are supported by WebNN. -bool IsInputRankSupportedByOp(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) { - const std::string_view op_type = node.OpType(); - const auto it = op_inputs_map.find(op_type); - if (it == op_inputs_map.end()) { - LOGS(logger, VERBOSE) << "Operator type: [" << op_type << "] is not found in the op inputs map."; +// Check if a single input's rank of an ONNX op is supported by corresponding WebNN op. +bool IsInputRankSupported(const emscripten::val& wnn_limits, + const std::string_view webnn_op_type, + const std::string_view input_name, + const size_t input_rank, + const std::string_view node_name, + const logging::Logger& logger) { + const std::string webnn_op_type_str(webnn_op_type); + const std::string input_name_str(input_name); + + if (wnn_limits[webnn_op_type_str].isUndefined()) { + LOGS(logger, VERBOSE) << "WebNN op type: [" << webnn_op_type + << "] is not defined in WebNN MLOpSupportLimits."; return false; } - const auto& input_defs = node.InputDefs(); - const std::string_view webnn_op_type = it->second.opType; - const std::string webnn_op_type_str(webnn_op_type); + const emscripten::val input_limits = wnn_limits[webnn_op_type_str][input_name_str]; - for (const auto& input : it->second.inputs) { - if (static_cast(input.index) >= input_defs.size() || input_defs[input.index] == nullptr) { - LOGS(logger, VERBOSE) << "Input index [" << input.index - << "] for operator type [" << op_type - << "], corresponding WebNN op type [" << webnn_op_type - << "], WebNN input name [" << input.name - << "] is invalid."; - return false; - } + if (input_limits.isUndefined()) { + LOGS(logger, VERBOSE) << "Node name: [" << node_name + << "], WebNN op type: [" << webnn_op_type + << "], input [" << input_name + << "]: limits are not defined in WebNN MLOpSupportLimits."; + return false; + } - std::vector input_shape; - if (!GetShape(*input_defs[input.index], input_shape, logger)) { - return false; - } + const emscripten::val rank_range = input_limits["rankRange"]; + if (rank_range.isUndefined()) { + LOGS(logger, VERBOSE) << "WebNN op type [" << webnn_op_type + << "] input [" << input_name + << "]: missing 'rankRange' attribute."; + return false; + } - const std::string input_name_str(input.name); - if (wnn_limits[webnn_op_type_str].isUndefined() || - wnn_limits[webnn_op_type_str][input_name_str].isUndefined()) { - LOGS(logger, VERBOSE) << "Operator type: [" << op_type - << "], input index: [" << input.index - << "], corresponding WebNN op type: " << webnn_op_type - << ", WebNN input name " << input.name - << " is not defined in wnn_limits."; - return false; - } + const emscripten::val min_val = rank_range["min"]; + const emscripten::val max_val = rank_range["max"]; + if (min_val.isUndefined() || max_val.isUndefined()) { + LOGS(logger, VERBOSE) << "WebNN op type [" << webnn_op_type + << "] input [" << input_name + << "]: its 'rankRange' limits is missing valid 'min' or 'max' attributes."; + return false; + } - const auto& input_limits = wnn_limits[webnn_op_type_str][input_name_str]; - if (input_limits["rankRange"].isUndefined()) { - LOGS(logger, VERBOSE) << "Operator type: [" << op_type - << "], input index: [" << input.index - << "], corresponding WebNN op type: " << webnn_op_type - << ", WebNN input name " << input.name - << "'s rankRange is not defined."; - return false; + size_t min_rank = min_val.as(); + size_t max_rank = max_val.as(); + if (input_rank < min_rank || input_rank > max_rank) { + LOGS(logger, VERBOSE) << "Node name: [" << node_name + << "] WebNN op type [" << webnn_op_type + << "] input [" << input_name << "] rank " << input_rank + << " is not in supported range [" << min_rank << ", " << max_rank << "]"; + return false; + } + + return true; +} + +bool IsInputRankSupportedByOp(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) { + const std::string_view onnx_op_type = node.OpType(); + const std::string_view webnn_op_type = GetWebNNOpType(onnx_op_type); + + if (webnn_op_type.empty()) { + LOGS(logger, VERBOSE) << "ONNX op type: [" << onnx_op_type << "]'s corresponding WebNN op is not found."; + return false; + } + + std::vector inputs; + if (!GetWebNNOpInputs(onnx_op_type, inputs, logger)) { + return false; + } + + const auto& input_defs = node.InputDefs(); + + for (const auto& input : inputs) { + // If it is an optional input and is absent, skip. + if (!TensorExists(input_defs, input.index)) { + continue; } - int input_dim_size = static_cast(input_shape.size()); - int min_rank = input_limits["rankRange"]["min"].as(); - int max_rank = input_limits["rankRange"]["max"].as(); - - if (input_dim_size < min_rank || input_dim_size > max_rank) { - LOGS(logger, VERBOSE) << "Operator type: [" << op_type - << "], input index: [" << input.index - << "], corresponding WebNN op type: " << webnn_op_type - << ", WebNN input name: " << input.name - << ", input size " << input_dim_size - << " is not in supported range [" << min_rank << ", " << max_rank << "]"; + std::vector shape; + if (!GetShape(*input_defs[input.index], shape, logger) || + !IsInputRankSupported(wnn_limits, webnn_op_type, input.name, + shape.size(), + node.Name(), logger)) { return false; } } + return true; } diff --git a/onnxruntime/core/providers/webnn/builders/helper.h b/onnxruntime/core/providers/webnn/builders/helper.h index d59788600f997..50e361ede221e 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.h +++ b/onnxruntime/core/providers/webnn/builders/helper.h @@ -216,6 +216,13 @@ bool IsTensorShapeSupported(const NodeArg& node_arg, const std::string& parent_n bool IsInputRankSupportedByOp(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger); +bool IsInputRankSupported(const emscripten::val& wnn_limits, + const std::string_view webnn_op_type, + const std::string_view input_name, + const size_t input_rank, + const std::string_view node_name, + const logging::Logger& logger); + // Get a set of nodes supported by WebNN EP. std::unordered_set GetSupportedNodes(const GraphViewer& graph_viewer, const emscripten::val& wnn_builder, @@ -244,6 +251,33 @@ inline std::string_view GetWebNNOpType(const std::string_view onnx_op_type) { return (it != op_inputs_map.end()) ? it->second.opType : ""; } +// Get corresponding input name of WebNN op type by ONNX op type from op_input_map +inline std::string_view GetWebNNInputName(const std::string_view onnx_op_type, const int input_index) { + const auto it = op_inputs_map.find(onnx_op_type); + + if (it != op_inputs_map.end()) { + for (const auto& input : it->second.inputs) { + if (input.index == input_index) { + return input.name; + } + } + } + + return ""; +} + +inline bool GetWebNNOpInputs(const std::string_view onnx_op_type, + std::vector& inputs, + const logging::Logger& logger) { + const auto it = op_inputs_map.find(onnx_op_type); + if (it == op_inputs_map.end()) { + LOGS(logger, VERBOSE) << "WebNN op inputs not found for op type: " << onnx_op_type; + return false; + } + inputs = it->second.inputs; + return true; +} + bool AreDataTypesSame(const std::string_view op_type, gsl::span input_types, const logging::Logger& logger); diff --git a/onnxruntime/core/providers/webnn/builders/impl/argmax_min_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/argmax_min_op_builder.cc index fc630af8cf1e3..fdf1709d87bac 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/argmax_min_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/argmax_min_op_builder.cc @@ -18,10 +18,6 @@ class ArgMaxMinOpBuilder : public BaseOpBuilder { private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override ORT_MUST_USE_RESULT; - - // Operator support related. - bool IsOpSupportedImpl(const GraphViewer&, const Node& node, - WebnnDeviceType device_type, const logging::Logger& logger) const override; }; // Add operator related. @@ -65,20 +61,6 @@ Status ArgMaxMinOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, return Status::OK(); } -// Operator support related. -bool ArgMaxMinOpBuilder::IsOpSupportedImpl(const GraphViewer& /* initializers */, - const Node& node, - WebnnDeviceType device_type, - const logging::Logger& logger) const { - const auto& input_defs = node.InputDefs(); - - std::vector input_shape; - if (!GetShape(*input_defs[0], input_shape, logger)) - return false; - - return true; -} - void CreateArgMaxMinOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { if (op_registrations.op_builder_map.count(op_type) > 0) return; diff --git a/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc index b0ec006db6986..3c8e7fa34f7ed 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc @@ -62,13 +62,12 @@ bool BaseOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& node, int32_t input_type; if (!GetType(input, input_type, logger)) return false; - const std::string_view webnn_op_type = GetWebNNOpType(op_type); - if (webnn_op_type.empty()) - return false; + const std::string_view webnn_op_type = GetWebNNOpType(op_type); const std::string_view webnn_input_name = GetWebNNOpFirstInputName(op_type); return IsDataTypeSupportedByWebNNOp(op_type, webnn_op_type, input_type, wnn_limits, - webnn_input_name, "input", logger); + webnn_input_name, "input", logger) && + IsInputRankSupportedByOp(node, wnn_limits, logger); } bool BaseOpBuilder::HasSupportedOutputs(const Node& node, const emscripten::val& wnn_limits, diff --git a/onnxruntime/core/providers/webnn/builders/impl/binary_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/binary_op_builder.cc index 280ffc83eae89..851dc373923ac 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/binary_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/binary_op_builder.cc @@ -73,9 +73,10 @@ bool BinaryOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& nod return false; } - std::string webnn_input_name = op_type == "PRelu" ? "input" : "a"; + const std::string_view webnn_input_name = GetWebNNOpFirstInputName(op_type); std::string onnx_input_name = op_type == "PRelu" || op_type == "Pow" ? "X" : "A"; - return IsInputRankSupportedByOp(node, wnn_limits, logger) && IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, webnn_input_name, onnx_input_name, logger); + return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, webnn_input_name, onnx_input_name, logger) && + IsInputRankSupportedByOp(node, wnn_limits, logger); } void CreateBinaryOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/concat_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/concat_op_builder.cc index 8589237617745..db5e8cd51656c 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/concat_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/concat_op_builder.cc @@ -75,7 +75,8 @@ bool ConcatOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& nod } } - return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "inputs", "inputs", logger); + return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "inputs", "inputs", logger) && + IsInputRankSupportedByOp(node, wnn_limits, logger); } void CreateConcatOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc index b9383a63fe307..e0bfb3bd682e8 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc @@ -324,7 +324,7 @@ Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N x_zero_point = model_builder.CreateOrGetConstant(x_type, 0); } - // Scale is not used by ConvInteger but required by DequantizeLinear. So set it to deafult value 1.0f. + // Scale is not used by ConvInteger but required by DequantizeLinear. So set it to default value 1.0f. // The x_zero_point must be a scalar and the scale input should have the same shape as the zero point input. // So the x_scale must be a scalar too. x_scale = model_builder.CreateOrGetConstant(ONNX_NAMESPACE::TensorProto_DataType_FLOAT, 1.0f); diff --git a/onnxruntime/core/providers/webnn/builders/impl/cumsum_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/cumsum_op_builder.cc index 7528d9ad2ff51..f3c392b608e45 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/cumsum_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/cumsum_op_builder.cc @@ -77,10 +77,6 @@ bool CumSumOpBuilder::IsOpSupportedImpl(const GraphViewer& graph_viewer, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); - std::vector input_shape; - if (!GetShape(*input_defs[0], input_shape, logger)) - return false; - const std::string axis_name = GetTensorName(input_defs, 1); // Inputs contain optional 'axis' input. const auto* init = graph_viewer.GetConstantInitializer(axis_name); diff --git a/onnxruntime/core/providers/webnn/builders/impl/dropout_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/dropout_op_builder.cc index c22dd9e97bb1a..37a00fcb12abd 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/dropout_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/dropout_op_builder.cc @@ -21,11 +21,6 @@ class DropoutOpBuilder : public BaseOpBuilder { private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override ORT_MUST_USE_RESULT; - - // Operator support related. - private: - bool IsOpSupportedImpl(const GraphViewer&, const Node& node, - const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; }; // Add operator related. @@ -65,26 +60,13 @@ Status DropoutOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, emscripten::val options = emscripten::val::object(); options.set("label", output_defs[1]->Name() + "_identity"); // Add additional identity op in case the mask is the output of a WebNN graph, - // beacuse WebNN does not support a constant operand as output. + // because WebNN does not support a constant operand as output. emscripten::val mask_output = model_builder.GetBuilder().call("identity", one_constant, options); model_builder.AddOperand(output_defs[1]->Name(), std::move(mask_output)); } return Status::OK(); } -// Operator support related. -bool DropoutOpBuilder::IsOpSupportedImpl(const GraphViewer&, - const Node& node, - const WebnnDeviceType /* device_type */, - const logging::Logger& logger) const { - const auto& input_defs = node.InputDefs(); - std::vector input_shape; - if (!GetShape(*input_defs[0], input_shape, logger)) - return false; - - return true; -} - void CreateDropoutOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { op_registrations.builders.push_back(std::make_unique()); op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); diff --git a/onnxruntime/core/providers/webnn/builders/impl/einsum_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/einsum_op_builder.cc index e5b4fcddc4221..6aa760c0f4baf 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/einsum_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/einsum_op_builder.cc @@ -28,6 +28,8 @@ class EinsumOpBuilder : public BaseOpBuilder { const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; bool HasSupportedInputsImpl(const GraphViewer&, const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const override; + bool HasSupportedOutputsImpl(const Node& /* node */, const emscripten::val& /* wnn_limits */, + const logging::Logger& /* logger */) const override; }; // Helper functions, thanks for DML EP's OperatorHelper. @@ -42,12 +44,6 @@ enum class RecognizedOperatorType { Total, }; -struct RecognizedOperatorInfo { - RecognizedOperatorType recognized_operator_type; - std::initializer_list component_ranks; - std::initializer_list label_indices; -}; - struct Component { uint32_t label_index_begin; uint32_t label_index_end; @@ -598,7 +594,7 @@ Status EinsumOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, } } - // tranpose input + // transpose input std::vector permutation(input_labels.size()); for (uint32_t idx = 0; idx < input_labels.size(); idx++) { if (idx != diagonal_idx_1 && idx != diagonal_idx_2) { @@ -620,7 +616,7 @@ Status EinsumOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, options_trilu.set("upper", false); output = model_builder.GetBuilder().call("triangular", output, options_trilu); // tril - // reducesum to achieve the diagonal values + // reduceSum to achieve the diagonal values std::vector input_shape; std::vector reduced_axes; ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get shape"); @@ -700,12 +696,6 @@ bool EinsumOpBuilder::IsOpSupportedImpl(const GraphViewer&, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); - if (input_defs.size() > 2) { - // TODO: Support more than two inputs. - LOGS(logger, VERBOSE) << "EinSum only supports up to two inputs."; - return false; - } - NodeAttrHelper helper(node); const auto equation = helper.Get("equation", std::string(" ")); std::vector label_indices; @@ -724,13 +714,6 @@ bool EinsumOpBuilder::IsOpSupportedImpl(const GraphViewer&, return false; } - RecognizedOperatorType recognized_operator_type = DetermineRecognizedOperatorType(label_indices, components, - output_dimensions); - if (recognized_operator_type == RecognizedOperatorType::None) { - LOGS(logger, VERBOSE) << "The equation is not supported in Einsum."; - return false; - } - return true; } @@ -738,9 +721,14 @@ bool EinsumOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& nod const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); + if (input_defs.size() > 2) { + // TODO: Support more than two inputs. + LOGS(logger, VERBOSE) << "EinSum only supports up to two inputs."; + return false; + } + const std::string_view op_type = node.OpType(); - int32_t input0_type; - int32_t input1_type; + int32_t input0_type, input1_type; bool has_input1 = TensorExists(input_defs, 1); if (!GetType(*input_defs[0], input0_type, logger) || @@ -754,6 +742,13 @@ bool EinsumOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& nod return false; } + std::vector input0_shape; + std::vector input1_shape; + if (!GetShape(*input_defs[0], input0_shape, logger) || + (has_input1 && !GetShape(*input_defs[1], input1_shape, logger))) { + return false; + } + NodeAttrHelper helper(node); const auto equation = helper.Get("equation", std::string(" ")); std::vector label_indices; @@ -770,17 +765,54 @@ bool EinsumOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& nod RecognizedOperatorType recognized_operator_type = DetermineRecognizedOperatorType(label_indices, components, output_dimensions); + std::string_view decomposed_op_type; if (recognized_operator_type == RecognizedOperatorType::None) { LOGS(logger, VERBOSE) << "The equation is not supported in Einsum."; return false; - } else if (recognized_operator_type == RecognizedOperatorType::Pairwise) { - // Map to WebNN's gemm or matmul - return IsDataTypeSupportedByWebNNOp(op_type, "matmul", input0_type, wnn_limits, "a", "inputs", logger); + } else if (recognized_operator_type == RecognizedOperatorType::Multiply) { + decomposed_op_type = "Mul"; } else if (recognized_operator_type == RecognizedOperatorType::ReduceSum) { - return IsDataTypeSupportedByWebNNOp(op_type, "reduceSum", input0_type, wnn_limits, "input", "inputs", logger); - } else { - return IsDataTypeSupportedByWebNNOp(op_type, "identity", input0_type, wnn_limits, "input", "inputs", logger); + decomposed_op_type = "ReduceSum"; + } else if (recognized_operator_type == RecognizedOperatorType::Diagonal) { + decomposed_op_type = "Trilu"; + } else if (recognized_operator_type == RecognizedOperatorType::Transpose) { + decomposed_op_type = "Transpose"; + } else if (recognized_operator_type == RecognizedOperatorType::Pairwise) { + decomposed_op_type = "MatMul"; + } else { // Identity + // For the Identity case, we simply forward the input to the output without any modification. + return true; + } + + const std::string_view wnn_input0_name = GetWebNNInputName(decomposed_op_type, 0); + const std::string_view decompose_wnn_op_type = GetWebNNOpType(decomposed_op_type); + if (decompose_wnn_op_type.empty() || + !IsDataTypeSupportedByWebNNOp(op_type, decompose_wnn_op_type, input0_type, + wnn_limits, wnn_input0_name, "inputs", logger) || + !IsInputRankSupported(wnn_limits, decompose_wnn_op_type, wnn_input0_name, + input0_shape.size(), node.Name(), logger)) { + return false; + } + + if (has_input1) { + const std::string_view wnn_input1_name = GetWebNNInputName(decomposed_op_type, 1); + return IsDataTypeSupportedByWebNNOp(op_type, decompose_wnn_op_type, input1_type, + wnn_limits, wnn_input1_name, "inputs", logger) && + IsInputRankSupported(wnn_limits, decompose_wnn_op_type, wnn_input1_name, + input1_shape.size(), node.Name(), logger); } + + return true; +} + +bool EinsumOpBuilder::HasSupportedOutputsImpl(const Node& /* node */, + const emscripten::val& /* wnn_limits */, + const logging::Logger& /* logger */) const { + // The Einsum op produces output with the same data type as its input. + // Therefore, checking the output data type is unnecessary. + // This override prevents calling the base class implementation, as the base implementation + // would return false due to Einsum being a decomposed op. + return true; } void CreateEinsumOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/gatherElements_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gatherElements_op_builder.cc index 06beb56415609..ae4c3705fdb2e 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/gatherElements_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/gatherElements_op_builder.cc @@ -56,14 +56,14 @@ bool GatherElementsOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const N const auto& indices = *node.InputDefs()[1]; const std::string_view op_type = node.OpType(); - int32_t data_type; - int32_t indices_type; + int32_t data_type, indices_type; if (!GetType(data, data_type, logger) || !GetType(indices, indices_type, logger)) { return false; } return IsDataTypeSupportedByOp(op_type, data_type, wnn_limits, "input", "data", logger) && - IsDataTypeSupportedByOp(op_type, indices_type, wnn_limits, "indices", "indices", logger); + IsDataTypeSupportedByOp(op_type, indices_type, wnn_limits, "indices", "indices", logger) && + IsInputRankSupportedByOp(node, wnn_limits, logger); } void CreateGatherElementsOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/gatherND_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gatherND_op_builder.cc index 9200c596c0e53..af508c2800f4b 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/gatherND_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/gatherND_op_builder.cc @@ -61,14 +61,14 @@ bool GatherNDOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& n const auto& indices = *node.InputDefs()[1]; const std::string_view op_type = node.OpType(); - int32_t data_type; - int32_t indices_type; + int32_t data_type, indices_type; if (!GetType(data, data_type, logger) || !GetType(indices, indices_type, logger)) { return false; } return IsDataTypeSupportedByOp(op_type, data_type, wnn_limits, "input", "data", logger) && - IsDataTypeSupportedByOp(op_type, indices_type, wnn_limits, "indices", "indices", logger); + IsDataTypeSupportedByOp(op_type, indices_type, wnn_limits, "indices", "indices", logger) && + IsInputRankSupportedByOp(node, wnn_limits, logger); } void CreateGatherNDOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/gather_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gather_op_builder.cc index d84c70032e1d1..7111a8f6beaa3 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/gather_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/gather_op_builder.cc @@ -20,8 +20,6 @@ class GatherOpBuilder : public BaseOpBuilder { const logging::Logger& logger) const override ORT_MUST_USE_RESULT; // Operator support related. - bool IsOpSupportedImpl(const GraphViewer&, const Node& node, - const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; bool HasSupportedInputsImpl(const GraphViewer&, const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; @@ -50,38 +48,20 @@ Status GatherOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, } // Operator support related. - -bool GatherOpBuilder::IsOpSupportedImpl(const GraphViewer&, - const Node& node, - const WebnnDeviceType /* device_type */, - const logging::Logger& logger) const { - const auto& input_defs = node.InputDefs(); - std::vector input_shape; - if (!GetShape(*input_defs[0], input_shape, logger)) - return false; - const auto rank = input_shape.size(); - if (rank < 1) { - LOGS(logger, VERBOSE) << "Gather only supports input shapes >= 1D, but input is " - << rank << "d shape"; - return false; - } - - return true; -} - bool GatherOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input = *node.InputDefs()[0]; const auto& indices = *node.InputDefs()[1]; const std::string_view op_type = node.OpType(); - int32_t input_type; - int32_t indices_type; + int32_t input_type, indices_type; + if (!GetType(input, input_type, logger) || !GetType(indices, indices_type, logger)) return false; return IsDataTypeSupportedByOp(op_type, input_type, wnn_limits, "input", "data", logger) && - IsDataTypeSupportedByOp(op_type, indices_type, wnn_limits, "indices", "indices", logger); + IsDataTypeSupportedByOp(op_type, indices_type, wnn_limits, "indices", "indices", logger) && + IsInputRankSupportedByOp(node, wnn_limits, logger); } void CreateGatherOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc index 02f46c85d1d06..7af17fdc5db78 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc @@ -91,7 +91,7 @@ Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N a_zero_point = model_builder.GetOperand(node.InputDefs()[2]->Name()); std::vector a_zero_point_shape; ORT_RETURN_IF_NOT(GetShape(*input_defs[2], a_zero_point_shape, logger), "Cannot get shape of a_zero_point"); - // Scale is not used by MatMulInteger but required by DequantizeLinear. So set it to deafult value 1.0f. + // Scale is not used by MatMulInteger but required by DequantizeLinear. So set it to default value 1.0f. // The scale input should have the same shape as the zero point input. a_scale = model_builder.CreateOrGetConstant(ONNX_NAMESPACE::TensorProto_DataType_FLOAT, 1.0f, @@ -268,11 +268,45 @@ bool GemmOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& node, return false; } - if (op_type == "MatMulInteger") { - // The first decomposed op of MatMulInteger is DequantizeLinear, and so - // we only need to ensure it supports the input0_type. + if (op_type == "Gemm") { + return IsInputRankSupportedByOp(node, wnn_limits, logger) && + IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "a", "A", logger); + } else if (op_type == "MatMulInteger") { + // Check up to 4 inputs for MatMulInteger + for (size_t i = 0; i < input_defs.size(); ++i) { + std::vector shape; + if (!GetShape(*input_defs[i], shape, logger)) { + return false; + } + + // We made workaround to support 1D for input A and B, skip further checks if they are 1D + if (i <= 1 && shape.size() == 1) { + continue; + } + + // For DequantizeLinear, input indices: 0 (x), 1 (scale), 2 (zero_point) + if (!IsInputRankSupported(wnn_limits, "dequantizeLinear", + (i < 2) ? "input" : "zeroPoint", + shape.size(), node.Name(), logger)) { + return false; + } + } return IsDataTypeSupportedByOp("DequantizeLinear", input0_type, wnn_limits, "input", "x", logger); - } else { + } else { // MatMul + for (int i = 0; i < 2; ++i) { + std::vector shape; + if (!GetShape(*input_defs[i], shape, logger)) { + return false; + } + + if (shape.size() == 1) { + continue; + } + + if (!IsInputRankSupported(wnn_limits, "matmul", (i == 0) ? "a" : "b", shape.size(), node.Name(), logger)) { + return false; + } + } return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "a", "A", logger); } } diff --git a/onnxruntime/core/providers/webnn/builders/impl/gru_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gru_op_builder.cc index dfe80dd419092..95e75a3083cc2 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/gru_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/gru_op_builder.cc @@ -219,7 +219,8 @@ bool GruOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& node, return false; } - return IsDataTypeSupportedByOp(op_type, input_X_type, wnn_limits, "input", "X", logger); + return IsDataTypeSupportedByOp(op_type, input_X_type, wnn_limits, "input", "X", logger) && + IsInputRankSupportedByOp(node, wnn_limits, logger); } bool GruOpBuilder::HasSupportedOutputsImpl(const Node& node, diff --git a/onnxruntime/core/providers/webnn/builders/impl/logical_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/logical_op_builder.cc index 42940083cad8e..55d468c4843cb 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/logical_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/logical_op_builder.cc @@ -91,8 +91,10 @@ bool LogicalOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& no } } + const std::string_view webnn_input_name = GetWebNNOpFirstInputName(op_type); std::string onnx_input_name = op_type == "Not" ? "X" : "A"; - return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "a", onnx_input_name, logger); + return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, webnn_input_name, onnx_input_name, logger) && + IsInputRankSupportedByOp(node, wnn_limits, logger); } void CreateLogicalOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/lrn_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/lrn_op_builder.cc index 8936bda875aef..e8aab725375ad 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/lrn_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/lrn_op_builder.cc @@ -21,8 +21,6 @@ class LRNOpBuilder : public BaseOpBuilder { // Operator support related. private: - bool IsOpSupportedImpl(const GraphViewer&, const Node& node, - const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; bool HasSupportedInputsImpl(const GraphViewer&, const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const override; bool HasSupportedOutputsImpl(const Node& node, const emscripten::val& wnn_limits, @@ -128,11 +126,10 @@ Status LRNOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, } // Operator support related. -bool LRNOpBuilder::IsOpSupportedImpl(const GraphViewer&, - const Node& node, - const WebnnDeviceType /* device_type */, - const logging::Logger& logger) const { +bool LRNOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); + std::vector input_shape; if (!GetShape(*input_defs[0], input_shape, logger)) return false; @@ -143,12 +140,6 @@ bool LRNOpBuilder::IsOpSupportedImpl(const GraphViewer&, return false; } - return true; -} - -bool LRNOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& node, - const emscripten::val& wnn_limits, const logging::Logger& logger) const { - const auto& input_defs = node.InputDefs(); const std::string_view op_type = node.OpType(); int32_t input_type = 0; if (!GetType(*input_defs[0], input_type, logger)) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/lstm_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/lstm_op_builder.cc index 09e584bc66f8a..04d59e2f30d15 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/lstm_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/lstm_op_builder.cc @@ -242,7 +242,8 @@ bool LstmOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& node, return false; } - return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "input", "X", logger); + return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "input", "X", logger) && + IsInputRankSupportedByOp(node, wnn_limits, logger); } bool LstmOpBuilder::HasSupportedOutputsImpl(const Node& node, diff --git a/onnxruntime/core/providers/webnn/builders/impl/matMulNBits_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/matMulNBits_op_builder.cc index 111d03571e974..9ab403b7051d2 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/matMulNBits_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/matMulNBits_op_builder.cc @@ -48,7 +48,7 @@ void MatMulNBitsBuilder::AddInitializersToSkip(ModelBuilder& model_builder, cons // DequantizeLinear + Transpose + MatMul. Given that the CPU EP currently only supports // 4-bit quantization, we only handle 4-bit quantization here. // -// To align with WebNN's dequantizeLinear op contraints, the following transformations are +// To align with WebNN's dequantizeLinear op constraints, the following transformations are // required for MatMulNBits inputs: // 1. B: must be a constant initializer and registered as a 'uint4' WebNN constant with shape // [N, n_blocks_per_col, blob_size * 2]. @@ -159,10 +159,6 @@ bool MatMulNBitsBuilder::IsOpSupportedImpl(const GraphViewer& graph_viewer, const logging::Logger& logger) const { const auto& name = node.Name(); const auto& input_defs = node.InputDefs(); - std::vector input_shape; - if (!GetShape(*input_defs[0], input_shape, logger)) { - return false; - } // Inputs B and zero_points (if present) must be initializers if (!graph_viewer.GetConstantInitializer(input_defs[1]->Name())) { // B @@ -193,6 +189,10 @@ bool MatMulNBitsBuilder::HasSupportedInputsImpl(const GraphViewer&, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); const std::string_view op_type = node.OpType(); + std::vector input_shape; + if (!GetShape(*input_defs[0], input_shape, logger)) { + return false; + } int32_t A_type = 0; int32_t B_type = 0; @@ -227,10 +227,13 @@ bool MatMulNBitsBuilder::HasSupportedInputsImpl(const GraphViewer&, return false; } - // We only support 4-bit quantization, which is represented as the uint4 data type in WebNN. - // Ensure that uint4 is supported. + // Data type: Currently, only 4-bit quantization is supported, represented as the uint4 data type in WebNN. + // Ensure that the uint4 data type is supported by WebNN's dequantizeLinear op. + // Input rank: Only the rank of the first input (A) is flexible. Verify that its rank is supported by + // WebNN's matmul op. return IsDataTypeSupportedByOp("DequantizeLinear", ONNX_NAMESPACE::TensorProto_DataType_UINT4, - wnn_limits, "input", "x", logger); + wnn_limits, "input", "x", logger) && + IsInputRankSupported(wnn_limits, "matmul", "a", input_shape.size(), node.Name(), logger); } bool MatMulNBitsBuilder::HasSupportedOutputsImpl(const Node& node, const emscripten::val& wnn_limits, diff --git a/onnxruntime/core/providers/webnn/builders/impl/max_min_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/max_min_op_builder.cc index 4e4014e3553ea..9f5ac6ef15735 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/max_min_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/max_min_op_builder.cc @@ -20,8 +20,6 @@ class MaxMinOpBuilder : public BaseOpBuilder { const logging::Logger& logger) const override ORT_MUST_USE_RESULT; // Operator support related. - bool IsOpSupportedImpl(const GraphViewer& graph_viewer, const Node& node, - WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; bool HasSupportedInputsImpl(const GraphViewer&, const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; @@ -68,25 +66,6 @@ Status MaxMinOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, } // Operator support related. -bool MaxMinOpBuilder::IsOpSupportedImpl(const GraphViewer&, - const Node& node, - WebnnDeviceType /* device_type */, - const logging::Logger& logger) const { - const auto& input_defs = node.InputDefs(); - const auto& op_type = node.OpType(); - - std::vector input_shape; - if (!GetShape(*input_defs[0], input_shape, logger)) - return false; - - if (input_defs.size() < 1) { - LOGS(logger, VERBOSE) << op_type << " requires at least one input (data)"; - return false; - } - - return true; -} - bool MaxMinOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); @@ -108,7 +87,8 @@ bool MaxMinOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& nod } } - return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "a", "data_0", logger); + return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "a", "data_0", logger) && + IsInputRankSupportedByOp(node, wnn_limits, logger); } void CreateMaxMinOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc index 148eacac98e4a..9fb643f055ef3 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc @@ -46,28 +46,14 @@ Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder emscripten::val options = emscripten::val::object(); options.set("label", node.Name()); - std::vector scale_shape; const size_t scale_input_index = op_type == "SkipSimplifiedLayerNormalization" ? 2 : 1; - ORT_RETURN_IF_NOT(GetShape(*input_defs[scale_input_index], scale_shape, logger), "Cannot get scale shape"); - const auto scale_size = scale_shape.size(); - // Except LayerNormalization, other normalization ops' scale input should be 1-D. - if (op_type == "LayerNormalization") { - ORT_RETURN_IF_NOT(scale_size >= 1 && scale_size <= rank, - "The scale size should be less than or equal to input size."); - } else { - ORT_RETURN_IF_NOT(scale_size == 1, "The scale size should be one."); - } - emscripten::val scale = model_builder.GetOperand(input_defs[scale_input_index]->Name()); options.set("scale", scale); const size_t bias_input_index = op_type == "SkipSimplifiedLayerNormalization" ? 3 : 2; emscripten::val bias = emscripten::val::undefined(); if (TensorExists(input_defs, bias_input_index)) { - // Bias input exists, and bias's shape should be the same as scale's shape. - std::vector bias_shape; - ORT_RETURN_IF_NOT(GetShape(*input_defs[bias_input_index], bias_shape, logger), "Cannot get bias shape"); - ORT_RETURN_IF_NOT(bias_shape == scale_shape, "The bias' shape should be equal to scale's shape."); + // Bias input exists. bias = model_builder.GetOperand(input_defs[bias_input_index]->Name()); options.set("bias", bias); } @@ -279,12 +265,6 @@ bool NormalizationOpBuilder::IsOpSupportedImpl(const GraphViewer&, return false; } - std::vector input_shape; - if (!GetShape(*input_defs[0], input_shape, logger)) { - LOGS(logger, VERBOSE) << "Cannot get input shape."; - return false; - } - const auto& output_defs = node.OutputDefs(); if (op_type == "SkipSimplifiedLayerNormalization") { if (output_defs.size() > 4) { @@ -316,33 +296,28 @@ bool NormalizationOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const No const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); const std::string_view op_type = node.OpType(); - int32_t input0_type; // input data type - int32_t input1_type; // scale data type - int32_t input2_type; // B data type - int32_t input3_type; // mean data type - int32_t input4_type; // var data type - bool has_input2 = TensorExists(input_defs, 2); - bool has_input3 = TensorExists(input_defs, 3); - bool has_input4 = TensorExists(input_defs, 4); - - if (!GetType(*input_defs[0], input0_type, logger) || - !GetType(*input_defs[1], input1_type, logger) || - (has_input2 && !GetType(*input_defs[2], input2_type, logger)) || - (has_input3 && !GetType(*input_defs[3], input3_type, logger)) || - (has_input4 && !GetType(*input_defs[4], input4_type, logger))) { - return false; - } - std::vector input_types = {input0_type, input1_type}; - if (has_input2) { - input_types.push_back(input2_type); - } - if (has_input3) { - input_types.push_back(input3_type); + std::vector input_types; + bool all_types_valid = true; + + // Iterate through all inputs and check their existence and types + for (size_t i = 0; i <= input_defs.size(); ++i) { + if (TensorExists(input_defs, i)) { + int32_t input_type; + if (!GetType(*input_defs[i], input_type, logger)) { + all_types_valid = false; + break; + } + input_types.push_back(input_type); + } } - if (has_input4) { - input_types.push_back(input4_type); + + // Return false if any input type is invalid + if (!all_types_valid) { + return false; } + + // Check if all input data types are the same if (!AreDataTypesSame(op_type, input_types, logger)) { return false; } @@ -355,13 +330,29 @@ bool NormalizationOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const No const std::string_view webnn_op_type = GetWebNNOpType(decomposed_op_type); const std::string_view webnn_input_name = GetWebNNOpFirstInputName(decomposed_op_type); if (!IsDataTypeSupportedByWebNNOp( - op_type, webnn_op_type, input0_type, wnn_limits, webnn_input_name, "input", logger)) { + op_type, webnn_op_type, input_types[0], wnn_limits, webnn_input_name, "input", logger)) { return false; } } - return true; + + std::vector input_shape; + if (!GetShape(*input_defs[0], input_shape, logger)) { + return false; + } + // It's complicated to check all the decomposed ops' input rank support. + // Ensure at least the first input rank is supported by the decomposed ops (pow and div accept the first input). + return IsInputRankSupported(wnn_limits, "pow", "a", input_shape.size(), node.Name(), logger) && + IsInputRankSupported(wnn_limits, "div", "a", input_shape.size(), node.Name(), logger); } else { - return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "input", "X", logger); + bool is_data_type_supported = IsDataTypeSupportedByOp(op_type, input_types[0], wnn_limits, "input", "X", logger); + if (op_type == "InstanceNormalization") { + // Skip input rank check for InstanceNormalization, as we will reshape the input to 4D if necessary. + return is_data_type_supported; + } + + // For other ops, check both data type and input rank compatibility. + bool is_input_rank_supported = IsInputRankSupportedByOp(node, wnn_limits, logger); + return is_input_rank_supported && is_data_type_supported; } } diff --git a/onnxruntime/core/providers/webnn/builders/impl/pool_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/pool_op_builder.cc index f2a3f08b73148..5d921c5176a64 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/pool_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/pool_op_builder.cc @@ -133,20 +133,6 @@ bool PoolOpBuilder::IsOpSupportedImpl(const GraphViewer&, const WebnnDeviceType /* device_type */, const logging::Logger& logger) const { const auto& op_type = node.OpType(); - const auto& input_defs = node.InputDefs(); - - std::vector input_shape; - if (!GetShape(*input_defs[0], input_shape, logger)) - return false; - - const auto input_size = input_shape.size(); - if (input_size != 4) { - LOGS(logger, VERBOSE) - << op_type << " only supports rank-4 tensor, input [" - << input_defs[0]->Name() << "] has actual dim count " << input_size; - return false; - } - NodeAttrHelper helper(node); if (op_type == "AveragePool" || op_type == "LpPool" || op_type == "MaxPool") { if (helper.Get("kernel_shape", std::vector{1, 1}).size() != 2) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/qdq_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/qdq_op_builder.cc index dd25fb9bf9315..053c41773db40 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/qdq_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/qdq_op_builder.cc @@ -167,7 +167,8 @@ bool QDQOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& node, return false; } - return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "input", "x", logger) && + return IsInputRankSupportedByOp(node, wnn_limits, logger) && + IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "input", "x", logger) && IsDataTypeSupportedByOp(op_type, input1_type, wnn_limits, "scale", "x_scale", logger) && (!has_input2 || IsDataTypeSupportedByOp(op_type, input2_type, wnn_limits, "zeroPoint", "x_zero_point", logger)); } diff --git a/onnxruntime/core/providers/webnn/builders/impl/reduction_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/reduction_op_builder.cc index a3a0397eda4a3..6ea9b0a440d93 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/reduction_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/reduction_op_builder.cc @@ -128,16 +128,10 @@ bool ReductionOpBuilder::IsOpSupportedImpl(const GraphViewer& graph_viewer, const WebnnDeviceType /* device_type */, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); - - std::vector input_shape; - if (!GetShape(*input_defs[0], input_shape, logger)) - return false; - - const auto& op_type = node.OpType(); const std::string axes_name = GetTensorName(input_defs, 1); // If the optional input 'axes' is provided, it must be an initializer. if (!axes_name.empty() && !graph_viewer.GetConstantInitializer(axes_name)) { - LOGS(logger, VERBOSE) << "Input axes of " << op_type << " must be a constant"; + LOGS(logger, VERBOSE) << "Input axes of " << node.OpType() << " must be a constant"; return false; } diff --git a/onnxruntime/core/providers/webnn/builders/impl/reshape_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/reshape_op_builder.cc index 8cbb381e0f53e..0444ae3afb56a 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/reshape_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/reshape_op_builder.cc @@ -79,11 +79,6 @@ bool ReshapeOpBuilder::IsOpSupportedImpl(const GraphViewer& graph_viewer, const WebnnDeviceType /* device_type */, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); - - std::vector input_shape; - if (!GetShape(*input_defs[0], input_shape, logger)) - return false; - const auto& perm_name = input_defs[1]->Name(); const auto* perm_init = graph_viewer.GetConstantInitializer(perm_name); if (!perm_init) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/rotaryEmbedding_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/rotaryEmbedding_op_builder.cc index 893ca9d2419c7..37071b1030e11 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/rotaryEmbedding_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/rotaryEmbedding_op_builder.cc @@ -285,7 +285,7 @@ Status RotaryEmbeddingOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_build sign_buffer.set(1, 1.0f); } else if (input_data_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) { if (model_builder.IsFloat16ArrayAvailable()) { - // Float16Array is avaliable - use Float16Array. + // Float16Array is available - use Float16Array. sign_buffer = emscripten::val::global("Float16Array").new_(2); sign_buffer.set(0, -1.0f); sign_buffer.set(1, 1.0f); diff --git a/onnxruntime/core/providers/webnn/builders/impl/scatterElements_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/scatterElements_op_builder.cc index f894e8bfbd517..c2974bd988f6b 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/scatterElements_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/scatterElements_op_builder.cc @@ -71,7 +71,6 @@ bool ScatterElementsOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const const auto& data = *node.InputDefs()[0]; const auto& indices = *node.InputDefs()[1]; const auto& updates = *node.InputDefs()[2]; - const std::string_view op_type = node.OpType(); int32_t data_type; int32_t indices_type; @@ -85,8 +84,11 @@ bool ScatterElementsOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const return false; } + const std::string_view op_type = node.OpType(); + return IsDataTypeSupportedByOp(op_type, data_type, wnn_limits, "input", "data", logger) && - IsDataTypeSupportedByOp(op_type, indices_type, wnn_limits, "indices", "indices", logger); + IsDataTypeSupportedByOp(op_type, indices_type, wnn_limits, "indices", "indices", logger) && + IsInputRankSupportedByOp(node, wnn_limits, logger); } void CreateScatterElementsOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/scatterND_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/scatterND_op_builder.cc index e61ac3dcc9617..a7788cfd847e9 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/scatterND_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/scatterND_op_builder.cc @@ -63,7 +63,6 @@ bool ScatterNDOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& const auto& data = *node.InputDefs()[0]; const auto& indices = *node.InputDefs()[1]; const auto& updates = *node.InputDefs()[2]; - const std::string_view op_type = node.OpType(); int32_t data_type; int32_t indices_type; @@ -76,9 +75,10 @@ bool ScatterNDOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& if (data_type != updates_type) { return false; } - + const std::string_view op_type = node.OpType(); return IsDataTypeSupportedByOp(op_type, data_type, wnn_limits, "input", "data", logger) && - IsDataTypeSupportedByOp(op_type, indices_type, wnn_limits, "indices", "indices", logger); + IsDataTypeSupportedByOp(op_type, indices_type, wnn_limits, "indices", "indices", logger) && + IsInputRankSupportedByOp(node, wnn_limits, logger); } void CreateScatterNDOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/slice_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/slice_op_builder.cc index 8853891ff8ed6..5efbfe932c602 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/slice_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/slice_op_builder.cc @@ -136,10 +136,6 @@ bool SliceOpBuilder::IsOpSupportedImpl(const GraphViewer& graph_viewer, const No const auto& name = node.Name(); const auto& op_type = node.OpType(); const auto& input_defs = node.InputDefs(); - std::vector input_shape; - if (!GetShape(*input_defs[0], input_shape, logger)) { - return false; - } if (input_defs.size() < 3) { LOGS(logger, VERBOSE) << op_type << " [" << name << "] requires at least 3 inputs (data, starts, ends) but got " @@ -166,10 +162,17 @@ bool SliceOpBuilder::HasSupportedInputsImpl(const GraphViewer& graph_viewer, con const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); const auto& input = *input_defs[0]; - const std::string_view op_type = node.OpType(); + std::vector input_shape; + if (!GetShape(*input_defs[0], input_shape, logger)) { + return false; + } + int32_t input_type; - if (!GetType(input, input_type, logger)) + if (!GetType(input, input_type, logger)) { return false; + } + + const std::string_view op_type = node.OpType(); // If there is step < 0, check data type support of reverse. if (TensorExists(input_defs, 4)) { @@ -178,13 +181,15 @@ bool SliceOpBuilder::HasSupportedInputsImpl(const GraphViewer& graph_viewer, con if (!init || !ReadIntArrayFrom1DTensor(*init, steps, graph_viewer, logger)) return false; if (std::any_of(steps.begin(), steps.end(), [](int64_t step) { return step < 0; })) { - if (!IsDataTypeSupportedByWebNNOp(op_type, "reverse", input_type, wnn_limits, "input", "data", logger)) { + if (!IsDataTypeSupportedByWebNNOp(op_type, "reverse", input_type, wnn_limits, "input", "data", logger) || + !IsInputRankSupported(wnn_limits, "reverse", "input", input_shape.size(), node.Name(), logger)) { return false; } } } - return IsDataTypeSupportedByOp(op_type, input_type, wnn_limits, "input", "data", logger); + return IsDataTypeSupportedByOp(op_type, input_type, wnn_limits, "input", "data", logger) && + IsInputRankSupportedByOp(node, wnn_limits, logger); } void CreateSliceOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/softmax_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/softmax_op_builder.cc index 23e73bb8f1e74..99d137f81864c 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/softmax_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/softmax_op_builder.cc @@ -18,11 +18,6 @@ class SoftmaxOpBuilder : public BaseOpBuilder { private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override ORT_MUST_USE_RESULT; - - // Operator support related. - private: - bool IsOpSupportedImpl(const GraphViewer&, const Node& node, - const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; }; Status SoftmaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, @@ -46,20 +41,6 @@ Status SoftmaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, return Status::OK(); } -// Operator support related. - -bool SoftmaxOpBuilder::IsOpSupportedImpl(const GraphViewer&, - const Node& node, - const WebnnDeviceType /* device_type */, - const logging::Logger& logger) const { - const auto& input_defs = node.InputDefs(); - std::vector input_shape; - if (!GetShape(*input_defs[0], input_shape, logger)) - return false; - - return true; -} - void CreateSoftmaxOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { op_registrations.builders.push_back(std::make_unique()); op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); diff --git a/onnxruntime/core/providers/webnn/builders/impl/squeeze_unsqueeze_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/squeeze_unsqueeze_op_builder.cc index 1ba6df9febf14..7e34e35ebac16 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/squeeze_unsqueeze_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/squeeze_unsqueeze_op_builder.cc @@ -127,9 +127,6 @@ bool SqueezeUnsqueezeOpBuilder::IsOpSupportedImpl(const GraphViewer& graph_viewe const logging::Logger& logger) const { const auto& op_type = node.OpType(); const auto& input_defs = node.InputDefs(); - std::vector input_shape; - if (!GetShape(*input_defs[0], input_shape, logger)) - return false; if (input_defs.size() < 1) { LOGS(logger, ERROR) << op_type << " has no input tensor"; diff --git a/onnxruntime/core/providers/webnn/builders/impl/ternary_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/ternary_op_builder.cc index 7a7f64b1ec96d..8973757a24e99 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/ternary_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/ternary_op_builder.cc @@ -66,7 +66,8 @@ bool TernaryOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& no return false; } - return IsDataTypeSupportedByOp(op_type, input1_type, wnn_limits, "trueValue", "X", logger); + return IsDataTypeSupportedByOp(op_type, input1_type, wnn_limits, "trueValue", "X", logger) && + IsInputRankSupportedByOp(node, wnn_limits, logger); } void CreateTernaryOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/tile_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/tile_op_builder.cc index 29b232026d7df..24d96588559ae 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/tile_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/tile_op_builder.cc @@ -77,15 +77,6 @@ bool TileOpBuilder::IsOpSupportedImpl(const GraphViewer& graph_viewer, return false; } - std::vector input_shape; - if (!GetShape(*input_defs[0], input_shape, logger)) - return false; - - if (input_shape.empty()) { - LOGS(logger, VERBOSE) << "Tile does not support empty input shape"; - return false; - } - return true; } diff --git a/onnxruntime/core/providers/webnn/builders/impl/triangular_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/triangular_op_builder.cc index 5a267557b9454..7a4d172c556fa 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/triangular_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/triangular_op_builder.cc @@ -76,15 +76,6 @@ bool TriangularOpBuilder::IsOpSupportedImpl(const GraphViewer& graph_viewer, const WebnnDeviceType /* device_type */, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); - std::vector input_shape; - if (!GetShape(*input_defs[0], input_shape, logger)) - return false; - const auto input_size = input_shape.size(); - if (input_size < 2) { - LOGS(logger, VERBOSE) << "Triangular only supports input size >= 2D shape, input is " - << input_size << "d shape"; - return false; - } const std::string diagonal_name = GetTensorName(input_defs, 1); // Inputs contain optional 'diagonal' input. diff --git a/onnxruntime/core/providers/webnn/builders/map_info.h b/onnxruntime/core/providers/webnn/builders/map_info.h index 5e860eea7cac9..1c30fed7a7916 100644 --- a/onnxruntime/core/providers/webnn/builders/map_info.h +++ b/onnxruntime/core/providers/webnn/builders/map_info.h @@ -47,6 +47,7 @@ constexpr std::array supported_fallback // Use ONNX-to-ONNX op mapping to improve the search complexity for WebNN ops in the op_inputs_map. const std::map> decomposed_op_map = { {"ConvInteger", {"Cast", "Conv", "DequantizeLinear"}}, + {"Einsum", {"MatMul", "Mul", "ReduceSum", "Reshape", "Transpose", "Trilu"}}, {"GroupQueryAttention", {"Add", "Cast", "Concat", "CumSum", "Div", "Expand", "Less", "MatMul", "Reshape", "ScatterND", "Softmax", "Transpose", "Where"}}, @@ -139,7 +140,7 @@ const std::unordered_map op_inputs_map = { {"Mul", {"mul", {{0, "a"}, {1, "b"}}}}, {"Pow", {"pow", {{0, "a"}, {1, "b"}}}}, {"Concat", {"concat", {{0, "inputs"}}}}, - {"Not", {"logicalNot", {{0, "input"}}}}, + {"Not", {"logicalNot", {{0, "a"}}}}, {"Flatten", {"reshape", {{0, "input"}}}}, {"LpPool", {"l2Pool2d", {{0, "input"}}}}, {"Reshape", {"reshape", {{0, "input"}}}}, @@ -159,7 +160,6 @@ const std::unordered_map op_inputs_map = { {"Softsign", {"softsign", {{0, "input"}}}}, {"Unsqueeze", {"reshape", {{0, "input"}}}}, {"Or", {"logicalOr", {{0, "a"}, {1, "b"}}}}, - {"Einsum", {"matmul", {{0, "a"}, {1, "b"}}}}, {"HardSwish", {"hardSwish", {{0, "input"}}}}, {"LeakyRelu", {"leakyRelu", {{0, "input"}}}}, {"MatMul", {"matmul", {{0, "a"}, {1, "b"}}}}, diff --git a/onnxruntime/core/providers/webnn/builders/model_builder.h b/onnxruntime/core/providers/webnn/builders/model_builder.h index 4468831181d42..d2cd0639affd0 100644 --- a/onnxruntime/core/providers/webnn/builders/model_builder.h +++ b/onnxruntime/core/providers/webnn/builders/model_builder.h @@ -78,7 +78,7 @@ class ModelBuilder { const GraphViewer& graph_viewer_; const logging::Logger& logger_; const bool is_float16array_available_ = !emscripten::val::global("Float16Array").isUndefined() && - emscripten::val::global("Float16Array").hasOwnProperty("from"); + !emscripten::val::global("Float16Array")["from"].isUndefined(); emscripten::val wnn_context_ = emscripten::val::undefined(); emscripten::val wnn_builder_ = emscripten::val::undefined(); diff --git a/onnxruntime/core/session/compile_api.cc b/onnxruntime/core/session/compile_api.cc index d910e3ea74b57..59b0992d827e1 100644 --- a/onnxruntime/core/session/compile_api.cc +++ b/onnxruntime/core/session/compile_api.cc @@ -128,6 +128,35 @@ ORT_API_STATUS_IMPL(OrtCompileAPI::ModelCompilationOptions_SetOutputModelPath, API_IMPL_END } +ORT_API_STATUS_IMPL(OrtCompileAPI::ModelCompilationOptions_SetEpContextBinaryInformation, + _In_ OrtModelCompilationOptions* ort_model_compile_options, + const ORTCHAR_T* output_directory, + const ORTCHAR_T* model_name) { + API_IMPL_BEGIN +#if !defined(ORT_MINIMAL_BUILD) + auto model_compile_options = reinterpret_cast(ort_model_compile_options); + + std::string output_dir = PathToUTF8String(output_directory); + if (output_dir.empty()) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Invalid output directory: path is empty"); + } + + std::string model_name_str = ToUTF8String(model_name); + if (model_name_str.empty()) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Invalid model name: string is empty"); + } + + ORT_API_RETURN_IF_STATUS_NOT_OK(model_compile_options->SetEpContextBinaryInformation(output_dir, model_name_str)); + return nullptr; +#else + ORT_UNUSED_PARAMETER(ort_model_compile_options); + ORT_UNUSED_PARAMETER(output_directory); + ORT_UNUSED_PARAMETER(model_name); + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "Compile API is not supported in this build"); +#endif // !defined(ORT_MINIMAL_BUILD) + API_IMPL_END +} + ORT_API_STATUS_IMPL(OrtCompileAPI::ModelCompilationOptions_SetOutputModelExternalInitializersFile, _In_ OrtModelCompilationOptions* ort_model_compile_options, const ORTCHAR_T* external_initializers_file_path, @@ -248,6 +277,7 @@ static constexpr OrtCompileApi ort_compile_api = { // End of Version 22 - DO NOT MODIFY ABOVE &OrtCompileAPI::ModelCompilationOptions_SetFlags, + &OrtCompileAPI::ModelCompilationOptions_SetEpContextBinaryInformation, }; // checks that we don't violate the rule that the functions must remain in the slots they were originally assigned diff --git a/onnxruntime/core/session/compile_api.h b/onnxruntime/core/session/compile_api.h index 5f11b894f2004..93cc5dbf20fce 100644 --- a/onnxruntime/core/session/compile_api.h +++ b/onnxruntime/core/session/compile_api.h @@ -30,5 +30,7 @@ ORT_API_STATUS_IMPL(ModelCompilationOptions_SetEpContextEmbedMode, _In_ OrtModel ORT_API_STATUS_IMPL(CompileModel, _In_ const OrtEnv* env, _In_ const OrtModelCompilationOptions* model_options); ORT_API_STATUS_IMPL(ModelCompilationOptions_SetFlags, _In_ OrtModelCompilationOptions* model_options, size_t flags); +ORT_API_STATUS_IMPL(ModelCompilationOptions_SetEpContextBinaryInformation, _In_ OrtModelCompilationOptions* model_compile_options, + _In_ const ORTCHAR_T* output_dir, _In_ const ORTCHAR_T* model_name); } // namespace OrtCompileAPI diff --git a/onnxruntime/core/session/ep_api_utils.h b/onnxruntime/core/session/ep_api_utils.h index daccd24453371..a0904c32011a7 100644 --- a/onnxruntime/core/session/ep_api_utils.h +++ b/onnxruntime/core/session/ep_api_utils.h @@ -16,6 +16,10 @@ struct ForwardToFactory { return static_cast(this_ptr)->GetVendor(); } + static uint32_t ORT_API_CALL GetVendorId(const OrtEpFactory* this_ptr) noexcept { + return static_cast(this_ptr)->GetVendorId(); + } + static const char* ORT_API_CALL GetVersion(const OrtEpFactory* this_ptr) noexcept { return static_cast(this_ptr)->GetVersion(); } diff --git a/onnxruntime/core/session/ep_factory_internal.cc b/onnxruntime/core/session/ep_factory_internal.cc index b289010cc6c5b..fa4ef2515ca92 100644 --- a/onnxruntime/core/session/ep_factory_internal.cc +++ b/onnxruntime/core/session/ep_factory_internal.cc @@ -14,17 +14,19 @@ namespace onnxruntime { using Forward = ForwardToFactory; -EpFactoryInternal::EpFactoryInternal(const std::string& ep_name, const std::string& vendor, +EpFactoryInternal::EpFactoryInternal(const std::string& ep_name, const std::string& vendor, uint32_t vendor_id, GetSupportedFunc&& get_supported_func, CreateFunc&& create_func) : ep_name_{ep_name}, vendor_{vendor}, + vendor_id_{vendor_id}, get_supported_func_{std::move(get_supported_func)}, create_func_{create_func} { ort_version_supported = ORT_API_VERSION; OrtEpFactory::GetName = Forward::GetFactoryName; OrtEpFactory::GetVendor = Forward::GetVendor; + OrtEpFactory::GetVendorId = Forward::GetVendorId; OrtEpFactory::GetVersion = Forward::GetVersion; OrtEpFactory::GetSupportedDevices = Forward::GetSupportedDevices; OrtEpFactory::CreateEp = Forward::CreateEp; diff --git a/onnxruntime/core/session/ep_factory_internal.h b/onnxruntime/core/session/ep_factory_internal.h index 087c0c60f8f4e..ee08e2233c529 100644 --- a/onnxruntime/core/session/ep_factory_internal.h +++ b/onnxruntime/core/session/ep_factory_internal.h @@ -33,12 +33,13 @@ class EpFactoryInternal : public OrtEpFactory { const OrtSessionOptions* session_options, const OrtLogger* logger, std::unique_ptr* ep)>; - EpFactoryInternal(const std::string& ep_name, const std::string& vendor, + EpFactoryInternal(const std::string& ep_name, const std::string& vendor, uint32_t vendor_id, GetSupportedFunc&& get_supported_func, CreateFunc&& create_func); const char* GetName() const noexcept { return ep_name_.c_str(); } const char* GetVendor() const noexcept { return vendor_.c_str(); } + uint32_t GetVendorId() const noexcept { return vendor_id_; } const char* GetVersion() const noexcept; OrtStatus* GetSupportedDevices(_In_reads_(num_devices) const OrtHardwareDevice* const* devices, @@ -67,6 +68,7 @@ class EpFactoryInternal : public OrtEpFactory { private: const std::string ep_name_; // EP name library was registered with const std::string vendor_; // EP vendor name + const uint32_t vendor_id_; // EP vendor ID const GetSupportedFunc get_supported_func_; // function to return supported devices const CreateFunc create_func_; // function to create the EP instance diff --git a/onnxruntime/core/session/ep_library_internal.cc b/onnxruntime/core/session/ep_library_internal.cc index 25f70f7549a16..ce5736f601b45 100644 --- a/onnxruntime/core/session/ep_library_internal.cc +++ b/onnxruntime/core/session/ep_library_internal.cc @@ -61,7 +61,8 @@ std::unique_ptr EpLibraryInternal::CreateCpuEp() { }; std::string ep_name = kCpuExecutionProvider; - auto cpu_factory = std::make_unique(ep_name, "Microsoft", get_supported, create_cpu_ep); + auto cpu_factory = std::make_unique(ep_name, "Microsoft", OrtDevice::VendorIds::MICROSOFT, + get_supported, create_cpu_ep); return std::make_unique(std::move(cpu_factory)); } @@ -122,7 +123,8 @@ std::unique_ptr EpLibraryInternal::CreateDmlEp() { return nullptr; }; - auto dml_factory = std::make_unique(ep_name, "Microsoft", is_supported, create_dml_ep); + auto dml_factory = std::make_unique(ep_name, "Microsoft", OrtDevice::VendorIds::MICROSOFT, + is_supported, create_dml_ep); return std::make_unique(std::move(dml_factory)); } @@ -170,7 +172,8 @@ std::unique_ptr EpLibraryInternal::CreateWebGpuEp() { return nullptr; }; - auto webgpu_factory = std::make_unique(ep_name, "Microsoft", is_supported, create_webgpu_ep); + auto webgpu_factory = std::make_unique(ep_name, "Microsoft", OrtDevice::VendorIds::MICROSOFT, + is_supported, create_webgpu_ep); return std::make_unique(std::move(webgpu_factory)); } diff --git a/onnxruntime/core/session/ep_library_provider_bridge.cc b/onnxruntime/core/session/ep_library_provider_bridge.cc index 73423a4744576..70937bdc5d3e8 100644 --- a/onnxruntime/core/session/ep_library_provider_bridge.cc +++ b/onnxruntime/core/session/ep_library_provider_bridge.cc @@ -72,6 +72,7 @@ Status EpLibraryProviderBridge::Load() { auto internal_factory = std::make_unique(factory->GetName(factory), factory->GetVendor(factory), + factory->GetVendorId(factory), is_supported_fn, create_fn); factory_ptrs_.push_back(internal_factory.get()); diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 86a61a4d0ee74..f147242da668f 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -423,7 +423,13 @@ void InferenceSession::ConstructorCommon(const SessionOptions& session_options, { if (!external_intra_op_thread_pool_) { bool allow_intra_op_spinning = +#if !defined(ORT_CLIENT_PACKAGE_BUILD) session_options_.config_options.GetConfigOrDefault(kOrtSessionOptionsConfigAllowIntraOpSpinning, "1") == "1"; +#else + // default KOrtSessionOptionsConfigAllowIntraOpSpinning to "0" for ORT builds targeting client/on-device workloads, + // to reduce CPU utilization and improve power efficiency. + session_options_.config_options.GetConfigOrDefault(kOrtSessionOptionsConfigAllowIntraOpSpinning, "0") == "1"; +#endif OrtThreadPoolParams to = session_options_.intra_op_param; std::basic_stringstream ss; if (to.name) { @@ -461,7 +467,13 @@ void InferenceSession::ConstructorCommon(const SessionOptions& session_options, if (session_options_.execution_mode == ExecutionMode::ORT_PARALLEL) { if (!external_inter_op_thread_pool_) { bool allow_inter_op_spinning = +#if !defined(ORT_CLIENT_PACKAGE_BUILD) session_options_.config_options.GetConfigOrDefault(kOrtSessionOptionsConfigAllowInterOpSpinning, "1") == "1"; +#else + // default kOrtSessionOptionsConfigAllowInterOpSpinning to "0" for ORT builds targeting client/on-device workloads, + // to reduce CPU utilization and improve power efficiency. + session_options_.config_options.GetConfigOrDefault(kOrtSessionOptionsConfigAllowInterOpSpinning, "0") == "1"; +#endif OrtThreadPoolParams to = session_options_.inter_op_param; to.auto_set_affinity = to.thread_pool_size == 0 && session_options_.execution_mode == ExecutionMode::ORT_SEQUENTIAL; std::basic_stringstream ss; diff --git a/onnxruntime/core/session/model_compilation_options.cc b/onnxruntime/core/session/model_compilation_options.cc index 5de0f03fafc08..bbb110033f54c 100644 --- a/onnxruntime/core/session/model_compilation_options.cc +++ b/onnxruntime/core/session/model_compilation_options.cc @@ -72,8 +72,8 @@ Status ModelCompilationOptions::SetOutputModelPath(const std::string& output_mod if (log_manager != nullptr && log_manager->HasDefaultLogger()) { const logging::Logger& logger = log_manager->DefaultLogger(); LOGS(logger, WARNING) << "Output model path length (" << ep_context_gen_options.output_model_file_path.size() - << ") exceeds limit of " << ConfigOptions::kMaxKeyLength << " characters." - << "ORT will still generated the expected output file, but EPs will see an empty " + << ") exceeds limit of " << ConfigOptions::kMaxValueLength << " characters." + << "ORT will still generate the expected output file, but EPs will see an empty " << "output model path in SessionOption's ConfigOptions."; } } @@ -98,6 +98,36 @@ Status ModelCompilationOptions::SetOutputModelBuffer(onnxruntime::AllocatorPtr a return Status::OK(); } +Status ModelCompilationOptions::SetEpContextBinaryInformation(const std::string& output_directory, + const std::string& model_name) { + if (output_directory.empty() || model_name.empty()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "output_dir or model_name is empty."); + } + + std::filesystem::path output_dir_path(output_directory); + if (output_dir_path.has_filename() && output_dir_path.extension() == "") { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "output_dir is not a valid directory."); + } + + std::filesystem::path ctx_model_path = output_directory / std::filesystem::path(model_name); + + if (ctx_model_path.string().size() <= ConfigOptions::kMaxValueLength) { + ORT_RETURN_IF_ERROR(session_options_.value.config_options.AddConfigEntry(kOrtSessionOptionEpContextFilePath, + ctx_model_path.string().c_str())); + } else { + logging::LoggingManager* log_manager = env_.GetLoggingManager(); + if (log_manager != nullptr && log_manager->HasDefaultLogger()) { + const logging::Logger& logger = log_manager->DefaultLogger(); + LOGS(logger, WARNING) << "output_directory length with model_name length together exceeds limit of " + << ConfigOptions::kMaxValueLength << " characters." + << "ORT will still generate the expected output file, but EPs will see an empty " + << "output path in SessionOption's ConfigOptions."; + } + } + + return Status::OK(); +} + Status ModelCompilationOptions::SetEpContextEmbedMode(bool embed_ep_context_in_model) { ORT_RETURN_IF_ERROR(session_options_.value.config_options.AddConfigEntry( kOrtSessionOptionEpContextEmbedMode, embed_ep_context_in_model ? "1" : "0")); @@ -146,7 +176,7 @@ Status ModelCompilationOptions::ResetOutputModelSettings() { ep_context_gen_options.output_model_buffer_ptr = nullptr; ep_context_gen_options.output_model_buffer_size_ptr = nullptr; ep_context_gen_options.output_model_buffer_allocator = nullptr; - return session_options_.value.config_options.AddConfigEntry(kOrtSessionOptionEpContextFilePath, ""); + return Status::OK(); } Status ModelCompilationOptions::CheckInputModelSettings() const { diff --git a/onnxruntime/core/session/model_compilation_options.h b/onnxruntime/core/session/model_compilation_options.h index f96f0317cdaca..2824df863013d 100644 --- a/onnxruntime/core/session/model_compilation_options.h +++ b/onnxruntime/core/session/model_compilation_options.h @@ -72,6 +72,16 @@ class ModelCompilationOptions { Status SetOutputModelBuffer(onnxruntime::AllocatorPtr allocator, void** output_model_buffer_ptr, size_t* output_model_buffer_size_ptr); + /// + /// Sets information relate to EP context binary file. + /// EP use this information to decide the location and context binary file name. + /// Used while compiling model with input and output in memory buffer + /// + /// The folder path to the generated context binary file + /// Model name used to decide the context binary file name: [model_name]_[ep].bin + /// Status indicating potential error + Status SetEpContextBinaryInformation(const std::string& output_directory, const std::string& model_name); + /// /// Enables or disables the embedding of EPContext binary data into the `ep_cache_context` attribute of EPContext /// nodes. Defaults to false (dumped to file). diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index e7f60fd48a14f..db2a62c77d1bc 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -2591,6 +2591,29 @@ ORT_API_STATUS_IMPL(OrtApis::Graph_GetOnnxIRVersion, _In_ const OrtGraph* graph, API_IMPL_END } +ORT_API_STATUS_IMPL(OrtApis::Graph_GetNumOperatorSets, _In_ const OrtGraph* graph, _Out_ size_t* num_operator_sets) { + API_IMPL_BEGIN + if (num_operator_sets == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "'num_operator_sets' argument is NULL"); + } + + ORT_API_RETURN_IF_STATUS_NOT_OK(graph->GetNumOperatorSets(*num_operator_sets)); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtApis::Graph_GetOperatorSets, _In_ const OrtGraph* graph, + _Out_writes_(num_operator_sets) const char** domains, + _Out_writes_(num_operator_sets) int64_t* opset_versions, _In_ size_t num_operator_sets) { + API_IMPL_BEGIN + gsl::span domains_span(domains, num_operator_sets); + gsl::span versions_span(opset_versions, num_operator_sets); + ORT_API_RETURN_IF_STATUS_NOT_OK(graph->GetOperatorSets(domains_span, versions_span)); + + return nullptr; + API_IMPL_END +} + ORT_API_STATUS_IMPL(OrtApis::Graph_GetNumInputs, _In_ const OrtGraph* graph, _Out_ size_t* num_inputs) { API_IMPL_BEGIN if (num_inputs == nullptr) { @@ -2691,6 +2714,91 @@ ORT_API_STATUS_IMPL(OrtApis::Graph_GetParentNode, _In_ const OrtGraph* graph, _O API_IMPL_END } +ORT_API_STATUS_IMPL(OrtApis::Graph_GetGraphView, _In_ const OrtGraph* src_graph, + _In_ const OrtNode** nodes, + _In_ size_t num_nodes, + _Outptr_ OrtGraph** dst_graph) { + API_IMPL_BEGIN + + if (num_nodes == 0) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "'num_nodes' argument should be > 0"); + } + + const EpGraph* ep_graph = EpGraph::ToInternal(src_graph); + if (ep_graph == nullptr) { + return OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "src_graph is a ModelEditorGraph which doesn't support Graph_GetSubGraph."); + } + const Graph& graph = ep_graph->GetGraphViewer().GetGraph(); + + // Create a GraphViewer with filtered info + std::unique_ptr indexed_sub_graph = std::make_unique(); + std::unique_ptr metadef = std::make_unique(); + metadef->name = "sub_graph"; + metadef->since_version = 1; + std::unordered_set outputs; + std::unordered_set initializers; + + auto add_inputs = [&](ConstPointerContainer> defs) { + for (const auto* def : defs) { + if (def->Exists()) { + // not the output of a previous node + if (outputs.count(def->Name()) == 0) { + metadef->inputs.push_back(def->Name()); + } else { + // consumed by node so no longer subgraph output + // NOTE: Ignoring edge case where a node output is an overall graph output AND a node input + outputs.erase(def->Name()); + } + + if (graph.IsInitializedTensor(def->Name())) { + initializers.insert(def); + } + } + } + }; + + auto add_node = [&](const Node& node) { + indexed_sub_graph->nodes.push_back(node.Index()); + add_inputs(node.InputDefs()); + add_inputs(node.ImplicitInputDefs()); + + for (const auto* def : node.OutputDefs()) { + outputs.insert(def->Name()); + } + }; + + // Add nodes + for (size_t node_idx = 0; node_idx < num_nodes; node_idx++) { + const OrtNode* ort_node = nodes[node_idx]; + const EpNode* ep_node = EpNode::ToInternal(ort_node); + if (ep_node == nullptr) { + return OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "node is a ModelEditorNode which doesn't support Graph_GetSubGraph."); + } + add_node(ep_node->GetInternalNode()); + } + + // Add initializers + for (auto& initializer : initializers) { + metadef->constant_initializers.push_back(initializer->Name()); + } + + // Add outputs + for (auto& output : outputs) { + metadef->outputs.push_back(output); + } + + indexed_sub_graph->SetMetaDef(std::move(metadef)); + auto graph_viewer = std::make_unique(graph, *indexed_sub_graph.get()); + + std::unique_ptr result; + ORT_API_RETURN_IF_STATUS_NOT_OK(EpGraph::Create(std::move(graph_viewer), std::move(indexed_sub_graph), result)); + + *dst_graph = result.release(); + + return nullptr; + API_IMPL_END +} + // // OrtNode // @@ -2922,10 +3030,11 @@ ORT_API_STATUS_IMPL(OrtApis::Node_GetNumSubgraphs, _In_ const OrtNode* node, _Ou } ORT_API_STATUS_IMPL(OrtApis::Node_GetSubgraphs, _In_ const OrtNode* node, - _Out_writes_(num_subgraphs) const OrtGraph** subgraphs, _In_ size_t num_subgraphs) { + _Out_writes_(num_subgraphs) const OrtGraph** subgraphs, _In_ size_t num_subgraphs, + _Out_writes_opt_(num_subgraphs) const char** attribute_names) { API_IMPL_BEGIN gsl::span graphs_span(subgraphs, num_subgraphs); - ORT_API_RETURN_IF_STATUS_NOT_OK(node->GetSubgraphs(graphs_span)); + ORT_API_RETURN_IF_STATUS_NOT_OK(node->GetSubgraphs(graphs_span, attribute_names)); return nullptr; API_IMPL_END } @@ -2943,6 +3052,23 @@ ORT_API_STATUS_IMPL(OrtApis::Node_GetGraph, _In_ const OrtNode* node, API_IMPL_END } +ORT_API_STATUS_IMPL(OrtApis::Node_GetEpName, _In_ const OrtNode* node, + _Outptr_result_maybenull_ const char** out) { + API_IMPL_BEGIN + if (out == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "'out' argument is NULL"); + } + + const EpNode* ep_node = EpNode::ToInternal(node); + if (ep_node == nullptr) { + return OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "node is a ModelEditorNode which doesn't support Node_GetEpName."); + } + + *out = ep_node->GetEpName().c_str(); + return nullptr; + API_IMPL_END +} + ORT_API(const OrtTrainingApi*, OrtApis::GetTrainingApi, uint32_t version) { #ifdef ENABLE_TRAINING_APIS if (version >= 13 && version <= ORT_API_VERSION) @@ -3594,6 +3720,8 @@ static constexpr OrtApi ort_api_1_to_23 = { &OrtApis::ValueInfo_IsFromOuterScope, &OrtApis::Graph_GetName, &OrtApis::Graph_GetOnnxIRVersion, + &OrtApis::Graph_GetNumOperatorSets, + &OrtApis::Graph_GetOperatorSets, &OrtApis::Graph_GetNumInputs, &OrtApis::Graph_GetInputs, &OrtApis::Graph_GetNumOutputs, @@ -3603,6 +3731,7 @@ static constexpr OrtApi ort_api_1_to_23 = { &OrtApis::Graph_GetNumNodes, &OrtApis::Graph_GetNodes, &OrtApis::Graph_GetParentNode, + &OrtApis::Graph_GetGraphView, &OrtApis::Node_GetId, &OrtApis::Node_GetName, &OrtApis::Node_GetOperatorType, @@ -3622,6 +3751,7 @@ static constexpr OrtApi ort_api_1_to_23 = { &OrtApis::Node_GetNumSubgraphs, &OrtApis::Node_GetSubgraphs, &OrtApis::Node_GetGraph, + &OrtApis::Node_GetEpName, &OrtApis::GetRunConfigEntry, diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index cbacbfce0740d..9ab927006c320 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -631,6 +631,10 @@ ORT_API_STATUS_IMPL(ValueInfo_IsFromOuterScope, _In_ const OrtValueInfo* value_i // OrtGraph ORT_API_STATUS_IMPL(Graph_GetName, _In_ const OrtGraph* graph, _Outptr_ const char** graph_name); ORT_API_STATUS_IMPL(Graph_GetOnnxIRVersion, _In_ const OrtGraph* graph, _Out_ int64_t* onnx_ir_version); +ORT_API_STATUS_IMPL(Graph_GetNumOperatorSets, _In_ const OrtGraph* graph, _Out_ size_t* num_operator_sets); +ORT_API_STATUS_IMPL(Graph_GetOperatorSets, _In_ const OrtGraph* graph, + _Out_writes_(num_operator_sets) const char** domains, + _Out_writes_(num_operator_sets) int64_t* opset_versions, _In_ size_t num_operator_sets); ORT_API_STATUS_IMPL(Graph_GetNumInputs, _In_ const OrtGraph* graph, _Out_ size_t* num_inputs); ORT_API_STATUS_IMPL(Graph_GetInputs, _In_ const OrtGraph* graph, _Out_writes_(num_inputs) const OrtValueInfo** inputs, _In_ size_t num_inputs); @@ -645,6 +649,8 @@ ORT_API_STATUS_IMPL(Graph_GetNumNodes, _In_ const OrtGraph* graph, _Out_ size_t* ORT_API_STATUS_IMPL(Graph_GetNodes, const OrtGraph* graph, _Out_writes_(num_nodes) const OrtNode** nodes, _In_ size_t num_nodes); ORT_API_STATUS_IMPL(Graph_GetParentNode, _In_ const OrtGraph* graph, _Outptr_result_maybenull_ const OrtNode** node); +ORT_API_STATUS_IMPL(Graph_GetGraphView, _In_ const OrtGraph* graph, _In_ const OrtNode** nodes, _In_ size_t num_nodes, + _Outptr_ OrtGraph** subgraph); // OrtNode ORT_API_STATUS_IMPL(Node_GetId, _In_ const OrtNode* node, _Out_ size_t* node_id); @@ -671,8 +677,10 @@ ORT_API_STATUS_IMPL(OpAttr_GetType, _In_ const OrtOpAttr* attribute, _Out_ OrtOp ORT_API_STATUS_IMPL(OpAttr_GetName, _In_ const OrtOpAttr* attribute, _Outptr_ const char** name); ORT_API_STATUS_IMPL(Node_GetNumSubgraphs, _In_ const OrtNode* node, _Out_ size_t* num_subgraphs); ORT_API_STATUS_IMPL(Node_GetSubgraphs, _In_ const OrtNode* node, - _Out_writes_(num_subgraphs) const OrtGraph** subgraphs, _In_ size_t num_subgraphs); + _Out_writes_(num_subgraphs) const OrtGraph** subgraphs, _In_ size_t num_subgraphs, + _Out_writes_opt_(num_subgraphs) const char** attribute_names); ORT_API_STATUS_IMPL(Node_GetGraph, _In_ const OrtNode* node, _Outptr_result_maybenull_ const OrtGraph** graph); +ORT_API_STATUS_IMPL(Node_GetEpName, _In_ const OrtNode* node, _Outptr_result_maybenull_ const char** out); ORT_API_STATUS_IMPL(GetRunConfigEntry, _In_ const OrtRunOptions* options, _In_z_ const char* config_key, _Outptr_result_maybenull_z_ const char** config_value); diff --git a/onnxruntime/core/session/provider_policy_context.cc b/onnxruntime/core/session/provider_policy_context.cc index e8d62ab86f517..211bf8b2d15a4 100644 --- a/onnxruntime/core/session/provider_policy_context.cc +++ b/onnxruntime/core/session/provider_policy_context.cc @@ -22,7 +22,13 @@ namespace onnxruntime { namespace { bool MatchesEpVendor(const OrtEpDevice* d) { - // TODO: Would be better to match on Id. Should the EP add that in EP metadata? + // match on vendor id if provided + uint32_t factory_vendor_id = d->ep_factory->GetVendorId(d->ep_factory); + if (factory_vendor_id != 0 && d->device->vendor_id == factory_vendor_id) { + return true; + } + + // match on vendor name return d->device->vendor == d->ep_vendor; } diff --git a/onnxruntime/core/util/qmath.h b/onnxruntime/core/util/qmath.h index 0172902bdf4e2..f7d5cdb98aa1d 100644 --- a/onnxruntime/core/util/qmath.h +++ b/onnxruntime/core/util/qmath.h @@ -1001,4 +1001,53 @@ struct BlockedQuantizeLinear { #endif +/** + * @brief Run MlasDequantizeLinear in parallel, with provided thread pool + */ + +template +void ParDequantizeLinearStd(const InputQuantType* input, + float* output, + size_t num_elems, + float scale, + InputQuantType zero_point, + concurrency::ThreadPool* thread_pool) { + constexpr std::ptrdiff_t block_size = 128; + const std::ptrdiff_t num_blocks = (num_elems + block_size - 1) / block_size; + const TensorOpCost unit_cost{static_cast(block_size * sizeof(InputQuantType)), + static_cast(block_size * sizeof(float)), + static_cast(block_size) * 2.0}; + concurrency::ThreadPool::TryParallelFor(thread_pool, num_blocks, unit_cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) { + auto begin_idx = begin * block_size; + auto end_idx = std::min(static_cast(num_elems), end * block_size); + MlasDequantizeLinear(&(input[begin_idx]), &(output[begin_idx]), end_idx - begin_idx, scale, zero_point); + }); +} + +// Note: this doesn't use MLAS kernel. There are currently no MLAS kernels for fp16 QuantizeLinear or DequantizeLinear. +template +void ParDequantizeLinearStd(const InputQuantType* input, + MLFloat16* output, + size_t num_elems, + MLFloat16 scale, + InputQuantType zero_point, + concurrency::ThreadPool* thread_pool) { + constexpr std::ptrdiff_t block_size = 128; + const std::ptrdiff_t num_blocks = (num_elems + block_size - 1) / block_size; + const TensorOpCost unit_cost{static_cast(block_size * sizeof(InputQuantType)), + static_cast(block_size * sizeof(MLFloat16)), + static_cast(block_size) * 2.0}; + + const int32_t zp_s32 = static_cast(zero_point); + const float sc_f32 = scale.ToFloat(); + + concurrency::ThreadPool::TryParallelFor(thread_pool, num_blocks, unit_cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) { + auto begin_idx = begin * block_size; + auto end_idx = std::min(static_cast(num_elems), end * block_size); + for (; begin_idx != end_idx; ++begin_idx) { + output[begin_idx] = MLFloat16(static_cast(static_cast(input[begin_idx]) - zp_s32) * sc_f32); + } + }); +} + } // namespace onnxruntime diff --git a/onnxruntime/core/util/thread_utils.h b/onnxruntime/core/util/thread_utils.h index d63d620dbc321..0b99723b2c75b 100644 --- a/onnxruntime/core/util/thread_utils.h +++ b/onnxruntime/core/util/thread_utils.h @@ -19,7 +19,13 @@ struct OrtThreadPoolParams { bool auto_set_affinity = false; // If it is true, the thread pool will spin a while after the queue became empty. +#if !defined(ORT_CLIENT_PACKAGE_BUILD) bool allow_spinning = true; +#else + // default allow_spinning to false for ORT builds targeting client/on-device workloads, + // to reduce CPU utilization and improve power efficiency. + bool allow_spinning = false; +#endif // It it is non-negative, thread pool will split a task by a decreasing block size // of remaining_of_total_iterations / (num_of_threads * dynamic_block_base_) diff --git a/onnxruntime/python/tools/quantization/base_quantizer.py b/onnxruntime/python/tools/quantization/base_quantizer.py index 9a297e451213a..e3303dac6c8c5 100644 --- a/onnxruntime/python/tools/quantization/base_quantizer.py +++ b/onnxruntime/python/tools/quantization/base_quantizer.py @@ -42,7 +42,7 @@ def __init__(self, **data: dict[str, Any]): for k, v in data.items(): if not isinstance(k, str): raise TypeError(f"Keys must be strings not {type(k)} for k={k!r}.") - if k != "axis" and not isinstance(v, (int, str, np.ndarray)): + if k != "axis" and not isinstance(v, (int, str, np.ndarray, float)): raise TypeError(f"Values must be numpy arrays, int, float, str not {type(v)} for k={k!r}.") if k == "axis" and not isinstance(v, int) and v is not None: raise TypeError(f"Axis value must be an int or None, not {type(v)}.") diff --git a/onnxruntime/python/tools/quantization/registry.py b/onnxruntime/python/tools/quantization/registry.py index fbeae39c39d21..319c5aa468f7e 100644 --- a/onnxruntime/python/tools/quantization/registry.py +++ b/onnxruntime/python/tools/quantization/registry.py @@ -86,6 +86,7 @@ "InstanceNormalization": QDQNormalization, "LayerNormalization": QDQNormalization, "BatchNormalization": QDQNormalization, + "TopK": QDQDirect8BitOp, } diff --git a/onnxruntime/python/tools/transformers/fusion_attention_clip.py b/onnxruntime/python/tools/transformers/fusion_attention_clip.py index fe93f5cd358bf..8711e368cd1e6 100644 --- a/onnxruntime/python/tools/transformers/fusion_attention_clip.py +++ b/onnxruntime/python/tools/transformers/fusion_attention_clip.py @@ -269,42 +269,48 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): attention_last_node = reshape_qkv add_qk = "" + causal_mask_nodes_1 = None + causal_mask_nodes_2 = None if add_mask is not None: - # 4D Add after Q x K' - add_qk_nodes = self.model.match_parent_path( - add_mask, - [ - "Where", - "Sub", - "Cast", - "Expand", - "Unsqueeze", - "Unsqueeze", - "Reshape", - "Reshape", - "Cast", - ], - [1, 2, 1, 0, 0, 0, 0, 0, 0], - ) - if add_qk_nodes is not None: + if add_mask.input[1] == "attention_mask": add_qk = add_mask.input[1] else: - # Here we do not match the whole subgraph since it is very complex. Instead, we just check whether a key path - # of computing causal mask. - causal_mask_nodes_1 = self.model.match_parent_path( - add_mask, - ["Concat", "Expand", "Unsqueeze", "Unsqueeze", "Where", "Less"], - [causal_mask_input_index, 0, 0, 0, 0, 0], - ) - # If the model is exported with batch_size == 1, there is no Concat node - causal_mask_nodes_2 = self.model.match_parent_path( + # 4D Add after Q x K' + add_qk_nodes = self.model.match_parent_path( add_mask, - ["Expand", "Unsqueeze", "Unsqueeze", "Where", "Less"], - [causal_mask_input_index, 0, 0, 0, 0], + [ + "Where", + "Sub", + "Cast", + "Expand", + "Unsqueeze", + "Unsqueeze", + "Reshape", + "Reshape", + "Cast", + ], + [1, 2, 1, 0, 0, 0, 0, 0, 0], ) - if causal_mask_nodes_1 is None and causal_mask_nodes_2 is None: - logger.debug("fuse_attention: failed to match causal mask subgraph") - return + if add_qk_nodes is not None: + add_qk = add_mask.input[1] + else: + # Here we do not match the whole subgraph since it is very complex. Instead, we just check whether a key path + # of computing causal mask. + causal_mask_nodes_1 = self.model.match_parent_path( + add_mask, + ["Concat", "Expand", "Unsqueeze", "Unsqueeze", "Where", "Less"], + [causal_mask_input_index, 0, 0, 0, 0, 0], + ) + # If the model is exported with batch_size == 1, there is no Concat node + causal_mask_nodes_2 = self.model.match_parent_path( + add_mask, + ["Expand", "Unsqueeze", "Unsqueeze", "Where", "Less"], + [causal_mask_input_index, 0, 0, 0, 0], + ) + + if causal_mask_nodes_1 is None and causal_mask_nodes_2 is None: + logger.debug("fuse_attention: failed to match causal mask subgraph") + return new_node = self.create_attention_node( mask_index=None, @@ -320,7 +326,7 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): output=attention_last_node.output[0], add_qk_str=add_qk, scale=None, - causal=(add_mask is not None), + causal=(causal_mask_nodes_1 is not None) or (causal_mask_nodes_2 is not None), ) if new_node is None: logger.debug("fuse_attention: failed to create fused node") diff --git a/onnxruntime/python/tools/transformers/models/llama/requirements.txt b/onnxruntime/python/tools/transformers/models/llama/requirements.txt index 6bd698f8b75b4..e16957eab80a1 100644 --- a/onnxruntime/python/tools/transformers/models/llama/requirements.txt +++ b/onnxruntime/python/tools/transformers/models/llama/requirements.txt @@ -1,7 +1,7 @@ onnxscript>=0.2.3 optimum>=1.14.1 optree -transformers==4.48.0 +transformers==4.52.1 torch>=2.7.0 onnx==1.17.0 datasets>=2.8.0 diff --git a/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py index ac696ff3788aa..e092285d57358 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py +++ b/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py @@ -410,7 +410,7 @@ def export_onnx_models( precision == Precision.FLOAT16, model.config.encoder_attention_heads, model.config.d_model, - model.config.num_hidden_layers, + model.config.decoder_layers, use_external_data_format, use_gpu=use_gpu, provider=provider, diff --git a/onnxruntime/python/tools/transformers/models/whisper/requirements.txt b/onnxruntime/python/tools/transformers/models/whisper/requirements.txt index f1758cc52280f..37fc72cd26e07 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/requirements.txt +++ b/onnxruntime/python/tools/transformers/models/whisper/requirements.txt @@ -1,5 +1,5 @@ torch>=2.7.0 -transformers>=4.52.3 +transformers==4.52.3 openai-whisper==20240927 ffmpeg-python datasets diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_decoder.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_decoder.py index fadf271ae913b..e10e616d35d38 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_decoder.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_decoder.py @@ -187,7 +187,7 @@ def input_names(self): *list( chain.from_iterable( (f"past_key_self_{i}", f"past_value_self_{i}", f"past_key_cross_{i}", f"past_value_cross_{i}") - for i in range(self.config.num_hidden_layers) + for i in range(self.config.decoder_layers) ) ), ] @@ -205,7 +205,7 @@ def output_names(self): f"present_key_cross_{i}", f"present_value_cross_{i}", ) - for i in range(self.config.num_hidden_layers) + for i in range(self.config.decoder_layers) ) ), ] @@ -214,8 +214,7 @@ def output_names(self): "logits", *list( chain.from_iterable( - (f"present_key_self_{i}", f"present_value_self_{i}") - for i in range(self.config.num_hidden_layers) + (f"present_key_self_{i}", f"present_value_self_{i}") for i in range(self.config.decoder_layers) ) ), ] diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py index 26dc3aee7018b..cd81edc1001be 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py @@ -127,7 +127,7 @@ def output_names(self): *list( chain.from_iterable( (f"present_key_cross_{i}", f"present_value_cross_{i}") - for i in range(self.config.num_hidden_layers) + for i in range(self.config.decoder_layers) ) ), ] @@ -143,7 +143,7 @@ def output_names(self): f"present_key_cross_{i}", f"present_value_cross_{i}", ) - for i in range(self.config.num_hidden_layers) + for i in range(self.config.decoder_layers) ) ), ] diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py index f66aa22eb0972..a236c4da1738e 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py @@ -763,7 +763,7 @@ def optimize_onnx( is_float16: bool, num_attention_heads: int, hidden_size: int, - num_layers: int, + num_decoder_layers: int, use_external_data_format: bool = False, use_gpu: bool = False, provider: str = "cpu", @@ -801,7 +801,7 @@ def optimize_onnx( m = add_cache_indirection_to_mha(m, past_seq_len_name) if output_qk: - m = add_output_qk_to_mha(m, skip_node_idxs=list(range(0, 2 * num_layers, 2))) + m = add_output_qk_to_mha(m, skip_node_idxs=list(range(0, 2 * num_decoder_layers, 2))) m.save_model_to_file(optimized_model_path, use_external_data_format, all_tensors_to_one_file=True) diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_inputs.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_inputs.py index 0b0882eface72..8937fea900d14 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_inputs.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_inputs.py @@ -94,14 +94,14 @@ def get_sample_past_key_values( torch.rand(batch_size, num_heads, past_seq_len, head_size, device=device, dtype=torch_dtype), torch.rand(batch_size, num_heads, past_seq_len, head_size, device=device, dtype=torch_dtype), ) - for _ in range(config.num_hidden_layers) + for _ in range(config.decoder_layers) ] cross_attention_kv_caches = [ ( torch.rand(batch_size, num_heads, max_source_positions, head_size, device=device, dtype=torch_dtype), torch.rand(batch_size, num_heads, max_source_positions, head_size, device=device, dtype=torch_dtype), ) - for _ in range(config.num_hidden_layers) + for _ in range(config.decoder_layers) ] return flatten_past_key_values(self_attention_kv_caches, cross_attention_kv_caches) @@ -187,7 +187,7 @@ def get_sample_QKs( # noqa: N802 torch.rand( batch_size, num_heads, sequence_length, config.max_source_positions, device=device, dtype=torch_dtype ) - for _ in range(config.num_hidden_layers) + for _ in range(config.decoder_layers) ] return QKs diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_jump_times.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_jump_times.py index a7c0d3538b8da..4dd5d7de1752b 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_jump_times.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_jump_times.py @@ -156,7 +156,7 @@ def input_names(self): "alignment_heads", "sot_sequence_length", "segment_length", - *[f"cross_qk_{i}" for i in range(self.config.num_hidden_layers)], + *[f"cross_qk_{i}" for i in range(self.config.decoder_layers)], ] return input_names diff --git a/onnxruntime/test/autoep/library/ep.cc b/onnxruntime/test/autoep/library/ep.cc index b498c40079f48..44b3f9a213abf 100644 --- a/onnxruntime/test/autoep/library/ep.cc +++ b/onnxruntime/test/autoep/library/ep.cc @@ -226,7 +226,7 @@ OrtStatus* ExampleEp::SaveConstantInitializers(const OrtGraph* graph) { /*static*/ OrtStatus* ORT_API_CALL ExampleEp::GetCapabilityImpl(OrtEp* this_ptr, const OrtGraph* graph, - OrtEpGraphSupportInfo* graph_support_info) { + OrtEpGraphSupportInfo* graph_support_info) noexcept { ExampleEp* ep = static_cast(this_ptr); size_t num_nodes = 0; @@ -290,7 +290,7 @@ OrtStatus* ORT_API_CALL ExampleEp::GetCapabilityImpl(OrtEp* this_ptr, const OrtG OrtStatus* ORT_API_CALL ExampleEp::CompileImpl(_In_ OrtEp* this_ptr, _In_ const OrtGraph** graphs, _In_ const OrtNode** fused_nodes, _In_ size_t count, _Out_writes_all_(count) OrtNodeComputeInfo** node_compute_infos, - _Out_writes_(count) OrtNode** ep_context_nodes) { + _Out_writes_(count) OrtNode** ep_context_nodes) noexcept { ExampleEp* ep = static_cast(this_ptr); const OrtApi& ort_api = ep->ort_api; @@ -328,6 +328,12 @@ OrtStatus* ORT_API_CALL ExampleEp::CompileImpl(_In_ OrtEp* this_ptr, _In_ const RETURN_IF_ERROR(ort_api.GetValueInfoName(node_inputs[0], &node_input_names[0])); RETURN_IF_ERROR(ort_api.GetValueInfoName(node_inputs[1], &node_input_names[1])); + const char* ep_name = nullptr; + RETURN_IF_ERROR(ort_api.Node_GetEpName(fused_nodes[0], &ep_name)); + if (std::strncmp(ep_name, "example_ep", 11) != 0) { + return ort_api.CreateStatus(ORT_EP_FAIL, "The fused node is expected to assigned to this EP to run on"); + } + // Associate the name of the fused node with our MulKernel. const char* fused_node_name = nullptr; RETURN_IF_ERROR(ort_api.Node_GetName(fused_nodes[0], &fused_node_name)); @@ -354,7 +360,7 @@ OrtStatus* ORT_API_CALL ExampleEp::CompileImpl(_In_ OrtEp* this_ptr, _In_ const /*static*/ void ORT_API_CALL ExampleEp::ReleaseNodeComputeInfosImpl(OrtEp* this_ptr, OrtNodeComputeInfo** node_compute_infos, - size_t num_node_compute_infos) { + size_t num_node_compute_infos) noexcept { (void)this_ptr; for (size_t i = 0; i < num_node_compute_infos; i++) { delete node_compute_infos[i]; diff --git a/onnxruntime/test/autoep/library/ep.h b/onnxruntime/test/autoep/library/ep.h index b8c63f39438ba..dfebcc52a0caf 100644 --- a/onnxruntime/test/autoep/library/ep.h +++ b/onnxruntime/test/autoep/library/ep.h @@ -31,14 +31,14 @@ class ExampleEp : public OrtEp, public ApiPtrs { private: static const char* ORT_API_CALL GetNameImpl(const OrtEp* this_ptr) noexcept; static OrtStatus* ORT_API_CALL ExampleEp::GetCapabilityImpl(OrtEp* this_ptr, const OrtGraph* graph, - OrtEpGraphSupportInfo* graph_support_info); + OrtEpGraphSupportInfo* graph_support_info) noexcept; static OrtStatus* ORT_API_CALL CompileImpl(_In_ OrtEp* this_ptr, _In_ const OrtGraph** graphs, _In_ const OrtNode** fused_nodes, _In_ size_t count, _Out_writes_all_(count) OrtNodeComputeInfo** node_compute_infos, - _Out_writes_(count) OrtNode** ep_context_nodes); + _Out_writes_(count) OrtNode** ep_context_nodes) noexcept; static void ORT_API_CALL ReleaseNodeComputeInfosImpl(OrtEp* this_ptr, OrtNodeComputeInfo** node_compute_infos, - size_t num_node_compute_infos); + size_t num_node_compute_infos) noexcept; OrtStatus* CreateEpContextNodes(gsl::span fused_nodes, /*out*/ gsl::span ep_context_nodes); diff --git a/onnxruntime/test/autoep/library/ep_factory.cc b/onnxruntime/test/autoep/library/ep_factory.cc index d4895102b0bf1..19a44008b8c97 100644 --- a/onnxruntime/test/autoep/library/ep_factory.cc +++ b/onnxruntime/test/autoep/library/ep_factory.cc @@ -14,6 +14,7 @@ ExampleEpFactory::ExampleEpFactory(const char* ep_name, ApiPtrs apis) ort_version_supported = ORT_API_VERSION; // set to the ORT version we were compiled with. GetName = GetNameImpl; GetVendor = GetVendorImpl; + GetVendorId = GetVendorIdImpl; GetVersion = GetVersionImpl; GetSupportedDevices = GetSupportedDevicesImpl; @@ -87,6 +88,12 @@ const char* ORT_API_CALL ExampleEpFactory::GetVendorImpl(const OrtEpFactory* thi return factory->vendor_.c_str(); } +/*static*/ +uint32_t ORT_API_CALL ExampleEpFactory::GetVendorIdImpl(const OrtEpFactory* this_ptr) noexcept { + const auto* factory = static_cast(this_ptr); + return factory->vendor_id_; +} + /*static*/ const char* ORT_API_CALL ExampleEpFactory::GetVersionImpl(const OrtEpFactory* this_ptr) noexcept { const auto* factory = static_cast(this_ptr); diff --git a/onnxruntime/test/autoep/library/ep_factory.h b/onnxruntime/test/autoep/library/ep_factory.h index fda77f12c4814..72fa1c1301841 100644 --- a/onnxruntime/test/autoep/library/ep_factory.h +++ b/onnxruntime/test/autoep/library/ep_factory.h @@ -21,6 +21,7 @@ class ExampleEpFactory : public OrtEpFactory, public ApiPtrs { static const char* ORT_API_CALL GetNameImpl(const OrtEpFactory* this_ptr) noexcept; static const char* ORT_API_CALL GetVendorImpl(const OrtEpFactory* this_ptr) noexcept; + static uint32_t ORT_API_CALL GetVendorIdImpl(const OrtEpFactory* this_ptr) noexcept; static const char* ORT_API_CALL GetVersionImpl(const OrtEpFactory* this_ptr) noexcept; @@ -53,6 +54,7 @@ class ExampleEpFactory : public OrtEpFactory, public ApiPtrs { const std::string ep_name_; // EP name const std::string vendor_{"Contoso"}; // EP vendor name + const uint32_t vendor_id_{0xB357}; // EP vendor ID const std::string ep_version_{"0.1.0"}; // EP version // CPU allocator so we can control the arena behavior. optional as ORT always provides a CPU allocator if needed. diff --git a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc index 7b77ca8c69225..4c3f9e8dd4dbd 100644 --- a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc @@ -527,18 +527,20 @@ void RunTest(int64_t M, int64_t N, int64_t K, int64_t block_size, bool has_zerop if (std::is_same_v) { #ifdef USE_CUDA execution_providers.push_back(DefaultCudaExecutionProvider()); + RunTest(opts, std::move(execution_providers)); #endif #ifdef USE_ROCM execution_providers.push_back(DefaultRocmExecutionProvider()); + RunTest(opts, std::move(execution_providers)); #endif #ifdef USE_DML execution_providers.push_back(DefaultDmlExecutionProvider()); + RunTest(opts, std::move(execution_providers)); #endif #ifdef USE_WEBGPU execution_providers.push_back(DefaultWebGpuExecutionProvider()); -#endif - RunTest(opts, std::move(execution_providers)); +#endif } else { #ifdef USE_ROCM execution_providers.push_back(DefaultRocmExecutionProvider()); diff --git a/onnxruntime/test/ep_graph/test_ep_graph.cc b/onnxruntime/test/ep_graph/test_ep_graph.cc index 60498e6510ec2..17e829e37f729 100644 --- a/onnxruntime/test/ep_graph/test_ep_graph.cc +++ b/onnxruntime/test/ep_graph/test_ep_graph.cc @@ -1,16 +1,24 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include +#include #include #include #include #include +#include #include "core/common/common.h" #include "core/framework/tensorprotoutils.h" #include "core/framework/tensor_type_and_shape.h" #include "core/framework/onnxruntime_typeinfo.h" #include "core/session/onnxruntime_cxx_api.h" +#include "core/graph/ep_api_types.h" +#include "core/graph/graph_proto_serializer.h" + +#define ORT_EP_UTILS_ORT_GRAPH_TO_PROTO_IMPL +#include "core/providers/utils/ort_graph_to_proto.h" #include "test/ep_graph/test_ep_graph_utils.h" #include "test/util/include/api_asserts.h" @@ -26,6 +34,7 @@ namespace test { // forward-declaration for utility that uses public C APIs to check that an OrtGraph is equivalent // to a graph represented by the internal ORT GraphViewer class. static void CheckGraphCApi(const GraphViewer& graph_viewer, const OrtGraph& api_graph); +static void Check_Graph_GetSubgraph(const OrtGraph& api_graph); // // Tests @@ -68,6 +77,178 @@ TEST(EpGraphTest, Check3LayerNestedSubgraph) { CheckGraphCApi(test_graph->GetGraphViewer(), test_graph->GetOrtGraph()); } +TEST(EpGraphTest, Check3LayerNestedSubgraphV2) { + // The overall structure of this model is similar to the one used in "Check3LayerNestedSubgraph" test. + // The model consists of a graph with subgraphs nested across three levels. + // In this scenario, a third-layer subgraph consumes an input from the first-layer graph (not an initializer). + auto test_graph = TestGraph::Load(ORT_TSTR("testdata/three_layer_nested_subgraph_v2.onnx")); + ASSERT_NE(test_graph, nullptr) << "Failed to load test model"; + + CheckGraphCApi(test_graph->GetGraphViewer(), test_graph->GetOrtGraph()); +} + +static void RunMNISTModel(const ORTCHAR_T* model_path, std::vector& output_data) { + auto memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); + Ort::SessionOptions sess_options; + Ort::Session session(*ort_env, model_path, sess_options); + + std::vector input_shape = {1, 1, 28, 28}; + std::vector input_data(28 * 28, 0.5f); + std::vector ort_inputs; + std::vector ort_input_names; + + // Add 'Input3' + ort_inputs.emplace_back(Ort::Value::CreateTensor( + memory_info, input_data.data(), input_data.size(), input_shape.data(), input_shape.size())); + ort_input_names.push_back("Input3"); + + // Run session and get outputs + std::array output_names{"Plus214_Output_0"}; + std::vector ort_outputs = session.Run(Ort::RunOptions{nullptr}, ort_input_names.data(), ort_inputs.data(), + ort_inputs.size(), output_names.data(), output_names.size()); + + // Check output type and number of elements. + Ort::Value& ort_output = ort_outputs[0]; + auto output_type_shape = ort_output.GetTensorTypeAndShapeInfo(); + size_t num_output_elems = output_type_shape.GetElementCount(); + + ASSERT_EQ(output_type_shape.GetElementType(), ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT); + ASSERT_EQ(num_output_elems, 10); + + // Return output data. + const float* output_values = ort_output.GetTensorData(); + output_data.assign(output_values, output_values + num_output_elems); +} + +// Test serializing an OrtGraph (MNIST) to GraphProto. Saves initializers to external file. +// Checks that the outputs of the serialized and original models are identical. +TEST(EpGraphTest, SerializeToProto_Mnist) { + const ORTCHAR_T* original_model_path = ORT_TSTR("testdata/mnist.onnx"); + const ORTCHAR_T* serialized_model_path = ORT_TSTR("mnist_serialized.onnx"); + std::filesystem::remove(serialized_model_path); + + { + auto test_graph = TestGraph::Load(original_model_path); + ASSERT_NE(test_graph, nullptr) << "Failed to load test model"; + + // Serialize OrtGraph to GraphProto. Save initializers to external file. + std::string ext_ini_file_path = "mnist_serialized.bin"; + std::filesystem::remove(ext_ini_file_path); + std::ofstream ext_ini_ofs(ext_ini_file_path, std::ios::binary); + auto handle_initializer_data = [&ext_ini_ofs, &ext_ini_file_path](const OrtValueInfo* value_info, + const void* data, size_t bytes, + bool& is_external, std::string& location, + int64_t& offset) -> Ort::Status { + // OrtValueInfo* could be used to query initializer's name, type, shape, + // node consumers, etc. + (void)value_info; + + if (bytes <= 127) { + is_external = false; // Keep small initializers stored inside the TensorProto. + return Ort::Status{nullptr}; + } + + offset = ext_ini_ofs.tellp(); + location = ext_ini_file_path; + ext_ini_ofs.write(static_cast(data), bytes); + ext_ini_ofs.flush(); + is_external = true; // True if is external initializer. + + return Ort::Status{nullptr}; + }; + + ONNX_NAMESPACE::ModelProto model_proto; + OrtEpUtils::OrtGraphToProto(test_graph->GetOrtGraph(), model_proto, handle_initializer_data); + + std::ofstream ofs(serialized_model_path, std::ios::binary); + model_proto.SerializeToOstream(&ofs); + ofs.flush(); + + ASSERT_TRUE(std::filesystem::exists(serialized_model_path)); + ASSERT_TRUE(std::filesystem::exists(ext_ini_file_path)); + } + + // Compare output of the original and serialized models. Should be identical. + std::vector output_original; + std::vector output_serialized; + + RunMNISTModel(original_model_path, output_original); + RunMNISTModel(serialized_model_path, output_serialized); + + EXPECT_EQ(output_serialized, output_original); +} + +static void Run3LayerModel(const ORTCHAR_T* model_path, bool input_cond, std::vector& output_data) { + auto memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); + Ort::SessionOptions sess_options; + Ort::Session session(*ort_env, model_path, sess_options); + + std::vector input_shape = {1}; + std::vector ort_inputs; + std::vector ort_input_names; + + // Add 'if_cond_input' + ort_inputs.emplace_back(Ort::Value::CreateTensor( + memory_info, &input_cond, 1, input_shape.data(), input_shape.size())); + ort_input_names.push_back("if_cond_input"); + + // Run session and get outputs + std::array output_names{"if_cond_output"}; + std::vector ort_outputs = session.Run(Ort::RunOptions{nullptr}, ort_input_names.data(), ort_inputs.data(), + ort_inputs.size(), output_names.data(), output_names.size()); + + // Check output type and number of elements. + Ort::Value& ort_output = ort_outputs[0]; + auto output_type_shape = ort_output.GetTensorTypeAndShapeInfo(); + size_t num_output_elems = output_type_shape.GetElementCount(); + + ASSERT_EQ(output_type_shape.GetElementType(), ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT); + ASSERT_EQ(num_output_elems, 1); + + // Return output data. + const float* output_values = ort_output.GetTensorData(); + output_data.assign(output_values, output_values + num_output_elems); +} + +// Test serializing an OrtGraph to GraphProto. The model has 3 layers of nested subgraphs. +// Checks that the outputs of the serialized and original models are identical. +TEST(EpGraphTest, SerializeToProto_3LayerSubgraphs) { + const ORTCHAR_T* original_model_path = ORT_TSTR("testdata/three_layer_nested_subgraph.onnx"); + const ORTCHAR_T* serialized_model_path = ORT_TSTR("three_layer_nested_subgraph_serialized.onnx"); + std::filesystem::remove(serialized_model_path); + + { + auto test_graph = TestGraph::Load(original_model_path); + ASSERT_NE(test_graph, nullptr) << "Failed to load test model"; + + // Serialize OrtGraph to ModelProto (all initializers stored within TensorProtos). + ONNX_NAMESPACE::ModelProto model_proto; + OrtEpUtils::OrtGraphToProto(test_graph->GetOrtGraph(), model_proto); + + std::ofstream ofs(serialized_model_path, std::ios::binary); + model_proto.SerializeToOstream(&ofs); + ofs.flush(); + + ASSERT_TRUE(std::filesystem::exists(serialized_model_path)); + } + + // Compare output of the original and serialized models. Should be identical. + std::vector output_original; + std::vector output_serialized; + + { + Run3LayerModel(original_model_path, true, output_original); + Run3LayerModel(serialized_model_path, true, output_serialized); + EXPECT_EQ(output_serialized, output_original); + } + + { + Run3LayerModel(original_model_path, false, output_original); + Run3LayerModel(serialized_model_path, false, output_serialized); + EXPECT_EQ(output_serialized, output_original); + } +} + // // Utils for traversing an OrtGraph and checking against GraphViewer. // @@ -307,6 +488,48 @@ static void CheckValueInfosCApi(const GraphViewer& graph_viewer, gsl::span nodes(num_nodes); + ASSERT_ORTSTATUS_OK(ort_api.Graph_GetNodes(&api_graph, nodes.data(), nodes.size())); + + // Select a half of nodes to create a OrtGraph + size_t num_selected_nodes = std::max((nodes.size() >> 1), (size_t)1); + std::vector selected_nodes(num_selected_nodes); + + for (size_t i = 0; i < num_selected_nodes; i++) { + selected_nodes[i] = nodes[i]; + } + + OrtGraph* sub_graph; + ASSERT_ORTSTATUS_OK(ort_api.Graph_GetGraphView(&api_graph, selected_nodes.data(), selected_nodes.size(), &sub_graph)); + + // Convert OrtGraph/GraphViewer to ModelProto and dump it to disk. + // If the GraphViewer associated with the OrtGraph somehow is incorrect, GraphViewerToProto() will throw. + const GraphViewer& sub_graph_viewer = EpGraph::ToInternal(sub_graph)->GetGraphViewer(); + std::unique_ptr model = std::make_unique(sub_graph_viewer.Name(), true, sub_graph_viewer.GetGraph().GetLogger()); + auto model_proto = std::make_unique(model->ToProto()); + GraphViewerToProto(sub_graph_viewer, *model_proto->mutable_graph(), true, true, static_cast(1)); + model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); + + const char* graph_name = nullptr; + ASSERT_ORTSTATUS_OK(ort_api.Graph_GetName(&api_graph, &graph_name)); + std::string name = graph_name; + name += "_half.onnx"; + + // Dump the graph for debugging + // std::fstream dump(name, std::ios::out | std::ios::trunc | std::ios::binary); + // model_proto->SerializeToOstream(&dump); + + ort_api.ReleaseGraph(sub_graph); +} + // Checks that the contents of the original GraphViewer matches the contents of the OrtGraph. // Uses the public C APIs to traverse the OrtGraph. static void CheckGraphCApi(const GraphViewer& graph_viewer, const OrtGraph& api_graph) { @@ -470,9 +693,10 @@ static void CheckGraphCApi(const GraphViewer& graph_viewer, const OrtGraph& api_ } // Check node subgraphs - std::vector> node_subgraphs = node->GetSubgraphs(); + std::unordered_map> node_subgraphs_map = + node->GetAttributeNameToSubgraphMap(); - if (!node_subgraphs.empty()) { + if (!node_subgraphs_map.empty()) { // Check node's implicit inputs to its subgraph nodes. const auto implicit_input_node_args = node->ImplicitInputDefs(); @@ -489,18 +713,34 @@ static void CheckGraphCApi(const GraphViewer& graph_viewer, const OrtGraph& api_ // Recursively check subgraphs. size_t api_num_node_subgraphs = 0; ASSERT_ORTSTATUS_OK(ort_api.Node_GetNumSubgraphs(api_node, &api_num_node_subgraphs)); + ASSERT_EQ(api_num_node_subgraphs, node_subgraphs_map.size()); std::vector api_node_subgraphs(api_num_node_subgraphs); - ASSERT_ORTSTATUS_OK(ort_api.Node_GetSubgraphs(api_node, api_node_subgraphs.data(), api_node_subgraphs.size())); - - for (size_t subgraph_idx = 0; subgraph_idx < node_subgraphs.size(); subgraph_idx++) { - auto subgraph_viewer = std::make_unique(*node_subgraphs[subgraph_idx]); - const OrtGraph* api_subgraph = api_node_subgraphs[subgraph_idx]; + std::vector api_subgraph_attr_names(api_num_node_subgraphs); + ASSERT_ORTSTATUS_OK(ort_api.Node_GetSubgraphs(api_node, api_node_subgraphs.data(), api_node_subgraphs.size(), + api_subgraph_attr_names.data())); + + for (const auto& [attr_name, subgraph] : node_subgraphs_map) { + // find index of this subgraph. + size_t api_subgraph_idx = api_num_node_subgraphs; + for (size_t subgraph_idx = 0; subgraph_idx < api_num_node_subgraphs; subgraph_idx++) { + if (api_subgraph_attr_names[subgraph_idx] == attr_name) { + api_subgraph_idx = subgraph_idx; + break; + } + } + ASSERT_NE(api_subgraph_idx, api_num_node_subgraphs); + // Recursively check the subgraph + auto subgraph_viewer = std::make_unique(*subgraph); + const OrtGraph* api_subgraph = api_node_subgraphs[api_subgraph_idx]; CheckGraphCApi(*subgraph_viewer, *api_subgraph); } } } + + // Check creating an OrtGraph from a subset of nodes in an OrtGraph + Check_Graph_GetSubgraph(api_graph); } } // namespace test diff --git a/onnxruntime/test/ep_graph/test_ep_graph_utils.cc b/onnxruntime/test/ep_graph/test_ep_graph_utils.cc index b7743e65061de..3b3bc4c6da911 100644 --- a/onnxruntime/test/ep_graph/test_ep_graph_utils.cc +++ b/onnxruntime/test/ep_graph/test_ep_graph_utils.cc @@ -30,6 +30,7 @@ std::unique_ptr TestGraph::Load(const ORTCHAR_T* model_path) { const OrtGraph& TestGraph::GetOrtGraph() const { return *api_graph; } const GraphViewer& TestGraph::GetGraphViewer() const { return graph_viewer; } +const Model& TestGraph::GetModel() const { return *model; } static Status GetInputIndices(const Node& consumer_node, const std::string& name, /*out*/ std::vector& indices) { diff --git a/onnxruntime/test/ep_graph/test_ep_graph_utils.h b/onnxruntime/test/ep_graph/test_ep_graph_utils.h index b0ed825f21d71..2ce107cf734c6 100644 --- a/onnxruntime/test/ep_graph/test_ep_graph_utils.h +++ b/onnxruntime/test/ep_graph/test_ep_graph_utils.h @@ -28,6 +28,7 @@ class TestGraph { static std::unique_ptr Load(const ORTCHAR_T* model_path); const OrtGraph& GetOrtGraph() const; const GraphViewer& GetGraphViewer() const; + const Model& GetModel() const; private: std::shared_ptr model; diff --git a/onnxruntime/test/framework/ep_plugin_provider_test.cc b/onnxruntime/test/framework/ep_plugin_provider_test.cc index 18bc9cf05b36d..4c5dcd2bd7580 100644 --- a/onnxruntime/test/framework/ep_plugin_provider_test.cc +++ b/onnxruntime/test/framework/ep_plugin_provider_test.cc @@ -36,7 +36,7 @@ struct TestOrtEp : ::OrtEp, ApiPtrs { // Individual tests should fill out the other function pointers as needed. } - static const char* ORT_API_CALL GetNameImpl(const OrtEp* /*this_ptr*/) { + static const char* ORT_API_CALL GetNameImpl(const OrtEp* /*this_ptr*/) noexcept { constexpr const char* ep_name = "TestOrtEp"; return ep_name; } @@ -50,7 +50,7 @@ struct TestOrtEpFactory : ::OrtEpFactory { ReleaseEp = ReleaseEpImpl; } - static void ORT_API_CALL ReleaseEpImpl(::OrtEpFactory* /*this_ptr*/, OrtEp* ep) { + static void ORT_API_CALL ReleaseEpImpl(::OrtEpFactory* /*this_ptr*/, OrtEp* ep) noexcept { delete static_cast(ep); } }; @@ -125,7 +125,7 @@ TEST(PluginExecutionProviderTest, GetPreferredLayout) { } { - auto prefer_nhwc_fn = [](OrtEp* /*this_ptr*/, OrtEpDataLayout* preferred_data_layout) -> ::OrtStatus* { + auto prefer_nhwc_fn = [](OrtEp* /*this_ptr*/, OrtEpDataLayout* preferred_data_layout) noexcept -> ::OrtStatus* { *preferred_data_layout = OrtEpDataLayout::OrtEpDataLayout_NCHW; return nullptr; }; @@ -135,7 +135,7 @@ TEST(PluginExecutionProviderTest, GetPreferredLayout) { #if !defined(ORT_NO_EXCEPTIONS) { - auto invalid_layout_fn = [](OrtEp* /*this_ptr*/, OrtEpDataLayout* preferred_data_layout) -> ::OrtStatus* { + auto invalid_layout_fn = [](OrtEp* /*this_ptr*/, OrtEpDataLayout* preferred_data_layout) noexcept -> ::OrtStatus* { *preferred_data_layout = static_cast(-1); return nullptr; }; @@ -144,7 +144,7 @@ TEST(PluginExecutionProviderTest, GetPreferredLayout) { } { - auto failing_fn = [](OrtEp* this_ptr, OrtEpDataLayout* /*preferred_data_layout*/) -> ::OrtStatus* { + auto failing_fn = [](OrtEp* this_ptr, OrtEpDataLayout* /*preferred_data_layout*/) noexcept -> ::OrtStatus* { auto* test_ort_ep = static_cast(this_ptr); return test_ort_ep->ort_api->CreateStatus(OrtErrorCode::ORT_FAIL, "I can't decide what data layout I prefer."); }; @@ -167,7 +167,7 @@ TEST(PluginExecutionProviderTest, ShouldConvertDataLayoutForOp) { const char* /*node_domain*/, const char* node_op_type, OrtEpDataLayout target_data_layout, - int* should_convert) -> ::OrtStatus* { + int* should_convert) noexcept -> ::OrtStatus* { EXPECT_EQ(target_data_layout, OrtEpDataLayout::OrtEpDataLayout_NHWC); if (node_op_type == std::string_view{"Conv"}) { @@ -201,7 +201,7 @@ TEST(PluginExecutionProviderTest, ShouldConvertDataLayoutForOp) { const char* /*node_domain*/, const char* /*node_op_type*/, OrtEpDataLayout /*target_data_layout*/, - int* /*should_convert*/) -> ::OrtStatus* { + int* /*should_convert*/) noexcept -> ::OrtStatus* { auto* test_ort_ep = static_cast(this_ptr); return test_ort_ep->ort_api->CreateStatus(OrtErrorCode::ORT_FAIL, "To convert to NHWC or not to convert to NHWC..."); diff --git a/onnxruntime/test/mlas/bench/bench_computesoftmax.cpp b/onnxruntime/test/mlas/bench/bench_computesoftmax.cpp index 65822eb294d7d..ea36383f70621 100644 --- a/onnxruntime/test/mlas/bench/bench_computesoftmax.cpp +++ b/onnxruntime/test/mlas/bench/bench_computesoftmax.cpp @@ -58,10 +58,10 @@ void COMPUTESOFTMAXINPLACE(benchmark::State& state) { std::copy(data.begin(), data.end(), input); // Copy the data to the aligned memory // warming up run - MlasComputeSoftmax(input, output, N, D, false, false, tp.get()); + MlasComputeSoftmax(input, output, N, D, false, false, 0.0f, tp.get()); for (auto _ : state) { - MlasComputeSoftmax(input, output, N, D, false, false, tp.get()); + MlasComputeSoftmax(input, output, N, D, false, false, 0.0f, tp.get()); } free(ptr.underlying_buffer); diff --git a/onnxruntime/test/mlas/unittest/test_dequantizelinear.cpp b/onnxruntime/test/mlas/unittest/test_dequantizelinear.cpp new file mode 100644 index 0000000000000..b994981364947 --- /dev/null +++ b/onnxruntime/test/mlas/unittest/test_dequantizelinear.cpp @@ -0,0 +1,75 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "test_util.h" + +template +class MlasDequantizeLinearTest : public MlasTestBase { + private: + MatrixGuardBuffer BufferInput; + MatrixGuardBuffer BufferOutput; + MatrixGuardBuffer BufferOutputReference; + + void GenerateReference(const QuantInt* Input, float* OutputReference, size_t N, float Scale, QuantInt ZeroPoint) { + int32_t ZeroPointS32 = static_cast(ZeroPoint); + + for (size_t n = 0; n < N; n++) { + OutputReference[n] = static_cast(static_cast(Input[n]) - ZeroPointS32) * Scale; + } + } + + void Test(size_t N) { + QuantInt* Input = BufferInput.GetBuffer(N); + float* Output = BufferOutput.GetBuffer(N); + float* OutputReference = BufferOutputReference.GetBuffer(N); + + std::default_random_engine generator(static_cast(N)); + + std::uniform_real_distribution min_gen(-10.f, -10e-3f); + float MinimumValue = min_gen(generator); + + std::uniform_real_distribution max_gen(10e-3f, 10.f); + float MaximumValue = max_gen(generator); + + float Scale = (MaximumValue - MinimumValue) / 512.f; + + std::uniform_int_distribution zp_distribution(std::numeric_limits::min(), + std::numeric_limits::max()); + QuantInt ZeroPoint = static_cast(zp_distribution(generator)); + + for (size_t n = 0; n < N; n++) { + Input[n] = static_cast(zp_distribution(generator)); + } + + GenerateReference(Input, OutputReference, N, Scale, ZeroPoint); + MlasDequantizeLinear(Input, Output, N, Scale, ZeroPoint); + + for (size_t n = 0; n < N; n++) { + ASSERT_EQ(Output[n], OutputReference[n]) << ", size=" << N << ", index=" << n; + } + } + + public: + static const char* GetTestSuiteName() { + if constexpr (std::is_same_v) { + return "DequantizeLinearS8"; + } else { + return "DequantizeLinearU8"; + } + } + + void ExecuteShort(void) override { + for (size_t n = 1; n <= 512; n++) { + Test(n); + } + } +}; + +static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) { + size_t count = 0; + if (is_short_execute) { + count += MlasDirectShortExecuteTests>::RegisterShortExecute(); + count += MlasDirectShortExecuteTests>::RegisterShortExecute(); + } + return count; +}); diff --git a/onnxruntime/test/mlas/unittest/test_softmax.cpp b/onnxruntime/test/mlas/unittest/test_softmax.cpp index 041b6c61cd5bf..4d7a45143b311 100644 --- a/onnxruntime/test/mlas/unittest/test_softmax.cpp +++ b/onnxruntime/test/mlas/unittest/test_softmax.cpp @@ -152,7 +152,7 @@ class MlasSoftmaxTest : public MlasTestBase { } void Test(const float* Input, float* Output, float* OutputReference, size_t N, size_t D, bool LogSoftmax, bool SmoothSoftmax) { - MlasComputeSoftmax(Input, Output, N, D, LogSoftmax, SmoothSoftmax, threadpool_); + MlasComputeSoftmax(Input, Output, N, D, LogSoftmax, SmoothSoftmax, 0.0f, threadpool_); ReferenceSoftmax(Input, OutputReference, N, D, LogSoftmax, SmoothSoftmax); constexpr float AbsoluteTolerance = 1e-6f; @@ -206,7 +206,7 @@ class MlasSoftmaxTest : public MlasTestBase { InputReference[nd] = Input[nd].ToFloat(); } - MlasComputeSoftmax(Input, Output, N, D, LogSoftmax, SmoothSoftmax, threadpool_); + MlasComputeSoftmax(Input, Output, N, D, LogSoftmax, SmoothSoftmax, 0.0f, threadpool_); ReferenceSoftmax(InputReference, OutputReference, N, D, LogSoftmax, SmoothSoftmax); constexpr float AbsoluteTolerance = 5e-3f; diff --git a/onnxruntime/test/providers/cpu/math/softmax_test.cc b/onnxruntime/test/providers/cpu/math/softmax_test.cc index 649c9af7cc80b..215203b31f49c 100644 --- a/onnxruntime/test/providers/cpu/math/softmax_test.cc +++ b/onnxruntime/test/providers/cpu/math/softmax_test.cc @@ -61,7 +61,8 @@ TEST(SoftmaxOperator, webgpu_nan) { test.AddOutput("Y", dimensions, expected_result); // explicitly disable for EPs that do not handle NaN - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCpuExecutionProvider, kCoreMLExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", + {kCpuExecutionProvider, kCoreMLExecutionProvider, kDmlExecutionProvider}); } #endif diff --git a/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc b/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc index 4e7a6356a5129..8fdbf0060eaa0 100644 --- a/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc @@ -33,6 +33,32 @@ TEST(DequantizeLinearOpTest, Int8) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); } +// scalar zero & scale with uint8 (large enough input to execute MLAS vectorized loop) +TEST(DequantizeLinearOpTest, Uint8_Large) { + OpTester test("DequantizeLinear", 10); + std::vector dims{1, 1039}; // not evenly divisible by 16 (loop unroll amount) to test handling of leftover inputs + test.AddInput("x", dims, std::vector(1039, 1)); + test.AddInput("x_scale", {}, {1.0f}); + test.AddInput("x_zero_point", {}, {1}); + test.AddOutput("y", dims, std::vector(1039, 0.0f)); + // Disable Tensorrt EP due to error:node1_quantize_scale_node: out of bounds channel axis 1. Number of input dimensions is 1. + // Disable WebGPU EP because it requires dims.Size() to be multiple of 4. Fails with error: needs at least component size 4. + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kWebGpuExecutionProvider}); +} + +// scalar zero & scale with int8 (large enough input to execute MLAS vectorized loop) +TEST(DequantizeLinearOpTest, Int8_Large) { + OpTester test("DequantizeLinear", 10); + std::vector dims{1, 1039}; // not evenly divisible by 16 (loop unroll amount) to test handling of leftover inputs + test.AddInput("x", dims, std::vector(1039, 1)); + test.AddInput("x_scale", {}, {1.0f}); + test.AddInput("x_zero_point", {}, {1}); + test.AddOutput("y", dims, std::vector(1039, 0.0f)); + // Disable Tensorrt EP due to error:node1_quantize_scale_node: out of bounds channel axis 1. Number of input dimensions is 1. + // Disable WebGPU EP because it requires dims.Size() to be multiple of 4. Fails with error: needs at least component size 4. + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kWebGpuExecutionProvider}); +} + // scalar zero & scale with int4 TEST(DequantizeLinearOpTest, Int4) { OpTester test("DequantizeLinear", 21); diff --git a/onnxruntime/test/providers/cpu/tensor/scatter_nd_op_test.cc b/onnxruntime/test/providers/cpu/tensor/scatter_nd_op_test.cc index 895c8ab3e53e4..e6d113e1e4dca 100644 --- a/onnxruntime/test/providers/cpu/tensor/scatter_nd_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/scatter_nd_op_test.cc @@ -235,5 +235,16 @@ TEST(ScatterNDOpTest, ScatterND_18_max) { test1.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); } +// Test for ScatterND with empty indices - output should be same as input +TEST(ScatterNDOpTest, ScatterND_empty_indices) { + // Test with float data type and minimal empty case + OpTester test1("ScatterND", 11); + test1.AddInput("data", {2, 3}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}); + test1.AddInput("indices", {0, 1}, {}); // Empty indices tensor - no indices to process + test1.AddInput("updates", {0, 3}, {}); // Empty updates tensor + test1.AddOutput("output", {2, 3}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}); // Same as input + test1.Run(OpTester::ExpectResult::kExpectSuccess, "", {kDmlExecutionProvider}); +} + } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc index 4febfe7ba836d..739e39a6975e2 100644 --- a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc +++ b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc @@ -509,6 +509,11 @@ TEST_F(QnnHTPBackendTests, CompileApi_FromSessionOptions_InputAndOutputModelsInB Ort::ModelCompilationOptions compile_options(*ort_env, session_options); compile_options.SetInputModelFromBuffer(reinterpret_cast(model_data.data()), model_data.size()); compile_options.SetOutputModelBuffer(allocator, &output_model_buffer, &output_model_buffer_size); + std::string target_dir = "./testdata/"; + std::string model_name = "test_model_in_mem.onnx"; + auto pos = model_name.rfind(".onnx"); + std::string bin_file_name = model_name.substr(0, pos) + "_qnn.bin"; + compile_options.SetEpContextBinaryInformation(ToWideString(target_dir).c_str(), ToWideString(model_name).c_str()); compile_options.SetEpContextEmbedMode(false); // Compile the model. @@ -519,12 +524,18 @@ TEST_F(QnnHTPBackendTests, CompileApi_FromSessionOptions_InputAndOutputModelsInB ASSERT_TRUE(output_model_buffer != nullptr); ASSERT_TRUE(output_model_buffer_size > 0); + ASSERT_TRUE(std::filesystem::exists(target_dir + bin_file_name)) << "expected context binary file should exist"; + // Check that the compiled model has the expected number of EPContext nodes. CheckEpContextNodeCounts(output_model_buffer, output_model_buffer_size, 2, 2); + // Add session option "ep.context_file_path" so that the session can use it to locate the [model_name]_qnn.bin file + std::string ctx_model = target_dir + model_name; + session_options.AddConfigEntry(kOrtSessionOptionEpContextFilePath, ctx_model.c_str()); // Should be able to create a session with the compiled model and the original session options. EXPECT_NO_THROW((Ort::Session(*ort_env, output_model_buffer, output_model_buffer_size, session_options))); + std::filesystem::remove(target_dir + bin_file_name); allocator.Free(output_model_buffer); } } @@ -1649,7 +1660,6 @@ static void DumpModelWithSharedCtx(ProviderOptions provider_options, Ort::Session session2(*ort_env, ToPathString(onnx_model_path2).c_str(), so); } -#if defined(__aarch64__) || defined(_M_ARM64) static void GetModelInputNames(const std::string& model_path, std::vector& input_names, std::vector& output_names, @@ -1669,7 +1679,6 @@ static void GetModelInputNames(const std::string& model_path, output_names.push_back(output->Name()); } } -#endif // 1. Create 2 QDQ models // 2. Initialize 2 Ort sessions which share the same QNN EP from these 2 QDQ models @@ -1994,6 +2003,73 @@ TEST_F(QnnHTPBackendTests, LoadFromArrayWithQnnEpContextGenPathValidation) { }); } } + +TEST_F(QnnHTPBackendTests, QnnEpDynamicOptions) { + ProviderOptions provider_options; + provider_options["backend_type"] = "htp"; + provider_options["offload_graph_io_quantization"] = "0"; + + Ort::SessionOptions so; + so.AppendExecutionProvider("QNN", provider_options); + so.SetLogSeverityLevel(ORT_LOGGING_LEVEL_VERBOSE); + + Ort::Session session(*ort_env, ORT_TSTR("testdata/qnn_ctx/qnn_multi_ctx_embed.onnx"), so); + + std::vector input_names; + std::vector output_names; + GetModelInputNames("testdata/qnn_ctx/qnn_multi_ctx_embed.onnx", input_names, output_names, + DefaultLoggingManager().DefaultLogger()); + + // Run sessions + // prepare input + std::vector input_dim{3, 4}; + std::vector input_value(3 * 4, 0.0f); + Ort::MemoryInfo info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault); + std::vector ort_inputs; + std::vector input_names_c; + for (size_t i = 0; i < input_names.size(); ++i) { + auto input_tensor = Ort::Value::CreateTensor(info, input_value.data(), input_value.size(), + input_dim.data(), input_dim.size()); + ort_inputs.push_back(std::move(input_tensor)); + input_names_c.push_back(input_names[i].c_str()); + } + std::vector output_names_c; + for (size_t i = 0; i < output_names.size(); ++i) { + output_names_c.push_back(output_names[i].c_str()); + } + + auto ort_output = session.Run(Ort::RunOptions{}, input_names_c.data(), ort_inputs.data(), ort_inputs.size(), + output_names_c.data(), 1); + + const char* const workload_type[] = {"ep.dynamic.workload_type"}; + const char* const efficient_type[] = {"Efficient"}; + const char* const default_type[] = {"Default"}; + + // Test Efficient & Default options + session.SetEpDynamicOptions(workload_type, efficient_type, 1); + ort_output = session.Run(Ort::RunOptions{}, input_names_c.data(), ort_inputs.data(), ort_inputs.size(), + output_names_c.data(), 1); + + session.SetEpDynamicOptions(workload_type, default_type, 1); + ort_output = session.Run(Ort::RunOptions{}, input_names_c.data(), ort_inputs.data(), ort_inputs.size(), + output_names_c.data(), 1); + + // Test invalid EP dynamic option and invalid workload type + const char* const dne[] = {"DNE"}; + try { + session.SetEpDynamicOptions(workload_type, dne, 1); + FAIL() << "Expected exception to be thrown for workload type DNE but was set successfully"; + } catch (const std::exception& e) { + EXPECT_STREQ("Invalid EP Workload Type.", e.what()); + } + + try { + session.SetEpDynamicOptions(dne, efficient_type, 1); + FAIL() << "Expected exception to be thrown for dynamic option DNE but was set successfully"; + } catch (const std::exception& e) { + EXPECT_STREQ("Unsupported EP Dynamic Option", e.what()); + } +} #endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) } // namespace test diff --git a/onnxruntime/test/providers/qnn/simple_op_htp_test.cc b/onnxruntime/test/providers/qnn/simple_op_htp_test.cc index 85f8250f70fc5..4c0a53e83e274 100644 --- a/onnxruntime/test/providers/qnn/simple_op_htp_test.cc +++ b/onnxruntime/test/providers/qnn/simple_op_htp_test.cc @@ -1254,6 +1254,38 @@ TEST_F(QnnHTPBackendTests, GridSample_U16_Nearest) { true); } +// Test QDQ GridSample with `linear` mode on opset 20+. +TEST_F(QnnHTPBackendTests, GridSample_Linear_ZerosPadding) { + RunQDQOpTest("GridSample", + {TestInputDef({1, 3, 4, 6}, false, GetFloatDataInRange(-10.0f, 10.0f, 72)), + TestInputDef({1, 4, 6, 2}, false, GetFloatDataInRange(-10.0f, 10.0f, 48))}, + {utils::MakeAttribute("mode", "linear"), utils::MakeAttribute("padding_mode", "zeros")}, + /*opset_version=*/20, + /*expected_ep_assignment=*/ExpectedEPNodeAssignment::All); +} + +TEST_F(QnnHTPBackendTests, GridSample_Linear_AlignCorners_BorderPadding) { + RunQDQOpTest("GridSample", + {TestInputDef({1, 3, 4, 6}, false, GetFloatDataInRange(-10.0f, 10.0f, 72)), + TestInputDef({1, 4, 6, 2}, false, GetFloatDataInRange(-10.0f, 10.0f, 48))}, + {utils::MakeAttribute("align_corners", static_cast(1)), + utils::MakeAttribute("mode", "linear"), + utils::MakeAttribute("padding_mode", "border")}, + /*opset_version=*/20, + /*expected_ep_assignment=*/ExpectedEPNodeAssignment::All); +} + +TEST_F(QnnHTPBackendTests, GridSample_Linear_ReflectionPadding_U16) { + RunQDQOpTest("GridSample", + {TestInputDef({1, 3, 4, 6}, false, GetFloatDataInRange(-10.0f, 10.0f, 72)), + TestInputDef({1, 4, 6, 2}, false, GetFloatDataInRange(-10.0f, 10.0f, 48))}, + {utils::MakeAttribute("mode", "linear"), utils::MakeAttribute("padding_mode", "reflection")}, + /*opset_version=*/21, + /*expected_ep_assignment=*/ExpectedEPNodeAssignment::All, + /*op_domain=*/kOnnxDomain, + /*use_contrib_qdq=*/true); +} + // Test QDQ GridSample with reflection padding mode // Inaccuracy detected for output 'output', element 2. // Output quant params: scale=0.024269860237836838, zero_point=0. diff --git a/onnxruntime/test/python/quantization/test_op_topk.py b/onnxruntime/test/python/quantization/test_op_topk.py new file mode 100644 index 0000000000000..1fdd0c987d1e8 --- /dev/null +++ b/onnxruntime/test/python/quantization/test_op_topk.py @@ -0,0 +1,103 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import unittest + +import numpy as np +from onnx import TensorProto, helper, save +from op_test_utils import TestDataFeeds, check_model_correctness, check_op_type_count, check_qtype_by_node_type + +from onnxruntime.quantization import QuantFormat, QuantType, quantize_static + + +class TestTopKModel(unittest.TestCase): + @staticmethod + def construct_model(model_path, input_shape, axis_attr, k): + input_tensor = helper.make_tensor_value_info("input", TensorProto.FLOAT, input_shape) + k_tensor = helper.make_tensor("k", TensorProto.INT64, [1], [k]) + output_shape = input_shape[:] + output_shape[axis_attr] = k + output_values = helper.make_tensor_value_info("values", TensorProto.FLOAT, [1, k]) + output_indices = helper.make_tensor_value_info("indices", TensorProto.INT64, [1, k]) + + node = helper.make_node( + "TopK", inputs=["input", "k"], outputs=["values", "indices"], name="topk_node", axis=axis_attr + ) + + graph = helper.make_graph( + [node], + "quant_topk_op_test", + [input_tensor], + [output_values, output_indices], + initializer=[k_tensor], + ) + + model = helper.make_model( + graph, opset_imports=[helper.make_opsetid("", 16), helper.make_opsetid("com.microsoft", 1)] + ) + save(model, model_path) + + def quantize_topk_test(self, activation_type, weight_type, extra_options={}): # noqa: B006 + model_fp32_path = "topk_fp32.onnx" + input_shape = [1, 10] + axis = 1 + k = 3 + self.construct_model(model_fp32_path, input_shape, axis, k) + + input_data_list = [ + {"input": np.array([[1.8, 2.5, -5.9, 5.2, 4.1, 7.3, 0.2, -0.5, 0.845, 3.9]], dtype=np.float32)} + ] + data_reader = TestDataFeeds(input_data_list) + + activation_proto_qtype = TensorProto.UINT8 if activation_type == QuantType.QUInt8 else TensorProto.INT8 + activation_type_str = "u8" if (activation_type == QuantType.QUInt8) else "s8" + weight_type_str = "u8" if (weight_type == QuantType.QUInt8) else "s8" + model_qdq_path = f"topk_{activation_type_str}{weight_type_str}_{'QNoInCk' if extra_options['ForceQuantizeNoInputCheck'] else 'NoQNoInCk'}_qdq.onnx" + + # Verify QDQ mode + data_reader.rewind() + quantize_static( + model_fp32_path, + model_qdq_path, + data_reader, + quant_format=QuantFormat.QDQ, + activation_type=activation_type, + weight_type=weight_type, + extra_options=extra_options, + ) + qdqnode_counts = ( + { + "TopK": 1, + "QuantizeLinear": 2, + "DequantizeLinear": 2, + } + if extra_options["ForceQuantizeNoInputCheck"] + else { + "TopK": 1, + "QuantizeLinear": 0, + "DequantizeLinear": 0, + } + ) + check_op_type_count(self, model_qdq_path, **qdqnode_counts) + qnode_io_qtypes = { + "QuantizeLinear": [ + ["i", 2, activation_proto_qtype], + ["o", 0, activation_proto_qtype], + ] + } + check_qtype_by_node_type(self, model_qdq_path, qnode_io_qtypes) + data_reader.rewind() + check_model_correctness(self, model_fp32_path, model_qdq_path, data_reader.get_next()) + + def test_quantize_topk_u8u8(self): + self.quantize_topk_test(QuantType.QUInt8, QuantType.QUInt8, extra_options={"ForceQuantizeNoInputCheck": True}) + + def test_quantize_topk_u8u8_no_force_quantize_no_input_check(self): + self.quantize_topk_test(QuantType.QUInt8, QuantType.QUInt8, extra_options={"ForceQuantizeNoInputCheck": False}) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxruntime/test/python/transformers/test_data/models/phi-4-v-instruct-vision-attention.onnx b/onnxruntime/test/python/transformers/test_data/models/phi-4-v-instruct-vision-attention.onnx new file mode 100644 index 0000000000000000000000000000000000000000..34cf26c13d3fc98f8a97aa3f9999e3d99e5bf847 GIT binary patch literal 7729 zcmeI1%Wl&^6hLEFY1~UGX5^MANT3E)mAtT%B7%r2p=KFTgkaarXp*?Ki9O@mV_qv( zhz&fg>%mqFw?GGRPD z6iMVIX>Gab?Cdy=_J>{gs6jd4aVF0ilS#>)A&nF9(&+^(g-Xct2SVJCyL*EHZBmg* zwT3ooH(m^b_z8RKB~O)b+Nf__>R@5;j>$l9`zBPpI1NI<*S~z*et4p3?dyFJIZ@D0 zL@Ev?eAZxw2H4n>(o;?dP8;-i_=>*vf+JsoHQApVTY?g-DHp~oB9;zG)gAfdKKD|e z#U6chVg0q=WYkyAUu+XrcotFLV}rD+pJ@7|twWeA6wFcF+wFZO_p^{TTP<>@FhB(@ z534&KIuGLd%<=kiF%Q0K@D~X)12;dDFf~M$nym-5TbFW2RjNA*fPYCUfrtg19wjXH zZQrm=tuv*&`>a%Y|9FwNO><3W;Eoh5_OidP8dl-WWU{-btBZ66Wi1vBj3>qu89)ZE zG6VLHEp@u=s$U>nhuiw&DIl29N<{02x3AkO5=>89)Y*0b~FfKn9QjWB?hk zKk4}&u9;Q5?oaK1W8~o8xByFPP&G7SL4}liO!j>!lcm%<2Hmg@?oZV=H{q_Defwgz XZm5cGv7%^tn=mTw{Yh>|H`jgvT1N^V literal 0 HcmV?d00001 diff --git a/onnxruntime/test/python/transformers/test_gqa_cpu.py b/onnxruntime/test/python/transformers/test_gqa_cpu.py index 461c243b82212..7f2134b2cda4f 100644 --- a/onnxruntime/test/python/transformers/test_gqa_cpu.py +++ b/onnxruntime/test/python/transformers/test_gqa_cpu.py @@ -13,6 +13,7 @@ import random import unittest from dataclasses import dataclass +from enum import Enum import numpy import torch @@ -38,11 +39,17 @@ ATOL = None -class Formats: +class Formats(Enum): BSNH = 0 BNSH = 1 +class QKOutputType(Enum): + NO_OUTPUT = 0 + BEFORE_SOFTMAX = 1 + AFTER_SOFTMAX = 2 + + @dataclass class Config: batch_size: int = 0 @@ -54,6 +61,8 @@ class Config: head_size: int = 0 has_position_ids: bool = False has_attention_bias: bool = False + has_head_sink: bool = False + qk_output: QKOutputType = QKOutputType.NO_OUTPUT @dataclass @@ -67,6 +76,8 @@ class PromptConfig: head_size: int = 0 has_position_ids: bool = False has_attention_bias: bool = False + has_head_sink: bool = False + qk_output: QKOutputType = QKOutputType.NO_OUTPUT # LLaMA Microsoft model @@ -151,6 +162,15 @@ def create_group_query_attention_graph_prompt( ): past_kv_seqlen = config.buffer_sequence_length if share_buffer else 0 present_kv_seqlen = config.buffer_sequence_length if share_buffer else config.kv_sequence_length + + output_names = [ + "output", + "present_key", + "present_value", + ] + if config.qk_output != QKOutputType.NO_OUTPUT: + output_names.append("output_qk") + nodes = [ helper.make_node( "GroupQueryAttention", @@ -166,8 +186,9 @@ def create_group_query_attention_graph_prompt( "sin_cache" if rotary else "", "position_ids" if config.has_position_ids else "", "attention_bias" if config.has_attention_bias else "", + "head_sink" if config.has_head_sink else "", ], - ["output", "present_key", "present_value"], + output_names, "GroupQueryAttention_0", num_heads=config.num_heads, kv_num_heads=config.kv_num_heads, @@ -176,6 +197,7 @@ def create_group_query_attention_graph_prompt( rotary_interleaved=rotary_interleaved, softcap=softcap, smooth_softmax=1 if use_smooth_softmax else 0, + qk_output=config.qk_output.value, # is_past_bsnh=1 if past_kv_format == Formats.BSNH else 0, # kv_share_buffer=1 if share_buffer else 0, domain="com.microsoft", @@ -289,6 +311,15 @@ def create_group_query_attention_graph_prompt( ), ] + if config.has_head_sink: + graph_input += [ + helper.make_tensor_value_info( + "head_sink", + ort_type, + [config.num_heads], + ), + ] + graph_output = [ helper.make_tensor_value_info( "output", @@ -337,6 +368,15 @@ def create_group_query_attention_graph_prompt( ), ] + if config.qk_output != QKOutputType.NO_OUTPUT: + graph_output += [ + helper.make_tensor_value_info( + "output_qk", + ort_type, + [config.batch_size, config.num_heads, config.kv_sequence_length, config.kv_sequence_length], + ), + ] + graph = helper.make_graph( nodes, "GroupQueryAttention_Graph", @@ -365,6 +405,15 @@ def create_group_query_attention_graph_past( present_kv_seqlen = ( config.kv_sequence_length if share_buffer else config.kv_sequence_length + config.sequence_length ) + + output_names = [ + "output", + "present_key", + "present_value", + ] + if config.qk_output != QKOutputType.NO_OUTPUT: + output_names.append("output_qk") + nodes = [ helper.make_node( "GroupQueryAttention", @@ -380,8 +429,9 @@ def create_group_query_attention_graph_past( "sin_cache" if rotary else "", "position_ids" if config.has_position_ids else "", "attention_bias" if config.has_attention_bias else "", + "head_sink" if config.has_head_sink else "", ], - ["output", "present_key", "present_value"], + output_names, "GroupQueryAttention_0", num_heads=config.num_heads, kv_num_heads=config.kv_num_heads, @@ -390,6 +440,7 @@ def create_group_query_attention_graph_past( rotary_interleaved=rotary_interleaved, softcap=softcap, smooth_softmax=1 if use_smooth_softmax else 0, + qk_output=config.qk_output.value, # is_past_bsnh=1 if past_kv_format == Formats.BSNH else 0, # kv_share_buffer=1 if share_buffer else 0, domain="com.microsoft", @@ -441,6 +492,7 @@ def create_group_query_attention_graph_past( [1], ), ] + if not packed: graph_input += [ helper.make_tensor_value_info( @@ -462,6 +514,7 @@ def create_group_query_attention_graph_past( ], ), ] + if rotary: graph_input += [ helper.make_tensor_value_info( @@ -498,6 +551,15 @@ def create_group_query_attention_graph_past( ), ] + if config.has_head_sink: + graph_input += [ + helper.make_tensor_value_info( + "head_sink", + ort_type, + [config.num_heads], + ), + ] + graph_output = [ helper.make_tensor_value_info( "output", @@ -526,6 +588,15 @@ def create_group_query_attention_graph_past( ), ] + if config.qk_output != QKOutputType.NO_OUTPUT: + graph_output += [ + helper.make_tensor_value_info( + "output_qk", + ort_type, + [config.batch_size, config.num_heads, config.sequence_length, present_kv_seqlen], + ), + ] + graph = helper.make_graph( nodes, "GroupQueryAttention_Graph", @@ -552,17 +623,17 @@ def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random"): def generate_qkv(q, k, v, query_padding_mask=None, key_padding_mask=None, kvpacked=False, qkvpacked=False): """ Arguments: - q: (batch_size, seqlen_q, nheads, d) - k: (batch_size, seqlen_k, nheads_k, d) - v: (batch_size, seqlen_k, nheads_k, d) + q: (batch_size, seqlen_q, num_heads, d) + k: (batch_size, seqlen_k, num_heads_k, d) + v: (batch_size, seqlen_k, num_heads_k, d) query_padding_mask: (batch_size, seqlen), bool key_padding_mask: (batch_size, seqlen), bool """ assert not (kvpacked and qkvpacked) - batch_size, seqlen_q, nheads, d = q.shape - _, seqlen_k, nheads_k, _ = k.shape - assert k.shape == (batch_size, seqlen_k, nheads_k, d) - assert v.shape == (batch_size, seqlen_k, nheads_k, d) + batch_size, seqlen_q, num_heads, d = q.shape + _, seqlen_k, num_heads_k, _ = k.shape + assert k.shape == (batch_size, seqlen_k, num_heads_k, d) + assert v.shape == (batch_size, seqlen_k, num_heads_k, d) if query_padding_mask is not None: q_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, query_padding_mask) @@ -593,7 +664,7 @@ def output_pad_fn(output_unpad): if qkvpacked: assert (query_padding_mask == key_padding_mask).all() - assert nheads == nheads_k + assert num_heads == num_heads_k qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1) qkv = torch.stack([q, k, v], dim=2) if query_padding_mask is not None: @@ -714,6 +785,8 @@ def gqa_prompt_func( seqlens_k=None, position_ids=None, attention_bias=None, + head_sink=None, + output_qk=None, window_size=-1, past_kv_format=Formats.BSNH, share_buffer=True, @@ -746,9 +819,18 @@ def gqa_prompt_func( if config.has_attention_bias: assert attention_bias is not None + if config.qk_output != QKOutputType.NO_OUTPUT: + assert output_qk is not None + if new_k is not None: new_k = torch.reshape(new_k, (config.batch_size, config.kv_sequence_length, -1)) new_v = torch.reshape(new_v, (config.batch_size, config.kv_sequence_length, -1)) + + sess_options = SessionOptions() + ort_session = InferenceSession(onnx_model_str, sess_options, providers=["CPUExecutionProvider"]) + io_binding = ort_session.io_binding() + ort_outputs = {} + if share_buffer: ort_inputs = { "query": q.detach().cpu().numpy(), @@ -757,10 +839,6 @@ def gqa_prompt_func( "seqlens_k": seqlens_k.detach().cpu().numpy().astype(numpy.int32), "total_sequence_length": torch.tensor([config.q_sequence_length], dtype=torch.int32).detach().cpu().numpy(), } - - sess_options = SessionOptions() - ort_session = InferenceSession(onnx_model_str, sess_options, providers=["CPUExecutionProvider"]) - io_binding = ort_session.io_binding() if new_k is not None: ort_inputs["key"] = new_k.detach().cpu().numpy() ort_inputs["value"] = new_v.detach().cpu().numpy() @@ -797,25 +875,18 @@ def gqa_prompt_func( io_binding.bind_output("output") io_binding.bind_ortvalue_output("present_key", ort_inputs["past_key"]) io_binding.bind_ortvalue_output("present_value", ort_inputs["past_value"]) - ort_session.run_with_iobinding(io_binding) - ort_output, present_k, present_v = io_binding.copy_outputs_to_cpu() - ort_output = numpy.array(ort_output) - output = torch.tensor(ort_output) - return output, present_k, present_v else: ort_inputs = { "query": q.detach().cpu().numpy(), "seqlens_k": seqlens_k.detach().cpu().numpy().astype(numpy.int32), "total_sequence_length": torch.tensor([config.q_sequence_length], dtype=torch.int32).detach().cpu().numpy(), } - sess_options = SessionOptions() - ort_session = InferenceSession(onnx_model_str, sess_options, providers=["CPUExecutionProvider"]) - io_binding = ort_session.io_binding() if new_k is not None: ort_inputs["key"] = new_k.detach().cpu().numpy() ort_inputs["value"] = new_v.detach().cpu().numpy() io_binding.bind_cpu_input("key", ort_inputs["key"]) io_binding.bind_cpu_input("value", ort_inputs["value"]) + if cos is not None: ort_inputs["cos_cache"] = cos.detach().cpu().numpy() ort_inputs["sin_cache"] = sin.detach().cpu().numpy() @@ -836,11 +907,26 @@ def gqa_prompt_func( io_binding.bind_output("output") io_binding.bind_output("present_key") io_binding.bind_output("present_value") - ort_session.run_with_iobinding(io_binding) + + if config.has_head_sink: + ort_inputs["head_sink"] = head_sink.detach().cpu().numpy() + io_binding.bind_cpu_input("head_sink", ort_inputs["head_sink"]) + + if config.qk_output != QKOutputType.NO_OUTPUT: + ort_outputs["output_qk"] = OrtValue.ortvalue_from_numpy(output_qk.detach().cpu().numpy(), "cpu", 0) + io_binding.bind_ortvalue_output("output_qk", ort_outputs["output_qk"]) + + ort_session.run_with_iobinding(io_binding) + + out_qk = None + if config.qk_output != QKOutputType.NO_OUTPUT: + ort_output, present_k, present_v, out_qk = io_binding.copy_outputs_to_cpu() + else: ort_output, present_k, present_v = io_binding.copy_outputs_to_cpu() - ort_output = numpy.array(ort_output) - output = torch.tensor(ort_output) - return output, present_k, present_v + ort_output = numpy.array(ort_output) + output = torch.tensor(ort_output) + + return output, present_k, present_v, out_qk def gqa_past_func( @@ -855,6 +941,8 @@ def gqa_past_func( seqlens_k=None, position_ids=None, attention_bias=None, + head_sink=None, + output_qk=None, past_kv_format=Formats.BSNH, share_buffer=True, window_size=-1, @@ -887,9 +975,18 @@ def gqa_past_func( if config.has_attention_bias: assert attention_bias is not None + if config.qk_output != QKOutputType.NO_OUTPUT: + assert output_qk is not None + if new_k is not None: new_k = torch.reshape(new_k, (config.batch_size, config.sequence_length, -1)) new_v = torch.reshape(new_v, (config.batch_size, config.sequence_length, -1)) + + sess_options = SessionOptions() + ort_session = InferenceSession(onnx_model_str, sess_options, providers=["CPUExecutionProvider"]) + io_binding = ort_session.io_binding() + ort_outputs = {} + if share_buffer: ort_inputs = { "query": q.detach().cpu().numpy(), @@ -901,9 +998,6 @@ def gqa_past_func( .cpu() .numpy(), } - sess_options = SessionOptions() - ort_session = InferenceSession(onnx_model_str, sess_options, providers=["CPUExecutionProvider"]) - io_binding = ort_session.io_binding() if new_k is not None and new_v is not None: ort_inputs["key"] = new_k.detach().cpu().numpy() ort_inputs["value"] = new_v.detach().cpu().numpy() @@ -940,11 +1034,6 @@ def gqa_past_func( io_binding.bind_output("output") io_binding.bind_ortvalue_output("present_key", ort_inputs["past_key"]) io_binding.bind_ortvalue_output("present_value", ort_inputs["past_value"]) - ort_session.run_with_iobinding(io_binding) - ort_output, present_k, present_v = io_binding.copy_outputs_to_cpu() - ort_output = numpy.array(ort_output) - output = torch.tensor(ort_output) - return output, present_k, present_v else: ort_inputs = { "query": q.detach().cpu().numpy(), @@ -958,9 +1047,6 @@ def gqa_past_func( .cpu() .numpy(), } - sess_options = SessionOptions() - ort_session = InferenceSession(onnx_model_str, sess_options, providers=["CPUExecutionProvider"]) - io_binding = ort_session.io_binding() if new_k is not None and new_v is not None: ort_inputs["key"] = new_k.detach().cpu().numpy() ort_inputs["value"] = new_v.detach().cpu().numpy() @@ -988,11 +1074,26 @@ def gqa_past_func( io_binding.bind_output("output") io_binding.bind_output("present_key") io_binding.bind_output("present_value") - ort_session.run_with_iobinding(io_binding) + + if config.has_head_sink: + ort_inputs["head_sink"] = head_sink.detach().cpu().numpy() + io_binding.bind_cpu_input("head_sink", ort_inputs["head_sink"]) + + if config.qk_output != QKOutputType.NO_OUTPUT: + ort_outputs["output_qk"] = OrtValue.ortvalue_from_numpy(output_qk.detach().cpu().numpy(), "cpu", 0) + io_binding.bind_ortvalue_output("output_qk", ort_outputs["output_qk"]) + + ort_session.run_with_iobinding(io_binding) + + out_qk = None + if config.qk_output != QKOutputType.NO_OUTPUT: + ort_output, present_k, present_v, out_qk = io_binding.copy_outputs_to_cpu() + else: ort_output, present_k, present_v = io_binding.copy_outputs_to_cpu() - ort_output = numpy.array(ort_output) - output = torch.tensor(ort_output) - return output, present_k, present_v + ort_output = numpy.array(ort_output) + output = torch.tensor(ort_output) + + return output, present_k, present_v, out_qk def construct_causal_mask(seqlen_q, seqlen_k, query_padding_mask=None, key_padding_mask=None, device=None): @@ -1025,11 +1126,28 @@ def construct_local_mask( ) -def smooth_softmax_ref(x): - x_max = x.amax(axis=-1, keepdim=True) - x_max = torch.maximum(x_max, torch.zeros_like(x_max)) - w = torch.exp(x - x_max) - return w * torch.reciprocal(w.sum(axis=-1, keepdim=True) + torch.exp(-x_max)) +def smooth_softmax_ref(x, head_sink): + """ + Arguments: + x: (batch_size, num_heads, seqlen_q, seqlen_k) + head_sink: (num_heads) or None + Output: + y: (batch_size, num_heads, seqlen_q, seqlen_k) + """ + assert len(x.shape) == 4 + b, n, s, t = x.shape + + if head_sink is not None: + assert len(head_sink.shape) == 1 + assert head_sink.shape[0] == x.shape[1] + sink = head_sink.reshape(1, n, 1, 1).expand(b, -1, s, -1) + else: + sink = torch.zeros(b, n, s, 1, dtype=x.dtype) + + y = torch.cat([x, sink], dim=-1) + y = torch.softmax(y, dim=-1) + y = y[..., :-1] + return y def attention_ref( @@ -1046,16 +1164,17 @@ def attention_ref( upcast=True, reorder_ops=False, use_smooth_softmax=False, + head_sink=None, ): """ Arguments: - q: (batch_size, seqlen_q, nheads, head_dim) - k: (batch_size, seqlen_k, nheads_k, head_dim) - v: (batch_size, seqlen_k, nheads_k, head_dim) + q: (batch_size, seqlen_q, num_heads, head_dim) + k: (batch_size, seqlen_k, num_heads_k, head_dim) + v: (batch_size, seqlen_k, num_heads_k, head_dim) query_padding_mask: (batch_size, seqlen_q) key_padding_mask: (batch_size, seqlen_k) dropout_p: float - dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k) + dropout_mask: (batch_size, num_heads, seqlen_q, seqlen_k) causal: whether to apply causal masking window_size: (int, int), left and right window size upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast @@ -1064,8 +1183,10 @@ def attention_ref( without changing the math. This is to estimate the numerical error from operation reordering. use_smooth_softmax: whether use smooth softmax or not + head_sink: (num_heads) or None Output: output: (batch_size, seqlen_q, nheads, head_dim) + masked_scores: (batch_size, nheads, seqlen_q, seqlen_k), before softmax attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax after dropout """ if causal: @@ -1085,8 +1206,10 @@ def attention_ref( scores = scores / softcap scores = scores.tanh() scores = scores * softcap + masked_scores = scores.clone() if key_padding_mask is not None: scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf")) + masked_scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), 0) if window_size[0] >= 0 or window_size[1] >= 0: local_mask = construct_local_mask( seqlen_q, @@ -1096,10 +1219,11 @@ def attention_ref( key_padding_mask, q.device, ) + masked_scores.masked_fill_(local_mask, 0.0) scores.masked_fill_(local_mask, float("-inf")) - if use_smooth_softmax: - attention = smooth_softmax_ref(scores) + if use_smooth_softmax or (head_sink is not None): + attention = smooth_softmax_ref(scores, head_sink) else: attention = torch.softmax(scores, dim=-1) @@ -1121,7 +1245,7 @@ def attention_ref( if query_padding_mask is not None: output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0) - return output.to(dtype=dtype_og), attention.to(dtype=dtype_og) + return output.to(dtype=dtype_og), masked_scores.to(dtype=dtype_og), attention.to(dtype=dtype_og) def attention_qkvpacked_ref( @@ -1133,6 +1257,7 @@ def attention_qkvpacked_ref( upcast=True, reorder_ops=False, use_smooth_softmax=False, + head_sink=None, ): return attention_ref( qkv[:, :, 0], @@ -1146,6 +1271,7 @@ def attention_qkvpacked_ref( causal=causal, reorder_ops=reorder_ops, use_smooth_softmax=use_smooth_softmax, + head_sink=head_sink, ) @@ -1186,6 +1312,10 @@ def get_custom_position_ids(batch_size, sequence_length, seqlens_k=None, past=Fa return position_ids +def get_custom_head_sink(num_heads, torch_type=torch.float16): + return torch.rand(num_heads, dtype=torch_type) + + def parity_check_gqa_prompt( config, torch_type, @@ -1248,6 +1378,8 @@ def parity_check_gqa_prompt( requires_grad=False, ) + head_sink = get_custom_head_sink(config.num_heads, torch_type) if config.has_head_sink else None + window_size = (-1, -1) left_window_size = -1 if local: @@ -1305,6 +1437,20 @@ def parity_check_gqa_prompt( else None ) + output_qk = ( + torch.zeros( + config.batch_size, + config.num_heads, + config.kv_sequence_length, + config.q_sequence_length, + device="cpu", + dtype=torch_type, + requires_grad=False, + ) + if config.qk_output != QKOutputType.NO_OUTPUT + else None + ) + arange = rearrange(torch.arange(config.buffer_sequence_length, device="cpu"), "s -> 1 s") cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1") kv_seqlens = torch.tensor([config.kv_sequence_length], device="cpu").repeat(config.batch_size) @@ -1315,7 +1461,7 @@ def parity_check_gqa_prompt( k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) key_padding_mask = arange < cache_seqlens_expanded - out_ref, _ = attention_ref( + out_ref, out_qk_pre_softmax_ref, out_qk_post_softmax_ref = attention_ref( q_ro, k_cache_rep, v_cache_rep, @@ -1327,6 +1473,7 @@ def parity_check_gqa_prompt( window_size=window_size, softcap=softcap, use_smooth_softmax=use_smooth_softmax, + head_sink=head_sink, ) out_ref = out_ref.detach().cpu().numpy() if past_format == Formats.BNSH: @@ -1337,7 +1484,7 @@ def parity_check_gqa_prompt( # Cache seqlens is reduced by 1 since it is required to be past_seq_len + seq_len - 1 if packed: packed_qkv = torch.concatenate([q, new_k, new_v], dim=2) - out, present_k, present_v = gqa_prompt_func( + out, present_k, present_v, out_qk = gqa_prompt_func( packed_qkv, k, v, @@ -1349,6 +1496,8 @@ def parity_check_gqa_prompt( cache_seqlens - 1, position_ids, attention_bias, + head_sink, + output_qk, left_window_size, past_format, True, @@ -1359,7 +1508,7 @@ def parity_check_gqa_prompt( numpy_type=numpy_type, ) else: - out, present_k, present_v = gqa_prompt_func( + out, present_k, present_v, out_qk = gqa_prompt_func( q, k, v, @@ -1371,6 +1520,8 @@ def parity_check_gqa_prompt( cache_seqlens - 1, position_ids, attention_bias, + head_sink, + output_qk, left_window_size, past_format, True, @@ -1384,6 +1535,22 @@ def parity_check_gqa_prompt( out = torch.reshape(out, (config.batch_size, config.q_sequence_length, config.num_heads, config.head_size)) out = out.detach().cpu().numpy() + if config.qk_output != QKOutputType.NO_OUTPUT: + out_qk_ref = ( + out_qk_post_softmax_ref if config.qk_output == QKOutputType.AFTER_SOFTMAX else out_qk_pre_softmax_ref + ) + out_qk_ref = out_qk_ref.detach().cpu().numpy() + + for batch_idx in range(config.batch_size): + total_seqlen = cache_seqlens[batch_idx] + assert numpy.allclose( + out_qk[batch_idx, :, :, :total_seqlen], + out_qk_ref[batch_idx, :, :, :total_seqlen], + rtol=rtol, + atol=atol, + equal_nan=True, + ) + # Make sure past-present buffer updating correctly assert numpy.allclose(present_k, k_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True) assert numpy.allclose(present_v, v_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True) @@ -1425,6 +1592,8 @@ def parity_check_gqa_prompt( config.has_position_ids, " has_attention_bias:", config.has_attention_bias, + " qk_output:", + config.qk_output, " Mean Error:", numpy.mean(numpy.abs(out - out_ref)), correct, @@ -1531,12 +1700,28 @@ def parity_check_gqa_prompt_no_buff( else None ) + head_sink = get_custom_head_sink(config.num_heads, torch_type=torch_type) if config.has_head_sink else None + + output_qk = ( + torch.zeros( + config.batch_size, + config.num_heads, + config.kv_sequence_length, + config.q_sequence_length, + device="cpu", + dtype=torch_type, + requires_grad=False, + ) + if config.qk_output != QKOutputType.NO_OUTPUT + else None + ) + brange = rearrange(torch.arange(config.kv_sequence_length, device="cpu"), "s -> 1 s") cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1") new_mask = brange < cache_seqlens_expanded k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) - out_ref, _ = attention_ref( + out_ref, out_qk_pre_softmax_ref, out_qk_post_softmax_ref = attention_ref( q_ro, k_cache_rep, v_cache_rep, @@ -1548,6 +1733,7 @@ def parity_check_gqa_prompt_no_buff( window_size=window_size, softcap=softcap, use_smooth_softmax=use_smooth_softmax, + head_sink=head_sink, ) out_ref = out_ref.detach().cpu().numpy() if past_format == Formats.BNSH: @@ -1558,7 +1744,7 @@ def parity_check_gqa_prompt_no_buff( # Cache seqlens is reduced by 1 since it is required to be past_seq_len + seq_len - 1 if packed: packed_qkv = torch.concatenate([q, new_k, new_v], dim=2) - out, present_k, present_v = gqa_prompt_func( + out, present_k, present_v, out_qk = gqa_prompt_func( packed_qkv, None, None, @@ -1570,6 +1756,8 @@ def parity_check_gqa_prompt_no_buff( cache_seqlens - 1, position_ids, attention_bias, + head_sink, + output_qk, left_window_size, past_format, False, @@ -1580,7 +1768,7 @@ def parity_check_gqa_prompt_no_buff( numpy_type=numpy_type, ) else: - out, present_k, present_v = gqa_prompt_func( + out, present_k, present_v, out_qk = gqa_prompt_func( q, None, None, @@ -1592,6 +1780,8 @@ def parity_check_gqa_prompt_no_buff( cache_seqlens - 1, position_ids, attention_bias, + head_sink, + output_qk, left_window_size, past_format, False, @@ -1605,6 +1795,22 @@ def parity_check_gqa_prompt_no_buff( out = torch.reshape(out, (config.batch_size, config.q_sequence_length, config.num_heads, config.head_size)) out = out.detach().cpu().numpy() + if config.qk_output != QKOutputType.NO_OUTPUT: + out_qk_ref = ( + out_qk_post_softmax_ref if config.qk_output == QKOutputType.AFTER_SOFTMAX else out_qk_pre_softmax_ref + ) + out_qk_ref = out_qk_ref.detach().cpu().numpy() + + for batch_idx in range(config.batch_size): + total_seqlen = cache_seqlens[batch_idx] + assert numpy.allclose( + out_qk[batch_idx, :, :, :total_seqlen], + out_qk_ref[batch_idx, :, :, :total_seqlen], + rtol=rtol, + atol=atol, + equal_nan=True, + ) + # Make sure past-present buffer updating correctly assert numpy.allclose(present_k, k_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True) assert numpy.allclose(present_v, v_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True) @@ -1646,6 +1852,8 @@ def parity_check_gqa_prompt_no_buff( config.has_position_ids, " has_attention_bias:", config.has_attention_bias, + " qk_output:", + config.qk_output, " Mean Error:", numpy.mean(numpy.abs(out - out_ref)), correct, @@ -1759,6 +1967,8 @@ def parity_check_gqa_past( cos, sin = None, None q_ro, k_ro = q, new_k + head_sink = get_custom_head_sink(config.num_heads, torch_type=torch_type) if config.has_head_sink else None + arange = rearrange(torch.arange(config.kv_sequence_length, device="cpu"), "s -> 1 s") cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1") update_mask = torch.logical_and( @@ -1769,7 +1979,7 @@ def parity_check_gqa_past( k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) key_padding_mask = arange < cache_seqlens_expanded + config.sequence_length - out_ref, _ = attention_ref( + out_ref, out_qk_pre_softmax_ref, out_qk_post_softmax_ref = attention_ref( q_ro, k_cache_rep, v_cache_rep, @@ -1781,6 +1991,7 @@ def parity_check_gqa_past( window_size=window_size, softcap=softcap, use_smooth_softmax=use_smooth_softmax, + head_sink=head_sink, ) out_ref = out_ref.detach().cpu().numpy() if past_format == Formats.BNSH: @@ -1807,10 +2018,24 @@ def parity_check_gqa_past( else None ) + output_qk = ( + torch.zeros( + config.batch_size, + config.num_heads, + config.sequence_length, + config.kv_sequence_length, + device="cpu", + dtype=torch_type, + requires_grad=False, + ) + if config.qk_output != QKOutputType.NO_OUTPUT + else None + ) + # ORT function if packed: packed_qkv = torch.concatenate([q, new_k, new_v], dim=2) - out, present_k, present_v = gqa_past_func( + out, present_k, present_v, out_qk = gqa_past_func( packed_qkv, k, v, @@ -1822,6 +2047,8 @@ def parity_check_gqa_past( cache_seqlens, position_ids, attention_bias, + head_sink, + output_qk, past_format, True, left_window_size, @@ -1832,7 +2059,7 @@ def parity_check_gqa_past( numpy_type=numpy_type, ) else: - out, present_k, present_v = gqa_past_func( + out, present_k, present_v, out_qk = gqa_past_func( q, k, v, @@ -1844,6 +2071,8 @@ def parity_check_gqa_past( cache_seqlens, position_ids, attention_bias, + head_sink, + output_qk, past_format, True, left_window_size, @@ -1857,6 +2086,22 @@ def parity_check_gqa_past( out = torch.reshape(out, (config.batch_size, config.sequence_length, config.num_heads, config.head_size)) out = out.detach().cpu().numpy() + if config.qk_output != QKOutputType.NO_OUTPUT: + out_qk_ref = ( + out_qk_post_softmax_ref if config.qk_output == QKOutputType.AFTER_SOFTMAX else out_qk_pre_softmax_ref + ) + out_qk_ref = out_qk_ref.detach().cpu().numpy() + + for batch_idx in range(config.batch_size): + total_seqlen = cache_seqlens[batch_idx] + 1 + assert numpy.allclose( + out_qk[batch_idx, :, :, :total_seqlen], + out_qk_ref[batch_idx, :, :, :total_seqlen], + rtol=rtol, + atol=atol, + equal_nan=True, + ) + # Make sure past-present buffer updating correctly assert numpy.allclose(present_k, k_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True) assert numpy.allclose(present_v, v_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True) @@ -1882,6 +2127,8 @@ def parity_check_gqa_past( softcap, " smooth_softmax:", use_smooth_softmax, + " head_sink:", + config.has_head_sink, " B:", config.batch_size, " S:", @@ -1898,6 +2145,8 @@ def parity_check_gqa_past( config.has_position_ids, " has_attention_bias:", config.has_attention_bias, + " qk_output:", + config.qk_output, " Mean Error:", numpy.mean(numpy.abs(out - out_ref)), correct, @@ -2017,6 +2266,8 @@ def parity_check_gqa_past_no_buff( cos, sin = None, None q_ro, k_ro = q, new_k + head_sink = get_custom_head_sink(config.num_heads, torch_type) if config.has_head_sink else None + arange = rearrange(torch.arange(config.kv_sequence_length + config.sequence_length, device="cpu"), "s -> 1 s") cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1") update_mask = torch.logical_and( @@ -2027,7 +2278,7 @@ def parity_check_gqa_past_no_buff( k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) key_padding_mask = arange < cache_seqlens_expanded + config.sequence_length - out_ref, _ = attention_ref( + out_ref, out_qk_pre_softmax_ref, out_qk_post_softmax_ref = attention_ref( q_ro, k_cache_rep, v_cache_rep, @@ -2039,6 +2290,7 @@ def parity_check_gqa_past_no_buff( window_size=window_size, softcap=softcap, use_smooth_softmax=use_smooth_softmax, + head_sink=head_sink, ) out_ref = out_ref.detach().cpu().numpy() if past_format == Formats.BNSH: @@ -2065,10 +2317,24 @@ def parity_check_gqa_past_no_buff( else None ) + output_qk = ( + torch.zeros( + config.batch_size, + config.num_heads, + config.sequence_length, + config.kv_sequence_length + config.sequence_length, + device="cpu", + dtype=torch_type, + requires_grad=False, + ) + if config.qk_output != QKOutputType.NO_OUTPUT + else None + ) + # Flash function if packed: packed_qkv = torch.concatenate([q, new_k, new_v], dim=2) - out, present_k, present_v = gqa_past_func( + out, present_k, present_v, out_qk = gqa_past_func( packed_qkv, k, v, @@ -2080,6 +2346,8 @@ def parity_check_gqa_past_no_buff( cache_seqlens, position_ids, attention_bias, + head_sink, + output_qk, past_format, False, window_size=left_window_size, @@ -2090,7 +2358,7 @@ def parity_check_gqa_past_no_buff( numpy_type=numpy_type, ) else: - out, present_k, present_v = gqa_past_func( + out, present_k, present_v, out_qk = gqa_past_func( q, k, v, @@ -2102,6 +2370,8 @@ def parity_check_gqa_past_no_buff( cache_seqlens, position_ids, attention_bias, + head_sink, + output_qk, past_format, False, window_size=left_window_size, @@ -2115,6 +2385,22 @@ def parity_check_gqa_past_no_buff( out = torch.reshape(out, (config.batch_size, config.sequence_length, config.num_heads, config.head_size)) out = out.detach().cpu().numpy() + if config.qk_output != QKOutputType.NO_OUTPUT: + out_qk_ref = ( + out_qk_post_softmax_ref if config.qk_output == QKOutputType.AFTER_SOFTMAX else out_qk_pre_softmax_ref + ) + out_qk_ref = out_qk_ref.detach().cpu().numpy() + + for batch_idx in range(config.batch_size): + total_seqlen = cache_seqlens[batch_idx] + 1 + assert numpy.allclose( + out_qk[batch_idx, :, :, :total_seqlen], + out_qk_ref[batch_idx, :, :, :total_seqlen], + rtol=rtol, + atol=atol, + equal_nan=True, + ) + # Compare results all_close = numpy.allclose(out, out_ref, rtol=rtol, atol=atol, equal_nan=True) correct = GREEN + "True" + RESET if all_close else RED + "False" + RESET @@ -2134,6 +2420,8 @@ def parity_check_gqa_past_no_buff( softcap, " smooth_softmax:", use_smooth_softmax, + " head_sink:", + config.has_head_sink, "past kv format:", "BSNH" if past_format == Formats.BSNH else "BNSH", " B:", @@ -2152,6 +2440,8 @@ def parity_check_gqa_past_no_buff( config.has_position_ids, " has_attention_bias:", config.has_attention_bias, + " qk_output:", + config.qk_output, " Mean Error:", numpy.mean(numpy.abs(out - out_ref)), correct, @@ -2180,7 +2470,16 @@ def setUp(self): ] def run_test_config( - self, test_func, config_class, batches, seqs, num_h, h_sizes, pos_ids_attn_bias, additional_params=None + self, + test_func, + config_class, + batches, + seqs, + num_h, + h_sizes, + pos_ids_attn_bias, + qk_output, + additional_params=None, ): if additional_params is None: additional_params = {} @@ -2202,33 +2501,59 @@ def run_test_config( for softcap in [0.0, 50.0]: for use_smooth_softmax in [False, True]: for has_pos, has_attn in pos_ids_attn_bias: - if config_class == PromptConfig: - config = config_class( - b, s, s2, s + s2 + 8, n, n2, h, has_pos, has_attn - ) - else: # Config - sp = random.randint(1, s2 - s) if s2 - s > 0 else 0 - config = config_class(b, s, s2, sp, n, n2, h, has_pos, has_attn) - - params = { - "config": config, - "torch_type": precision["torch_type"], - "numpy_type": precision["numpy_type"], - "ort_type": precision["ort_type"], - "rtol": precision["rtol"], - "atol": precision["atol"], - "local": local, - "past_format": Formats.BNSH, - "rotary": rotary, - "rotary_interleaved": rotary_interleaved, - "packed": packed, - "softcap": softcap, - "use_smooth_softmax": use_smooth_softmax, - } - params.update(additional_params) - - all_close = test_func(**params) - self.assertTrue(all_close) + for head_sink in [False, True]: + if use_smooth_softmax and head_sink: + continue + for output_qk in qk_output: + if config_class == PromptConfig: + config = config_class( + b, + s, + s2, + s + s2 + 8, + n, + n2, + h, + has_pos, + has_attn, + head_sink, + output_qk, + ) + else: # Config + sp = random.randint(1, s2 - s) if s2 - s > 0 else 0 + config = config_class( + b, + s, + s2, + sp, + n, + n2, + h, + has_pos, + has_attn, + head_sink, + output_qk, + ) + + params = { + "config": config, + "torch_type": precision["torch_type"], + "numpy_type": precision["numpy_type"], + "ort_type": precision["ort_type"], + "rtol": precision["rtol"], + "atol": precision["atol"], + "local": local, + "past_format": Formats.BNSH, + "rotary": rotary, + "rotary_interleaved": rotary_interleaved, + "packed": packed, + "softcap": softcap, + "use_smooth_softmax": use_smooth_softmax, + } + params.update(additional_params) + + all_close = test_func(**params) + self.assertTrue(all_close) def test_gqa_no_past(self): print("-------- TEST GQA NO PAST (PROMPT CASE) ---------") @@ -2245,12 +2570,33 @@ def test_gqa_no_past(self): ) num_h = [(32, 8)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] h_sizes = [128] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] + qk_output = ( + [QKOutputType.NO_OUTPUT] + if pipeline_mode + else [QKOutputType.NO_OUTPUT, QKOutputType.BEFORE_SOFTMAX, QKOutputType.AFTER_SOFTMAX] + ) # Test with buffer - self.run_test_config(parity_check_gqa_prompt, PromptConfig, batches, seqs, num_h, h_sizes, pos_ids_attn_bias) + self.run_test_config( + parity_check_gqa_prompt, + PromptConfig, + batches, + seqs, + num_h, + h_sizes, + pos_ids_attn_bias, + qk_output, + ) # Test without buffer self.run_test_config( - parity_check_gqa_prompt_no_buff, PromptConfig, batches, seqs, num_h, h_sizes, pos_ids_attn_bias + parity_check_gqa_prompt_no_buff, + PromptConfig, + batches, + seqs, + num_h, + h_sizes, + pos_ids_attn_bias, + qk_output, ) def test_gqa_past(self): @@ -2268,11 +2614,25 @@ def test_gqa_past(self): ) num_h = [(9, 3)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] h_sizes = [64] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] + qk_output = ( + [QKOutputType.NO_OUTPUT] + if pipeline_mode + else [QKOutputType.NO_OUTPUT, QKOutputType.BEFORE_SOFTMAX, QKOutputType.AFTER_SOFTMAX] + ) # Test with buffer - self.run_test_config(parity_check_gqa_past, Config, batches, seqs, num_h, h_sizes, pos_ids_attn_bias) + self.run_test_config(parity_check_gqa_past, Config, batches, seqs, num_h, h_sizes, pos_ids_attn_bias, qk_output) # Test without buffer - self.run_test_config(parity_check_gqa_past_no_buff, Config, batches, seqs, num_h, h_sizes, pos_ids_attn_bias) + self.run_test_config( + parity_check_gqa_past_no_buff, + Config, + batches, + seqs, + num_h, + h_sizes, + pos_ids_attn_bias, + qk_output, + ) def test_gqa_interactive_one_batch(self): print("-------- TEST GQA INTERACTIVE ---------") @@ -2287,6 +2647,7 @@ def test_gqa_interactive_one_batch(self): if pipeline_mode else [(False, False), (True, True), (False, True), (True, False)] ) + qk_output = [QKOutputType.NO_OUTPUT, QKOutputType.BEFORE_SOFTMAX, QKOutputType.AFTER_SOFTMAX] num_h = [(32, 8)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] h_sizes = [32] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] @@ -2299,6 +2660,7 @@ def test_gqa_interactive_one_batch(self): num_h, h_sizes, pos_ids_attn_bias, + qk_output, additional_params={"softcap": 0.0, "use_smooth_softmax": False}, ) self.run_test_config( @@ -2309,6 +2671,7 @@ def test_gqa_interactive_one_batch(self): num_h, h_sizes, pos_ids_attn_bias, + qk_output, additional_params={"softcap": 0.0, "use_smooth_softmax": False}, ) diff --git a/onnxruntime/test/python/transformers/test_gqa_cuda.py b/onnxruntime/test/python/transformers/test_gqa_cuda.py index 2f5b638a57d0c..79976a92e54bf 100644 --- a/onnxruntime/test/python/transformers/test_gqa_cuda.py +++ b/onnxruntime/test/python/transformers/test_gqa_cuda.py @@ -782,7 +782,8 @@ def attention_ref( scores.masked_fill_(local_mask, float("-inf")) if use_smooth_softmax: - attention = smooth_softmax_ref(scores) + head_sink = None + attention = smooth_softmax_ref(scores, head_sink) else: attention = torch.softmax(scores, dim=-1) diff --git a/onnxruntime/test/python/transformers/test_paged_attention_cuda.py b/onnxruntime/test/python/transformers/test_paged_attention_cuda.py index 410860a324a9d..ca5c9c2ce133f 100644 --- a/onnxruntime/test/python/transformers/test_paged_attention_cuda.py +++ b/onnxruntime/test/python/transformers/test_paged_attention_cuda.py @@ -401,7 +401,8 @@ def attention_ref( scores.masked_fill_(local_mask, float("-inf")) if use_smooth_softmax: - attention = smooth_softmax_ref(scores) + head_sink = None + attention = smooth_softmax_ref(scores, head_sink) else: attention = torch.softmax(scores, dim=-1) diff --git a/onnxruntime/test/python/transformers/test_phi_vision.py b/onnxruntime/test/python/transformers/test_phi_vision.py index 67f89e633a146..d276366706af9 100644 --- a/onnxruntime/test/python/transformers/test_phi_vision.py +++ b/onnxruntime/test/python/transformers/test_phi_vision.py @@ -149,7 +149,7 @@ def __init__(self): self.attn = PhiVCLIPAttention() self.ln = torch.nn.LayerNorm(20, eps=1e-05) - def forward(self, x): + def forward(self, x, attention_mask=None): # SkipLayerNorm ------+ # | | # Attention | @@ -163,8 +163,7 @@ def forward(self, x): x = self.ln(x) residual = x - # Attention + MatMul - x = self.attn(x) + x = self.attn(x, attention_mask=attention_mask) # SkipLayerNorm x = residual + x @@ -194,14 +193,31 @@ def verify_fusion(self, optimized_model, expected_model_filename): ) def export(self, model, inputs): - torch.onnx.export( - model, - args=inputs, - f=os.path.join(os.path.dirname(__file__), "export.onnx"), - export_params=True, - opset_version=14, - do_constant_folding=True, - ) + path = os.path.join(os.path.dirname(__file__), "export.onnx") + + if len(inputs) == 2: + torch.onnx.export( + model, + args=inputs, + f=path, + export_params=True, + opset_version=14, + do_constant_folding=True, + input_names=["input", "attention_mask"], + dynamic_axes={ + "input": {0: "batch", 1: "seq"}, + "attention_mask": {0: "batch", 2: "seq", 3: "seq"}, + }, + ) + else: + torch.onnx.export( + model, + args=inputs, + f=path, + export_params=True, + opset_version=14, + do_constant_folding=True, + ) def tearDown(self): path = os.path.join(os.path.dirname(__file__), "export.onnx") @@ -249,6 +265,38 @@ def test_phi_vision_attention(self): ) self.verify_fusion(optimized_model, "phi-3.5-v-instruct-vision-attention.onnx") + def test_phi_vision_attention_with_mask(self): + model = PhiVCLIPAttentionAndLayerNorm() + + batch, seq_len, dim = 1, 2, 20 + mask = torch.zeros(batch, 1, seq_len, seq_len) + mask[:, 1:] = float("-inf") + + inputs = (torch.randn(batch, seq_len, dim), mask) + self.export(model, inputs) + original_model = onnx.load(os.path.join(os.path.dirname(__file__), "export.onnx")) + options = FusionOptions("clip") + optimized_model = optimize_model( + original_model, + model_type="clip", + num_heads=2, + hidden_size=20, + optimization_options=options, + opt_level=0, + use_gpu=True, + ) + self.verify_fusion(optimized_model, "phi-4-v-instruct-vision-attention.onnx") + + graph = optimized_model.model.graph + attention_node = next((n for n in graph.node if n.name == "Attention_0"), None) + self.assertIsNotNone(attention_node, "Could not find the Attention fused node") + attr_names = [attr.name for attr in attention_node.attribute] + self.assertNotIn( + "unidirectional", + attr_names, + f"The attention node should not have a 'unidirectional' attribute: {attr_names}", + ) + if __name__ == "__main__": unittest.main() diff --git a/onnxruntime/test/testdata/three_layer_nested_subgraph_v2.onnx b/onnxruntime/test/testdata/three_layer_nested_subgraph_v2.onnx new file mode 100644 index 0000000000000000000000000000000000000000..d036541a70aa087f6007ec7261f5f1115b0e22f2 GIT binary patch literal 1892 zcmc&#&2G~`5Y8GWaV8~=whcrEAaP&c$9+r)V_e7+CMS24QOn};JS@5uuwhojaXfGV8^v_J5P zdou3)00^KKITpQ`lc^PqiAS-P$a?eD%nd@~hP}}dH(80r++4G?cA)rhT<3}SzV zo0H)NM;CKS-}5BRvOL4jGD9eH!WEX$*~psB!&{MZ1XACWRiwTs0x4Tok{~7J8<3Kg zyCC)P1w$%{yoS^cLrIz#W*xi{bu2O*#%bu4m&0LSw8Ff{j<~^0?WofhLE6E5aOx9p z+-hn{z1&rhvR6xj#l0Rpft7%`1{)f}8YmiKUuB|a&*yC0BB8Y#SE$(ftwJ>%Q#WDR zcU7`1t}VgNkwx6ZGU0g_>iSUC`&kVX?S*?o`uvGX(i5vu@IN| z?*bMeM{k4CleJI+1N)Ij+#zSHS&GlH!_In#AF`<{cM)O@maxVVTc1$c`~O>;pxRP( zIXcxvj{r1AK$R14@*wLUUe<5PZY?Vr@Ax{%IKP6((s~;_f@}%gISCP9`Mn8CLM$zz ztjLTTtJ|gos#d`TJ`=yt>P%cC=%)K$iEO=@Z0!RYCTsuD^;qlE&7WDoU|`u|{wi#u zIc1`bUgE^=dGRX1OxccXz73K+AWBcXbEQ8{)4^}*ur}5cKJ0ex&TT6IS7u(=7R%@O gpK*_GjBuRe!e9%~W$t+nclOVTCER-|6zZFQ0aGjCcK`qY literal 0 HcmV?d00001 diff --git a/requirements-lintrunner.txt b/requirements-lintrunner.txt index 450b955f161af..f02e3e8058c29 100644 --- a/requirements-lintrunner.txt +++ b/requirements-lintrunner.txt @@ -1,6 +1,6 @@ # This file is auto updated by dependabot # When any package below is changed, you shall run "lintrunner init" again. lintrunner==0.12.7 -lintrunner-adapters==0.12.4 -ruff==0.12.2 -clang-format==20.1.7 +lintrunner-adapters==0.12.5 +ruff==0.12.3 +clang-format==20.1.8 diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index f6e37d33b2414..893f3c80fa4b8 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -284,6 +284,8 @@ def generate_vcpkg_install_options(build_dir, args): vcpkg_install_options.append("--x-feature=vsinpu-ep") if args.use_webgpu: vcpkg_install_options.append("--x-feature=webgpu-ep") + if args.wgsl_template == "dynamic": + vcpkg_install_options.append("--x-feature=webgpu-ep-wgsl-template-dynamic") if args.use_webnn: vcpkg_install_options.append("--x-feature=webnn-ep") if args.use_xnnpack: @@ -470,6 +472,7 @@ def generate_build_tree( else "OFF" ), "-Donnxruntime_REDUCED_OPS_BUILD=" + ("ON" if is_reduced_ops_build(args) else "OFF"), + "-Donnxruntime_CLIENT_PACKAGE_BUILD=" + ("ON" if args.client_package_build else "OFF"), "-Donnxruntime_BUILD_MS_EXPERIMENTAL_OPS=" + ("ON" if args.ms_experimental else "OFF"), "-Donnxruntime_ENABLE_LTO=" + ("ON" if args.enable_lto else "OFF"), "-Donnxruntime_USE_ACL=" + ("ON" if args.use_acl else "OFF"), diff --git a/tools/ci_build/build_args.py b/tools/ci_build/build_args.py index ad27b8124c458..53d53f3e15e99 100644 --- a/tools/ci_build/build_args.py +++ b/tools/ci_build/build_args.py @@ -527,6 +527,15 @@ def add_size_reduction_args(parser: argparse.ArgumentParser) -> None: ) +def add_client_package_args(parser: argparse.ArgumentParser) -> None: + """Adds arguments for client package build package.""" + parser.add_argument( + "--client_package_build", + action="store_true", + help="Create ORT package with default settings more appropriate for client/on-device workloads.", + ) + + def add_python_binding_args(parser: argparse.ArgumentParser) -> None: """Adds arguments for Python bindings.""" parser.add_argument("--enable_pybind", action="store_true", help="Enable Python bindings.") @@ -833,6 +842,7 @@ def convert_arg_line_to_args(self, arg_line: str) -> list[str]: # Use list[str] add_dependency_args(parser) add_extension_args(parser) add_size_reduction_args(parser) + add_client_package_args(parser) # Language Bindings add_python_binding_args(parser) diff --git a/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml index ee7f8f2fa386a..e5e2a4749ef85 100644 --- a/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml @@ -32,7 +32,7 @@ parameters: - name: QnnSdk displayName: QNN SDK version type: string - default: 2.36.0.250627 + default: 2.36.1.250708 jobs: - job: Build_QNN_EP diff --git a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml index aa25e3f31166a..202aa61da0b80 100644 --- a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml +++ b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml @@ -60,7 +60,7 @@ parameters: - name: QnnSdk displayName: QNN SDK Version type: string - default: 2.36.0.250627 + default: 2.36.1.250708 resources: repositories: diff --git a/tools/ci_build/github/azure-pipelines/custom-nuget-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/custom-nuget-packaging-pipeline.yml index 7addb3217072a..69dc9d1a8f63d 100644 --- a/tools/ci_build/github/azure-pipelines/custom-nuget-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/custom-nuget-packaging-pipeline.yml @@ -6,7 +6,7 @@ parameters: - name: QnnSdk displayName: QNN SDK Version type: string - default: 2.36.0.250627 + default: 2.36.1.250708 - name: IsReleaseBuild displayName: Is a release build? Set it to true if you are doing an Onnx Runtime release. diff --git a/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml index cf8bbbed70525..526ed71df2006 100644 --- a/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml @@ -33,7 +33,7 @@ parameters: - name: QnnSdk displayName: QNN SDK version type: string - default: 2.36.0.250627 + default: 2.36.1.250708 jobs: - job: Build_QNN_EP diff --git a/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml index de024f0b3456f..b99246625cb77 100644 --- a/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml @@ -59,7 +59,7 @@ parameters: - name: qnn_sdk_version type: string displayName: 'QNN SDK version. Only for QNN packages.' - default: 2.36.0.250627 + default: 2.36.1.250708 trigger: none diff --git a/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml index 4fa916db0de39..626a638121858 100644 --- a/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml @@ -2,7 +2,7 @@ parameters: - name: QnnSdk displayName: QNN SDK Version type: string - default: 2.36.0.250627 + default: 2.36.1.250708 - name: build_config displayName: Build Configuration diff --git a/tools/ci_build/github/azure-pipelines/stages/nuget-cuda-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/nuget-cuda-packaging-stage.yml index 84b6d30ee32ac..a87bb55441ac7 100644 --- a/tools/ci_build/github/azure-pipelines/stages/nuget-cuda-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/nuget-cuda-packaging-stage.yml @@ -72,6 +72,8 @@ stages: SpecificArtifact: ${{ parameters.SpecificArtifact }} BuildId: ${{ parameters.BuildId }} + - template: ../templates/set-version-number-variables-step.yml + # Reconstruct the build dir - task: PowerShell@2 displayName: 'PS: Extract nuget files gpu' @@ -114,6 +116,7 @@ stages: -p:OrtPackageId="Microsoft.ML.OnnxRuntime.Gpu" -p:IsReleaseBuild=${{ parameters.IsReleaseBuild }} -p:ReleaseVersionSuffix=$(ReleaseVersionSuffix) + -p:PackageVersion=$(OnnxRuntimeVersion) workingDirectory: '$(Build.SourcesDirectory)\csharp' - template: ../templates/win-esrp-dll.yml diff --git a/tools/ci_build/github/azure-pipelines/stages/py-cpu-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/py-cpu-packaging-stage.yml index 433250f05125e..e2c6b25f48b6d 100644 --- a/tools/ci_build/github/azure-pipelines/stages/py-cpu-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/py-cpu-packaging-stage.yml @@ -59,7 +59,7 @@ parameters: - name: qnn_sdk_version type: string displayName: 'QNN SDK version. Only for QNN packages.' - default: 2.36.0.250627 + default: 2.36.1.250708 stages: - ${{ if eq(parameters.enable_windows_cpu, true) }}: diff --git a/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar-test.yml b/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar-test.yml index ab779e164b36e..74f7f782fe1b2 100644 --- a/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar-test.yml +++ b/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar-test.yml @@ -19,7 +19,7 @@ parameters: - name: QnnSDKVersion displayName: QNN SDK Version type: string - default: '2.36.0.250627' + default: '2.36.1.250708' - name: enableWebGpu displayName: Enable WebGPU test diff --git a/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar.yml b/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar.yml index 110f83ff587c8..92e862bd79008 100644 --- a/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar.yml +++ b/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar.yml @@ -53,7 +53,7 @@ parameters: - name: QnnSDKVersion displayName: QNN SDK Version type: string - default: '2.36.0.250627' + default: '2.36.1.250708' - name: is1ES displayName: Is 1ES pipeline diff --git a/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml b/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml index 535784933a087..5b48a14e2afc3 100644 --- a/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml +++ b/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml @@ -47,7 +47,7 @@ parameters: - name: QnnSDKVersion displayName: QNN SDK Version type: string - default: 2.36.0.250627 + default: 2.36.1.250708 - name: is1ES displayName: Is 1ES pipeline diff --git a/tools/ci_build/github/azure-pipelines/templates/jobs/download_linux_qnn_sdk.yml b/tools/ci_build/github/azure-pipelines/templates/jobs/download_linux_qnn_sdk.yml index 3e7427cc7a2e3..930dc83b73460 100644 --- a/tools/ci_build/github/azure-pipelines/templates/jobs/download_linux_qnn_sdk.yml +++ b/tools/ci_build/github/azure-pipelines/templates/jobs/download_linux_qnn_sdk.yml @@ -1,7 +1,7 @@ parameters: - name: QnnSDKVersion type: string - default: '2.36.0.250627' + default: '2.36.1.250708' steps: - script: | diff --git a/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_qnn_sdk.yml b/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_qnn_sdk.yml index e3f549e2d649f..96eea6cd6d2fb 100644 --- a/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_qnn_sdk.yml +++ b/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_qnn_sdk.yml @@ -1,7 +1,7 @@ parameters: - name: QnnSDKVersion type: string - default: '2.36.0.250627' + default: '2.36.1.250708' steps: - powershell: | diff --git a/tools/ci_build/github/azure-pipelines/templates/py-linux-qnn.yml b/tools/ci_build/github/azure-pipelines/templates/py-linux-qnn.yml index d533fb7c83ddd..caee5367950e6 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-linux-qnn.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-linux-qnn.yml @@ -26,7 +26,7 @@ parameters: - name: QnnSdk displayName: QNN SDK version type: string - default: 2.36.0.250627 + default: 2.36.1.250708 - name: is1ES displayName: 'Whether the pipeline is running in 1ES' diff --git a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml index cd060d1fbf19f..185f41822a7e5 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml @@ -7,7 +7,7 @@ parameters: - name: QNN_SDK displayName: QNN SDK Version type: string - default: 2.36.0.250627 + default: 2.36.1.250708 - name: ENV_SETUP_SCRIPT type: string diff --git a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64ec-qnn.yml b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64ec-qnn.yml index 2a2ac49b4e073..9a1e7e5e251c9 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64ec-qnn.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64ec-qnn.yml @@ -7,7 +7,7 @@ parameters: - name: QNN_SDK displayName: QNN SDK Version type: string - default: 2.36.0.250627 + default: 2.36.1.250708 - name: ENV_SETUP_SCRIPT type: string diff --git a/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml b/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml index 8528fa3907e96..5affc152a0a4a 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml @@ -7,7 +7,7 @@ parameters: - name: QNN_SDK displayName: QNN SDK Version type: string - default: 2.36.0.250627 + default: 2.36.1.250708 - name: ENV_SETUP_SCRIPT type: string diff --git a/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml b/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml index 1406ce338f13e..29ebb8c4e4e61 100644 --- a/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml +++ b/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml @@ -1,5 +1,5 @@ parameters: - QnnSdk: '2.36.0.250627' + QnnSdk: '2.36.1.250708' build_config: 'RelWithDebInfo' IsReleaseBuild: false DoEsrp: false @@ -20,7 +20,7 @@ stages: name: ${{ parameters.qnn_ep_build_pool_name }} variables: OrtPackageId: ${{ parameters.OrtNugetPackageId }} - commonBuildArgs: '--compile_no_warning_as_error --skip_submodule_sync --build_shared_lib --cmake_generator "Visual Studio 17 2022" --config ${{ parameters.build_config }} --parallel --use_vcpkg --use_vcpkg_ms_internal_asset_cache --use_binskim_compliant_compile_flags ' + commonBuildArgs: '--compile_no_warning_as_error --skip_submodule_sync --build_shared_lib --client_package_build --cmake_generator "Visual Studio 17 2022" --config ${{ parameters.build_config }} --parallel --use_vcpkg --use_vcpkg_ms_internal_asset_cache --use_binskim_compliant_compile_flags ' steps: - template: set-version-number-variables-step.yml diff --git a/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml index 78fce1f9b9602..7ebf5394e4530 100644 --- a/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml @@ -33,7 +33,7 @@ parameters: - name: QnnSdk displayName: QNN SDK version type: string - default: 2.36.0.250627 + default: 2.36.1.250708 jobs: - job: 'BUILD_QNN_EP' @@ -50,7 +50,7 @@ jobs: matrix: SHARED_LIB: QnnLibKind: 'shared_lib' - ExtraQnnBuildArgs: '' + ExtraQnnBuildArgs: '--client_package_build' STATIC_LIB: QnnLibKind: 'static_lib' ExtraQnnBuildArgs: '' diff --git a/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml index eb77c9422853d..ffeb577547f69 100644 --- a/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml @@ -33,7 +33,7 @@ parameters: - name: QnnSdk displayName: QNN SDK version type: string - default: 2.36.0.250627 + default: 2.36.1.250708 jobs: - job: 'BUILD_QNN_EP' From 87a7ac04f3f9c5ed8c06e2bf6756e9fe2bc4c7b1 Mon Sep 17 00:00:00 2001 From: Ankit Maheshkar Date: Thu, 17 Jul 2025 13:13:49 +0530 Subject: [PATCH 067/138] Revert "Sync ORT main 16 07 25 (#744)" (#745) This reverts commit bc3dc45c1a06cc514d393efad13a86df08f97f14. --- .github/workflows/windows_webgpu.yml | 2 - cmake/CMakeLists.txt | 1 - cmake/adjust_global_compile_flags.cmake | 5 - .../external/onnxruntime_external_deps.cmake | 25 +- cmake/onnxruntime_mlas.cmake | 1 - cmake/onnxruntime_providers_tensorrt.cmake | 23 +- cmake/onnxruntime_providers_webgpu.cmake | 11 +- cmake/vcpkg.json | 9 +- .../EndToEndTests.Mobile.Automation/Tests.cs | 4 +- .../TestResultProcessor.cs | 3 +- docs/ContribOperators.md | 10 +- docs/OperatorKernels.md | 6 +- include/onnxruntime/core/graph/graph.h | 5 +- .../core/providers/utils/ort_graph_to_proto.h | 718 ------------------ .../core/session/onnxruntime_c_api.h | 99 +-- .../core/session/onnxruntime_cxx_api.h | 2 - .../core/session/onnxruntime_cxx_inline.h | 9 - .../core/session/onnxruntime_ep_c_api.h | 90 +-- .../onnxruntime_session_options_config_keys.h | 4 +- js/web/lib/wasm/jsep/webgpu/ops/slice.ts | 2 +- js/web/script/pull-prebuilt-wasm-artifacts.ts | 20 +- .../contrib_ops/cpu/bert/attention_common.h | 6 - .../contrib_ops/cpu/bert/attention_helper.h | 6 +- .../contrib_ops/cpu/bert/gqa_attention_base.h | 74 +- .../cpu/bert/group_query_attention.cc | 11 +- .../cpu/bert/group_query_attention_helper.h | 31 - .../cuda/bert/group_query_attention.cc | 6 - .../rocm/bert/group_query_attention.cu | 4 - .../webgpu/bert/group_query_attention.cc | 4 - onnxruntime/core/common/cpuid_info.cc | 2 +- onnxruntime/core/graph/abi_graph_types.h | 22 +- .../core/graph/contrib_ops/bert_defs.cc | 85 +-- onnxruntime/core/graph/ep_api_types.cc | 83 +- onnxruntime/core/graph/ep_api_types.h | 45 +- onnxruntime/core/graph/graph.cc | 4 - onnxruntime/core/graph/graph_viewer.cc | 12 +- .../core/graph/model_editor_api_types.h | 14 +- onnxruntime/core/mlas/inc/mlas.h | 16 - onnxruntime/core/mlas/lib/compute.cpp | 17 +- onnxruntime/core/mlas/lib/dequantize.cpp | 395 ---------- onnxruntime/core/mlas/lib/mlasi.h | 22 - onnxruntime/core/mlas/lib/platform.cpp | 2 - .../core/platform/windows/device_discovery.cc | 79 +- .../core/providers/cpu/math/softmax_shared.cc | 2 +- onnxruntime/core/providers/cpu/ml/ml_common.h | 2 +- .../cpu/quantization/quantize_linear.cc | 98 +-- .../providers/cuda/cuda_provider_factory.cc | 8 - .../src/ExecutionProvider.cpp | 5 +- .../nv_tensorrt_rtx/nv_execution_provider.cc | 2 +- .../qnn/builder/opbuilder/pool_op_builder.cc | 42 +- .../builder/opbuilder/simple_op_builder.cc | 14 +- .../qnn/builder/qnn_backend_manager.cc | 50 +- .../qnn/builder/qnn_backend_manager.h | 10 +- .../providers/qnn/qnn_execution_provider.cc | 59 +- .../providers/qnn/qnn_execution_provider.h | 7 +- .../providers/qnn/qnn_provider_factory.cc | 14 +- .../tensorrt_execution_provider_custom_ops.cc | 44 +- .../core/providers/webgpu/buffer_manager.cc | 25 +- .../core/providers/webgpu/tensor/cast.cc | 20 +- .../providers/webgpu/tensor/scatter_nd.cc | 22 +- .../core/providers/webgpu/tensor/slice.cc | 4 +- .../webgpu/webgpu_execution_provider.cc | 8 +- .../providers/webgpu/wgsl_templates/README.md | 4 +- .../webgpu/wgsl_templates/package-lock.json | 8 +- .../webgpu/wgsl_templates/package.json | 2 +- .../core/providers/webnn/builders/helper.cc | 126 ++- .../core/providers/webnn/builders/helper.h | 34 - .../builders/impl/argmax_min_op_builder.cc | 18 + .../webnn/builders/impl/base_op_builder.cc | 7 +- .../webnn/builders/impl/binary_op_builder.cc | 5 +- .../webnn/builders/impl/concat_op_builder.cc | 3 +- .../webnn/builders/impl/conv_op_builder.cc | 2 +- .../webnn/builders/impl/cumsum_op_builder.cc | 4 + .../webnn/builders/impl/dropout_op_builder.cc | 20 +- .../webnn/builders/impl/einsum_op_builder.cc | 90 +-- .../impl/gatherElements_op_builder.cc | 6 +- .../builders/impl/gatherND_op_builder.cc | 6 +- .../webnn/builders/impl/gather_op_builder.cc | 28 +- .../webnn/builders/impl/gemm_op_builder.cc | 44 +- .../webnn/builders/impl/gru_op_builder.cc | 3 +- .../webnn/builders/impl/logical_op_builder.cc | 4 +- .../webnn/builders/impl/lrn_op_builder.cc | 15 +- .../webnn/builders/impl/lstm_op_builder.cc | 3 +- .../builders/impl/matMulNBits_op_builder.cc | 19 +- .../webnn/builders/impl/max_min_op_builder.cc | 24 +- .../builders/impl/normalization_op_builder.cc | 87 ++- .../webnn/builders/impl/pool_op_builder.cc | 14 + .../webnn/builders/impl/qdq_op_builder.cc | 3 +- .../builders/impl/reduction_op_builder.cc | 8 +- .../webnn/builders/impl/reshape_op_builder.cc | 5 + .../impl/rotaryEmbedding_op_builder.cc | 2 +- .../impl/scatterElements_op_builder.cc | 6 +- .../builders/impl/scatterND_op_builder.cc | 6 +- .../webnn/builders/impl/slice_op_builder.cc | 21 +- .../webnn/builders/impl/softmax_op_builder.cc | 19 + .../impl/squeeze_unsqueeze_op_builder.cc | 3 + .../webnn/builders/impl/ternary_op_builder.cc | 3 +- .../webnn/builders/impl/tile_op_builder.cc | 9 + .../builders/impl/triangular_op_builder.cc | 9 + .../core/providers/webnn/builders/map_info.h | 4 +- .../providers/webnn/builders/model_builder.h | 2 +- onnxruntime/core/session/compile_api.cc | 30 - onnxruntime/core/session/compile_api.h | 2 - onnxruntime/core/session/ep_api_utils.h | 4 - .../core/session/ep_factory_internal.cc | 4 +- .../core/session/ep_factory_internal.h | 4 +- .../core/session/ep_library_internal.cc | 9 +- .../session/ep_library_provider_bridge.cc | 1 - onnxruntime/core/session/inference_session.cc | 12 - .../core/session/model_compilation_options.cc | 36 +- .../core/session/model_compilation_options.h | 10 - onnxruntime/core/session/onnxruntime_c_api.cc | 134 +--- onnxruntime/core/session/ort_apis.h | 10 +- .../core/session/provider_policy_context.cc | 8 +- onnxruntime/core/util/qmath.h | 49 -- onnxruntime/core/util/thread_utils.h | 6 - .../tools/quantization/base_quantizer.py | 2 +- .../python/tools/quantization/registry.py | 1 - .../transformers/fusion_attention_clip.py | 70 +- .../models/llama/requirements.txt | 2 +- .../models/whisper/convert_to_onnx.py | 2 +- .../models/whisper/requirements.txt | 2 +- .../models/whisper/whisper_decoder.py | 7 +- .../whisper/whisper_encoder_decoder_init.py | 4 +- .../models/whisper/whisper_helper.py | 4 +- .../models/whisper/whisper_inputs.py | 6 +- .../models/whisper/whisper_jump_times.py | 2 +- onnxruntime/test/autoep/library/ep.cc | 12 +- onnxruntime/test/autoep/library/ep.h | 6 +- onnxruntime/test/autoep/library/ep_factory.cc | 7 - onnxruntime/test/autoep/library/ep_factory.h | 2 - .../test/contrib_ops/matmul_4bits_test.cc | 6 +- onnxruntime/test/ep_graph/test_ep_graph.cc | 254 +------ .../test/ep_graph/test_ep_graph_utils.cc | 1 - .../test/ep_graph/test_ep_graph_utils.h | 1 - .../test/framework/ep_plugin_provider_test.cc | 14 +- .../test/mlas/bench/bench_computesoftmax.cpp | 4 +- .../mlas/unittest/test_dequantizelinear.cpp | 75 -- .../test/mlas/unittest/test_softmax.cpp | 4 +- .../test/providers/cpu/math/softmax_test.cc | 3 +- .../cpu/tensor/quantize_linear_test.cc | 26 - .../cpu/tensor/scatter_nd_op_test.cc | 11 - .../test/providers/qnn/qnn_ep_context_test.cc | 80 +- .../test/providers/qnn/simple_op_htp_test.cc | 32 - .../test/python/quantization/test_op_topk.py | 103 --- .../phi-4-v-instruct-vision-attention.onnx | Bin 7729 -> 0 bytes .../test/python/transformers/test_gqa_cpu.py | 559 +++----------- .../test/python/transformers/test_gqa_cuda.py | 3 +- .../transformers/test_paged_attention_cuda.py | 3 +- .../python/transformers/test_phi_vision.py | 70 +- .../three_layer_nested_subgraph_v2.onnx | Bin 1892 -> 0 bytes requirements-lintrunner.txt | 6 +- tools/ci_build/build.py | 3 - tools/ci_build/build_args.py | 10 - ...arm64-v8a-QNN-crosscompile-ci-pipeline.yml | 2 +- .../c-api-noopenmp-packaging-pipelines.yml | 2 +- .../custom-nuget-packaging-pipeline.yml | 2 +- .../azure-pipelines/linux-qnn-ci-pipeline.yml | 2 +- .../azure-pipelines/py-packaging-pipeline.yml | 2 +- .../qnn-ep-nuget-packaging-pipeline.yml | 2 +- .../stages/nuget-cuda-packaging-stage.yml | 3 - .../stages/py-cpu-packaging-stage.yml | 2 +- .../templates/android-java-api-aar-test.yml | 2 +- .../templates/android-java-api-aar.yml | 2 +- .../azure-pipelines/templates/c-api-cpu.yml | 2 +- .../templates/jobs/download_linux_qnn_sdk.yml | 2 +- .../templates/jobs/download_win_qnn_sdk.yml | 2 +- .../templates/py-linux-qnn.yml | 2 +- .../templates/py-win-arm64-qnn.yml | 2 +- .../templates/py-win-arm64ec-qnn.yml | 2 +- .../templates/py-win-x64-qnn.yml | 2 +- .../azure-pipelines/templates/qnn-ep-win.yml | 4 +- .../win-qnn-arm64-ci-pipeline.yml | 4 +- .../azure-pipelines/win-qnn-ci-pipeline.yml | 2 +- 174 files changed, 877 insertions(+), 3985 deletions(-) delete mode 100644 include/onnxruntime/core/providers/utils/ort_graph_to_proto.h delete mode 100644 onnxruntime/core/mlas/lib/dequantize.cpp delete mode 100644 onnxruntime/test/mlas/unittest/test_dequantizelinear.cpp delete mode 100644 onnxruntime/test/python/quantization/test_op_topk.py delete mode 100644 onnxruntime/test/python/transformers/test_data/models/phi-4-v-instruct-vision-attention.onnx delete mode 100644 onnxruntime/test/testdata/three_layer_nested_subgraph_v2.onnx diff --git a/.github/workflows/windows_webgpu.yml b/.github/workflows/windows_webgpu.yml index 996e0d816d51a..70e8ea7e2792f 100644 --- a/.github/workflows/windows_webgpu.yml +++ b/.github/workflows/windows_webgpu.yml @@ -22,7 +22,6 @@ jobs: strategy: matrix: vcpkg_option: [novcpkg, vcpkg] - wgsl_template: [static, dynamic] env: OrtPackageId: Microsoft.ML.OnnxRuntime OnnxRuntimeBuildDirectory: ${{ github.workspace }} @@ -124,7 +123,6 @@ jobs: --build_nodejs ` --build_java ` --use_webgpu ` - --wgsl_template ${{ matrix.wgsl_template }} ` ${{ matrix.vcpkg_option == 'vcpkg' && '--use_vcpkg' || '' }} ` --cmake_extra_defines ` onnxruntime_BUILD_UNIT_TESTS=ON ` diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index b01110b2a4a03..fb4238731ffc3 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -151,7 +151,6 @@ option(onnxruntime_DISABLE_SPARSE_TENSORS "Disable sparse tensors data types" OF option(onnxruntime_DISABLE_OPTIONAL_TYPE "Disable optional type" OFF) option(onnxruntime_DISABLE_FLOAT8_TYPES "Disable float 8 types" OFF) option(onnxruntime_MINIMAL_BUILD "Exclude as much as possible from the build. Support ORT format models. No support for ONNX format models." OFF) -option(onnxruntime_CLIENT_PACKAGE_BUILD "Enables default settings that are more appropriate for client/on-device workloads." OFF) cmake_dependent_option(onnxruntime_DISABLE_RTTI "Disable RTTI" ON "NOT onnxruntime_ENABLE_PYTHON;NOT onnxruntime_USE_CUDA" OFF) # For now onnxruntime_DISABLE_EXCEPTIONS will only work with onnxruntime_MINIMAL_BUILD, more changes (ONNX, non-CPU EP, ...) are required to run this standalone cmake_dependent_option(onnxruntime_DISABLE_EXCEPTIONS "Disable exception handling. Requires onnxruntime_MINIMAL_BUILD currently." ON "onnxruntime_MINIMAL_BUILD;NOT onnxruntime_ENABLE_PYTHON" OFF) diff --git a/cmake/adjust_global_compile_flags.cmake b/cmake/adjust_global_compile_flags.cmake index 6d517003fa6b6..59d99ade131cd 100644 --- a/cmake/adjust_global_compile_flags.cmake +++ b/cmake/adjust_global_compile_flags.cmake @@ -95,11 +95,6 @@ if (onnxruntime_MINIMAL_BUILD) endif() endif() -# ORT build with default settings more appropriate for client/on-device workloads. -if (onnxruntime_CLIENT_PACKAGE_BUILD) - add_compile_definitions(ORT_CLIENT_PACKAGE_BUILD) -endif() - if (onnxruntime_ENABLE_LTO) include(CheckIPOSupported) check_ipo_supported(RESULT ipo_enabled OUTPUT ipo_output) diff --git a/cmake/external/onnxruntime_external_deps.cmake b/cmake/external/onnxruntime_external_deps.cmake index 228906030d14c..e8f6bbe895d29 100644 --- a/cmake/external/onnxruntime_external_deps.cmake +++ b/cmake/external/onnxruntime_external_deps.cmake @@ -774,24 +774,13 @@ if (onnxruntime_USE_WEBGPU) endif() if (NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten" AND onnxruntime_WGSL_TEMPLATE STREQUAL "dynamic") - if(onnxruntime_USE_VCPKG) - find_package(unofficial-duktape CONFIG REQUIRED) - add_library(duktape_static ALIAS unofficial::duktape::duktape) - else() - onnxruntime_fetchcontent_declare( - duktape - URL ${DEP_URL_duktape} - URL_HASH SHA1=${DEP_SHA1_duktape} - EXCLUDE_FROM_ALL - ) - onnxruntime_fetchcontent_makeavailable(duktape) - - if(NOT TARGET duktape_static) - add_library(duktape_static STATIC "${duktape_SOURCE_DIR}/src/duktape.c") - target_compile_features(duktape_static PRIVATE c_std_99) - target_include_directories(duktape_static INTERFACE $) - endif() - endif() + onnxruntime_fetchcontent_declare( + duktape + URL ${DEP_URL_duktape} + URL_HASH SHA1=${DEP_SHA1_duktape} + EXCLUDE_FROM_ALL + ) + onnxruntime_fetchcontent_makeavailable(duktape) endif() endif() diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index 47e7779d93b33..f8f5546ae9465 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -31,7 +31,6 @@ onnxruntime_add_static_library(onnxruntime_mlas ${MLAS_SRC_DIR}/eltwise.cpp ${MLAS_SRC_DIR}/erf.cpp ${MLAS_SRC_DIR}/compute.cpp - ${MLAS_SRC_DIR}/dequantize.cpp ${MLAS_SRC_DIR}/quantize.cpp ${MLAS_SRC_DIR}/qgemm_kernel_default.cpp ${MLAS_SRC_DIR}/qladd.cpp diff --git a/cmake/onnxruntime_providers_tensorrt.cmake b/cmake/onnxruntime_providers_tensorrt.cmake index 4184e0b049afc..69c81a5ec7b9d 100644 --- a/cmake/onnxruntime_providers_tensorrt.cmake +++ b/cmake/onnxruntime_providers_tensorrt.cmake @@ -72,9 +72,10 @@ endif() # TensorRT 10 GA onwards, the TensorRT libraries will have major version appended to the end on Windows, - # for example, nvinfer_10.dll, nvonnxparser_10.dll ... + # for example, nvinfer_10.dll, nvinfer_plugin_10.dll, nvonnxparser_10.dll ... if (WIN32 AND TRT_GREATER_OR_EQUAL_TRT_10_GA) set(NVINFER_LIB "nvinfer_${NV_TENSORRT_MAJOR}") + set(NVINFER_PLUGIN_LIB "nvinfer_plugin_${NV_TENSORRT_MAJOR}") set(PARSER_LIB "nvonnxparser_${NV_TENSORRT_MAJOR}") endif() @@ -82,11 +83,15 @@ set(NVINFER_LIB "nvinfer") endif() + if (NOT NVINFER_PLUGIN_LIB) + set(NVINFER_PLUGIN_LIB "nvinfer_plugin") + endif() + if (NOT PARSER_LIB) set(PARSER_LIB "nvonnxparser") endif() - MESSAGE(STATUS "Looking for ${NVINFER_LIB}") + MESSAGE(STATUS "Looking for ${NVINFER_LIB} and ${NVINFER_PLUGIN_LIB}") find_library(TENSORRT_LIBRARY_INFER ${NVINFER_LIB} HINTS ${TENSORRT_ROOT} @@ -96,6 +101,14 @@ MESSAGE(STATUS "Can't find ${NVINFER_LIB}") endif() + find_library(TENSORRT_LIBRARY_INFER_PLUGIN ${NVINFER_PLUGIN_LIB} + HINTS ${TENSORRT_ROOT} + PATH_SUFFIXES lib lib64 lib/x64) + + if (NOT TENSORRT_LIBRARY_INFER_PLUGIN) + MESSAGE(STATUS "Can't find ${NVINFER_PLUGIN_LIB}") + endif() + if (onnxruntime_USE_TENSORRT_BUILTIN_PARSER) MESSAGE(STATUS "Looking for ${PARSER_LIB}") @@ -107,7 +120,7 @@ MESSAGE(STATUS "Can't find ${PARSER_LIB}") endif() - set(TENSORRT_LIBRARY ${TENSORRT_LIBRARY_INFER} ${TENSORRT_LIBRARY_NVONNXPARSER}) + set(TENSORRT_LIBRARY ${TENSORRT_LIBRARY_INFER} ${TENSORRT_LIBRARY_INFER_PLUGIN} ${TENSORRT_LIBRARY_NVONNXPARSER}) MESSAGE(STATUS "Find TensorRT libs at ${TENSORRT_LIBRARY}") else() if (TRT_GREATER_OR_EQUAL_TRT_10_GA) @@ -140,7 +153,7 @@ endif() # Static libraries are just nvonnxparser_static on all platforms set(onnxparser_link_libs nvonnxparser_static) - set(TENSORRT_LIBRARY ${TENSORRT_LIBRARY_INFER}) + set(TENSORRT_LIBRARY ${TENSORRT_LIBRARY_INFER} ${TENSORRT_LIBRARY_INFER_PLUGIN}) MESSAGE(STATUS "Find TensorRT libs at ${TENSORRT_LIBRARY}") endif() @@ -148,7 +161,7 @@ # nvonnxparser_static is linked against tensorrt libraries in onnx-tensorrt # See https://github.com/onnx/onnx-tensorrt/blob/8af13d1b106f58df1e98945a5e7c851ddb5f0791/CMakeLists.txt#L121 # However, starting from TRT 10 GA, nvonnxparser_static doesn't link against tensorrt libraries. - # Therefore, the above code finds ${TENSORRT_LIBRARY_INFER}. + # Therefore, the above code finds ${TENSORRT_LIBRARY_INFER} and ${TENSORRT_LIBRARY_INFER_PLUGIN}. if(onnxruntime_CUDA_MINIMAL) set(trt_link_libs ${CMAKE_DL_LIBS} ${TENSORRT_LIBRARY}) else() diff --git a/cmake/onnxruntime_providers_webgpu.cmake b/cmake/onnxruntime_providers_webgpu.cmake index 2865ad33b39f4..5b80b1262464d 100644 --- a/cmake/onnxruntime_providers_webgpu.cmake +++ b/cmake/onnxruntime_providers_webgpu.cmake @@ -172,12 +172,10 @@ file(MAKE_DIRECTORY ${WGSL_GENERATED_DIR}) # Find all WGSL template input files - file(GLOB_RECURSE WGSL_TEMPLATE_FILES - "${ONNXRUNTIME_ROOT}/core/providers/webgpu/*.wgsl.template" - "${ONNXRUNTIME_ROOT}/contrib_ops/webgpu/*.wgsl.template") + file(GLOB_RECURSE WGSL_TEMPLATE_FILES "${ONNXRUNTIME_ROOT}/core/providers/webgpu/*.wgsl.template") # Set wgsl-gen command line options as a list - set(WGSL_GEN_OPTIONS "-i" "${ONNXRUNTIME_ROOT}/core/providers/webgpu/" "-i" "${ONNXRUNTIME_ROOT}/contrib_ops/webgpu/" "--output" "${WGSL_GENERATED_DIR}" "-I" "wgsl_template_gen/" "--preserve-code-ref" "--verbose") + set(WGSL_GEN_OPTIONS "-i" "../" "--output" "${WGSL_GENERATED_DIR}" "-I" "wgsl_template_gen/" "--preserve-code-ref" "--verbose") if (onnxruntime_WGSL_TEMPLATE STREQUAL "static") if (CMAKE_BUILD_TYPE STREQUAL "Debug") list(APPEND WGSL_GEN_OPTIONS "--generator" "static-cpp-literal") @@ -209,9 +207,10 @@ # Add the generated directory to include paths target_include_directories(onnxruntime_providers_webgpu PRIVATE ${WGSL_GENERATED_ROOT}) elseif(onnxruntime_WGSL_TEMPLATE STREQUAL "dynamic") + add_library(duktape_static STATIC "${duktape_SOURCE_DIR}/src/duktape.c") + target_compile_features(duktape_static PRIVATE c_std_99) target_link_libraries(onnxruntime_providers_webgpu duktape_static) - onnxruntime_add_include_to_target(onnxruntime_providers_webgpu duktape_static) - + target_include_directories(onnxruntime_providers_webgpu PRIVATE ${duktape_SOURCE_DIR}/src) # Define the path to the generated templates.js file target_compile_definitions(onnxruntime_providers_webgpu PRIVATE "ORT_WGSL_TEMPLATES_JS_PATH=\"${WGSL_GENERATED_TEMPLATES_JS}\"") diff --git a/cmake/vcpkg.json b/cmake/vcpkg.json index 373ecec440921..7c6b2fed36d1b 100644 --- a/cmake/vcpkg.json +++ b/cmake/vcpkg.json @@ -43,6 +43,7 @@ "ms-gsl", "nlohmann-json", "onnx", + "optional-lite", { "name": "protobuf", "version>=": "3.21.12" @@ -93,10 +94,6 @@ "webgpu-ep": { "description": "Build with WebGPU EP", "dependencies": [] - }, - "webgpu-ep-wgsl-template-dynamic": { - "description": "Build with WebGPU EP with dynamic WGSL template code generator", - "dependencies": ["duktape"] } }, "overrides": [ @@ -107,10 +104,6 @@ { "name": "flatbuffers", "version": "23.5.26" - }, - { - "name": "duktape", - "version": "2.7.0#2" } ] } diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests.Mobile/EndToEndTests.Mobile.Automation/Tests.cs b/csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests.Mobile/EndToEndTests.Mobile.Automation/Tests.cs index 6e6190b8227b8..c28830ec72157 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests.Mobile/EndToEndTests.Mobile.Automation/Tests.cs +++ b/csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests.Mobile/EndToEndTests.Mobile.Automation/Tests.cs @@ -40,12 +40,10 @@ public void RunPlatformUnitTest() var serializedResultSummary = _app.Invoke(_getResultsBackdoorMethodName)?.ToString(); Assert.IsNotEmpty(serializedResultSummary, "Test results were not returned"); - // Fix security issue (overflow with too much nesting): GHSA-5crp-9r3c-p9vr - JsonConvert.DefaultSettings = () => new JsonSerializerSettings { MaxDepth = 128 }; var testSummary = JsonConvert.DeserializeObject(serializedResultSummary); Assert.AreEqual(testSummary.Failed, 0, $"{testSummary.Failed} tests failed"); _app.Screenshot("Post-testing"); } } -} +} \ No newline at end of file diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Devices/TestResultProcessor.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Devices/TestResultProcessor.cs index 625cc2c54055c..8419d261e4a41 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Devices/TestResultProcessor.cs +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Devices/TestResultProcessor.cs @@ -45,9 +45,8 @@ public TestResultSummary GetResults() public string GetSerializedResults() { var resultSummary = GetResults(); - JsonConvert.DefaultSettings = () => new JsonSerializerSettings { MaxDepth = 128 }; var serializedResultSummary = JsonConvert.SerializeObject(resultSummary, Formatting.Indented); return serializedResultSummary; } } -} +} \ No newline at end of file diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index f3dcde1abe37a..b80918e6615e1 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -2545,8 +2545,6 @@ This version of the operator has been available since version 1 of the 'com.micr
left_window_size for local attention (like Mistral). Default value is -1 meaning unused.
num_heads : int (required)
Number of attention heads for q
-
qk_output : int
-
Output values of QK matrix multiplication before (1) or after (2) softmax normalization. Default value is 0 (don't output).
rotary_interleaved : int
Rotate using interleaved pattern. Default value is 0 (False).
scale : float
@@ -2557,7 +2555,7 @@ This version of the operator has been available since version 1 of the 'com.micr
Softcap value for attention weights. Default value is 0.
-#### Inputs (7 - 12) +#### Inputs (7 - 11)
query : T
@@ -2582,11 +2580,9 @@ This version of the operator has been available since version 1 of the 'com.micr
2D tensor with shape (batch_size, sequence_length). When processing the first prompt the kernel uses only the first element
attention_bias (optional) : T
additional add to QxK' with shape (batch_size or 1, num_heads or 1, sequence_length, total_sequence_length)
-
head_sink (optional) : T
-
1D tensor with shape (num_heads). Each head has a smooth factor adding to the denominator of softmax.
-#### Outputs (3 - 4) +#### Outputs
output : T
@@ -2595,8 +2591,6 @@ This version of the operator has been available since version 1 of the 'com.micr
present state key with support for format BNSH. When past_key uses same tensor as present_key(k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length +kv_sequence_length.
present_value : T
present state value with support for format BNSH. When past_value uses same tensor as present_value(k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length +kv_sequence_length.
-
output_qk (optional) : T
-
Values of QK matrix multiplication, either before or after softmax normalization
#### Type Constraints diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index fa6c731231405..1ffcabee8cc10 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -538,7 +538,7 @@ Do not modify directly.* |Gelu|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(float)| |GreedySearch|*in* input_ids:**I**
*in* max_length:**I**
*in* min_length:**I**
*in* repetition_penalty:**T**
*in* vocab_mask:**I**
*in* prefix_vocab_mask:**I**
*in* attention_mask:**I**
*out* sequences:**I**|1+|**T** = tensor(float)| |GridSample|*in* X:**T1**
*in* Grid:**T1**
*out* Y:**T2**|1+|**T1** = tensor(float)
**T2** = tensor(float)| -|GroupQueryAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* seqlens_k:**M**
*in* total_sequence_length:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*in* position_ids:**tensor(int64)**
*in* attention_bias:**T**
*in* head_sink:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**
*out* output_qk:**T**|1+|**M** = tensor(int32)
**T** = tensor(float), tensor(float16)| +|GroupQueryAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* seqlens_k:**M**
*in* total_sequence_length:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*in* position_ids:**tensor(int64)**
*in* attention_bias:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**M** = tensor(int32)
**T** = tensor(float), tensor(float16)| |Inverse|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |MatMulBnb4|*in* A:**T1**
*in* B:**T2**
*in* absmax:**T1**
*out* Y:**T1**|1+|**T1** = tensor(float)
**T2** = tensor(uint8)| |MatMulFpQ4|*in* A:**T1**
*in* B:**T2**
*in* B_shape:**T3**
*out* Y:**T1**|1+|**T1** = tensor(float)
**T2** = tensor(uint8)
**T3** = tensor(int64)| @@ -942,7 +942,7 @@ Do not modify directly.* |GreedySearch|*in* input_ids:**I**
*in* max_length:**I**
*in* min_length:**I**
*in* repetition_penalty:**T**
*in* vocab_mask:**I**
*in* prefix_vocab_mask:**I**
*in* attention_mask:**I**
*out* sequences:**I**|1+|**T** = tensor(float), tensor(float16)| |GridSample|*in* X:**T1**
*in* Grid:**T1**
*out* Y:**T2**|1+|**T1** = tensor(float)
**T2** = tensor(float)| |GroupNorm|*in* X:**T**
*in* gamma:**M**
*in* beta:**M**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| -|GroupQueryAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* seqlens_k:**M**
*in* total_sequence_length:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*in* position_ids:**tensor(int64)**
*in* attention_bias:**T**
*in* head_sink:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**
*out* output_qk:**T**|1+|**M** = tensor(int32)
**T** = tensor(bfloat16), tensor(float16)| +|GroupQueryAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* seqlens_k:**M**
*in* total_sequence_length:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*in* position_ids:**tensor(int64)**
*in* attention_bias:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**M** = tensor(int32)
**T** = tensor(bfloat16), tensor(float16)| |Inverse|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |Irfft|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |LongformerAttention|*in* input:**T**
*in* weight:**T**
*in* bias:**T**
*in* mask:**T**
*in* global_weight:**T**
*in* global_bias:**T**
*in* global:**G**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| @@ -1420,7 +1420,7 @@ Do not modify directly.* |FusedMatMulActivation|*in* A:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| |Gelu|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| |GroupNorm|*in* X:**T**
*in* gamma:**M**
*in* beta:**M**
*out* Y:**T**|1+|**M** = tensor(float), tensor(float16)
**T** = tensor(float), tensor(float16)| -|GroupQueryAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* seqlens_k:**M**
*in* total_sequence_length:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*in* position_ids:**tensor(int64)**
*in* attention_bias:**T**
*in* head_sink:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**
*out* output_qk:**T**|1+|**M** = tensor(int32)
**T** = tensor(float), tensor(float16)| +|GroupQueryAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* seqlens_k:**M**
*in* total_sequence_length:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*in* position_ids:**tensor(int64)**
*in* attention_bias:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**M** = tensor(int32)
**T** = tensor(float), tensor(float16)| |MatMulIntegerToFloat|*in* A:**T1**
*in* B:**T2**
*in* a_scale:**T3**
*in* b_scale:**T3**
*in* a_zero_point:**T1**
*in* b_zero_point:**T2**
*in* bias:**T3**
*out* Y:**T3**|1+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(float), tensor(float16)| |MatMulNBits|*in* A:**T1**
*in* B:**T2**
*in* scales:**T1**
*in* zero_points:**T3**
*in* g_idx:**T4**
*in* bias:**T1**
*out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(uint8)| |MultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* key_padding_mask:**M**
*in* attention_bias:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* past_sequence_length:**M**
*in* cache_indirection:**M**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**
*out* qk:**QK**|1+|**M** = tensor(int32)
**T** = tensor(float), tensor(float16)| diff --git a/include/onnxruntime/core/graph/graph.h b/include/onnxruntime/core/graph/graph.h index c18a42cc1bbc1..54e03a31fceef 100644 --- a/include/onnxruntime/core/graph/graph.h +++ b/include/onnxruntime/core/graph/graph.h @@ -952,12 +952,9 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi return const_cast(this)->GetNodeArg(name); } - // Searches for a NodeArg in the current graph and its parent graphs, and returns the corresponding mutable NodeArg + // search this and up through any parent_graph_ instance for a NodeArg NodeArg* GetNodeArgIncludingParentGraphs(const std::string& node_arg_name); - // Searches for a NodeArg in the current graph and its parent graphs, and returns the corresponding const NodeArg - const NodeArg* GetNodeArgIncludingParentGraphs(const std::string& node_arg_name) const; - /** Gets a mutable NodeArg by name. Creates a new NodeArg that is owned by this Graph if not found. @param name The NodeArg name. @param[in] p_arg_type Optional TypeProto to use if the NodeArg needs to be created. diff --git a/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h b/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h deleted file mode 100644 index 37665542f614f..0000000000000 --- a/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h +++ /dev/null @@ -1,718 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -/* - SUMMARY: - Utilities to serialize an OrtGraph into an ONNX GraphProto or ModelProto. Can be used by execution provider - implementations that need to convert an OrtGraph instance into an ONNX protobuf model. - - Users may copy this file and modify as needed. - - USAGE: - This is a header-only implementation that includes both the function declarations and definitions. Copy this file - into a project that links with both ONNX Runtime and ONNX. - - Define the ORT_EP_UTILS_ORT_GRAPH_TO_PROTO_IMPL preprocessor macro before the #include statement in exactly one C++ - file to define the implementation. Example: - - #define ORT_EP_UTILS_ORT_GRAPH_TO_PROTO_IMPL - #include "ort_graph_to_proto.h" - - Other compilation units that depend on these utilities should include this file without defining the - preprocessor macro. - - Example program snippets are shown below. Refer to the function declarations for detailed usage information. - - EXAMPLE SNIPPET (initializers stored within TensorProto): - - ```C++ - #define ORT_EP_UTILS_ORT_GRAPH_TO_PROTO_IMPL - #include "ort_graph_to_proto.h" - - OrtStatus* ORT_API_CALL GetCapabilityImpl(OrtEp* this_ptr, const OrtGraph* ort_graph, - OrtEpGraphSupportInfo* graph_support_info) { - onnx::GraphProto graph_proto; - OrtEpUtils::OrtGraphToProto(*ort_graph, graph_proto); - - // graph_proto stores initializers internally - } - ``` - - EXAMPLE SNIPPET (large initializers stored in external file): - - ```C++ - #define ORT_EP_UTILS_ORT_GRAPH_TO_PROTO_IMPL - #include "ort_graph_to_proto.h" - - OrtStatus* ORT_API_CALL GetCapabilityImpl(OrtEp* this_ptr, const OrtGraph* ort_graph, - OrtEpGraphSupportInfo* graph_support_info) { - std::string external_file_path = "weights.bin"; - std::ofstream out_file(external_file_path, std::ios::binary); - - auto handle_initializer_data = [&external_file_path, &out_file](const OrtValueInfo* value_info, - const void* data, size_t bytes, - bool& is_external, std::string& location, - int64_t& offset) -> Ort::Status { - // OrtValueInfo* could be used to query initializer's name, type, shape, consumers, etc. - (void)value_info; - - if (bytes <= 127) { - is_external = false; // Keep small initializers stored inside the TensorProto. - return Ort::Status{nullptr}; - } - - offset = out_file.tellp(); - location = external_file_path; - out_file.write(static_cast(data), bytes); - out_file.flush(); - is_external = true; // True if is external initializer - return Ort::Status{nullptr}; - } - - ONNX_NAMESPACE::GraphProto graph_proto; - OrtEpUtils::OrtGraphToProto(*ort_graph, graph_proto, handle_initializer_data); - - // graph_proto stores large initializers in an external file - } - ``` -*/ - -#ifndef INCLUDE_ONNXRUNTIME_CORE_PROVIDERS_UTILS_ORT_GRAPH_TO_PROTO_H_ -#define INCLUDE_ONNXRUNTIME_CORE_PROVIDERS_UTILS_ORT_GRAPH_TO_PROTO_H_ - -#include -#include "core/session/onnxruntime_cxx_api.h" -#include "onnx/onnx_pb.h" - -namespace OrtEpUtils { - -/// -/// Signature of user-provided function to handle initializer data. Called by OrtGraphToProto() for every initializer. -/// -/// If the function sets the `is_external` output parameter to false, OrtGraphToProto() stores initializer data -/// within the TensorProto as raw_data. -/// -/// Otherwise, if the function sets `is_external` to true, OrtGraphToProto() assumes that this function stores the -/// initializer data in a file. In this case, OrtGraphToProto() configures the corresponding TensorProto to point the -/// location and offset returned via the `location` and `offset` output parameters. -/// -/// It is recommended to keep small initializers with byte size <= 127 stored inline the TensorProto to ensure -/// ONNX shape inference works correctly with the serialized ONNX model. -/// -/// OrtValueInfo for the initializer. Can be used to query name, type, shape, -/// and consumer nodes. -/// Opaque pointer to the initializer data. -/// Size in bytes of the initializer data. -/// Output parameter set to true if the initializer data is stored externally. The -/// implementer is responsible for writing the initializer data to file. If set to false, -/// the initializer will be stored within the TensorProto. -/// Output parameter set to the location (e.g., file) into which the initializer is stored -/// by the implementer of this function. Ignored if `is_external` is set to false. -/// Output parameter set to the offset (e.g., file offset) into which the initializer is stored -/// by the implementer of this function. Ignored if `is_external` is set to false. -/// An Ort::Status indicating success or an error. Serialization exits if this returns an error. -using HandleInitializerDataFunc = std::function; - -/// -/// Serializes the provided OrtGraph to a onnx::GraphProto. -/// Allows the caller to provide a function that specifies whether an initializer should be stored -/// within a TensorProto, written to a file, or remain as an in-memory external initializer (not valid ONNX). -/// -/// OrtGraph instance to serialize. -/// Destination GraphProto into which to serialize the input OrtGraph. -/// Optional function called to allow the user to determine -/// where the initializer data is stored. -/// An Ort::Status indicating success or an error. -Ort::Status OrtGraphToProto(const OrtGraph& ort_graph, - onnx::GraphProto& graph_proto, - HandleInitializerDataFunc handle_initializer_data_func = nullptr); - -/// -/// Serializes the provided top-level OrtGraph to a onnx::ModelProto. -/// Allows the caller to provide a function that specifies whether an initializer should be stored -/// within a TensorProto, written to a file, or remain as an in-memory external initializer (not valid ONNX). -/// -/// OrtGraph instance to serialize. -/// Destination ModelProto into which to serialize the input OrtGraph. -/// Optional function called to allow the user to determine -/// where the initializer data is stored. -/// An Ort::Status indicating success or an error. -Ort::Status OrtGraphToProto(const OrtGraph& ort_graph, - onnx::ModelProto& model_proto, - HandleInitializerDataFunc handle_initializer_data_func = nullptr); -} // namespace OrtEpUtils - -// End of header -#endif // INCLUDE_ONNXRUNTIME_CORE_PROVIDERS_UTILS_ORT_GRAPH_TO_PROTO_H_ - -// -// IMPLEMENTATION BELOW -// -#ifdef ORT_EP_UTILS_ORT_GRAPH_TO_PROTO_IMPL - -#include -#include -#include -#include -#include -#include - -#define ORT_EP_UTILS_C_RETURN_IF_ERROR(fn) \ - do { \ - OrtStatus* _status = (fn); \ - if (_status != nullptr) { \ - return Ort::Status{_status}; \ - } \ - } while (0) - -#define ORT_EP_UTILS_CXX_RETURN_IF_ERROR(fn) \ - do { \ - Ort::Status _status = (fn); \ - if (!_status.IsOK()) { \ - return _status; \ - } \ - } while (0) - -#define ORT_EP_UTILS_C_RETURN_IF(cond, ort_api, msg) \ - do { \ - if ((cond)) { \ - return Ort::Status{(ort_api).CreateStatus(ORT_FAIL, (msg))}; \ - } \ - } while (0) - -namespace OrtEpUtils { - -static Ort::Status GetOrtValueInfoTensorTypeShape(const OrtValueInfo& ort_value_info, - bool get_symbolic_dims, - /*out*/ ONNXTensorElementDataType& elem_type, - /*out*/ std::vector& dims, - /*out*/ std::vector& symbolic_dims); -static Ort::Status OrtValueInfoToProto(const OrtValueInfo& ort_value_info, onnx::ValueInfoProto& value_info_proto); -static Ort::Status OrtOpAttrToProto(const OrtOpAttr& ort_attr, onnx::AttributeProto& attr_proto); - -Ort::Status OrtGraphToProto(const OrtGraph& ort_graph, - onnx::GraphProto& graph_proto, - HandleInitializerDataFunc handle_initializer_data_func) { - const OrtApi& ort_api = Ort::GetApi(); - - // - // Set GraphProto metadata - // - const char* graph_name = nullptr; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetName(&ort_graph, &graph_name)); - graph_proto.set_name(graph_name); - graph_proto.set_doc_string("Serialized from OrtGraph"); - - // - // Set GraphProto inputs and outputs - // - size_t num_graph_inputs = 0; - size_t num_graph_outputs = 0; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetNumInputs(&ort_graph, &num_graph_inputs)); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetNumOutputs(&ort_graph, &num_graph_outputs)); - - std::vector graph_inputs(num_graph_inputs); - std::vector graph_outputs(num_graph_outputs); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetInputs(&ort_graph, graph_inputs.data(), graph_inputs.size())); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetOutputs(&ort_graph, graph_outputs.data(), graph_outputs.size())); - - for (const OrtValueInfo* ort_value_info : graph_inputs) { - onnx::ValueInfoProto* value_info_proto = graph_proto.mutable_input()->Add(); - ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtValueInfoToProto(*ort_value_info, *value_info_proto)); - } - - for (const OrtValueInfo* ort_value_info : graph_outputs) { - onnx::ValueInfoProto* value_info_proto = graph_proto.mutable_output()->Add(); - ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtValueInfoToProto(*ort_value_info, *value_info_proto)); - } - - // - // Set GraphProto nodes, value_infos, and initializers. - // - - // Use std::maps to store OrtValueInfos for GraphProto.value_info and GraphProto.initializer. - // A std::map maintains its elements in a stable ordering. - std::map value_infos; // For GraphProto.value_info - std::map initializer_value_infos; // For GraphProto.initializer - - // Helper function to collect an OrtValueInfo into `value_infos` or `initializer_value_infos`. - // Optionally returns the OrtValueInfo name to the caller. - auto collect_value_info = [&ort_api, &value_infos, - &initializer_value_infos](const OrtValueInfo& ort_value_info, - /*out*/ const char** value_name_out = nullptr) -> Ort::Status { - const char* value_name = nullptr; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetValueInfoName(&ort_value_info, &value_name)); - - if (value_name_out != nullptr) { - *value_name_out = value_name; - } - - if (value_infos.count(value_name) != 0 || initializer_value_infos.count(value_name) != 0) { - return Ort::Status{nullptr}; // Already processed this OrtValueInfo. - } - - bool is_required_graph_input = false; - bool is_optional_graph_input = false; - bool is_graph_output = false; - bool is_constant_initializer = false; - bool is_from_outer_scope = false; - - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ValueInfo_IsRequiredGraphInput(&ort_value_info, &is_required_graph_input)); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ValueInfo_IsOptionalGraphInput(&ort_value_info, &is_optional_graph_input)); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ValueInfo_IsGraphOutput(&ort_value_info, &is_graph_output)); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ValueInfo_IsConstantInitializer(&ort_value_info, &is_constant_initializer)); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ValueInfo_IsFromOuterScope(&ort_value_info, &is_from_outer_scope)); - - // Don't add graph inputs or graph outputs to GraphProto's list of value_infos. - // Do add initializers (constant and non-constant) to GraphProto's list of initializer tensors. - // For values defined in an outer scope, just add the value info but not the initializer. - if (is_from_outer_scope) { - value_infos.emplace(value_name, &ort_value_info); - } else if (is_optional_graph_input) { - initializer_value_infos.emplace(value_name, &ort_value_info); - } else if (is_constant_initializer) { - value_infos.emplace(value_name, &ort_value_info); - initializer_value_infos.emplace(value_name, &ort_value_info); - } else if (!is_required_graph_input && !is_graph_output) { - value_infos.emplace(value_name, &ort_value_info); // This is an internal OrtValueInfo. - } - - return Ort::Status{nullptr}; - }; - - size_t num_nodes = 0; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetNumNodes(&ort_graph, &num_nodes)); - - std::vector nodes(num_nodes); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetNodes(&ort_graph, nodes.data(), nodes.size())); - - // Loop through all nodes (topological order): add NodeProto instances to GraphProto and track OrtValueInfos - // that will be stored in GraphProto.value_info and GraphProto.initializer. - for (size_t i = 0; i < num_nodes; i++) { - const OrtNode* ort_node = nodes[i]; - onnx::NodeProto* node_proto = graph_proto.add_node(); - - const char* node_name = nullptr; - const char* node_domain = nullptr; - const char* node_op_type = nullptr; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetName(ort_node, &node_name)); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetDomain(ort_node, &node_domain)); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetOperatorType(ort_node, &node_op_type)); - - node_proto->set_name(node_name); - node_proto->set_domain(node_domain); - node_proto->set_op_type(node_op_type); - - size_t num_inputs = 0; - size_t num_implicit_inputs = 0; - size_t num_outputs = 0; - size_t num_attrs = 0; - size_t num_subgraphs = 0; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetNumInputs(ort_node, &num_inputs)); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetNumImplicitInputs(ort_node, &num_implicit_inputs)); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetNumOutputs(ort_node, &num_outputs)); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetNumAttributes(ort_node, &num_attrs)); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetNumSubgraphs(ort_node, &num_subgraphs)); - - // Handle node attributes - if (num_attrs > 0) { - std::vector ort_attrs(num_attrs); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetAttributes(ort_node, ort_attrs.data(), ort_attrs.size())); - - for (const OrtOpAttr* ort_attr : ort_attrs) { - OrtOpAttrType attr_type = OrtOpAttrType::ORT_OP_ATTR_UNDEFINED; - - Ort::Status status{ort_api.OpAttr_GetType(ort_attr, &attr_type)}; - if (!status.IsOK()) { - // This is an attribute type that ORT does not support via ReadOpAttr(), like subgraphs, so skip it. - // Can use Node_GetSubgraphs to get subgraphs. - continue; - } - - onnx::AttributeProto* attr_proto = node_proto->add_attribute(); - ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtOpAttrToProto(*ort_attr, *attr_proto)); - } - } - - // Handle node subgraphs - if (num_subgraphs > 0) { - std::vector ort_subgraphs(num_subgraphs); - std::vector subgraph_attr_names(num_subgraphs); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetSubgraphs(ort_node, ort_subgraphs.data(), ort_subgraphs.size(), - subgraph_attr_names.data())); - - for (size_t subgraph_idx = 0; subgraph_idx < num_subgraphs; subgraph_idx++) { - const OrtGraph* ort_subgraph = ort_subgraphs[subgraph_idx]; - const char* subgraph_attr_name = subgraph_attr_names[subgraph_idx]; - - onnx::AttributeProto* attr_proto = node_proto->add_attribute(); - onnx::GraphProto* subgraph_proto = attr_proto->mutable_g(); - - attr_proto->set_name(subgraph_attr_name); - attr_proto->set_type(onnx::AttributeProto_AttributeType_GRAPH); - ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtGraphToProto(*ort_subgraph, *subgraph_proto)); - } - } - - // Handle node inputs - if (num_inputs > 0) { - std::vector ort_inputs(num_inputs); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetInputs(ort_node, ort_inputs.data(), ort_inputs.size())); - - for (const OrtValueInfo* ort_value_info : ort_inputs) { - if (ort_value_info == nullptr) { - // missing optional input. - node_proto->add_input(""); - continue; - } - - const char* value_name = nullptr; - ORT_EP_UTILS_CXX_RETURN_IF_ERROR(collect_value_info(*ort_value_info, &value_name)); - - node_proto->add_input(value_name); - } - } - - // Handle implicit inputs to this node. - if (num_implicit_inputs > 0) { - std::vector ort_implicit_inputs(num_implicit_inputs); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetImplicitInputs(ort_node, ort_implicit_inputs.data(), - ort_implicit_inputs.size())); - - for (const OrtValueInfo* ort_value_info : ort_implicit_inputs) { - assert(ort_value_info != nullptr); - ORT_EP_UTILS_CXX_RETURN_IF_ERROR(collect_value_info(*ort_value_info, /*value_name_out*/ nullptr)); - } - } - - // Handle node outputs - if (num_outputs > 0) { - std::vector ort_outputs(num_outputs); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetOutputs(ort_node, ort_outputs.data(), ort_outputs.size())); - - for (const OrtValueInfo* ort_value_info : ort_outputs) { - if (ort_value_info == nullptr) { - // missing optional output. - node_proto->add_output(""); - continue; - } - - const char* value_name = nullptr; - ORT_EP_UTILS_CXX_RETURN_IF_ERROR(collect_value_info(*ort_value_info, &value_name)); - - node_proto->add_output(value_name); - } - } - } - - // Add value_infos to GraphProto as ValueInfoProto objects. - for (const std::pair& entry : value_infos) { - onnx::ValueInfoProto* value_info_proto = graph_proto.mutable_value_info()->Add(); - ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtValueInfoToProto(*entry.second, *value_info_proto)); - } - - // Add initializers to GraphProto as TensorProto objects. - for (const std::pair& entry : initializer_value_infos) { - const OrtValueInfo* initializer_value_info = entry.second; - std::string initializer_name = std::string{entry.first}; // Need a null-terminated string. - std::vector initializer_dims; - std::vector initializer_sym_dims; - ONNXTensorElementDataType initializer_elem_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; - ORT_EP_UTILS_CXX_RETURN_IF_ERROR(GetOrtValueInfoTensorTypeShape(*initializer_value_info, /*get_sym_dims*/ false, - initializer_elem_type, initializer_dims, - initializer_sym_dims)); - - onnx::TensorProto* tensor_proto = graph_proto.add_initializer(); - tensor_proto->set_name(initializer_name); - tensor_proto->set_data_type(initializer_elem_type); - - auto* tensor_proto_dims = tensor_proto->mutable_dims(); - for (int64_t dim : initializer_dims) { - tensor_proto_dims->Add(dim); - } - - const OrtValue* ort_value = nullptr; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ValueInfo_GetInitializerValue(initializer_value_info, &ort_value)); - - const void* data = nullptr; - size_t data_bytes = 0; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetTensorData(ort_value, &data)); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetTensorSizeInBytes(ort_value, &data_bytes)); - - std::string ext_location; - int64_t ext_offset = 0; - bool is_external = false; - - if (handle_initializer_data_func != nullptr) { - ORT_EP_UTILS_CXX_RETURN_IF_ERROR(handle_initializer_data_func(initializer_value_info, data, data_bytes, - is_external, ext_location, ext_offset)); - } - - if (is_external) { - tensor_proto->set_data_location(onnx::TensorProto_DataLocation_EXTERNAL); - auto* ext_data_entries = tensor_proto->mutable_external_data(); - onnx::StringStringEntryProto* location_entry = ext_data_entries->Add(); - onnx::StringStringEntryProto* offset_entry = ext_data_entries->Add(); - - location_entry->set_key("location"); - location_entry->set_value(ext_location); - offset_entry->set_key("offset"); - offset_entry->set_value(std::to_string(ext_offset)); - } else { - // User wants to store data inline the TensorProto's raw_data - tensor_proto->set_data_location(onnx::TensorProto_DataLocation_DEFAULT); - tensor_proto->set_raw_data(data, data_bytes); - } - } - - return Ort::Status{nullptr}; -} - -Ort::Status OrtGraphToProto(const OrtGraph& ort_graph, - onnx::ModelProto& model_proto, - HandleInitializerDataFunc handle_initializer_data_func) { - const OrtApi& ort_api = Ort::GetApi(); - - // Check that OrtGraph is a top-level graph (no parent node). - const OrtNode* parent_node = nullptr; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetParentNode(&ort_graph, &parent_node)); - ORT_EP_UTILS_C_RETURN_IF(parent_node != nullptr, ort_api, "Cannot serialize nested OrtGraph into a ModelProto"); - - // Set model description. - model_proto.set_doc_string("Serialized from OrtGraph"); - model_proto.set_producer_name("ort_ep_utils::OrtGraphToProto"); - - // Set ir version. - int64_t ir_version = 0; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetOnnxIRVersion(&ort_graph, &ir_version)); - model_proto.set_ir_version(ir_version); - - // Set operator sets. - size_t num_operator_sets = 0; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetNumOperatorSets(&ort_graph, &num_operator_sets)); - ORT_EP_UTILS_C_RETURN_IF(num_operator_sets == 0, ort_api, "OrtGraph should have at least one operator set."); - - std::vector domains(num_operator_sets, nullptr); - std::vector opset_versions(num_operator_sets); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetOperatorSets(&ort_graph, domains.data(), opset_versions.data(), - num_operator_sets)); - - auto* operator_sets = model_proto.mutable_opset_import(); - - for (size_t i = 0; i < num_operator_sets; ++i) { - onnx::OperatorSetIdProto* operator_set = operator_sets->Add(); - operator_set->set_domain(domains[i]); - operator_set->set_version(opset_versions[i]); - } - - model_proto.clear_graph(); - onnx::GraphProto* graph_proto = model_proto.mutable_graph(); - - ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtGraphToProto(ort_graph, *graph_proto, handle_initializer_data_func)); - - return Ort::Status{nullptr}; -} - -static Ort::Status GetOrtValueInfoTensorTypeShape(const OrtValueInfo& ort_value_info, - bool get_symbolic_dims, - /*out*/ ONNXTensorElementDataType& elem_type, - /*out*/ std::vector& dims, - /*out*/ std::vector& symbolic_dims) { - const OrtApi& ort_api = Ort::GetApi(); - - const OrtTypeInfo* ort_type_info = nullptr; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetValueInfoTypeInfo(&ort_value_info, &ort_type_info)); - - ONNXType ort_onnx_type = ONNX_TYPE_UNKNOWN; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetOnnxTypeFromTypeInfo(ort_type_info, &ort_onnx_type)); - ORT_EP_UTILS_C_RETURN_IF(ort_onnx_type != ONNX_TYPE_TENSOR, ort_api, "Expected OrtValueInfo to represent a Tensor"); - - const OrtTensorTypeAndShapeInfo* ort_type_shape = nullptr; - ONNXTensorElementDataType ort_elem_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.CastTypeInfoToTensorInfo(ort_type_info, &ort_type_shape)); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetTensorElementType(ort_type_shape, &ort_elem_type)); - - size_t num_dims = 0; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetDimensionsCount(ort_type_shape, &num_dims)); - - std::vector ort_dims(num_dims, 0); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetDimensions(ort_type_shape, ort_dims.data(), ort_dims.size())); - - elem_type = ort_elem_type; - dims = std::move(ort_dims); - - if (get_symbolic_dims) { - std::vector ort_dim_syms(num_dims, nullptr); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetSymbolicDimensions(ort_type_shape, ort_dim_syms.data(), - ort_dim_syms.size())); - - symbolic_dims.reserve(num_dims); - for (const char* sym_dim : ort_dim_syms) { - symbolic_dims.push_back(sym_dim); - } - } - - return Ort::Status{nullptr}; -} - -// Create an onnx::ValueInfoProto from an OrtValueInfo (name, type, shape). -static Ort::Status OrtValueInfoToProto(const OrtValueInfo& ort_value_info, - onnx::ValueInfoProto& value_info_proto) { - const OrtApi& ort_api = Ort::GetApi(); - - std::vector ort_dims; - std::vector ort_dim_syms; - ONNXTensorElementDataType ort_elem_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; - - // We currently only support ONNX tensors. Support for other types (e.g., ONNX_TYPE_SEQUENCE) can be added later. - ORT_EP_UTILS_CXX_RETURN_IF_ERROR(GetOrtValueInfoTensorTypeShape(ort_value_info, /*get_sym_dims*/ true, - ort_elem_type, ort_dims, ort_dim_syms)); - - const char* value_name = nullptr; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetValueInfoName(&ort_value_info, &value_name)); - value_info_proto.set_name(value_name); - - onnx::TypeProto_Tensor* type_proto_tensor = value_info_proto.mutable_type()->mutable_tensor_type(); - type_proto_tensor->set_elem_type(ort_elem_type); - - onnx::TensorShapeProto* shape_proto = type_proto_tensor->mutable_shape(); - - for (size_t dim_idx = 0; dim_idx < ort_dims.size(); dim_idx++) { - onnx::TensorShapeProto_Dimension* dim_proto = shape_proto->add_dim(); - - if (ort_dims[dim_idx] >= 0) { - dim_proto->set_dim_value(ort_dims[dim_idx]); - } else { - const std::string& dim_param = ort_dim_syms[dim_idx]; - - // If dim_param is empty, leave dim_proto with neither the dim_value or dim_param set, - // which represents an unknown dimension. - if (!dim_param.empty()) { - dim_proto->set_dim_param(dim_param); - } - } - } - - return Ort::Status{nullptr}; -} - -static Ort::Status OrtOpAttrToProto(const OrtOpAttr& ort_attr, onnx::AttributeProto& attr_proto) { - const OrtApi& ort_api = Ort::GetApi(); - - const char* attr_name = nullptr; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.OpAttr_GetName(&ort_attr, &attr_name)); - attr_proto.set_name(attr_name); - - size_t total_attr_bytes = 0; - OrtOpAttrType attr_type = OrtOpAttrType::ORT_OP_ATTR_UNDEFINED; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.OpAttr_GetType(&ort_attr, &attr_type)); - - switch (attr_type) { - case OrtOpAttrType::ORT_OP_ATTR_INT: { - attr_proto.set_type(onnx::AttributeProto_AttributeType_INT); - - int64_t i_val = 0; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ReadOpAttr(&ort_attr, attr_type, &i_val, sizeof(i_val), &total_attr_bytes)); - attr_proto.set_i(i_val); - break; - } - case OrtOpAttrType::ORT_OP_ATTR_INTS: { - attr_proto.set_type(onnx::AttributeProto_AttributeType_INTS); - - // First call to ReadOpAttr gets the total byte size. Second call reads the data. - Ort::Status status{ort_api.ReadOpAttr(&ort_attr, attr_type, nullptr, 0, &total_attr_bytes)}; - std::vector i_vals(total_attr_bytes / sizeof(int64_t)); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ReadOpAttr(&ort_attr, attr_type, i_vals.data(), total_attr_bytes, - &total_attr_bytes)); - - auto* ints = attr_proto.mutable_ints(); - for (int64_t val : i_vals) { - ints->Add(val); - } - break; - } - case OrtOpAttrType::ORT_OP_ATTR_FLOAT: { - attr_proto.set_type(onnx::AttributeProto_AttributeType_FLOAT); - - float f_val = 0.0f; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ReadOpAttr(&ort_attr, attr_type, &f_val, sizeof(f_val), &total_attr_bytes)); - attr_proto.set_f(f_val); - break; - } - case OrtOpAttrType::ORT_OP_ATTR_FLOATS: { - attr_proto.set_type(onnx::AttributeProto_AttributeType_FLOATS); - - // First call to ReadOpAttr gets the total byte size. Second call reads the data. - Ort::Status status{ort_api.ReadOpAttr(&ort_attr, attr_type, nullptr, 0, &total_attr_bytes)}; - std::vector f_vals(total_attr_bytes / sizeof(float)); - - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ReadOpAttr(&ort_attr, attr_type, f_vals.data(), total_attr_bytes, - &total_attr_bytes)); - - auto* floats = attr_proto.mutable_floats(); - for (float val : f_vals) { - floats->Add(val); - } - break; - } - case OrtOpAttrType::ORT_OP_ATTR_STRING: { - attr_proto.set_type(onnx::AttributeProto_AttributeType_STRING); - - // First call to ReadOpAttr gets the total byte size. Second call reads the data. - Ort::Status status{ort_api.ReadOpAttr(&ort_attr, attr_type, nullptr, 0, &total_attr_bytes)}; - std::string* str = attr_proto.mutable_s(); - - str->resize(total_attr_bytes, '\0'); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ReadOpAttr(&ort_attr, attr_type, str->data(), total_attr_bytes, - &total_attr_bytes)); - - str->resize(total_attr_bytes - 1); // remove extra ending terminating '\0' character. - break; - } - case OrtOpAttrType::ORT_OP_ATTR_STRINGS: { - attr_proto.set_type(onnx::AttributeProto_AttributeType_STRINGS); - - // First call to ReadOpAttr gets the total byte size. Second call reads the data. - Ort::Status status{ort_api.ReadOpAttr(&ort_attr, attr_type, nullptr, 0, &total_attr_bytes)}; - std::vector chars(total_attr_bytes, '\0'); - - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ReadOpAttr(&ort_attr, attr_type, chars.data(), total_attr_bytes, - &total_attr_bytes)); - - auto* strs = attr_proto.mutable_strings(); - - // Strings are all in a single buffer, each separated with a '\0'. - // Extract each string and add it to the STRINGS attribute array. - char* at = chars.data(); - char* end = at + chars.size(); - - while (at < end) { - char* str_begin = at; - - while (*at && at < end) { - at++; - } - - strs->Add()->assign(str_begin, at - str_begin); - if (at < end) { - assert(*at == '\0'); - at++; // Skip '\0' to get to the beginning of the next string. - } - } - - break; - } - default: { - std::string err_msg = "Unexpected OrtOpAttrType with value " + std::to_string(static_cast(attr_type)); - return Ort::Status(err_msg.c_str(), ORT_FAIL); - } - } - - return Ort::Status{nullptr}; -} - -} // namespace OrtEpUtils -#endif // ORT_EP_UTILS_ORT_GRAPH_TO_PROTO_IMPL diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 82e782112974f..86c0b60db2bc4 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -66,7 +66,6 @@ extern "C" { #define _In_reads_(X) #define _Inout_updates_(X) #define _Out_writes_(X) -#define _Out_writes_opt_(X) #define _Inout_updates_all_(X) #define _Out_writes_bytes_all_(X) #define _Out_writes_all_(X) @@ -4750,8 +4749,6 @@ struct OrtApi { * \param[in] len Number of bytes allowed to store in data * \param[out] out Number of bytes required to save the data when the call failed, or the real number of bytes saved to data on success * - * \note Does not support reading graph attributes. Refer to Node_GetSubgraphs. - * * \since Version 1.17. */ ORT_API2_STATUS(ReadOpAttr, _In_ const OrtOpAttr* op_attr, _In_ OrtOpAttrType type, _Inout_ void* data, _In_ size_t len, _Out_ size_t* out); @@ -5571,45 +5568,6 @@ struct OrtApi { */ ORT_API2_STATUS(Graph_GetOnnxIRVersion, _In_ const OrtGraph* graph, _Out_ int64_t* onnx_ir_version); - /** \brief Returns the number of operator sets that the graph's model uses. - * - * \note An operator set is uniquely identified by the (domain, opset_version) pair. All models must have at - * least one entry that specifies which entry of the ONNX operator set is used. The ONNX domain is represented by - * an empty string. - * - * \param[in] graph The OrtGraph instance. - * \param[out] num_operator_sets Output parameter set to the number of operator sets that the graph's model uses. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.23. - */ - ORT_API2_STATUS(Graph_GetNumOperatorSets, _In_ const OrtGraph* graph, _Out_ size_t* num_operator_sets); - - /** \brief Returns the operator sets that the graph's model uses. - * - * \note An operator set is uniquely identified by the (domain, opset_version) pair. All models must have at - * least one entry that specifies which entry of the ONNX operator set is used. The ONNX domain is represented by - * an empty string. - * - * \param[in] graph The OrtGraph instance. - * \param[out] domains Pre-allocated array of `num_operator_sets` elements that is filled with - * null-terminated domain names. - * \param[out] opset_versions Pre-allocated array of `num_operator_sets` elements that is filled with - * the opset version of the corresponding domain in the `domains` array. - * \param[in] num_operator_sets The size of the `domains` and `opset_versions` arrays. - * Typical usage sets this to the result of Graph_GetNumOperatorSets(). - * An error status is returned if `num_operator_sets` is less than the actual number - * of operator sets. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.23. - */ - ORT_API2_STATUS(Graph_GetOperatorSets, _In_ const OrtGraph* graph, - _Out_writes_(num_operator_sets) const char** domains, - _Out_writes_(num_operator_sets) int64_t* opset_versions, _In_ size_t num_operator_sets); - /** \brief Returns the number of graph inputs. * * \note The count includes initializers that are included in the list of graph inputs. @@ -5748,24 +5706,6 @@ struct OrtApi { */ ORT_API2_STATUS(Graph_GetParentNode, _In_ const OrtGraph* graph, _Outptr_result_maybenull_ const OrtNode** node); - /** \brief Returns an OrtGraph that contains a subset of nodes in the source OrtGraph. - * - * Note: - * The lifetime of "dst_graph" is tied to that of "src_graph", as they both internally reference - * the same underlying graph. - * - * \param[in] src_graph The source OrtGraph instance. - * \param[in] nodes A subset of the nodes/OrtNodes in 'graph'. - * \param[in] num_nodes Number of nodes. - * \param[out] dst_sub_graph An OrtGraph created from a given set of nodes. Must be released by calling ReleaseGraph. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.23. - */ - ORT_API2_STATUS(Graph_GetGraphView, _In_ const OrtGraph* src_graph, _In_ const OrtNode** nodes, - _In_ size_t num_nodes, _Outptr_ OrtGraph** dst_graph); - /// @} /// \name OrtNode @@ -5993,24 +5933,20 @@ struct OrtApi { /** \brief Get the subgraphs, as OrtGraph instances, contained by the given node. * - * \note Only certain operator types (e.g., If and Loop) contain nested subgraphs. ONNX nodes store subgraphs in - * their attributes, however, this function must be used to obtain subgraphs from an OrtNode. + * \note Only certain operator types (e.g., If and Loop) contain nested subgraphs. * * \param[in] node The OrtNode instance. * \param[out] subgraphs Pre-allocated array of `num_subgraphs` elements that is filled with the node's subgraphs. * \param[in] num_subgraphs The size of the `num_subgraphs` array. * Typical usage sets this to the result of Node_GetNumSubgraphs(). An error status is * returned if `num_subgraphs` is less than the number of node subgraphs. - * \param[out] attribute_names Optional pre-allocated array of `num_subgraphs` elements that is filled with the - * attribute names that correspond to the subgraphs. Ignored if set to NULL. * * \snippet{doc} snippets.dox OrtStatus Return Value * * \since Version 1.23. */ ORT_API2_STATUS(Node_GetSubgraphs, _In_ const OrtNode* node, - _Out_writes_(num_subgraphs) const OrtGraph** subgraphs, _In_ size_t num_subgraphs, - _Out_writes_opt_(num_subgraphs) const char** attribute_names); + _Out_writes_(num_subgraphs) const OrtGraph** subgraphs, _In_ size_t num_subgraphs); /** \brief Get the node's parent OrtGraph instance. * @@ -6026,19 +5962,6 @@ struct OrtApi { */ ORT_API2_STATUS(Node_GetGraph, _In_ const OrtNode* node, _Outptr_result_maybenull_ const OrtGraph** graph); - /** \brief Returns the execution provider name that this node is assigned to run on. - * Returns NULL if the node has not been assigned to any execution provider yet. - * For plugin execution providers, the name is the one returned by OrtEp::GetName. - * - * \param[in] node The OrtNode instance. - * \param[out] out Output execution provider type and can be NULL if node has not been assigned. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.23. - */ - ORT_API2_STATUS(Node_GetEpName, _In_ const OrtNode* node, _Outptr_result_maybenull_ const char** out); - /// @} /// \name OrtRunOptions @@ -6887,24 +6810,6 @@ struct OrtCompileApi { */ ORT_API2_STATUS(ModelCompilationOptions_SetFlags, _In_ OrtModelCompilationOptions* model_compile_options, size_t flags); - - /** Sets information related to EP context binary file. - * - * EP uses this information to decide the location and context binary file name. - * Used while compiling model with input and output in memory buffer - * - * \param[in] model_compile_options The OrtModelCompilationOptions instance. - * \param[in] output_directory Null terminated string of the path (wchar on Windows, char otherwise). - * \param[in] model_name Null terminated string of the model name (wchar on Windows, char otherwise). - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.23. - */ - ORT_API2_STATUS(ModelCompilationOptions_SetEpContextBinaryInformation, - _In_ OrtModelCompilationOptions* model_compile_options, - _In_ const ORTCHAR_T* output_directory, - _In_ const ORTCHAR_T* model_name); }; /* diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index d1b08f127fa2a..c59baa59c91a5 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -1161,8 +1161,6 @@ struct ModelCompilationOptions : detail::Base { size_t initializer_size_threshold); ///< Wraps OrtApi::ModelCompilationOptions_SetOutputModelExternalInitializersFile ModelCompilationOptions& SetOutputModelBuffer(OrtAllocator* allocator, void** output_model_buffer_ptr, size_t* output_model_buffer_size_ptr); ///< Wraps OrtApi::ModelCompilationOptions_SetOutputModelBuffer - ModelCompilationOptions& SetEpContextBinaryInformation(const ORTCHAR_T* output_directory, - const ORTCHAR_T* model_name); ///< Wraps OrtApi::ModelCompilationOptions_SetEpContextBinaryInformation ModelCompilationOptions& SetFlags(size_t flags); ///< Wraps OrtApi::ModelCompilationOptions_SetFlags }; diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index ba5d53e6c2dd0..612adc81d3309 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -819,15 +819,6 @@ inline ModelCompilationOptions& ModelCompilationOptions::SetOutputModelPath( return *this; } -inline ModelCompilationOptions& ModelCompilationOptions::SetEpContextBinaryInformation( - const ORTCHAR_T* output_directory, const ORTCHAR_T* model_name) { - Ort::ThrowOnError(GetCompileApi().ModelCompilationOptions_SetEpContextBinaryInformation( - this->p_, - output_directory, - model_name)); - return *this; -} - inline ModelCompilationOptions& ModelCompilationOptions::SetOutputModelExternalInitializersFile( const ORTCHAR_T* file_path, size_t initializer_size_threshold) { Ort::ThrowOnError(GetCompileApi().ModelCompilationOptions_SetOutputModelExternalInitializersFile( diff --git a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h index 5d00ce4940d02..44c7bb6ee424a 100644 --- a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h @@ -358,7 +358,7 @@ struct OrtEp { * * \since Version 1.22. */ - ORT_API_T(const char*, GetName, _In_ const OrtEp* this_ptr); + const char*(ORT_API_CALL* GetName)(_In_ const OrtEp* this_ptr); /** \brief Get information about the nodes supported by the OrtEp instance. * @@ -376,8 +376,8 @@ struct OrtEp { * * \since Version 1.23. */ - ORT_API2_STATUS(GetCapability, _In_ OrtEp* this_ptr, _In_ const OrtGraph* graph, - _Inout_ OrtEpGraphSupportInfo* graph_support_info); + OrtStatus*(ORT_API_CALL* GetCapability)(_In_ OrtEp* this_ptr, _In_ const OrtGraph* graph, + _Inout_ OrtEpGraphSupportInfo* graph_support_info); /** \brief Compile OrtGraph instances assigned to the OrtEp. Implementer must set a OrtNodeComputeInfo instance * for each OrtGraph in order to define its computation function. @@ -416,10 +416,10 @@ struct OrtEp { * * \since Version 1.23. */ - ORT_API2_STATUS(Compile, _In_ OrtEp* this_ptr, _In_ const OrtGraph** graphs, - _In_ const OrtNode** fused_nodes, _In_ size_t count, - _Out_writes_all_(count) OrtNodeComputeInfo** node_compute_infos, - _Out_writes_(count) OrtNode** ep_context_nodes); + OrtStatus*(ORT_API_CALL* Compile)(_In_ OrtEp* this_ptr, _In_ const OrtGraph** graphs, + _In_ const OrtNode** fused_nodes, _In_ size_t count, + _Out_writes_all_(count) OrtNodeComputeInfo** node_compute_infos, + _Out_writes_(count) OrtNode** ep_context_nodes); /** \brief Release OrtNodeComputeInfo instances. * @@ -429,9 +429,9 @@ struct OrtEp { * * \since Version 1.23. */ - ORT_API_T(void, ReleaseNodeComputeInfos, _In_ OrtEp* this_ptr, - OrtNodeComputeInfo** node_compute_infos, - _In_ size_t num_node_compute_infos); + void(ORT_API_CALL* ReleaseNodeComputeInfos)(_In_ OrtEp* this_ptr, + OrtNodeComputeInfo** node_compute_infos, + _In_ size_t num_node_compute_infos); /** \brief Get the EP's preferred data layout. * @@ -445,7 +445,8 @@ struct OrtEp { * * \since Version 1.23. */ - ORT_API2_STATUS(GetPreferredDataLayout, _In_ OrtEp* this_ptr, _Out_ OrtEpDataLayout* preferred_data_layout); + OrtStatus*(ORT_API_CALL* GetPreferredDataLayout)(_In_ OrtEp* this_ptr, + _Out_ OrtEpDataLayout* preferred_data_layout); /** \brief Given an op with domain `domain` and type `op_type`, determine whether an associated node's data layout * should be converted to `target_data_layout`. @@ -469,10 +470,11 @@ struct OrtEp { * * \since Version 1.23. */ - ORT_API2_STATUS(ShouldConvertDataLayoutForOp, _In_ OrtEp* this_ptr, - _In_z_ const char* domain, _In_z_ const char* op_type, - _In_ OrtEpDataLayout target_data_layout, - _Outptr_ int* should_convert); + OrtStatus*(ORT_API_CALL* ShouldConvertDataLayoutForOp)(_In_ OrtEp* this_ptr, + _In_z_ const char* domain, + _In_z_ const char* op_type, + _In_ OrtEpDataLayout target_data_layout, + _Outptr_ int* should_convert); /** \brief Set dynamic options on this EP. * @@ -490,10 +492,10 @@ struct OrtEp { * * \since Version 1.23. */ - ORT_API2_STATUS(SetDynamicOptions, _In_ OrtEp* this_ptr, - _In_reads_(num_options) const char* const* option_keys, - _In_reads_(num_options) const char* const* option_values, - _In_ size_t num_options); + OrtStatus*(ORT_API_CALL* SetDynamicOptions)(_In_ OrtEp* this_ptr, + _In_reads_(num_options) const char* const* option_keys, + _In_reads_(num_options) const char* const* option_values, + _In_ size_t num_options); /** \brief Called by ORT to notify the EP of the start of a run. * @@ -506,7 +508,8 @@ struct OrtEp { * * \since Version 1.23. */ - ORT_API2_STATUS(OnRunStart, _In_ OrtEp* this_ptr, _In_ const OrtRunOptions* run_options); + OrtStatus*(ORT_API_CALL* OnRunStart)(_In_ OrtEp* this_ptr, + _In_ const OrtRunOptions* run_options); /** \brief Called by ORT to notify the EP of the end of a run. * @@ -521,7 +524,9 @@ struct OrtEp { * * \since Version 1.23. */ - ORT_API2_STATUS(OnRunEnd, _In_ OrtEp* this_ptr, _In_ const OrtRunOptions* run_options, _In_ bool sync_stream); + OrtStatus*(ORT_API_CALL* OnRunEnd)(_In_ OrtEp* this_ptr, + _In_ const OrtRunOptions* run_options, + _In_ bool sync_stream); }; /** \brief The function signature that ORT will call to create OrtEpFactory instances. @@ -581,7 +586,7 @@ struct OrtEpFactory { * * \since Version 1.22. */ - ORT_API_T(const char*, GetName, const OrtEpFactory* this_ptr); + const char*(ORT_API_CALL* GetName)(const OrtEpFactory* this_ptr); /** \brief Get the name of vendor who owns the execution provider that the factory creates. * @@ -592,7 +597,7 @@ struct OrtEpFactory { * * \since Version 1.22. */ - ORT_API_T(const char*, GetVendor, const OrtEpFactory* this_ptr); // return EP vendor + const char*(ORT_API_CALL* GetVendor)(const OrtEpFactory* this_ptr); // return EP vendor /** \brief Get information from the execution provider about OrtHardwareDevice support. * @@ -611,12 +616,12 @@ struct OrtEpFactory { * * \since Version 1.22. */ - ORT_API2_STATUS(GetSupportedDevices, _In_ OrtEpFactory* this_ptr, - _In_reads_(num_devices) const OrtHardwareDevice* const* devices, - _In_ size_t num_devices, - _Inout_ OrtEpDevice** ep_devices, - _In_ size_t max_ep_devices, - _Out_ size_t* num_ep_devices); + OrtStatus*(ORT_API_CALL* GetSupportedDevices)(_In_ OrtEpFactory* this_ptr, + _In_reads_(num_devices) const OrtHardwareDevice* const* devices, + _In_ size_t num_devices, + _Inout_ OrtEpDevice** ep_devices, + _In_ size_t max_ep_devices, + _Out_ size_t* num_ep_devices); /** \brief Function to create an OrtEp instance for use in a Session. * @@ -642,12 +647,12 @@ struct OrtEpFactory { * * \since Version 1.22. */ - ORT_API2_STATUS(CreateEp, _In_ OrtEpFactory* this_ptr, - _In_reads_(num_devices) const OrtHardwareDevice* const* devices, - _In_reads_(num_devices) const OrtKeyValuePairs* const* ep_metadata_pairs, - _In_ size_t num_devices, - _In_ const OrtSessionOptions* session_options, - _In_ const OrtLogger* logger, _Outptr_ OrtEp** ep); + OrtStatus*(ORT_API_CALL* CreateEp)(_In_ OrtEpFactory* this_ptr, + _In_reads_(num_devices) const OrtHardwareDevice* const* devices, + _In_reads_(num_devices) const OrtKeyValuePairs* const* ep_metadata_pairs, + _In_ size_t num_devices, + _In_ const OrtSessionOptions* session_options, + _In_ const OrtLogger* logger, _Outptr_ OrtEp** ep); /** \brief Release the OrtEp instance. * @@ -656,18 +661,7 @@ struct OrtEpFactory { * * \since Version 1.22. */ - ORT_API_T(void, ReleaseEp, OrtEpFactory* this_ptr, struct OrtEp* ep); - - /** \brief Get the vendor id who owns the execution provider that the factory creates. - * - * This is typically the PCI vendor ID. See https://pcisig.com/membership/member-companies - * - * \param[in] this_ptr The OrtEpFactory instance. - * \return vendor_id The vendor ID of the execution provider the factory creates. - * - * \since Version 1.23. - */ - ORT_API_T(uint32_t, GetVendorId, const OrtEpFactory* this_ptr); + void(ORT_API_CALL* ReleaseEp)(OrtEpFactory* this_ptr, struct OrtEp* ep); /** \brief Get the version of the execution provider that the factory creates. * @@ -681,7 +675,7 @@ struct OrtEpFactory { * * \since Version 1.23. */ - ORT_API_T(const char*, GetVersion, _In_ const OrtEpFactory* this_ptr); + const char*(ORT_API_CALL* GetVersion)(_In_ const OrtEpFactory* this_ptr); /** \brief Create an OrtAllocator for the given OrtMemoryInfo. * diff --git a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h index 314cf76cc8044..97e53e6acee5a 100644 --- a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h +++ b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h @@ -148,9 +148,7 @@ static const char* const kOrtSessionOptionsUseDeviceAllocatorForInitializers = " // Configure whether to allow the inter_op/intra_op threads spinning a number of times before blocking // "0": thread will block if found no job to run -// "1": thread will spin a number of times before blocking -// The default is "0" when ORT is built with "ORT_CLIENT_PACKAGE_BUILD" and "1" otherwise. -// Thread spinning is disabled by default for client/on-device workloads to reduce cpu utilization and improve power efficiency. +// "1": default, thread will spin a number of times before blocking static const char* const kOrtSessionOptionsConfigAllowInterOpSpinning = "session.inter_op.allow_spinning"; static const char* const kOrtSessionOptionsConfigAllowIntraOpSpinning = "session.intra_op.allow_spinning"; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/slice.ts b/js/web/lib/wasm/jsep/webgpu/ops/slice.ts index c2085342efd80..5a837fd1e0bfa 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/slice.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/slice.ts @@ -98,7 +98,7 @@ const calculateInputIndicesImpl = ( `fn calculateInputIndices(output_indices: ${output.type.indices}) -> ${input.type.indices} { var input_indices: ${input.type.indices}; var carry = 0u; - for (var i = ${inputShape.length - 1}; i >= 0; i--) { + for (var i = ${inputShape.length}; i >= 0; i--) { let input_shape_i = ${getElementAt('uniforms.input_shape', 'i', inputShape.length)}; let steps_i = ${getElementAt('uniforms.steps', 'i', inputShape.length)}; let signs_i = ${getElementAt('uniforms.signs', 'i', inputShape.length)}; diff --git a/js/web/script/pull-prebuilt-wasm-artifacts.ts b/js/web/script/pull-prebuilt-wasm-artifacts.ts index 87008f51ff4b9..c3300f7272bb9 100644 --- a/js/web/script/pull-prebuilt-wasm-artifacts.ts +++ b/js/web/script/pull-prebuilt-wasm-artifacts.ts @@ -38,6 +38,7 @@ Usage: Options: -d --debug specify the debug build type of the artifacts to download. -l --latest if set, will always use the latest build, even if it is not completed yet. + --webgpu-ep if set, will use the webgpu EP wasm build instead of the default(JSEP) one. -h --help print this message and exit `; @@ -80,8 +81,9 @@ try { // The following code checks both the command line arguments and the npm_config_* environment variables to get the correct values. const debug = args.debug || process.env.npm_config_d || process.env.npm_config_debug; const latest = args.latest || process.env.npm_config_l || process.env.npm_config_latest; +const webgpuEp = args['webgpu-ep'] || process.env.npm_config_webgpu_ep; -const folderName = debug ? 'Debug_wasm' : 'Release_wasm'; +const folderName = (debug ? 'Debug_wasm' : 'Release_wasm') + (webgpuEp ? '_webgpu' : ''); const allowImcomplete = latest; const run = args._[0]; // The first non-option argument @@ -149,17 +151,13 @@ async function downloadArtifactsForRun(run: any): Promise { if (!fs.existsSync(WASM_FOLDER)) { fs.mkdirSync(WASM_FOLDER); } else { + // TODO: revise artifacts download + const filesToDelete = ['ort-wasm-simd-threaded.jsep.mjs', 'ort-wasm-simd-threaded.jsep.wasm']; + if (!folderName.endsWith('_webgpu')) { + filesToDelete.push('ort-wasm-simd-threaded.mjs', 'ort-wasm-simd-threaded.wasm'); + } fs.readdirSync(WASM_FOLDER).forEach((file) => { - if ( - [ - 'ort-wasm-simd-threaded.jsep.mjs', - 'ort-wasm-simd-threaded.jsep.wasm', - 'ort-wasm-simd-threaded.jsep.mjs', - 'ort-wasm-simd-threaded.jsep.wasm', - 'ort-wasm-simd-threaded.mjs', - 'ort-wasm-simd-threaded.wasm', - ].includes(file) - ) { + if (filesToDelete.includes(file)) { const filePath = path.join(WASM_FOLDER, file); console.log(`Deleting old file: ${filePath}`); fs.unlinkSync(filePath); diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_common.h b/onnxruntime/contrib_ops/cpu/bert/attention_common.h index 80d374d3f0b25..243f611da49e1 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_common.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_common.h @@ -53,12 +53,6 @@ enum AttentionKernelType { AttentionKernel_Default }; -enum class QKOutputType : int { - NO_OUTPUT = 0, - BEFORE_SOFTMAX = 1, - AFTER_SOFTMAX = 2 -}; - constexpr bool LAYOUT_BSNH = false; constexpr bool LAYOUT_BNSH = true; diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/attention_helper.h index aef47edd5fcd2..ac32a4445f3ca 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_helper.h @@ -17,13 +17,13 @@ namespace onnxruntime { namespace contrib { template -inline void ComputeSmoothSoftmaxInplace(T* score, int D, float sink, ThreadPool* tp) { - MlasComputeSoftmax(score, score, 1, D, false, true, sink, tp); +inline void ComputeSmoothSoftmaxInplace(T* score, int N, int D, ThreadPool* tp) { + MlasComputeSoftmax(score, score, N, D, false, true, tp); } template inline void ComputeAttentionSoftmaxInplace(T* score, int N, int D, ThreadPool* tp) { - MlasComputeSoftmax(score, score, N, D, false, false, 0.0f, tp); + MlasComputeSoftmax(score, score, N, D, false, false, tp); } template diff --git a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h index 0d5117709c18a..c79508cbae273 100644 --- a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h +++ b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h @@ -35,8 +35,6 @@ class GQAAttentionBase { use_smooth_softmax_ = info.GetAttrOrDefault("smooth_softmax", 0) == 1; local_window_size_ = has_local ? static_cast(info.GetAttrOrDefault("local_window_size", -1)) : -1; - - qk_output_ = static_cast(info.GetAttrOrDefault("qk_output", static_cast(QKOutputType::NO_OUTPUT))); } int num_heads_; // number of attention heads of Q @@ -46,7 +44,6 @@ class GQAAttentionBase { bool do_rotary_; // whether or not to use rotary embeddings bool rotary_interleaved_; int local_window_size_; - int qk_output_; bool use_smooth_softmax_; @@ -54,14 +51,12 @@ class GQAAttentionBase { Status ApplyAttention(const T* Q, // Q data with shape BxNxSxH const T* K, // K data with shape BxN_kvxSxH const T* V, // V data with shape BxN_kvxSxH - const T* head_sink, // Head sink for smooth softmax, nullptr if not used const Tensor* attention_bias, // Attention bias to add to QxK' const Tensor* past_key, // past K input tensor (if not using past state) const Tensor* past_value, // past V input tensor (if not using past state) Tensor* output, // output tensor Tensor* present_key, // present K output tensor (if separating present KV) Tensor* present_value, // present V output tensor (if separating present KV) - Tensor* output_qk, // output QK buffer const Tensor* seqlens_k, // past sequence lengths tensor GroupQueryAttentionParameters& parameters, // attention parameters AllocatorPtr allocator, // allocator for temporary tensors @@ -69,7 +64,6 @@ class GQAAttentionBase { const bool is_prompt = parameters.is_first_prompt; const int batch_size = parameters.batch_size; const int sequence_length = parameters.sequence_length; - const int total_sequence_length = parameters.total_sequence_length; const int head_size = parameters.head_size; const int hidden_size = parameters.hidden_size; const bool packed_qkv = parameters.is_packed_qkv; @@ -85,7 +79,8 @@ class GQAAttentionBase { // Compute the attention score. bool gqa_mlas_supported = MlasGQASupported(CblasNoTrans, CblasTrans) && MlasGQASupported(CblasNoTrans, CblasNoTrans); - size_t bytes = SafeInt(batch_size) * num_heads_ * sequence_length * seqlen_present_kv_cache * (gqa_mlas_supported ? sizeof(T) : sizeof(float)); + size_t bytes = SafeInt(batch_size) * num_heads_ * sequence_length * seqlen_present_kv_cache * + (gqa_mlas_supported ? sizeof(T) : sizeof(float)); auto attention_probs = allocator->Alloc(bytes); BufferUniquePtr scratch_buffer(attention_probs, BufferDeleter(allocator)); @@ -101,13 +96,11 @@ class GQAAttentionBase { const T* k = packed_qkv ? Q + num_heads_ * sequence_length * head_size : K; - T* output_qk_buffer = output_qk != nullptr ? output_qk->MutableData() : nullptr; - if (gqa_mlas_supported) { - ComputeAttentionProbs(static_cast(attention_probs), Q, k, head_sink, seqlens_k->Data(), attention_bias_data, - batch_size, sequence_length, total_sequence_length, attention_bias_shape, seqlen_past_kv_cache, - seqlen_present_kv_cache, head_size, past_key_data, present_key_data, output_qk_buffer, - past_present_share_buffer, packed_qkv, is_prompt, tp, allocator); + ComputeAttentionProbs(static_cast(attention_probs), Q, k, seqlens_k->Data(), attention_bias_data, + batch_size, sequence_length, attention_bias_shape, seqlen_past_kv_cache, seqlen_present_kv_cache, + head_size, past_key_data, present_key_data, past_present_share_buffer, packed_qkv, is_prompt, + tp, allocator); // Compute the attentionScore * Value: out(B, N, S, H_v) = attention_probs(B, N, S, T) x V(B, N, T, H_v) const T* v = packed_qkv ? Q + (num_heads_ + kv_num_heads_) * sequence_length * head_size : V; @@ -117,10 +110,10 @@ class GQAAttentionBase { hidden_size, past_value_data, present_value_data, past_present_share_buffer, packed_qkv, is_prompt, tp, allocator); } else { - ComputeAttentionProbs(static_cast(attention_probs), Q, k, head_sink, seqlens_k->Data(), attention_bias_data, - batch_size, sequence_length, total_sequence_length, attention_bias_shape, seqlen_past_kv_cache, - seqlen_present_kv_cache, head_size, past_key_data, present_key_data, output_qk_buffer, - past_present_share_buffer, packed_qkv, is_prompt, tp, allocator); + ComputeAttentionProbs(static_cast(attention_probs), Q, k, seqlens_k->Data(), attention_bias_data, + batch_size, sequence_length, attention_bias_shape, seqlen_past_kv_cache, seqlen_present_kv_cache, + head_size, past_key_data, present_key_data, past_present_share_buffer, packed_qkv, is_prompt, + tp, allocator); // Compute the attentionScore * Value: out(B, N, S, H_v) = attention_probs(B, N, S, T) x V(B, N, T, H_v) const T* v = packed_qkv ? Q + (num_heads_ + kv_num_heads_) * sequence_length * head_size : V; @@ -143,19 +136,16 @@ class GQAAttentionBase { void ComputeAttentionProbs(U* attention_probs, // output buffer with size BxNxSxT const T* Q, // Q data. Its size is BxNxSxH const T* K, // k data. Its size is BxNxLxH - const T* head_sink, // for smooth softmax. Its size is N. const int32_t* seqlens_k, // total - 1 sequence lengths tensor const T* attention_bias, // optional attention bias const size_t batch_size, // batch size of self-attention const size_t sequence_length, // sequence length of self-attention (S) - const size_t total_sequence_length, // total sequence length (T) const gsl::span attention_bias_shape, // shape of the attention bias const size_t past_buffer_sequence_length, // sequence length of past state const size_t present_buffer_sequence_length, // sequence length of present state const size_t head_size, // head size of self-attention const T* past_key, // past key only T* present_key, // present key only - T* output_qk, // output QK buffer const bool past_present_share_buffer, // whether present key and value share the same buffer const bool packed_qkv, // whether Q, K, V are packed const bool is_prompt, // whether it is prompt @@ -207,11 +197,6 @@ class GQAAttentionBase { const ptrdiff_t output_offset = SafeInt(i) * sequence_length * present_buffer_sequence_length; U* output = attention_probs + output_offset; - T* output_qk_thread = nullptr; - if (output_qk != nullptr) { - const ptrdiff_t output_qk_offset = SafeInt(sequence_length) * total_sequence_length * (batch_index * num_heads_ + head_index); - output_qk_thread = output_qk + output_qk_offset; - } // Compute attention bias offset based on the batch and head indexes // Attention bias is of shape (B or 1, H or 1, S, T) so handle broadcasting @@ -325,6 +310,12 @@ class GQAAttentionBase { } } + if (use_smooth_softmax_) { + ComputeSmoothSoftmaxInplace(output_softmax + start_offset, 1, static_cast(window_size), nullptr); + } else { + ComputeAttentionSoftmaxInplace(output_softmax + start_offset, 1, static_cast(window_size), nullptr); + } + // set causal [seq_causal_length, total_seqlen) to 0.f for (size_t total_seq_id = seq_causal_length; total_seq_id < total_seqlen; total_seq_id++) { if constexpr (std::is_same::value) { @@ -334,30 +325,11 @@ class GQAAttentionBase { } } - if (qk_output_ == static_cast(QKOutputType::BEFORE_SOFTMAX)) { - WriteOutputQKHeadChunk(output_qk_thread, output_softmax, total_sequence_length); - } - - if (use_smooth_softmax_ || head_sink != nullptr) { - float sink = (head_sink != nullptr) ? static_cast(head_sink[head_index]) : 0.0f; - ComputeSmoothSoftmaxInplace(output_softmax + start_offset, static_cast(window_size), sink, nullptr); - } else { - ComputeAttentionSoftmaxInplace(output_softmax + start_offset, 1, static_cast(window_size), nullptr); - } - - if (qk_output_ == static_cast(QKOutputType::AFTER_SOFTMAX)) { - WriteOutputQKHeadChunk(output_qk_thread, output_softmax, total_sequence_length); - } - output_softmax += present_buffer_sequence_length; if (attention_bias_thread != nullptr) { attention_bias_thread += attention_total_seqlen; } - - if (output_qk_thread != nullptr) { - output_qk_thread += total_sequence_length; - } } } }); @@ -483,20 +455,6 @@ class GQAAttentionBase { SafeInt(sequence_length) * batch_size * num_heads_ * head_size); } } - - template - void WriteOutputQKHeadChunk(T* output_qk, const U* attention_probs, size_t total_sequence_length) const { - if (output_qk == nullptr) { - return; - } - - if constexpr (std::is_same_v) { - std::memcpy(output_qk, attention_probs, SafeInt(total_sequence_length) * sizeof(T)); - } else { - static_assert(std::is_same_v && std::is_same_v); - MlasConvertFloatToHalfBuffer(static_cast(attention_probs), output_qk, total_sequence_length); - } - } }; } // namespace contrib diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc index eb1560ac8e341..a912bd6e6b43c 100644 --- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc @@ -95,11 +95,6 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { Tensor* present_k = context->Output(1, present_k_shape); Tensor* present_v = context->Output(2, present_v_shape); - std::vector output_qk_shape{static_cast(batch_size), static_cast(num_heads_), static_cast(parameters.sequence_length), static_cast(parameters.total_sequence_length)}; - Tensor* output_qk = context->Output(3, output_qk_shape); - - ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckOutputs(output_qk, qk_output_)); - AllocatorPtr allocator; ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); @@ -211,12 +206,10 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); - const T* head_sink_data = (head_sink != nullptr) ? head_sink->Data() : nullptr; - // Compute the attention score and apply the score to V return ApplyAttention(q_rotary, packed_qkv ? nullptr : k_rotary, packed_qkv ? nullptr : V.Get().Data(), - head_sink_data, attention_bias, past_key, past_value, output, present_k, present_v, - output_qk, seqlens_k, parameters, allocator, context); + attention_bias, past_key, past_value, output, present_k, present_v, + seqlens_k, parameters, allocator, context); } } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h index f01ce985658aa..0f66119540b03 100644 --- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h @@ -398,37 +398,6 @@ Status CheckCustomAttentionInputs(const T* position_ids, return Status::OK(); } -template -Status CheckOutputs(const T* output_qk, int qk_output) { - const bool is_valid_qk_output = qk_output == static_cast(QKOutputType::NO_OUTPUT) || - qk_output == static_cast(QKOutputType::BEFORE_SOFTMAX) || - qk_output == static_cast(QKOutputType::AFTER_SOFTMAX); - if (!is_valid_qk_output) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "qk_output attribute received unsupported value ", qk_output); - } - - if (qk_output != static_cast(QKOutputType::NO_OUTPUT) && output_qk == nullptr) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "qk_output attribute was configured but output buffer was not provided"); - } - - return Status::OK(); -} - -inline Status CheckNoQKOutput(int num_outputs, int qk_output) { - if (num_outputs > 3) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "output_qk optional output is not supported"); - } - - if (qk_output != static_cast(QKOutputType::NO_OUTPUT)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "qk_output attribute is not supported"); - } - - return Status::OK(); -} - } // namespace group_query_attention_helper } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc index 9cb93cbcd3f32..68c4b01d2db20 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc @@ -109,12 +109,6 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { parameters.do_rotary = do_rotary_; parameters.rotary_interleaved = rotary_interleaved_; - // The current GQA CUDA implementation will never be able to have a QK output. - // GQA CUDA uses either flash attention or memory efficient attention. Neither kernel supports returning the QK output. - ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckNoQKOutput( - context->OutputCount(), - static_cast(Info().GetAttrOrDefault("qk_output", static_cast(QKOutputType::NO_OUTPUT))))); - if (do_rotary_ && (cos_cache == nullptr || sin_cache == nullptr)) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "cos_cache and sin_cache must be passed to GroupQueryAttention when do_rotary = 1"); diff --git a/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu b/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu index 09a6550549614..85aef55908506 100644 --- a/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu +++ b/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu @@ -213,10 +213,6 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* ctx) const { parameters.do_rotary = do_rotary_; parameters.rotary_interleaved = rotary_interleaved_; - ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckNoQKOutput( - context->OutputCount(), - static_cast(Info().GetAttrOrDefault("qk_output", static_cast(QKOutputType::NO_OUTPUT))))); - if (do_rotary_ && (cos_cache == nullptr || sin_cache == nullptr)) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "cos_cache and sin_cache must be passed to GroupQueryAttention when do_rotary = 1"); diff --git a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc index 1f039177b0a21..f3334b13dc645 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc @@ -178,10 +178,6 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& head_sink, params)); - ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckNoQKOutput( - context.OutputCount(), - static_cast(Info().GetAttrOrDefault("qk_output", static_cast(QKOutputType::NO_OUTPUT))))); - WebgpuAttentionParameters parameters(params); TensorShapeVector output_shape(3); output_shape[0] = static_cast(parameters.batch_size_); diff --git a/onnxruntime/core/common/cpuid_info.cc b/onnxruntime/core/common/cpuid_info.cc index c4667d53c0674..8ea593f107833 100644 --- a/onnxruntime/core/common/cpuid_info.cc +++ b/onnxruntime/core/common/cpuid_info.cc @@ -170,7 +170,7 @@ std::string CPUIDInfo::GetX86Vendor(int32_t* data) { uint32_t CPUIDInfo::GetVendorId(const std::string& vendor) { if (vendor == "GenuineIntel") return 0x8086; - if (vendor == "AuthenticAMD") return 0x1022; + if (vendor == "GenuineAMD") return 0x1022; if (vendor.find("Qualcomm") == 0) return 'Q' | ('C' << 8) | ('O' << 16) | ('M' << 24); if (vendor.find("NV") == 0) return 0x10DE; return 0; diff --git a/onnxruntime/core/graph/abi_graph_types.h b/onnxruntime/core/graph/abi_graph_types.h index 47fbe08da41ff..c3dd9321ebb0b 100644 --- a/onnxruntime/core/graph/abi_graph_types.h +++ b/onnxruntime/core/graph/abi_graph_types.h @@ -247,11 +247,8 @@ struct OrtNode { /// Gets the node's subgraphs (e.g., subgraphs contained by an If or Loop node). ///
/// Buffer into which to copy the subgraphs. - /// Optional buffer into which to copy the attribute name for each subgraph. - /// If set, must point to a buffer with the same number of elements as `subgraphs`. /// A status indicating success or an error. - virtual onnxruntime::Status GetSubgraphs(gsl::span subgraphs, - const char** opt_attribute_names) const = 0; + virtual onnxruntime::Status GetSubgraphs(gsl::span subgraphs) const = 0; /// /// Gets the node's parent graph, which is the graph that contains this node. @@ -283,23 +280,6 @@ struct OrtGraph { /// The model's ONNX IR version. virtual int64_t GetOnnxIRVersion() const = 0; - /// - /// Gets the number of operator sets (domain, opset version) the graph's model relies on. - /// - /// Output parameter set to the number of operator sets. - /// A status indicating success or an error. - virtual onnxruntime::Status GetNumOperatorSets(size_t& num_operator_sets) const = 0; - - /// - /// Gets the operator sets the graph's model relies on. An operator set is uniquely identified by a - /// (domain, opset version) pair. - /// - /// Buffer into which to copy the domains. - /// Buffer into which to copy the opset version for each domain. - /// A status indicating success or an error. - virtual onnxruntime::Status GetOperatorSets(gsl::span domains, - gsl::span opset_versions) const = 0; - /// /// Returns the number of graph inputs, including initializers that appear in the list of graph inputs. /// diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index e2b17aa84d2b1..f2757c2c96471 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -6,7 +6,6 @@ #include "core/graph/contrib_ops/quantization_defs.h" #include "core/graph/contrib_ops/onnx_function_util.h" #include "core/graph/contrib_ops/shape_inference_functions.h" -#include "contrib_ops/cpu/bert/attention_common.h" // Suppress a warning: global initializer calls a non-constexpr function 'symbol' which is from // ONNX_OPERATOR_SET_SCHEMA_EX macro and only happens in debug build #if defined(_WIN32) && !defined(NDEBUG) @@ -233,8 +232,7 @@ void MultiHeadAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& c // Type and shape inference for group query attention and sparse attention. void BaseGroupQueryAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& ctx, int past_key_index = -1, - int use_max_past_present_buffer = -1, - int output_qk_index = -1) { + int use_max_past_present_buffer = -1) { ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 0, 0); int64_t kv_sequence_length = -1; @@ -279,20 +277,13 @@ void BaseGroupQueryAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceConte } } - if (ctx.getNumOutputs() >= 3) { // has present output + if (ctx.getNumOutputs() > 1) { // has present output // copy the type from query to present key ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 0, 1); // copy the type from query to present value ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 0, 2); - int64_t total_sequence_length_value = 0; - const auto* total_sequence_length_data = ctx.getInputData(6); - if (total_sequence_length_data != nullptr) { - const auto& data = ParseData(total_sequence_length_data); - total_sequence_length_value = static_cast(data[0]); - } - if (past_key_index >= 0 && hasInputShape(ctx, past_key_index)) { auto& past_shape = getInputShape(ctx, past_key_index); auto& past_dims = past_shape.dim(); @@ -308,25 +299,30 @@ void BaseGroupQueryAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceConte ONNX_NAMESPACE::propagateShapeFromInputToOutput(ctx, static_cast(past_key_index) + 1, 2); } else if (use_max_past_present_buffer == 0) { if (kv_sequence_length > 0 && past_dims[2].has_dim_value()) { - const int64_t present_sequence_length = kv_sequence_length + past_dims[2].dim_value(); + int64_t total_sequence_length = kv_sequence_length + past_dims[2].dim_value(); ONNX_NAMESPACE::TensorShapeProto present_shape; for (auto& dim : past_dims) { *present_shape.add_dim() = dim; } - // shape of present key/value is (batch_size, kv_num_heads, present_sequence_length, head_size) - present_shape.mutable_dim(2)->set_dim_value(present_sequence_length); + // shape of present key/value is (batch_size, kv_num_heads, total_sequence_length, head_size) + present_shape.mutable_dim(2)->set_dim_value(total_sequence_length); updateOutputShape(ctx, 1, present_shape); updateOutputShape(ctx, 2, present_shape); } } else if (use_max_past_present_buffer == -1) { - if (total_sequence_length_value > 0 && past_dims[2].has_dim_value()) { + const auto* total_sequence_length_data = ctx.getInputData(6); + if (total_sequence_length_data != nullptr && past_dims[2].has_dim_value()) { + int64_t total_sequence_length_value = 0; + const auto& data = ParseData(total_sequence_length_data); + total_sequence_length_value = static_cast(data[0]); + // present_sequence_length = max(past_sequence_length, total_sequence_length) - const int64_t present_sequence_length = total_sequence_length_value > past_dims[2].dim_value() - ? total_sequence_length_value - : past_dims[2].dim_value(); + int64_t present_sequence_length = total_sequence_length_value > past_dims[2].dim_value() + ? total_sequence_length_value + : past_dims[2].dim_value(); ONNX_NAMESPACE::TensorShapeProto present_shape; for (auto& dim : past_dims) { @@ -340,50 +336,19 @@ void BaseGroupQueryAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceConte updateOutputShape(ctx, 2, present_shape); } } - - if (output_qk_index >= 0) { - const bool did_supply_qk_buffer = ctx.hasOutput(output_qk_index); - const int64_t qk_output_type = getAttribute(ctx, "qk_output", static_cast(QKOutputType::NO_OUTPUT)); - - if (qk_output_type == static_cast(QKOutputType::NO_OUTPUT) && did_supply_qk_buffer) { - fail_shape_inference("Output QK buffer was provided but qk_output attribute was not configured"); - } - - if (qk_output_type != static_cast(QKOutputType::NO_OUTPUT) && !did_supply_qk_buffer) { - fail_shape_inference("Output QK buffer was not provided but qk_output attribute was configured"); - } - - int64_t num_heads = getAttribute(ctx, "num_heads", 0); - if (did_supply_qk_buffer && hasInputShape(ctx, 0) && total_sequence_length_value > 0 && num_heads > 0) { - ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 0, output_qk_index); - - auto& query_shape = getInputShape(ctx, 0); - auto& query_dims = query_shape.dim(); - - if (query_dims[0].has_dim_value() && query_dims[1].has_dim_value()) { - ONNX_NAMESPACE::TensorShapeProto output_qk_shape; - *output_qk_shape.add_dim() = query_dims[0]; // batch_size - output_qk_shape.add_dim()->set_dim_value(num_heads); // num_heads - *output_qk_shape.add_dim() = query_dims[1]; // sequence_length - output_qk_shape.add_dim()->set_dim_value(total_sequence_length_value); // total_sequence_length - updateOutputShape(ctx, output_qk_index, output_qk_shape); - } - } - } } } } -void GroupQueryAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& ctx, int past_key_index, int qk_output_index) { +void GroupQueryAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& ctx, int past_key_index) { // TODO(aciddelgado): propagate output shapes depending if kv-share buffer is on or not constexpr int use_max_past_present_buffer = -1; - BaseGroupQueryAttentionTypeAndShapeInference(ctx, past_key_index, use_max_past_present_buffer, qk_output_index); + BaseGroupQueryAttentionTypeAndShapeInference(ctx, past_key_index, use_max_past_present_buffer); } void SparseAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& ctx, int past_key_index) { constexpr int use_max_past_present_buffer = 1; - constexpr int qk_output_index = -1; - BaseGroupQueryAttentionTypeAndShapeInference(ctx, past_key_index, use_max_past_present_buffer, qk_output_index); + BaseGroupQueryAttentionTypeAndShapeInference(ctx, past_key_index, use_max_past_present_buffer); } constexpr const char* Attention_ver1_doc = R"DOC( @@ -1162,10 +1127,6 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "Use a smooth factor in softmax.", AttributeProto::INT, static_cast(-1)) - .Attr("qk_output", - "Output values of QK matrix multiplication before (1) or after (2) softmax normalization. Default value is 0 (don't output).", - AttributeProto::INT, - static_cast(QKOutputType::NO_OUTPUT)) .Input(0, "query", "Query with shape (batch_size, sequence_length, hidden_size), or packed QKV with shape" @@ -1223,11 +1184,6 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "additional add to QxK' with shape (batch_size or 1, num_heads or 1, sequence_length, total_sequence_length)", "T", OpSchema::Optional) - .Input(11, - "head_sink", - "1D tensor with shape (num_heads). Each head has a smooth factor adding to the denominator of softmax.", - "T", - OpSchema::Optional) .Output(0, "output", "3D output tensor with shape (batch_size, sequence_length, hidden_size)", @@ -1244,15 +1200,10 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "(k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length +" "kv_sequence_length.", "T") - .Output(3, - "output_qk", - "Values of QK matrix multiplication, either before or after softmax normalization", - "T", - OpSchema::Optional) .TypeConstraint("T", {"tensor(float16)", "tensor(bfloat16)", "tensor(float)"}, "Constrain input and output to float tensors.") .TypeConstraint("M", {"tensor(int32)"}, "Constrain mask to int tensor.") .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { - GroupQueryAttentionTypeAndShapeInference(ctx, 3, 3); + GroupQueryAttentionTypeAndShapeInference(ctx, 3); })); constexpr const char* PagedAttention_ver1_doc = R"DOC( diff --git a/onnxruntime/core/graph/ep_api_types.cc b/onnxruntime/core/graph/ep_api_types.cc index f57543416a68f..698c7422a1e2a 100644 --- a/onnxruntime/core/graph/ep_api_types.cc +++ b/onnxruntime/core/graph/ep_api_types.cc @@ -129,12 +129,11 @@ Status EpNode::Create(const Node& node, const EpGraph* ep_graph, ConvertNodeArgsToValueInfos(ep_graph, value_infos_map, node_implicit_inputs, ep_node_implicit_inputs); - std::unordered_map> subgraphs_map = node.GetAttributeNameToSubgraphMap(); - ep_node_subgraphs.reserve(subgraphs_map.size()); + std::vector> node_subgraphs = node.GetSubgraphs(); + ep_node_subgraphs.reserve(node_subgraphs.size()); - for (const auto& [attr_name, subgraph] : subgraphs_map) { + for (gsl::not_null subgraph : node_subgraphs) { SubgraphState subgraph_state; - subgraph_state.attribute_name = attr_name; subgraph_state.subgraph_viewer = std::make_unique(*subgraph); ORT_RETURN_IF_ERROR(EpGraph::Create(*subgraph_state.subgraph_viewer, subgraph_state.ep_subgraph)); subgraph_state.ep_subgraph->SetParentNode(ep_node.get()); @@ -234,17 +233,12 @@ Status EpNode::GetNumSubgraphs(size_t& num_subgraphs) const { return Status::OK(); } -Status EpNode::GetSubgraphs(gsl::span subgraphs, - const char** opt_attribute_names) const { +Status EpNode::GetSubgraphs(gsl::span dst) const { const size_t num_subgraphs = subgraphs_.size(); - ORT_RETURN_IF_ERROR((CheckCopyDestination("node subgraphs", num_subgraphs, subgraphs))); + ORT_RETURN_IF_ERROR((CheckCopyDestination("node attributes", num_subgraphs, dst))); for (size_t i = 0; i < num_subgraphs; ++i) { - subgraphs[i] = subgraphs_[i].ep_subgraph.get(); - - if (opt_attribute_names) { - opt_attribute_names[i] = subgraphs_[i].attribute_name.c_str(); - } + dst[i] = subgraphs_[i].ep_subgraph.get(); } return Status::OK(); @@ -276,10 +270,6 @@ const OrtOpAttr* EpNode::GetAttribute(const std::string& name) const { } } -const std::string& EpNode::GetEpName() const { - return node_.GetExecutionProviderType(); -} - // // EpValueInfo // @@ -509,34 +499,10 @@ void EpGraph::IndexToEpNodeMap::SetEpNode(NodeIndex node_index, EpNode* ep_node) EpGraph::EpGraph(const GraphViewer& graph_viewer, PrivateTag) : OrtGraph(OrtGraphIrApi::kEpApi), graph_viewer_(graph_viewer) {} -EpGraph::EpGraph(std::unique_ptr graph_viewer, - std::unique_ptr indexed_sub_graph, - PrivateTag) - : OrtGraph(OrtGraphIrApi::kEpApi), - graph_viewer_(*graph_viewer.get()), - owned_graph_viewer_(std::move(graph_viewer)), - owned_indexed_sub_graph_(std::move(indexed_sub_graph)) {} - // Static class function to create a std::unique_ptr. Status EpGraph::Create(const GraphViewer& graph_viewer, /*out*/ std::unique_ptr& result) { auto ep_graph = std::make_unique(graph_viewer, PrivateTag{}); - return CreateImpl(std::move(ep_graph), graph_viewer, result); -} - -// Static class function to create a std::unique_ptr. -Status EpGraph::Create(std::unique_ptr src_graph_viewer, - std::unique_ptr src_indexed_sub_graph, - /*out*/ std::unique_ptr& result) { - auto& graph_viewer = *src_graph_viewer.get(); - auto ep_graph = std::make_unique(std::move(src_graph_viewer), - std::move(src_indexed_sub_graph), - PrivateTag{}); - - return CreateImpl(std::move(ep_graph), graph_viewer, result); -} - -Status EpGraph::CreateImpl(std::unique_ptr ep_graph, const GraphViewer& graph_viewer, /*out*/ std::unique_ptr& result) { AllocatorPtr initializer_allocator = CPUAllocator::DefaultInstance(); std::unordered_map> value_infos_map; @@ -694,43 +660,6 @@ const std::string& EpGraph::GetName() const { return graph_viewer_.Name(); } int64_t EpGraph::GetOnnxIRVersion() const { return graph_viewer_.GetOnnxIRVersion(); } -Status EpGraph::GetNumOperatorSets(size_t& num_operator_sets) const { - num_operator_sets = graph_viewer_.DomainToVersionMap().size(); - return Status::OK(); -} - -Status EpGraph::GetOperatorSets(gsl::span domains, - gsl::span opset_versions) const { - const std::unordered_map& domain_to_version = graph_viewer_.DomainToVersionMap(); - size_t num_operator_sets = domain_to_version.size(); - - ORT_RETURN_IF_ERROR((CheckCopyDestination("operator set domains", num_operator_sets, domains))); - ORT_RETURN_IF_ERROR((CheckCopyDestination("operator set versions", num_operator_sets, opset_versions))); - - // Collect (domain, version) pairs and sort them by domain to ensure user always gets a stable ordering. - std::vector> pairs; - pairs.reserve(num_operator_sets); - - for (const auto& [domain, version] : domain_to_version) { - pairs.emplace_back(domain.c_str(), version); - } - - std::sort(pairs.begin(), pairs.end(), - [](const std::pair& a, const std::pair& b) -> bool { - return std::strcmp(a.first, b.first) < 0; - }); - - // Copy sorted (domain, version) pairs into the destination buffers. - size_t index = 0; - for (const auto& [domain_c_str, version] : pairs) { - domains[index] = domain_c_str; - opset_versions[index] = version; - index++; - } - - return Status::OK(); -} - size_t EpGraph::GetNumInputs() const { return inputs_.size(); } diff --git a/onnxruntime/core/graph/ep_api_types.h b/onnxruntime/core/graph/ep_api_types.h index d3921e051e18a..4240f5636b7ae 100644 --- a/onnxruntime/core/graph/ep_api_types.h +++ b/onnxruntime/core/graph/ep_api_types.h @@ -111,7 +111,6 @@ struct EpNode : public OrtNode { struct SubgraphState { SubgraphState() = default; SubgraphState(SubgraphState&& other) = default; - std::string attribute_name; std::unique_ptr subgraph_viewer; // The graph_viewer wrapped by EpGraph below. std::unique_ptr ep_subgraph; }; @@ -183,8 +182,7 @@ struct EpNode : public OrtNode { Status GetNumSubgraphs(size_t& num_subgraphs) const override; // Gets the subgraphs contained by this node. - Status GetSubgraphs(gsl::span subgraphs, - const char** opt_attribute_names) const override; + Status GetSubgraphs(gsl::span subgraphs) const override; // Gets this node's parent graph, which is the graph that directly contains this node. Status GetGraph(const OrtGraph*& parent_graph) const override; @@ -208,9 +206,6 @@ struct EpNode : public OrtNode { // Helper that gets the node's attributes by name. const OrtOpAttr* GetAttribute(const std::string& name) const; - // Helper that gets the execution provider name that this node is assigned to run on. - const std::string& GetEpName() const; - private: // Back pointer to containing graph. Useful when traversing through nested subgraphs. // Will be nullptr if the EpNode was created without an owning graph. @@ -254,32 +249,15 @@ struct EpGraph : public OrtGraph { public: EpGraph(const GraphViewer& graph_viewer, PrivateTag); - EpGraph(std::unique_ptr graph_viewer, - std::unique_ptr indexed_sub_graph, - PrivateTag); /// /// Creates an instance of EpGraph, which wraps a GraphViewer. - /// This call is used when creating an EpGraph from a GraphViewer instance. The GraphViewer instance is not onwed by this EpGraph. /// /// /// /// static Status Create(const GraphViewer& graph_viewer, /*out*/ std::unique_ptr& result); - /// - /// Creates an instance of EpGraph, which wraps a GraphViewer. - /// This call is used when creating an EpGraph from a subset of nodes in another EpGraph. - /// In this case, due to the implementation of OrtApis::Graph_GetGraphView, the new EpGraph instance - /// must take ownership of both the GraphViewer and IndexedSubGraph. - /// - /// - /// - /// - static Status Create(std::unique_ptr graph_viewer, - std::unique_ptr indexed_sub_graph, - /*out*/ std::unique_ptr& result); - // Defines ToExternal() and ToInternal() functions to convert between OrtGraph and EpGraph. DEFINE_ORT_GRAPH_IR_TO_EXTERNAL_INTERNAL_FUNCS(OrtGraph, EpGraph, OrtGraphIrApi::kEpApi) @@ -293,14 +271,6 @@ struct EpGraph : public OrtGraph { // Returns the model's ONNX IR version. int64_t GetOnnxIRVersion() const override; - // Gets the number of operator sets that the graph's model uses. - Status GetNumOperatorSets(size_t& num_operator_sets) const override; - - // Gets the operator sets that the graph's model uses. An operator set is uniquely identified by a - // (domain, opset version) pair. - Status GetOperatorSets(gsl::span domains, - gsl::span opset_versions) const override; - // Get the number of graph inputs, including initializers that are listed as graph inputs. size_t GetNumInputs() const override; @@ -351,22 +321,9 @@ struct EpGraph : public OrtGraph { const OrtValue* GetInitializerValue(std::string_view name) const; private: - /// - /// The real implementation of creating an EpGraph instance. - /// Please use one of the above 'Create' functions that internally call this function, and avoid calling this function directly. - /// - /// - /// - /// - /// - static Status CreateImpl(std::unique_ptr ep_graph, const GraphViewer& graph_viewer, /*out*/ std::unique_ptr& result); - const GraphViewer& graph_viewer_; const EpNode* parent_node_ = nullptr; - std::unique_ptr owned_graph_viewer_ = nullptr; - std::unique_ptr owned_indexed_sub_graph_ = nullptr; - std::vector> nodes_; IndexToEpNodeMap index_to_ep_node_; diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index 4d3091520d876..ca40bad2b4250 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -1818,10 +1818,6 @@ NodeArg* Graph::GetNodeArgIncludingParentGraphs(const std::string& node_arg_name return node_arg; } -const NodeArg* Graph::GetNodeArgIncludingParentGraphs(const std::string& node_arg_name) const { - return const_cast(this)->GetNodeArgIncludingParentGraphs(node_arg_name); -} - void Graph::ReverseDFSFrom(gsl::span from, const std::function& enter, const std::function& leave, diff --git a/onnxruntime/core/graph/graph_viewer.cc b/onnxruntime/core/graph/graph_viewer.cc index 948ebaa5f7e15..1842c2b4a0d1f 100644 --- a/onnxruntime/core/graph/graph_viewer.cc +++ b/onnxruntime/core/graph/graph_viewer.cc @@ -168,15 +168,7 @@ GraphViewer::GraphViewer(const Graph& graph, const IndexedSubGraph* filter_info) filtered_node_inputs_including_initializers_.reserve(metadef->inputs.size()); for (const auto& input : metadef->inputs) { - // NodeArgs from the current scope or any outer scopes should be handled correctly. - // - // There is an edge case where the model consists of a graph with subgraphs nested across three levels. - // In this scenario, a third-layer subgraph consumes an input from the first-layer graph (not an initializer). - // When constructing a new GraphViewer for the second- and third-layer subgraphs, - // the second-layer graph may not have the corresponding value_info for that first-layer input, - // because the second-layer graph itself doesn't consume it. - // Therefore, when working within the second-layer graph, we need to search outer scopes for the missing value_info. - const auto* nodearg = graph.GetNodeArgIncludingParentGraphs(input); + const auto* nodearg = graph.GetNodeArg(input); ORT_ENFORCE(nodearg, "Mismatch between Graph and IndexedSubGraph. Input not found:", input); filtered_node_inputs_including_initializers_.push_back(nodearg); if (!graph.IsInitializedTensor(input)) { @@ -185,7 +177,7 @@ GraphViewer::GraphViewer(const Graph& graph, const IndexedSubGraph* filter_info) } for (const auto& output : metadef->outputs) { - const auto* nodearg = graph.GetNodeArgIncludingParentGraphs(output); + const auto* nodearg = graph.GetNodeArg(output); ORT_ENFORCE(nodearg, "Mismatch between Graph and IndexedSubGraph. Output not found:", output); filtered_node_outputs_.push_back(nodearg); } diff --git a/onnxruntime/core/graph/model_editor_api_types.h b/onnxruntime/core/graph/model_editor_api_types.h index 6e7e17374bb59..6330a42c115db 100644 --- a/onnxruntime/core/graph/model_editor_api_types.h +++ b/onnxruntime/core/graph/model_editor_api_types.h @@ -136,8 +136,7 @@ struct ModelEditorNode : public OrtNode { "OrtModelEditorApi does not support getting the subgraphs for OrtNode"); } - Status GetSubgraphs(gsl::span /*subgraphs*/, - const char** /*opt_attribute_names*/) const override { + Status GetSubgraphs(gsl::span /*subgraphs*/) const override { return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "OrtModelEditorApi does not support getting the subgraphs for OrtNode"); } @@ -177,17 +176,6 @@ struct ModelEditorGraph : public OrtGraph { return ONNX_NAMESPACE::Version::IR_VERSION; } - Status GetNumOperatorSets(size_t& /*num_operator_sets*/) const override { - return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, - "OrtModelEditorApi does not support getting the graph's operator sets."); - } - - Status GetOperatorSets(gsl::span /*domains*/, - gsl::span /*opset_versions*/) const override { - return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, - "OrtModelEditorApi does not support getting the graph's operator sets."); - } - size_t GetNumInputs() const override { return inputs.size(); } Status GetInputs(gsl::span /*result*/) const override { diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h index 4d85c35461825..3575e30721af7 100644 --- a/onnxruntime/core/mlas/inc/mlas.h +++ b/onnxruntime/core/mlas/inc/mlas.h @@ -1020,7 +1020,6 @@ MlasComputeSoftmax( size_t D, bool LogSoftmax, bool SmoothSoftmax, - float Sink, MLAS_THREADPOOL* ThreadPool ); @@ -1224,21 +1223,6 @@ MlasQuantizeLinearS4( int8_t ZeroPoint ); -// -// Linear dequantization routines. -// - -template -void -MLASCALL -MlasDequantizeLinear( - const InputType* Input, - float* Output, - size_t N, - float Scale, - InputType ZeroPoint - ); - /** * @brief Requantize a block of the intermediate buffer to the output buffer, * optionally adding the supplied bias diff --git a/onnxruntime/core/mlas/lib/compute.cpp b/onnxruntime/core/mlas/lib/compute.cpp index 669c73d2b9c06..96a2398796777 100644 --- a/onnxruntime/core/mlas/lib/compute.cpp +++ b/onnxruntime/core/mlas/lib/compute.cpp @@ -74,7 +74,6 @@ struct MLAS_SOFTMAX_WORK_BLOCK { ptrdiff_t ThreadCountN; bool LogSoftmax; bool SmoothSoftmax; - float Sink; const T* Input; T* Output; size_t N; @@ -851,7 +850,6 @@ Return Value: const size_t D = WorkBlock->D; const bool LogSoftmax = WorkBlock->LogSoftmax; const bool SmoothSoftmax = WorkBlock->SmoothSoftmax; - const float Sink = WorkBlock->Sink; const float* Input = WorkBlock->Input + n * D; float* Output = WorkBlock->Output + n * D; @@ -882,11 +880,10 @@ Return Value: #else float Maximum = MlasReduceMaximumF32Kernel(Input, D); #endif - if (SmoothSoftmax && Sink > Maximum) { - Maximum = Sink; - } - float NegativeMaximum = -Maximum; + if (SmoothSoftmax && NegativeMaximum > 0.0f) { + NegativeMaximum = 0.0f; + } // // Compute the exponential function for each element of the row (save to Temp if provided) and @@ -900,7 +897,7 @@ Return Value: #endif if (SmoothSoftmax) { - Accumulation += expf(Sink + NegativeMaximum); + Accumulation += expf(NegativeMaximum); } if (LogSoftmax) { @@ -1017,7 +1014,6 @@ MlasComputeSoftmax( size_t D, bool LogSoftmax, bool SmoothSoftmax, - float Sink, MLAS_THREADPOOL* ThreadPool ) /*++ @@ -1043,8 +1039,6 @@ Routine Description: SmoothSoftmax - Supplies true if a smooth factor is used in softmax operation. - Sink - Supplies the smooth factor to use in the softmax operation. - ThreadPool - Supplies the thread pool object to use, else nullptr if the base library threading support should be used. @@ -1066,7 +1060,6 @@ Return Value: WorkBlock.Output = Output; WorkBlock.N = N; WorkBlock.D = D; - WorkBlock.Sink = Sink; // // Compute the number of target threads given the complexity of the softmax @@ -1104,7 +1097,6 @@ MlasComputeSoftmax( size_t D, bool LogSoftmax, bool SmoothSoftmax, - float Sink, MLAS_THREADPOOL* ThreadPool ); @@ -1118,7 +1110,6 @@ MlasComputeSoftmax( size_t D, bool LogSoftmax, bool SmoothSoftmax, - float Sink, MLAS_THREADPOOL* ThreadPool ); diff --git a/onnxruntime/core/mlas/lib/dequantize.cpp b/onnxruntime/core/mlas/lib/dequantize.cpp deleted file mode 100644 index 175d3f668ac39..0000000000000 --- a/onnxruntime/core/mlas/lib/dequantize.cpp +++ /dev/null @@ -1,395 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - dequantize.cpp - -Abstract: - - This module implements routines to dequantize buffers. - - The dequantization formula as specified in the ONNX operator documentation is: - - Output = (Input - ZeroPoint) * Scale - ---*/ - -#include "mlasi.h" - -// -// DequantizeLinear reference implementation using the C++ runtime. -// - -template -static -MLAS_FORCEINLINE -void -MlasDequantizeLinearRefImpl( - const InputType* Input, - float* Output, - size_t N, - float Scale, - InputType ZeroPoint - ) -/*++ - -Routine Description: - - This routine quantizes the input buffer using the supplied quantization - parameters. - -Arguments: - - Input - Supplies the input buffer with quantized data. - - Output - Supplies the output buffer. - - N - Supplies the number of elements to process. - - Scale - Supplies the quantization scale. - - ZeroPoint - Supplies the quantization zero point value. - -Return Value: - - None. - ---*/ -{ - int32_t ZeroPointS32 = static_cast(ZeroPoint); - - for (size_t n = 0; n < N; n++) { - Output[n] = static_cast(static_cast(Input[n]) - ZeroPointS32) * Scale; - } -} - -#if defined(MLAS_SSE2_INTRINSICS) -// Implementation for Intel SSE 2. Refer to the Intel Intrisics Guide: -// https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html - -void -MLASCALL -MlasDequantizeLinearS8Kernel( - const int8_t* Input, - float* Output, - size_t N, - float Scale, - int8_t ZeroPoint - ) -{ - const __m128 ScaleVector = MlasBroadcastFloat32x4(Scale); - const __m128i ZeroPointS16Vector = _mm_set1_epi16(static_cast(ZeroPoint)); // Broadcast zp to 8 int16s - const __m128i Zeros = _mm_setzero_si128(); - - while (N >= 16) { - // Load a vector of 16 int8s: [0 ... 15] - __m128i VectorS8 = _mm_loadu_si128(reinterpret_cast(Input)); - - // Sign-extend into 2 vectors of 8 int16s - __m128i SignMaskS8 = _mm_cmpgt_epi8(Zeros, VectorS8); // 0xFF for every negative byte in VectorS8 - __m128i VectorS16_0 = _mm_unpacklo_epi8(VectorS8, SignMaskS8); // [0 ... 7] - __m128i VectorS16_1 = _mm_unpackhi_epi8(VectorS8, SignMaskS8); // [8 ... 15] - - // Subtract the zero-points in int16 domain. - VectorS16_0 = _mm_sub_epi16(VectorS16_0, ZeroPointS16Vector); - VectorS16_1 = _mm_sub_epi16(VectorS16_1, ZeroPointS16Vector); - - // Sign-extend into 4 vectors of 4 int32s - __m128i SignMaskS16_0 = _mm_cmpgt_epi16(Zeros, VectorS16_0); - __m128i VectorS32_0 = _mm_unpacklo_epi16(VectorS16_0, SignMaskS16_0); // [0 ... 3] - __m128i VectorS32_1 = _mm_unpackhi_epi16(VectorS16_0, SignMaskS16_0); // [4 ... 7] - - __m128i SignMaskS16_1 = _mm_cmpgt_epi16(Zeros, VectorS16_1); - __m128i VectorS32_2 = _mm_unpacklo_epi16(VectorS16_1, SignMaskS16_1); // [8 ... 11] - __m128i VectorS32_3 = _mm_unpackhi_epi16(VectorS16_1, SignMaskS16_1); // [12 ... 15] - - // Cast each int32x4 to float and multiply by the scale vector. - __m128 VectorF32_0 = _mm_mul_ps(_mm_cvtepi32_ps(VectorS32_0), ScaleVector); - __m128 VectorF32_1 = _mm_mul_ps(_mm_cvtepi32_ps(VectorS32_1), ScaleVector); - __m128 VectorF32_2 = _mm_mul_ps(_mm_cvtepi32_ps(VectorS32_2), ScaleVector); - __m128 VectorF32_3 = _mm_mul_ps(_mm_cvtepi32_ps(VectorS32_3), ScaleVector); - - // Store each int32x4 into the output. - _mm_storeu_ps(Output + 0, VectorF32_0); - _mm_storeu_ps(Output + 4, VectorF32_1); - _mm_storeu_ps(Output + 8, VectorF32_2); - _mm_storeu_ps(Output + 12, VectorF32_3); - - Input += 16; - Output += 16; - N -= 16; - } - - // Handle leftover elements (< 16) with the scalar reference implementation. - MlasDequantizeLinearRefImpl(Input, Output, N, Scale, ZeroPoint); -} - -void -MLASCALL -MlasDequantizeLinearU8Kernel( - const uint8_t* Input, - float* Output, - size_t N, - float Scale, - uint8_t ZeroPoint - ) -{ - const __m128 ScaleVector = MlasBroadcastFloat32x4(Scale); - const __m128i ZeroPointS16Vector = _mm_set1_epi16(static_cast(ZeroPoint)); // Broadcast zp to 8 int16s - const __m128i Zeros = _mm_setzero_si128(); - - while (N >= 16) { - // Load a vector of 16 uint8s: [0 ... 15] - __m128i VectorU8 = _mm_loadu_si128(reinterpret_cast(Input)); - - // Zero-extend into 2 vectors of 8 uint16s - __m128i VectorU16_0 = _mm_unpacklo_epi8(VectorU8, Zeros); // [0 ... 7] - __m128i VectorU16_1 = _mm_unpackhi_epi8(VectorU8, Zeros); // [8 ... 15] - - // Subtract the zero-points as uint16s. Due to two's compliment, negative results can be reinterpreted as int16 - __m128i VectorS16_0 = _mm_sub_epi16(VectorU16_0, ZeroPointS16Vector); - __m128i VectorS16_1 = _mm_sub_epi16(VectorU16_1, ZeroPointS16Vector); - - // Sign-extend into 4 vectors of 4 int32s - __m128i SignMaskS16_0 = _mm_cmpgt_epi16(Zeros, VectorS16_0); - __m128i VectorS32_0 = _mm_unpacklo_epi16(VectorS16_0, SignMaskS16_0); // [0 ... 3] - __m128i VectorS32_1 = _mm_unpackhi_epi16(VectorS16_0, SignMaskS16_0); // [4 ... 7] - - __m128i SignMaskS16_1 = _mm_cmpgt_epi16(Zeros, VectorS16_1); - __m128i VectorS32_2 = _mm_unpacklo_epi16(VectorS16_1, SignMaskS16_1); // [8 ... 11] - __m128i VectorS32_3 = _mm_unpackhi_epi16(VectorS16_1, SignMaskS16_1); // [12 ... 15] - - // Cast each int32x4 to float and multiply by the scale vector. - __m128 VectorF32_0 = _mm_mul_ps(_mm_cvtepi32_ps(VectorS32_0), ScaleVector); - __m128 VectorF32_1 = _mm_mul_ps(_mm_cvtepi32_ps(VectorS32_1), ScaleVector); - __m128 VectorF32_2 = _mm_mul_ps(_mm_cvtepi32_ps(VectorS32_2), ScaleVector); - __m128 VectorF32_3 = _mm_mul_ps(_mm_cvtepi32_ps(VectorS32_3), ScaleVector); - - // Store each int32x4 into the output. - _mm_storeu_ps(Output + 0, VectorF32_0); - _mm_storeu_ps(Output + 4, VectorF32_1); - _mm_storeu_ps(Output + 8, VectorF32_2); - _mm_storeu_ps(Output + 12, VectorF32_3); - - Input += 16; - Output += 16; - N -= 16; - } - - // Handle leftover elements (< 16) with the scalar reference implementation. - MlasDequantizeLinearRefImpl(Input, Output, N, Scale, ZeroPoint); -} - -template<> -void -MLASCALL -MlasDequantizeLinear( - const int8_t* Input, - float* Output, - size_t N, - float Scale, - int8_t ZeroPoint - ) -{ -#if defined(MLAS_TARGET_AMD64) - GetMlasPlatform().DequantizeLinearS8Kernel( -#else - MlasDequantizeLinearS8Kernel( -#endif - Input, Output, N, Scale, ZeroPoint); -} - -template<> -void -MLASCALL -MlasDequantizeLinear( - const uint8_t* Input, - float* Output, - size_t N, - float Scale, - uint8_t ZeroPoint - ) -{ -#if defined(MLAS_TARGET_AMD64) - GetMlasPlatform().DequantizeLinearU8Kernel( -#else - MlasDequantizeLinearU8Kernel( -#endif - Input, Output, N, Scale, ZeroPoint); -} -#elif defined(MLAS_NEON64_INTRINSICS) -// Implementation for ARM64 NEON. Refer to the ARM instrinsics guide: -// https://developer.arm.com/architectures/instruction-sets/intrinsics/ - -void -MLASCALL -MlasDequantizeLinearS8Kernel( - const int8_t* Input, - float* Output, - size_t N, - float Scale, - int8_t ZeroPoint - ) -{ - const float32x4_t ScaleVector = MlasBroadcastFloat32x4(Scale); - const int16x8_t ZeroPointVector = vdupq_n_s16(ZeroPoint); // Broadcast ZeroPoint (sign-extended to 16bits) - - while (N >= 16) { - // Load a vector of 16 int8s: [0 ... 15] - int8x16_t VectorS8 = vld1q_s8(Input); - - // Sign-extend into 2 vectors of 8 int16s - int16x8_t VectorS16_0 = vmovl_s8(vget_low_s8(VectorS8)); // [0 ... 7] - int16x8_t VectorS16_1 = vmovl_s8(vget_high_s8(VectorS8)); // [8 ... 15] - - // Subtract the zero-points in int16 domain. - VectorS16_0 = vsubq_s16(VectorS16_0, ZeroPointVector); - VectorS16_1 = vsubq_s16(VectorS16_1, ZeroPointVector); - - // Sign-extend into 4 vectors of 4 int32s - int32x4_t VectorS32_0 = vmovl_s16(vget_low_s16(VectorS16_0)); // [0 ... 3] - int32x4_t VectorS32_1 = vmovl_s16(vget_high_s16(VectorS16_0)); // [4 ... 7] - int32x4_t VectorS32_2 = vmovl_s16(vget_low_s16(VectorS16_1)); // [8 ... 11] - int32x4_t VectorS32_3 = vmovl_s16(vget_high_s16(VectorS16_1)); // [12 ... 15] - - // Cast each int32x4 to float and multiply by the scale vector. - float32x4_t VectorF32_0 = vmulq_f32(vcvtq_f32_s32(VectorS32_0), ScaleVector); - float32x4_t VectorF32_1 = vmulq_f32(vcvtq_f32_s32(VectorS32_1), ScaleVector); - float32x4_t VectorF32_2 = vmulq_f32(vcvtq_f32_s32(VectorS32_2), ScaleVector); - float32x4_t VectorF32_3 = vmulq_f32(vcvtq_f32_s32(VectorS32_3), ScaleVector); - - // Store each int32x4 into the output. - vst1q_f32(Output + 0, VectorF32_0); - vst1q_f32(Output + 4, VectorF32_1); - vst1q_f32(Output + 8, VectorF32_2); - vst1q_f32(Output + 12, VectorF32_3); - - N -= 16; - Input += 16; - Output += 16; - } - - // Handle leftover elements (< 16) with the scalar reference implementation. - MlasDequantizeLinearRefImpl(Input, Output, N, Scale, ZeroPoint); -} - -void -MLASCALL -MlasDequantizeLinearU8Kernel( - const uint8_t* Input, - float* Output, - size_t N, - float Scale, - uint8_t ZeroPoint - ) -{ - const float32x4_t ScaleVector = MlasBroadcastFloat32x4(Scale); - const uint8x8_t ZeroPointVector = vdup_n_u8(ZeroPoint); // Broadcast ZeroPoint to 8 uint8s - - while (N >= 16) { - // Load a vector of 16 uint8s: [0 ... 15] - uint8x16_t VectorU8 = vld1q_u8(Input); - - // Subtract zero-point. The vsubl_u8 instruction zero-extends its arguments to uint16 first. - // The reinterpret from uint16x8 to int16x8 is actually a NOP. - int16x8_t VectorS16_0 = vreinterpretq_s16_u16(vsubl_u8(vget_low_u8(VectorU8), ZeroPointVector)); // [0 ... 7] - int16x8_t VectorS16_1 = vreinterpretq_s16_u16(vsubl_u8(vget_high_u8(VectorU8), ZeroPointVector)); // [8 ... 15] - - // Sign-extend into 4 vectors of 4 int32s - int32x4_t VectorS32_0 = vmovl_s16(vget_low_s16(VectorS16_0)); // [0 ... 3] - int32x4_t VectorS32_1 = vmovl_s16(vget_high_s16(VectorS16_0)); // [4 ... 7] - int32x4_t VectorS32_2 = vmovl_s16(vget_low_s16(VectorS16_1)); // [8 ... 11] - int32x4_t VectorS32_3 = vmovl_s16(vget_high_s16(VectorS16_1)); // [12 ... 15] - - // Cast each int32x4 to float and multiply by the scale vector. - float32x4_t VectorF32_0 = vmulq_f32(vcvtq_f32_s32(VectorS32_0), ScaleVector); - float32x4_t VectorF32_1 = vmulq_f32(vcvtq_f32_s32(VectorS32_1), ScaleVector); - float32x4_t VectorF32_2 = vmulq_f32(vcvtq_f32_s32(VectorS32_2), ScaleVector); - float32x4_t VectorF32_3 = vmulq_f32(vcvtq_f32_s32(VectorS32_3), ScaleVector); - - // Store each int32x4 into the output. - vst1q_f32(Output + 0, VectorF32_0); - vst1q_f32(Output + 4, VectorF32_1); - vst1q_f32(Output + 8, VectorF32_2); - vst1q_f32(Output + 12, VectorF32_3); - - N -= 16; - Input += 16; - Output += 16; - } - - // Handle leftover elements (< 16) with the scalar reference implementation. - MlasDequantizeLinearRefImpl(Input, Output, N, Scale, ZeroPoint); -} - -template<> -void -MLASCALL -MlasDequantizeLinear( - const int8_t* Input, - float* Output, - size_t N, - float Scale, - int8_t ZeroPoint - ) -{ - MlasDequantizeLinearS8Kernel(Input, Output, N, Scale, ZeroPoint); -} - -template<> -void -MLASCALL -MlasDequantizeLinear( - const uint8_t* Input, - float* Output, - size_t N, - float Scale, - uint8_t ZeroPoint - ) -{ - MlasDequantizeLinearU8Kernel(Input, Output, N, Scale, ZeroPoint); -} -#else -// Implementation that uses the scalar reference implementation. - -template -void -MLASCALL -MlasDequantizeLinear( - const InputType* Input, - float* Output, - size_t N, - float Scale, - InputType ZeroPoint - ) -{ - MlasDequantizeLinearRefImpl(Input, Output, N, Scale, ZeroPoint); -} - -template -void -MLASCALL -MlasDequantizeLinear( - const int8_t* Input, - float* Output, - size_t N, - float Scale, - int8_t ZeroPoint - ); - -template -void -MLASCALL -MlasDequantizeLinear( - const uint8_t* Input, - float* Output, - size_t N, - float Scale, - uint8_t ZeroPoint - ); - -#endif diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index 0879d1b0ba510..0af3cd2e33b02 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -747,24 +747,6 @@ void float Scale, int8_t ZeroPoint); -typedef -void -(MLASCALL MLAS_DEQUANTIZE_LINEAR_U8_KERNEL)( - const uint8_t* Input, - float* Output, - size_t N, - float Scale, - uint8_t ZeroPoint); - -typedef -void -(MLASCALL MLAS_DEQUANTIZE_LINEAR_S8_KERNEL)( - const int8_t* Input, - float* Output, - size_t N, - float Scale, - int8_t ZeroPoint); - template struct MLAS_QUANT_KERNEL { @@ -921,8 +903,6 @@ extern "C" { MLAS_QUANTIZE_LINEAR_S4_KERNEL MlasQuantizeLinearS4Kernel; MLAS_QUANTIZE_LINEAR_U4_KERNEL MlasQuantizeLinearU4Kernel; #if defined(MLAS_TARGET_AMD64) - MLAS_DEQUANTIZE_LINEAR_S8_KERNEL MlasDequantizeLinearS8Kernel; - MLAS_DEQUANTIZE_LINEAR_U8_KERNEL MlasDequantizeLinearU8Kernel; MLAS_COMPUTE_UNARY_FLOAT_KERNEL MlasErfKernelFma3; MLAS_COMPUTE_UNARY_FLOAT_KERNEL MlasComputeExpF32KernelFma3; MLAS_COMPUTE_UNARY_FLOAT_KERNEL MlasComputeExpF32KernelAvx512F; @@ -1266,8 +1246,6 @@ struct MLAS_PLATFORM { MLAS_QUANTIZE_LINEAR_U16_KERNEL* QuantizeLinearU16Kernel; MLAS_QUANTIZE_LINEAR_S4_KERNEL* QuantizeLinearS4Kernel; MLAS_QUANTIZE_LINEAR_U4_KERNEL* QuantizeLinearU4Kernel; - MLAS_DEQUANTIZE_LINEAR_S8_KERNEL* DequantizeLinearS8Kernel; - MLAS_DEQUANTIZE_LINEAR_U8_KERNEL* DequantizeLinearU8Kernel; uint32_t NchwcBlockSize; uint32_t PreferredBufferAlignment; int32_t MaximumThreadCount; diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index 45bba5363d4f2..45d3a876beb86 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -285,8 +285,6 @@ Return Value: this->QuantizeLinearU16Kernel = MlasQuantizeLinearU16Kernel; this->QuantizeLinearS4Kernel = MlasQuantizeLinearS4Kernel; this->QuantizeLinearU4Kernel = MlasQuantizeLinearU4Kernel; - this->DequantizeLinearS8Kernel = MlasDequantizeLinearS8Kernel; - this->DequantizeLinearU8Kernel = MlasDequantizeLinearU8Kernel; #ifndef __APPLE__ #ifndef FORCE_GENERIC_ALGORITHMS this->CastF16ToF32Kernel = &MlasCastF16ToF32KernelSse; diff --git a/onnxruntime/core/platform/windows/device_discovery.cc b/onnxruntime/core/platform/windows/device_discovery.cc index fa645939a6395..dcc030cb3467d 100644 --- a/onnxruntime/core/platform/windows/device_discovery.cc +++ b/onnxruntime/core/platform/windows/device_discovery.cc @@ -89,10 +89,23 @@ uint64_t GetLuidKey(LUID luid) { return (uint64_t(luid.HighPart) << 32) | luid.LowPart; } +// Converts a wide string (up to 4 characters) representing a hardware ID component (e.g., "ABCD" from "VEN_ABCD") +// into a uint32_t. The conversion is done in a little-endian manner, meaning the first character +// of the string becomes the least significant byte of the integer, and the fourth character +// becomes the most significant byte. +uint32_t WStringToUint32Id(const std::wstring& vendor_name) { + uint32_t vendor_id = 0; + for (size_t i = 0; i < 4 && i < vendor_name.size(); ++i) { + // For little-endian, place each character at the appropriate byte position + // First character goes into lowest byte, last character into highest byte + vendor_id |= static_cast(vendor_name[i] & 0xFF) << (i * 8); + } + return vendor_id; +} + // returns info for display and processor entries. key is (vendor_id << 32 | device_id) // npus: (vendor_id << 32 | device_id) for devices we think are NPUs from DXCORE -std::unordered_map GetDeviceInfoSetupApi(const std::unordered_set& npus, - bool& have_remote_display_adapter) { +std::unordered_map GetDeviceInfoSetupApi(const std::unordered_set& npus) { std::unordered_map device_info; const GUID local_DXCORE_ADAPTER_ATTRIBUTE_D3D12_GENERIC_ML = {0xb71b0d41, 0x1088, 0x422f, 0xa2, 0x7c, 0x2, 0x50, 0xb7, 0xd3, 0xa9, 0x88}; @@ -138,7 +151,8 @@ std::unordered_map GetDeviceInfoSetupApi(const std::unorde if (auto idx = hardware_id.find(prefix); idx != std::wstring::npos) { auto id = hardware_id.substr(idx + prefix.size(), 4); if (id.size() == 4) { - return static_cast(std::stoul(id, nullptr, 16)); + // DXCore reports vendor and device IDs as 32-bit integer representations of the ASCII string. + return WStringToUint32Id(id); } } @@ -156,11 +170,6 @@ std::unordered_map GetDeviceInfoSetupApi(const std::unorde // Won't always have a vendor id from an ACPI entry. ACPI is not defined for this purpose. if (vendor_id == 0 && device_id == 0) { - static const std::wstring remote_display_adapter_id(L"RdpIdd_IndirectDisplay"); - if (guid == GUID_DEVCLASS_DISPLAY && remote_display_adapter_id == buffer) { - have_remote_display_adapter = true; - } - continue; } @@ -296,7 +305,7 @@ std::unordered_map GetDeviceInfoSetupApi(const std::unorde } // returns LUID to DeviceInfo -std::unordered_map GetDeviceInfoD3D12(bool have_remote_display_adapter) { +std::unordered_map GetDeviceInfoD3D12() { std::unordered_map device_info; ComPtr factory; @@ -305,8 +314,6 @@ std::unordered_map GetDeviceInfoD3D12(bool have_remote_dis return device_info; } - UINT num_adapters = 0; - ComPtr adapter; for (UINT i = 0; factory->EnumAdapters1(i, adapter.ReleaseAndGetAddressOf()) != DXGI_ERROR_NOT_FOUND; ++i) { DXGI_ADAPTER_DESC1 desc; @@ -332,12 +339,9 @@ std::unordered_map GetDeviceInfoD3D12(bool have_remote_dis info.metadata[L"LUID"] = std::to_wstring(key); info.metadata[L"DxgiAdapterNumber"] = std::to_wstring(i); info.metadata[L"DxgiVideoMemory"] = std::to_wstring(desc.DedicatedVideoMemory / (1024 * 1024)) + L" MB"; - - ++num_adapters; } - // iterate by high-performance GPU preference to add that info. - UINT cur_adapter = 0; + // iterate by high-performance GPU preference to add that info for (UINT i = 0; factory->EnumAdapterByGpuPreference( i, DXGI_GPU_PREFERENCE_HIGH_PERFORMANCE, IID_PPV_ARGS(adapter.ReleaseAndGetAddressOf())) != DXGI_ERROR_NOT_FOUND; @@ -348,41 +352,12 @@ std::unordered_map GetDeviceInfoD3D12(bool have_remote_dis } uint64_t key = GetLuidKey(desc.AdapterLuid); - auto it = device_info.find(key); - if (it == device_info.end()) { - continue; - } - DeviceInfo& info = it->second; - - // try and drop the Microsoft Remote Display Adapter. it does not have the DXGI_ADAPTER_FLAG_SOFTWARE flag set - // and the vendor id, device id and description are the same as the real device. the LUID is different to the real - // device. - // Assumption: it will have the worst performance index of the devices we're considering so we only check the - // last adapter - if (num_adapters > 1 && have_remote_display_adapter && cur_adapter == num_adapters - 1) { - ComPtr output; - if (adapter->EnumOutputs(0, &output) == DXGI_ERROR_NOT_FOUND) { - // D3D_DRIVER_TYPE_WARP. Software based or disabled adapter. - // An adapter can be disabled in an RDP session. e.g. integrated GPU is disabled if there's a discrete GPU - - // if we have seen this vendor_id+device_id combination with a different LUID before we drop it. - if (std::any_of(device_info.begin(), device_info.end(), - [key, &info](const auto& entry) { - const auto& entry_info = entry.second; - return key != entry.first && - info.vendor_id == entry_info.vendor_id && - info.device_id == entry_info.device_id; - })) { - device_info.erase(key); - continue; - } - } + auto it = device_info.find(key); + if (it != device_info.end()) { + DeviceInfo& info = it->second; + info.metadata[L"DxgiHighPerformanceIndex"] = std::to_wstring(i); } - - info.metadata[L"DxgiHighPerformanceIndex"] = std::to_wstring(i); - - ++cur_adapter; } return device_info; @@ -522,12 +497,10 @@ std::unordered_set DeviceDiscovery::DiscoverDevicesForPlatfor } } - // setupapi_info. key is vendor_id+device_id - bool have_remote_display_adapter = false; // set if we see the RdpIdd_IndirectDisplay hardware ID. - std::unordered_map setupapi_info = GetDeviceInfoSetupApi(npus, have_remote_display_adapter); - // d3d12 info. key is luid - std::unordered_map luid_to_d3d12_info = GetDeviceInfoD3D12(have_remote_display_adapter); + std::unordered_map luid_to_d3d12_info = GetDeviceInfoD3D12(); + // setupapi_info. key is vendor_id+device_id + std::unordered_map setupapi_info = GetDeviceInfoSetupApi(npus); // Ensure we have at least one CPU bool found_cpu = false; diff --git a/onnxruntime/core/providers/cpu/math/softmax_shared.cc b/onnxruntime/core/providers/cpu/math/softmax_shared.cc index e123414b03b21..2817dda9d0085 100644 --- a/onnxruntime/core/providers/cpu/math/softmax_shared.cc +++ b/onnxruntime/core/providers/cpu/math/softmax_shared.cc @@ -99,7 +99,7 @@ common::Status SoftmaxCPU(size_t N, float* Ydata, bool logarithmic, onnxruntime::concurrency::ThreadPool* thread_pool) { - MlasComputeSoftmax(Xdata, Ydata, N, D, logarithmic, false, 0.0f, thread_pool); + MlasComputeSoftmax(Xdata, Ydata, N, D, logarithmic, false, thread_pool); return Status::OK(); } diff --git a/onnxruntime/core/providers/cpu/ml/ml_common.h b/onnxruntime/core/providers/cpu/ml/ml_common.h index f7cc2523adbf6..3359b2a69fe83 100644 --- a/onnxruntime/core/providers/cpu/ml/ml_common.h +++ b/onnxruntime/core/providers/cpu/ml/ml_common.h @@ -445,7 +445,7 @@ void batched_update_scores_inplace(gsl::span scores, int64_t num_batches_in, } if (use_mlas) { - MlasComputeSoftmax(s, s, num_batches, onnxruntime::narrow(batch_size), false, false, 0.0f, threadpool); + MlasComputeSoftmax(s, s, num_batches, onnxruntime::narrow(batch_size), false, false, threadpool); } else { while (s < s_end) { gsl::span scores_for_batch(s, s + batch_size); diff --git a/onnxruntime/core/providers/cpu/quantization/quantize_linear.cc b/onnxruntime/core/providers/cpu/quantization/quantize_linear.cc index c691be6ffd0e8..adb2aee171f39 100644 --- a/onnxruntime/core/providers/cpu/quantization/quantize_linear.cc +++ b/onnxruntime/core/providers/cpu/quantization/quantize_linear.cc @@ -1,7 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include #include #include "core/framework/element_type_lists.h" #include "core/framework/float8.h" @@ -302,31 +301,14 @@ struct DequantizeLinearApply { * @param[in] zero_point same shape as scale */ void op(size_t M, size_t K, size_t N, const T* input, - const OutT* scale, OutT* output, const T* zero_point, concurrency::ThreadPool* thread_pool) { + const OutT* scale, OutT* output, const T* zero_point) { for (size_t m = 0; m < M; m++) { for (size_t k = 0; k < K; k++) { -#if defined(ORT_CLIENT_PACKAGE_BUILD) - // TODO: Only using multithreaded/SIMD DQ when ORT is built for client/on-device workloads. - // Make this the default behavior after more testing. - if constexpr (std::is_same_v || std::is_same_v) { - ParDequantizeLinearStd(input, output, N, scale[k], zero_point ? zero_point[k] : 0, thread_pool); - input += N; - output += N; - } else { - auto zp = zero_point ? static_cast(zero_point[k]) : 0; - auto sc = static_cast(scale[k]); - for (size_t n = 0; n < N; n++) { - *output++ = static_cast(static_cast(static_cast(*input++) - zp) * sc); - } - } -#else - ORT_UNUSED_PARAMETER(thread_pool); auto zp = zero_point ? static_cast(zero_point[k]) : 0; auto sc = static_cast(scale[k]); for (size_t n = 0; n < N; n++) { *output++ = static_cast(static_cast(static_cast(*input++) - zp) * sc); } -#endif // defined(ORT_CLIENT_PACKAGE_BUILD) } } } @@ -345,8 +327,7 @@ struct DequantizeLinearApply { * @param[in] zero_point same shape as scale */ void op(size_t M, size_t K, size_t N, size_t quant_block_size, - const T* input, const OutT* scale, OutT* output, const T* zero_point, concurrency::ThreadPool* thread_pool) { - ORT_UNUSED_PARAMETER(thread_pool); + const T* input, const OutT* scale, OutT* output, const T* zero_point) { if (zero_point) { for (size_t m = 0; m < M; m++) { for (size_t bd = 0; bd < K; bd += quant_block_size) { @@ -387,8 +368,7 @@ template struct DequantizeLinearApply { // per-tensor/layer or per-axis quantization void op(size_t M, size_t K, size_t N, - const T* input, const OutT* scale, OutT* output, const T* zero_point, concurrency::ThreadPool* thread_pool) { - ORT_UNUSED_PARAMETER(thread_pool); + const T* input, const OutT* scale, OutT* output, const T* zero_point) { size_t input_index = 0; for (size_t m = 0; m < M; m++) { @@ -414,8 +394,7 @@ struct DequantizeLinearApply { // Blocked quantization // TODO(fajin) : add mlas kernel to utilize multithreading, refer MlasDequantizeBlockwise. void op(size_t M, size_t K, size_t N, size_t quant_block_size, - const T* input, const OutT* scale, OutT* output, const T* zero_point, concurrency::ThreadPool* thread_pool) { - ORT_UNUSED_PARAMETER(thread_pool); + const T* input, const OutT* scale, OutT* output, const T* zero_point) { size_t input_index = 0; if (zero_point) { @@ -461,36 +440,36 @@ struct DequantizeLinearApply { #if !defined(DISABLE_FLOAT8_TYPES) -#define DEQUANTIZE_LINEAR_APPLY_FLOAT8(T) \ - template \ - struct DequantizeLinearApply { \ - /* Per-tensor/layer or per-axis quantization */ \ - void op(size_t M, size_t K, size_t N, \ - const T* input, const OutT* scale, OutT* output, const T*, concurrency::ThreadPool*) { \ - for (size_t m = 0; m < M; m++) { \ - for (size_t bd = 0; bd < K; bd++) { \ - auto sc = scale[bd]; \ - for (size_t bs = 0; bs < N; bs++, input++) { \ - *output++ = static_cast(input->ToFloat() * sc); \ - } \ - } \ - } \ - } \ - /* Blocked quantization */ \ - void op(size_t M, size_t K, size_t N, size_t quant_block_size, \ - const T* input, const OutT* scale, OutT* output, const T*, concurrency::ThreadPool*) { \ - for (size_t m = 0; m < M; m++) { \ - for (size_t bd = 0; bd < K; bd += quant_block_size) { \ - for (size_t qb = 0, qb_end = std::min(quant_block_size, K - bd); qb < qb_end; ++qb) { \ - for (size_t bs = 0; bs < N; bs++, input++) { \ - auto sc = static_cast(scale[bs]); \ - *output++ = static_cast(input->ToFloat() * sc); \ - } \ - } \ - scale += N; \ - } \ - } \ - } \ +#define DEQUANTIZE_LINEAR_APPLY_FLOAT8(T) \ + template \ + struct DequantizeLinearApply { \ + /* Per-tensor/layer or per-axis quantization */ \ + void op(size_t M, size_t K, size_t N, \ + const T* input, const OutT* scale, OutT* output, const T*) { \ + for (size_t m = 0; m < M; m++) { \ + for (size_t bd = 0; bd < K; bd++) { \ + auto sc = scale[bd]; \ + for (size_t bs = 0; bs < N; bs++, input++) { \ + *output++ = static_cast(input->ToFloat() * sc); \ + } \ + } \ + } \ + } \ + /* Blocked quantization */ \ + void op(size_t M, size_t K, size_t N, size_t quant_block_size, \ + const T* input, const OutT* scale, OutT* output, const T*) { \ + for (size_t m = 0; m < M; m++) { \ + for (size_t bd = 0; bd < K; bd += quant_block_size) { \ + for (size_t qb = 0, qb_end = std::min(quant_block_size, K - bd); qb < qb_end; ++qb) { \ + for (size_t bs = 0; bs < N; bs++, input++) { \ + auto sc = static_cast(scale[bs]); \ + *output++ = static_cast(input->ToFloat() * sc); \ + } \ + } \ + scale += N; \ + } \ + } \ + } \ }; DEQUANTIZE_LINEAR_APPLY_FLOAT8(Float8E4M3FN) @@ -534,7 +513,6 @@ Status DequantizeLinear::Compute(OpKernelContext* ctx) const { const auto to = x_scale.GetElementType(); const T* input = x.Data(); constexpr bool is_4bit = boost::mp11::mp_contains, T>::value; - concurrency::ThreadPool* thread_pool = ctx->GetOperatorThreadPool(); if (to == ONNX_NAMESPACE::TensorProto::FLOAT) { const float* scale = x_scale.Data(); @@ -544,12 +522,12 @@ Status DequantizeLinear::Compute(OpKernelContext* ctx) const { static_cast(broadcast_dim), static_cast(process_block_size), static_cast(block_size_), - input, scale, output, zero_point, thread_pool); + input, scale, output, zero_point); } else { DequantizeLinearApply().op(static_cast(process_block_count), static_cast(broadcast_dim), static_cast(process_block_size), - input, scale, output, zero_point, thread_pool); + input, scale, output, zero_point); } } else if (to == ONNX_NAMESPACE::TensorProto::FLOAT16) { const MLFloat16* scale = x_scale.Data(); @@ -559,12 +537,12 @@ Status DequantizeLinear::Compute(OpKernelContext* ctx) const { static_cast(broadcast_dim), static_cast(process_block_size), static_cast(block_size_), - input, scale, output, zero_point, thread_pool); + input, scale, output, zero_point); } else { DequantizeLinearApply().op(static_cast(process_block_count), static_cast(broadcast_dim), static_cast(process_block_size), - input, scale, output, zero_point, thread_pool); + input, scale, output, zero_point); } } else if (to == ONNX_NAMESPACE::TensorProto::BFLOAT16) { ORT_THROW("DequantizeLinear into BFLOAT16 is not implemented yet."); diff --git a/onnxruntime/core/providers/cuda/cuda_provider_factory.cc b/onnxruntime/core/providers/cuda/cuda_provider_factory.cc index f00bf51ae143d..2de496a9168a0 100644 --- a/onnxruntime/core/providers/cuda/cuda_provider_factory.cc +++ b/onnxruntime/core/providers/cuda/cuda_provider_factory.cc @@ -313,10 +313,8 @@ CUDA_Provider* GetProvider() { // OrtEpApi infrastructure to be able to use the CUDA EP as an OrtEpFactory for auto EP selection. struct CudaEpFactory : OrtEpFactory { CudaEpFactory(const OrtApi& ort_api_in) : ort_api{ort_api_in} { - ort_version_supported = ORT_API_VERSION; GetName = GetNameImpl; GetVendor = GetVendorImpl; - GetVendorId = GetVendorIdImpl; GetVersion = GetVersionImpl; GetSupportedDevices = GetSupportedDevicesImpl; CreateEp = CreateEpImpl; @@ -333,11 +331,6 @@ struct CudaEpFactory : OrtEpFactory { return factory->vendor.c_str(); } - static uint32_t GetVendorIdImpl(const OrtEpFactory* this_ptr) noexcept { - const auto* factory = static_cast(this_ptr); - return factory->vendor_id; - } - static const char* ORT_API_CALL GetVersionImpl(const OrtEpFactory* /*this_ptr*/) noexcept { return ORT_VERSION; } @@ -381,7 +374,6 @@ struct CudaEpFactory : OrtEpFactory { const OrtApi& ort_api; const std::string ep_name{kCudaExecutionProvider}; // EP name const std::string vendor{"Microsoft"}; // EP vendor name - uint32_t vendor_id{0x1414}; // Microsoft vendor ID }; extern "C" { diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp index 9611cb82d5a62..a5066a41981e5 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp @@ -781,10 +781,7 @@ namespace Dml // this branch could be reached with a bad custom operator or malformed file. If // a legitimate case reaches here and DML needs to support a new input/output type // besides tensors, then remove the assert. - - // If the model has nodes that use Optional we will arrive here. It's a valid ONNX model but - // TryGetTensorDataType doesn't handle Optional. - // assert(false); + assert(false); nodeContainsSupportedDataTypes = false; return; } diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc index c5b6507ac847b..711d81186bad1 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc @@ -1304,7 +1304,7 @@ std::vector NvExecutionProvider::CreatePreferredAllocators() { AllocatorCreationInfo pinned_allocator_info( [](OrtDevice::DeviceId device_id) { - return std::make_unique(CUDA_PINNED, device_id); + return std::make_unique(device_id, CUDA_PINNED); }, narrow(device_id_)); diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/pool_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/pool_op_builder.cc index 21947a22e2b92..86b684f8c6ebd 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/pool_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/pool_op_builder.cc @@ -235,7 +235,7 @@ Status PoolOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(reshape_input, reshape_input_info)); bool needs_reshape = false; - const std::string reshape_prior_out = input_names[0] + "_prior_reshape"; + const std::string reshape4d = input_names[0] + "_pre_reshape"; if (input_shape.size() == 3) { needs_reshape = true; // build new_shape = {N, 1, C, L} @@ -245,24 +245,25 @@ Status PoolOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra input_shape[1], input_shape[2]}; - QnnTensorWrapper reshape_prior_tensor( - reshape_prior_out, + const std::string reshape_node_name = "pre_reshape"; + QnnTensorWrapper rw( + reshape4d, QNN_TENSOR_TYPE_NATIVE, reshape_input_info.qnn_data_type, reshape_input_info.quant_param.Copy(), std::move(new_shape)); - ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(reshape_prior_tensor)), - "Failed to add reshape prior tensor."); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(rw)), + "Failed to add reshape-4d tensor."); ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode( - utils::GetNodeName(node_unit) + "_reshape_prior", + reshape_node_name, QNN_OP_PACKAGE_NAME_QTI_AISW, - QNN_OP_RESHAPE, + "Reshape", {input_names[0]}, - {reshape_prior_out}, + {reshape4d}, {}, do_op_validation), - "Failed to create reshape prior node for pool op."); - input_names[0] = reshape_prior_out; + "Failed to create reshape-4d node."); + input_names[0] = reshape4d; input_shape = {input_shape[0], 1, input_shape[1], input_shape[2]}; } @@ -445,7 +446,9 @@ Status PoolOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra } const auto& outputs = node_unit.Outputs(); const std::string real_out = outputs[0].node_arg.Name(); - const std::string pool_out = real_out + "_reshape_after"; + const std::string pool_name = "poolmax2d"; + const std::string pool_out = real_out + "_post_reshape"; + const std::string post_reshape_node_name = "post_reshape"; const std::string qnn_op = GetQnnOpType(op_type); TensorInfo output_info{}; ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(node_unit.Outputs()[0], output_info)); @@ -463,34 +466,33 @@ Status PoolOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra "Failed to add tensor for pool_out"); ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode( - utils::GetNodeName(node_unit) + "_pool2d", + pool_name, QNN_OP_PACKAGE_NAME_QTI_AISW, qnn_op, - {reshape_prior_out}, + {reshape4d}, {pool_out}, std::move(param_tensor_names), do_op_validation), - "Failed to create pool node for rank-3 input."); + "Failed to create QNN Pool node for rank-3 input."); std::vector final_shape3d = output_info.shape; - QnnTensorWrapper reshape_after_tensor( + QnnTensorWrapper reshape_back_tensor( real_out, tensor_type, output_info.qnn_data_type, output_info.quant_param.Copy(), std::move(final_shape3d)); - ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(reshape_after_tensor)), - "Failed to add reshape after tensor."); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(reshape_back_tensor)), "Failed to add tensor."); ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode( - utils::GetNodeName(node_unit) + "_reshape_after", + post_reshape_node_name, QNN_OP_PACKAGE_NAME_QTI_AISW, - QNN_OP_RESHAPE, + "Reshape", {pool_out}, {real_out}, {}, do_op_validation), - "Failed to create reshape after node for pool op."); + "Failed to create reshape-back node."); return Status::OK(); } diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc index 502ea86b689f4..2650316dd07ac 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc @@ -9,7 +9,7 @@ namespace onnxruntime { namespace qnn { -// Operator which only need to handle node inputs & outputs, no attributes or no need to handle attributes +// Operator which only need to hanle node inputs & outputs, no attributes or no need to handle attributes class SimpleOpBuilder : public BaseOpBuilder { public: SimpleOpBuilder() : BaseOpBuilder("SimpleOpBuilder") {} @@ -38,7 +38,7 @@ class SimpleOpBuilder : public BaseOpBuilder { const logging::Logger& logger, bool do_op_validation) const ORT_MUST_USE_RESULT; - static constexpr std::array gridsample_supported_modes = {"bilinear", "nearest", "linear"}; + static constexpr std::array gridsample_supported_modes = {"bilinear", "nearest"}; static constexpr std::array gridsample_supported_padding_modes = {"zeros", "border", "reflection"}; static constexpr std::array scatternd_supported_reduction = {"none", "add", "mul"}; }; @@ -60,8 +60,8 @@ Status SimpleOpBuilder::ExplicitOpCheck(QnnModelWrapper& qnn_model_wrapper, // To DO: Remove once QNN CPU supports ScatterND const auto qnn_backend_type = qnn_model_wrapper.GetQnnBackendType(); if (op_type == "ScatterND") { - ORT_RETURN_IF(qnn_backend_type == QnnBackendType::CPU, - "QNN EP does not support ScatterND op on CPU backend. Falling back to ORT CPU."); + ORT_RETURN_IF_NOT(qnn_backend_type == QnnBackendType::HTP, + "QNN EP only supports ScatterND op on HTP backend. Falling back to ORT CPU."); } // ONNX's Min, Max, and Sum operators accept a variable number of inputs (i.e., variadic). @@ -233,12 +233,12 @@ Status ProcessGridSampleAttributes(QnnModelWrapper& qnn_model_wrapper, std::string mode = node_helper.Get("mode", "linear"); Qnn_Scalar_t mode_qnn_scalar = QNN_SCALAR_INIT; mode_qnn_scalar.dataType = QNN_DATATYPE_UINT_32; - if ("linear" == mode || "bilinear" == mode) { + if ("bilinear" == mode) { mode_qnn_scalar.uint32Value = QNN_OP_GRID_SAMPLE_MODE_BILINEAR; } else if ("nearest" == mode) { mode_qnn_scalar.uint32Value = QNN_OP_GRID_SAMPLE_MODE_NEAREST; } else { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "GridSample mode only support [linear, bilinear, nearest]."); + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "GridSample mode only support bilinear & nearest."); } QnnParamWrapper mode_param(node_unit.Index(), node_unit.Name(), QNN_OP_GRID_SAMPLE_PARAM_MODE, mode_qnn_scalar); param_tensor_names.push_back(mode_param.GetParamTensorName()); @@ -254,7 +254,7 @@ Status ProcessGridSampleAttributes(QnnModelWrapper& qnn_model_wrapper, } else if ("reflection" == padding_mode) { padding_mode_qnn_scalar.uint32Value = QNN_OP_GRID_SAMPLE_PADDING_MODE_REFLECTION; } else { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "GridSample padding_mode only support [zeros, border, reflection]."); + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "GridSample padding_mode only support zeros, border & reflection."); } QnnParamWrapper padding_mode_param(node_unit.Index(), node_unit.Name(), QNN_OP_GRID_SAMPLE_PARAM_PADDING_MODE, padding_mode_qnn_scalar); param_tensor_names.push_back(padding_mode_param.GetParamTensorName()); diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc index 3dc103046424e..d22edaf33eb1c 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc @@ -839,23 +839,6 @@ Status QnnBackendManager::CreateContextVtcmBackupBufferSharingEnabled(std::unord return Status::OK(); } -Status QnnBackendManager::SetContextPriority(ContextPriority context_priority) { - QnnContext_Config_t context_priority_config = QNN_CONTEXT_CONFIG_INIT; - ORT_RETURN_IF_ERROR(SetQnnContextConfig(context_priority, context_priority_config)); - - QnnContext_Config_t* configs[] = {&context_priority_config, nullptr}; - for (const auto& context_handle : contexts_) { - auto result = qnn_interface_.contextSetConfig(context_handle, (const QnnContext_Config_t**)configs); - ORT_RETURN_IF(QNN_CONTEXT_NO_ERROR != result, "Failed to set context priority for context handle: ", context_handle); - } - - return Status::OK(); -} - -Status QnnBackendManager::ResetContextPriority() { - return SetContextPriority(context_priority_); -} - Status QnnBackendManager::CreateContext(bool enable_htp_weight_sharing) { if (true == context_created_) { LOGS_DEFAULT(INFO) << "Context created already."; @@ -1443,33 +1426,13 @@ Status QnnBackendManager::SetHtpPowerConfig(uint32_t htp_power_config_client_id, return Status::OK(); } -Status QnnBackendManager::SetRpcPowerConfigs(uint32_t htp_power_config_client_id, - uint32_t rpc_control_latency, - uint32_t rpc_polling_time) { +Status QnnBackendManager::SetRpcControlLatency(uint32_t htp_power_config_client_id, + uint32_t rpc_control_latency) { // This function is called in QNN EP's OnRunStart() even if QNN backend setup failed and the model is assigned // to a different EP. Therefore, we have to check that backend setup actually completed before trying to // set RPC control latency. Otherwise, this causes a segfault because the QNN backend library is unloaded. ORT_RETURN_IF_NOT(backend_setup_completed_, "Cannot set HTP RPC control latency if backend setup is not complete."); - - constexpr int kNumRpcPollingPowerConfigs = 2; - std::vector rpc_power_configs; - rpc_power_configs.reserve(kNumRpcPollingPowerConfigs); - - // Set rpc control latency here if (rpc_control_latency != 0) { - auto& rpc_control_latency_cfg = rpc_power_configs.emplace_back(); - rpc_control_latency_cfg.option = QNN_HTP_PERF_INFRASTRUCTURE_POWER_CONFIGOPTION_RPC_CONTROL_LATENCY; - rpc_control_latency_cfg.rpcControlLatencyConfig = rpc_control_latency; - } - - // Note: v68 does not support rpc polling mode - if (rpc_polling_time != 0) { - auto& rpc_polling_time_cfg = rpc_power_configs.emplace_back(); - rpc_polling_time_cfg.option = QNN_HTP_PERF_INFRASTRUCTURE_POWER_CONFIGOPTION_RPC_POLLING_TIME; - rpc_polling_time_cfg.rpcPollingTimeConfig = rpc_polling_time; - } - - if (rpc_power_configs.size() > 0) { QnnDevice_Infrastructure_t qnn_device_infra = nullptr; auto status = qnn_interface_.deviceGetInfrastructure(&qnn_device_infra); ORT_RETURN_IF(QNN_SUCCESS != status, "backendGetPerfInfrastructure failed."); @@ -1479,6 +1442,15 @@ Status QnnBackendManager::SetRpcPowerConfigs(uint32_t htp_power_config_client_id "HTP infra type = ", htp_infra->infraType, ", which is not perf infra type."); QnnHtpDevice_PerfInfrastructure_t& htp_perf_infra = htp_infra->perfInfra; + // Set rpc control latency here, but note that v68 doesn't support rpc polling mode. + constexpr int kNumRpcPollingPowerConfigs = 2; + std::vector rpc_power_configs(kNumRpcPollingPowerConfigs); + QnnHtpPerfInfrastructure_PowerConfig_t& rpc_control_latency_cfg = rpc_power_configs[0]; + // v68 doesn't support this. + QnnHtpPerfInfrastructure_PowerConfig_t& rpc_polling_time = rpc_power_configs[1]; + rpc_control_latency_cfg.option = QNN_HTP_PERF_INFRASTRUCTURE_POWER_CONFIGOPTION_RPC_CONTROL_LATENCY; + rpc_polling_time.option = QNN_HTP_PERF_INFRASTRUCTURE_POWER_CONFIGOPTION_RPC_POLLING_TIME; + rpc_control_latency_cfg.rpcControlLatencyConfig = rpc_control_latency; std::vector perf_power_configs_ptr = ObtainNullTermPtrVector(rpc_power_configs); status = htp_perf_infra.setPowerConfig(htp_power_config_client_id, perf_power_configs_ptr.data()); diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h index 2a71c7391b180..3e68df3024565 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h @@ -159,9 +159,8 @@ class QnnBackendManager : public std::enable_shared_from_this Status SetHtpPowerConfig(uint32_t htp_power_config_client_id, HtpPerformanceMode htp_performance_mode); - Status SetRpcPowerConfigs(uint32_t htp_power_config_client_id, - uint32_t rpc_control_latency, - uint32_t rpc_polling_time); + Status SetRpcControlLatency(uint32_t htp_power_config_client_id, + uint32_t rpc_control_latency); const QNN_INTERFACE_VER_TYPE& GetQnnInterface() { return qnn_interface_; } @@ -220,11 +219,6 @@ class QnnBackendManager : public std::enable_shared_from_this // For each node name, a mapping to the context handle will be created void ProcessContextFromBinListAsync(Qnn_ContextHandle_t handle, void* notifyParam); - // Sets the context priority to the given value, if valid - Status SetContextPriority(ContextPriority context_priority); - // Resets the context priority to the session default as defined by context_priority_ - Status ResetContextPriority(); - private: Status LoadBackend(); diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc index 3acb3347acee1..236447cc95c3d 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc @@ -1356,8 +1356,7 @@ QNNExecutionProvider::PerThreadContext::PerThreadContext(qnn::QnnBackendManager* uint32_t device_id, uint32_t core_id, qnn::HtpPerformanceMode default_htp_performance_mode, - uint32_t default_rpc_control_latency, - uint32_t default_rpc_polling_time) + uint32_t default_rpc_control_latency) : qnn_backend_manager_(qnn_backend_manager) { Status rt = qnn_backend_manager_->CreateHtpPowerCfgId(device_id, core_id, htp_power_config_id_); is_htp_power_config_id_valid_ = rt.IsOK(); @@ -1368,10 +1367,9 @@ QNNExecutionProvider::PerThreadContext::PerThreadContext(qnn::QnnBackendManager* ORT_IGNORE_RETURN_VALUE(qnn_backend_manager_->SetHtpPowerConfig(htp_power_config_id_, default_htp_performance_mode)); } - if (default_rpc_control_latency > 0 || default_rpc_polling_time > 0) { - ORT_IGNORE_RETURN_VALUE(qnn_backend_manager_->SetRpcPowerConfigs(htp_power_config_id_, - default_rpc_control_latency, - default_rpc_polling_time)); + if (default_rpc_control_latency > 0) { + ORT_IGNORE_RETURN_VALUE(qnn_backend_manager_->SetRpcControlLatency(htp_power_config_id_, + default_rpc_control_latency)); } } } @@ -1402,8 +1400,7 @@ QNNExecutionProvider::PerThreadContext& QNNExecutionProvider::GetPerThreadContex if (context_state_.retired_context_pool.empty()) { uint32_t core_id = 0; context = std::make_shared(qnn_backend_manager_.get(), device_id_, core_id, - default_htp_performance_mode_, default_rpc_control_latency_, - default_rpc_polling_time_); + default_htp_performance_mode_, default_rpc_control_latency_); } else { context = context_state_.retired_context_pool.back(); context_state_.retired_context_pool.pop_back(); @@ -1471,21 +1468,15 @@ Status QNNExecutionProvider::OnRunStart(const onnxruntime::RunOptions& run_optio LOGS_DEFAULT(VERBOSE) << "rpc_control_latency: " << rpc_control_latency; } - uint32_t rpc_polling_time = 0; - if (qnn::HtpPerformanceMode::kHtpBurst != htp_performance_mode) { - rpc_polling_time = 9999; - } - if (GetPerThreadContext().IsHtpPowerConfigIdValid()) { if (qnn::HtpPerformanceMode::kHtpDefault != htp_performance_mode) { ORT_RETURN_IF_ERROR(qnn_backend_manager_->SetHtpPowerConfig(GetPerThreadContext().GetHtpPowerConfigId(), htp_performance_mode)); } - if (rpc_control_latency > 0 || rpc_polling_time > 0) { - ORT_RETURN_IF_ERROR(qnn_backend_manager_->SetRpcPowerConfigs(GetPerThreadContext().GetHtpPowerConfigId(), - rpc_control_latency, - rpc_polling_time)); + if (rpc_control_latency > 0) { + ORT_RETURN_IF_ERROR(qnn_backend_manager_->SetRpcControlLatency(GetPerThreadContext().GetHtpPowerConfigId(), + rpc_control_latency)); } } @@ -1554,38 +1545,4 @@ OrtDevice QNNExecutionProvider::GetOrtDeviceByMemType(OrtMemType /* em_type */) return default_device_; } -Status QNNExecutionProvider::SetEpDynamicOptions(gsl::span keys, - gsl::span values) { - if (keys.size() != values.size()) { - LOGS_DEFAULT(ERROR) << "SetEpDynamicOptions: number of keys (" << keys.size() - << ") does not equal number of values (" << values.size() << ")."; - } - auto key_it = keys.begin(); - auto value_it = values.begin(); - - while (key_it != keys.end() && value_it != values.end()) { - std::string key(*key_it); - std::string value(*value_it); - - if (key == kOrtEpDynamicOptionsWorkloadType) { - if (value == "Default") { - ORT_RETURN_IF_ERROR(qnn_backend_manager_->ResetContextPriority()); - } else if (value == "Efficient") { - ORT_RETURN_IF_ERROR(qnn_backend_manager_->SetContextPriority(qnn::ContextPriority::LOW)); - } else { - LOGS_DEFAULT(ERROR) << "Invalid EP Workload Type: " << value; - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid EP Workload Type."); - } - } else { - LOGS_DEFAULT(ERROR) << "EP Dynamic Option \"" << key << "\" is not currently supported."; - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unsupported EP Dynamic Option"); - } - - key_it++; - value_it++; - } - - return Status::OK(); -} - } // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.h b/onnxruntime/core/providers/qnn/qnn_execution_provider.h index 6adf613932d66..06f9726ae96cf 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.h +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.h @@ -57,9 +57,6 @@ class QNNExecutionProvider : public IExecutionProvider { OrtDevice GetOrtDeviceByMemType(OrtMemType mem_type) const override; - Status SetEpDynamicOptions(gsl::span keys, - gsl::span value) override; - private: std::unordered_set GetSupportedNodes(const GraphViewer& graph_viewer, const std::unordered_map& node_unit_map, @@ -99,7 +96,6 @@ class QNNExecutionProvider : public IExecutionProvider { uint32_t device_id_ = 0; qnn::HtpPerformanceMode default_htp_performance_mode_ = qnn::HtpPerformanceMode::kHtpDefault; uint32_t default_rpc_control_latency_ = 0; - uint32_t default_rpc_polling_time_ = 0; bool enable_HTP_FP16_precision_ = true; bool share_ep_contexts_ = false; bool stop_share_ep_contexts_ = false; @@ -120,8 +116,7 @@ class QNNExecutionProvider : public IExecutionProvider { PerThreadContext(qnn::QnnBackendManager* qnn_backend_manager, uint32_t device_id, uint32_t core_id, qnn::HtpPerformanceMode default_htp_performance_mode, - uint32_t default_rpc_control_latency, - uint32_t default_rpc_polling_time); + uint32_t default_rpc_control_latency); ~PerThreadContext(); ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(PerThreadContext); diff --git a/onnxruntime/core/providers/qnn/qnn_provider_factory.cc b/onnxruntime/core/providers/qnn/qnn_provider_factory.cc index 785177ce37788..c679ea1adb286 100644 --- a/onnxruntime/core/providers/qnn/qnn_provider_factory.cc +++ b/onnxruntime/core/providers/qnn/qnn_provider_factory.cc @@ -125,10 +125,8 @@ struct QnnEpFactory : OrtEpFactory { OrtHardwareDeviceType hw_type, const char* qnn_backend_type) : ort_api{ort_api_in}, ep_name{ep_name}, ort_hw_device_type{hw_type}, qnn_backend_type{qnn_backend_type} { - ort_version_supported = ORT_API_VERSION; GetName = GetNameImpl; GetVendor = GetVendorImpl; - GetVendorId = GetVendorIdImpl; GetVersion = GetVersionImpl; GetSupportedDevices = GetSupportedDevicesImpl; CreateEp = CreateEpImpl; @@ -144,12 +142,7 @@ struct QnnEpFactory : OrtEpFactory { static const char* GetVendorImpl(const OrtEpFactory* this_ptr) noexcept { const auto* factory = static_cast(this_ptr); - return factory->ep_vendor.c_str(); - } - - static uint32_t GetVendorIdImpl(const OrtEpFactory* this_ptr) noexcept { - const auto* factory = static_cast(this_ptr); - return factory->ep_vendor_id; + return factory->vendor.c_str(); } static const char* ORT_API_CALL GetVersionImpl(const OrtEpFactory* /*this_ptr*/) noexcept { @@ -202,9 +195,8 @@ struct QnnEpFactory : OrtEpFactory { } const OrtApi& ort_api; - const std::string ep_name; // EP name - const std::string ep_vendor{"Microsoft"}; // EP vendor name - uint32_t ep_vendor_id{0x1414}; // Microsoft vendor ID + const std::string ep_name; // EP name + const std::string vendor{"Microsoft"}; // EP vendor name // Qualcomm vendor ID. Refer to the ACPI ID registry (search Qualcomm): https://uefi.org/ACPI_ID_List const uint32_t vendor_id{'Q' | ('C' << 8) | ('O' << 16) | ('M' << 24)}; diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc index 1e9fafe8aa323..90a4294fb47f0 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc @@ -7,25 +7,6 @@ #include "tensorrt_execution_provider_custom_ops.h" #include "tensorrt_execution_provider.h" -// The filename extension for a shared library is different per platform -#ifdef _WIN32 -#define LIBRARY_PREFIX -#define LIBRARY_EXTENSION ORT_TSTR(".dll") -#elif defined(__APPLE__) -#define LIBRARY_PREFIX "lib" -#define LIBRARY_EXTENSION ".dylib" -#else -#define LIBRARY_PREFIX "lib" -#define LIBRARY_EXTENSION ".so" -#endif - -#ifdef _WIN32 -#define ORT_DEF2STR_HELPER(x) L#x -#else -#define ORT_DEF2STR_HELPER(X) #X -#endif -#define ORT_DEF2STR(x) ORT_DEF2STR_HELPER(x) - namespace onnxruntime { extern TensorrtLogger& GetTensorrtLogger(bool verbose); @@ -77,31 +58,8 @@ common::Status CreateTensorRTCustomOpDomainList(std::vector& // Get all registered TRT plugins from registry LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Getting all registered TRT plugins from TRT plugin registry ..."; TensorrtLogger trt_logger = GetTensorrtLogger(false); - try { - void* library_handle = nullptr; - const auto& env = onnxruntime::GetDefaultEnv(); -#if NV_TENSORRT_MAJOR < 10 - auto full_path = env.GetRuntimePath() + - PathString(LIBRARY_PREFIX ORT_TSTR("nvinfer_plugin") LIBRARY_EXTENSION); -#else -#ifdef _WIN32 - auto full_path = PathString(LIBRARY_PREFIX ORT_TSTR("nvinfer_plugin_" ORT_DEF2STR(NV_TENSORRT_MAJOR)) LIBRARY_EXTENSION); -#else - auto full_path = PathString(LIBRARY_PREFIX ORT_TSTR("nvinfer_plugin") LIBRARY_EXTENSION ORT_TSTR("." ORT_DEF2STR(NV_TENSORRT_MAJOR))); -#endif -#endif - - ORT_THROW_IF_ERROR(env.LoadDynamicLibrary(full_path, false, &library_handle)); + initLibNvInferPlugins(&trt_logger, ""); - bool (*dyn_initLibNvInferPlugins)(void* logger, char const* libNamespace); - ORT_THROW_IF_ERROR(env.GetSymbolFromLibrary(library_handle, "initLibNvInferPlugins", (void**)&dyn_initLibNvInferPlugins)); - if (!dyn_initLibNvInferPlugins(&trt_logger, "")) { - LOGS_DEFAULT(INFO) << "[TensorRT EP] Default plugin library was found but was not able to initialize default plugins."; - } - LOGS_DEFAULT(INFO) << "[TensorRT EP] Default plugins successfully loaded."; - } catch (const std::exception&) { - LOGS_DEFAULT(INFO) << "[TensorRT EP] Default plugin library is not on the path and is therefore ignored"; - } int num_plugin_creator = 0; auto plugin_creators = getPluginRegistry()->getAllCreators(&num_plugin_creator); std::unordered_set registered_plugin_names; diff --git a/onnxruntime/core/providers/webgpu/buffer_manager.cc b/onnxruntime/core/providers/webgpu/buffer_manager.cc index 113a3f31be7f9..e8140a4d59eab 100644 --- a/onnxruntime/core/providers/webgpu/buffer_manager.cc +++ b/onnxruntime/core/providers/webgpu/buffer_manager.cc @@ -193,21 +193,27 @@ class BucketCacheManager : public IBufferCacheManager { } void ReleaseBuffer(WGPUBuffer buffer) override { - auto buffer_size = static_cast(wgpuBufferGetSize(buffer)); - - auto it = buckets_.find(buffer_size); - if (it != buckets_.end() && it->second.size() < buckets_limit_[buffer_size]) { - it->second.emplace_back(buffer); - } else { - wgpuBufferRelease(buffer); - } + pending_buffers_.emplace_back(buffer); } void OnRefresh(GraphCaptureState /*graph_capture_state*/) override { - // no-op + for (auto& buffer : pending_buffers_) { + auto buffer_size = static_cast(wgpuBufferGetSize(buffer)); + auto it = buckets_.find(buffer_size); + if (it != buckets_.end() && it->second.size() < buckets_limit_[buffer_size]) { + it->second.emplace_back(buffer); + } else { + wgpuBufferRelease(buffer); + } + } + + pending_buffers_.clear(); } ~BucketCacheManager() { + for (auto& buffer : pending_buffers_) { + wgpuBufferRelease(buffer); + } for (auto& pair : buckets_) { for (auto& buffer : pair.second) { wgpuBufferRelease(buffer); @@ -236,6 +242,7 @@ class BucketCacheManager : public IBufferCacheManager { } std::unordered_map buckets_limit_; std::unordered_map> buckets_; + std::vector pending_buffers_; std::vector buckets_keys_; }; diff --git a/onnxruntime/core/providers/webgpu/tensor/cast.cc b/onnxruntime/core/providers/webgpu/tensor/cast.cc index 313a96ba25509..7f92ea4ed3776 100644 --- a/onnxruntime/core/providers/webgpu/tensor/cast.cc +++ b/onnxruntime/core/providers/webgpu/tensor/cast.cc @@ -52,28 +52,10 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( .TypeConstraint("T1", CastOpTypeConstraints()) .TypeConstraint("T2", CastOpTypeConstraints()), Cast); -ONNX_OPERATOR_VERSIONED_KERNEL_EX( - Cast, - kOnnxDomain, - 19, 20, - kWebGpuExecutionProvider, - (*KernelDefBuilder::Create()) - .TypeConstraint("T1", CastOpTypeConstraints()) - .TypeConstraint("T2", CastOpTypeConstraints()), - Cast); -ONNX_OPERATOR_VERSIONED_KERNEL_EX( - Cast, - kOnnxDomain, - 21, 22, - kWebGpuExecutionProvider, - (*KernelDefBuilder::Create()) - .TypeConstraint("T1", CastOpTypeConstraints()) - .TypeConstraint("T2", CastOpTypeConstraints()), - Cast); ONNX_OPERATOR_KERNEL_EX( Cast, kOnnxDomain, - 23, + 19, kWebGpuExecutionProvider, (*KernelDefBuilder::Create()) .TypeConstraint("T1", CastOpTypeConstraints()) diff --git a/onnxruntime/core/providers/webgpu/tensor/scatter_nd.cc b/onnxruntime/core/providers/webgpu/tensor/scatter_nd.cc index 9f07e2d2a3988..f13e86c185928 100644 --- a/onnxruntime/core/providers/webgpu/tensor/scatter_nd.cc +++ b/onnxruntime/core/providers/webgpu/tensor/scatter_nd.cc @@ -146,24 +146,24 @@ Status ScatterND::ComputeInternal(ComputeContext& context) const { const auto* updates = context.Input(2); const auto& input_shape = input->Shape(); const auto& indices_shape = indices->Shape(); - auto* output = context.Output(0, input_shape); - const void* source = input->DataRaw(); - void* target = output->MutableDataRaw(); - // If source and target pointers are not equal (non-inplace operation), we need to copy the data. - if (target != source) { - ORT_RETURN_IF_ERROR(Info().GetDataTransferManager().CopyTensor(*input, *output)); - } - if (indices_shape.Size() == 0) { - // If the indices are empty, we can return early. - return Status::OK(); - } auto indices_rank = indices_shape.NumDimensions(); auto last_index_dimension = static_cast(indices_shape[indices_rank - 1]); auto num_updates_elements = static_cast(input_shape.SizeFromDimension(last_index_dimension)); // TODO: support bool with components 4. const size_t components = 1; auto output_size = static_cast((indices_shape.SizeToDimension(indices_rank - 1) + components - 1) / components); + auto* output = context.Output(0, input_shape); + if (output_size == 0) { + // If the output tensor is empty, we can return early. + return Status::OK(); + } MLDataType data_type = input->DataType(); + const void* source = input->DataRaw(); + void* target = output->MutableDataRaw(); + // If source and target pointers are not equal (non-inplace operation), we need to copy the data. + if (target != source) { + ORT_RETURN_IF_ERROR(Info().GetDataTransferManager().CopyTensor(*input, *output)); + } ScatterNDProgram program(reduction_, data_type); program .CacheHint(static_cast(reduction_)) diff --git a/onnxruntime/core/providers/webgpu/tensor/slice.cc b/onnxruntime/core/providers/webgpu/tensor/slice.cc index 7e8b434431781..39432db5113d1 100644 --- a/onnxruntime/core/providers/webgpu/tensor/slice.cc +++ b/onnxruntime/core/providers/webgpu/tensor/slice.cc @@ -172,8 +172,8 @@ Status Slice::ComputeInternal(ComputeContext& context) const { } if (step < 0) { // we are slicing in reverse - start = dim_value > 0 ? std::clamp(start, int64_t{0}, dim_value - 1) : 0; - end = dim_value > 0 ? std::clamp(end, int64_t{-1}, dim_value - 1) : -1; + start = std::clamp(start, int64_t{0}, dim_value - 1); + end = std::clamp(end, int64_t{-1}, dim_value - 1); // note that we are flipping start and end to switch to forward step signs.push_back(-1); steps.push_back(static_cast(-step)); diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc index 6e09f494f4a8d..460d220ecf1b9 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc @@ -123,9 +123,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 6, 8, Cast); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 9, 12, Cast); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 18, Cast); -class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 19, 20, Cast); -class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 21, 22, Cast); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 23, Cast); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 19, Cast); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 11, float, Clip); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 12, 12, float, Clip); @@ -457,9 +455,7 @@ std::unique_ptr RegisterKernels() { KERNEL_CREATE_INFO_VERSIONED(6, 8, Cast), KERNEL_CREATE_INFO_VERSIONED(9, 12, Cast), KERNEL_CREATE_INFO_VERSIONED(13, 18, Cast), - KERNEL_CREATE_INFO_VERSIONED(19, 20, Cast), - KERNEL_CREATE_INFO_VERSIONED(21, 22, Cast), - KERNEL_CREATE_INFO(23, Cast), + KERNEL_CREATE_INFO(19, Cast), // // activations BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/webgpu/wgsl_templates/README.md b/onnxruntime/core/providers/webgpu/wgsl_templates/README.md index 6bd2f98cc5713..c1a62e7fa7858 100644 --- a/onnxruntime/core/providers/webgpu/wgsl_templates/README.md +++ b/onnxruntime/core/providers/webgpu/wgsl_templates/README.md @@ -64,7 +64,7 @@ This section includes instructions for how to use the template system in the dev 1. Create WGSL template files in `.wgsl.template` extension. - [Reference: Template Syntax](https://github.com/fs-eire/wgsl-template?tab=readme-ov-file#template-syntax) - - [Reference: Built-in Utilities](https://github.com/fs-eire/wgsl-template?tab=readme-ov-file#Utilities) + - [Reference: Built-in Utilities](#Utilities) - [Example: Pad](../tensor/pad.wgsl.template) 2. In the implementation of `YourProgram::GenerateShaderCode()`, load and use the generated template files. @@ -117,4 +117,4 @@ This section includes instructions for how to use the template system in the dev 1. Build ORT once with dynamic template mode 2. Launch wgsl-gen in watch mode 3. Run ORT to debug/validate the shader - 4. Make changes to the template files, and repeat step (c) + 4. Make changes to the template files, and repeat step (3) diff --git a/onnxruntime/core/providers/webgpu/wgsl_templates/package-lock.json b/onnxruntime/core/providers/webgpu/wgsl_templates/package-lock.json index df1940ed6416b..7cde6c17f54e9 100644 --- a/onnxruntime/core/providers/webgpu/wgsl_templates/package-lock.json +++ b/onnxruntime/core/providers/webgpu/wgsl_templates/package-lock.json @@ -9,13 +9,13 @@ "version": "1.0.0", "license": "MIT", "dependencies": { - "@fs-eire/wgsl-template": "^0.1.13" + "@fs-eire/wgsl-template": "^0.1.3" } }, "node_modules/@fs-eire/wgsl-template": { - "version": "0.1.13", - "resolved": "https://registry.npmjs.org/@fs-eire/wgsl-template/-/wgsl-template-0.1.13.tgz", - "integrity": "sha512-SOQjVCQCUmXb9qYr2E3CKNs88/FzINuhFJiobBEkSAsyKtJby9oFWGZnrEO+hIl/oDTLA01LbjiDxuf6TGHE/w==", + "version": "0.1.10", + "resolved": "https://registry.npmjs.org/@fs-eire/wgsl-template/-/wgsl-template-0.1.10.tgz", + "integrity": "sha512-F5qQZxNweZ3ZD3d9RNc/g3nTiW7jyaAVi7SlMOL4wOfXh+Nm/qca2DISNTf3kjpVqkoazMJGbZ6TPQ4a/vjw0g==", "license": "MIT", "dependencies": { "minimist": "^1.2.8" diff --git a/onnxruntime/core/providers/webgpu/wgsl_templates/package.json b/onnxruntime/core/providers/webgpu/wgsl_templates/package.json index 246e7365531e0..34831ccddeb33 100644 --- a/onnxruntime/core/providers/webgpu/wgsl_templates/package.json +++ b/onnxruntime/core/providers/webgpu/wgsl_templates/package.json @@ -10,6 +10,6 @@ "author": "", "license": "MIT", "dependencies": { - "@fs-eire/wgsl-template": "^0.1.13" + "@fs-eire/wgsl-template": "^0.1.3" } } diff --git a/onnxruntime/core/providers/webnn/builders/helper.cc b/onnxruntime/core/providers/webnn/builders/helper.cc index 142d64caa64aa..e821265fff80d 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.cc +++ b/onnxruntime/core/providers/webnn/builders/helper.cc @@ -99,93 +99,69 @@ bool IsTensorShapeSupported(const NodeArg& node_arg, const std::string& parent_n return true; } -// Check if a single input's rank of an ONNX op is supported by corresponding WebNN op. -bool IsInputRankSupported(const emscripten::val& wnn_limits, - const std::string_view webnn_op_type, - const std::string_view input_name, - const size_t input_rank, - const std::string_view node_name, - const logging::Logger& logger) { - const std::string webnn_op_type_str(webnn_op_type); - const std::string input_name_str(input_name); - - if (wnn_limits[webnn_op_type_str].isUndefined()) { - LOGS(logger, VERBOSE) << "WebNN op type: [" << webnn_op_type - << "] is not defined in WebNN MLOpSupportLimits."; - return false; - } - - const emscripten::val input_limits = wnn_limits[webnn_op_type_str][input_name_str]; - - if (input_limits.isUndefined()) { - LOGS(logger, VERBOSE) << "Node name: [" << node_name - << "], WebNN op type: [" << webnn_op_type - << "], input [" << input_name - << "]: limits are not defined in WebNN MLOpSupportLimits."; - return false; - } - - const emscripten::val rank_range = input_limits["rankRange"]; - if (rank_range.isUndefined()) { - LOGS(logger, VERBOSE) << "WebNN op type [" << webnn_op_type - << "] input [" << input_name - << "]: missing 'rankRange' attribute."; - return false; - } - - const emscripten::val min_val = rank_range["min"]; - const emscripten::val max_val = rank_range["max"]; - if (min_val.isUndefined() || max_val.isUndefined()) { - LOGS(logger, VERBOSE) << "WebNN op type [" << webnn_op_type - << "] input [" << input_name - << "]: its 'rankRange' limits is missing valid 'min' or 'max' attributes."; - return false; - } - - size_t min_rank = min_val.as(); - size_t max_rank = max_val.as(); - if (input_rank < min_rank || input_rank > max_rank) { - LOGS(logger, VERBOSE) << "Node name: [" << node_name - << "] WebNN op type [" << webnn_op_type - << "] input [" << input_name << "] rank " << input_rank - << " is not in supported range [" << min_rank << ", " << max_rank << "]"; +// Check if all input tensor ranks of the given node are supported by WebNN. +bool IsInputRankSupportedByOp(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) { + const std::string_view op_type = node.OpType(); + const auto it = op_inputs_map.find(op_type); + if (it == op_inputs_map.end()) { + LOGS(logger, VERBOSE) << "Operator type: [" << op_type << "] is not found in the op inputs map."; return false; } - return true; -} - -bool IsInputRankSupportedByOp(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) { - const std::string_view onnx_op_type = node.OpType(); - const std::string_view webnn_op_type = GetWebNNOpType(onnx_op_type); + const auto& input_defs = node.InputDefs(); + const std::string_view webnn_op_type = it->second.opType; + const std::string webnn_op_type_str(webnn_op_type); - if (webnn_op_type.empty()) { - LOGS(logger, VERBOSE) << "ONNX op type: [" << onnx_op_type << "]'s corresponding WebNN op is not found."; - return false; - } + for (const auto& input : it->second.inputs) { + if (static_cast(input.index) >= input_defs.size() || input_defs[input.index] == nullptr) { + LOGS(logger, VERBOSE) << "Input index [" << input.index + << "] for operator type [" << op_type + << "], corresponding WebNN op type [" << webnn_op_type + << "], WebNN input name [" << input.name + << "] is invalid."; + return false; + } - std::vector inputs; - if (!GetWebNNOpInputs(onnx_op_type, inputs, logger)) { - return false; - } + std::vector input_shape; + if (!GetShape(*input_defs[input.index], input_shape, logger)) { + return false; + } - const auto& input_defs = node.InputDefs(); + const std::string input_name_str(input.name); + if (wnn_limits[webnn_op_type_str].isUndefined() || + wnn_limits[webnn_op_type_str][input_name_str].isUndefined()) { + LOGS(logger, VERBOSE) << "Operator type: [" << op_type + << "], input index: [" << input.index + << "], corresponding WebNN op type: " << webnn_op_type + << ", WebNN input name " << input.name + << " is not defined in wnn_limits."; + return false; + } - for (const auto& input : inputs) { - // If it is an optional input and is absent, skip. - if (!TensorExists(input_defs, input.index)) { - continue; + const auto& input_limits = wnn_limits[webnn_op_type_str][input_name_str]; + if (input_limits["rankRange"].isUndefined()) { + LOGS(logger, VERBOSE) << "Operator type: [" << op_type + << "], input index: [" << input.index + << "], corresponding WebNN op type: " << webnn_op_type + << ", WebNN input name " << input.name + << "'s rankRange is not defined."; + return false; } - std::vector shape; - if (!GetShape(*input_defs[input.index], shape, logger) || - !IsInputRankSupported(wnn_limits, webnn_op_type, input.name, - shape.size(), - node.Name(), logger)) { + int input_dim_size = static_cast(input_shape.size()); + int min_rank = input_limits["rankRange"]["min"].as(); + int max_rank = input_limits["rankRange"]["max"].as(); + + if (input_dim_size < min_rank || input_dim_size > max_rank) { + LOGS(logger, VERBOSE) << "Operator type: [" << op_type + << "], input index: [" << input.index + << "], corresponding WebNN op type: " << webnn_op_type + << ", WebNN input name: " << input.name + << ", input size " << input_dim_size + << " is not in supported range [" << min_rank << ", " << max_rank << "]"; return false; } } - return true; } diff --git a/onnxruntime/core/providers/webnn/builders/helper.h b/onnxruntime/core/providers/webnn/builders/helper.h index 50e361ede221e..d59788600f997 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.h +++ b/onnxruntime/core/providers/webnn/builders/helper.h @@ -216,13 +216,6 @@ bool IsTensorShapeSupported(const NodeArg& node_arg, const std::string& parent_n bool IsInputRankSupportedByOp(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger); -bool IsInputRankSupported(const emscripten::val& wnn_limits, - const std::string_view webnn_op_type, - const std::string_view input_name, - const size_t input_rank, - const std::string_view node_name, - const logging::Logger& logger); - // Get a set of nodes supported by WebNN EP. std::unordered_set GetSupportedNodes(const GraphViewer& graph_viewer, const emscripten::val& wnn_builder, @@ -251,33 +244,6 @@ inline std::string_view GetWebNNOpType(const std::string_view onnx_op_type) { return (it != op_inputs_map.end()) ? it->second.opType : ""; } -// Get corresponding input name of WebNN op type by ONNX op type from op_input_map -inline std::string_view GetWebNNInputName(const std::string_view onnx_op_type, const int input_index) { - const auto it = op_inputs_map.find(onnx_op_type); - - if (it != op_inputs_map.end()) { - for (const auto& input : it->second.inputs) { - if (input.index == input_index) { - return input.name; - } - } - } - - return ""; -} - -inline bool GetWebNNOpInputs(const std::string_view onnx_op_type, - std::vector& inputs, - const logging::Logger& logger) { - const auto it = op_inputs_map.find(onnx_op_type); - if (it == op_inputs_map.end()) { - LOGS(logger, VERBOSE) << "WebNN op inputs not found for op type: " << onnx_op_type; - return false; - } - inputs = it->second.inputs; - return true; -} - bool AreDataTypesSame(const std::string_view op_type, gsl::span input_types, const logging::Logger& logger); diff --git a/onnxruntime/core/providers/webnn/builders/impl/argmax_min_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/argmax_min_op_builder.cc index fdf1709d87bac..fc630af8cf1e3 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/argmax_min_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/argmax_min_op_builder.cc @@ -18,6 +18,10 @@ class ArgMaxMinOpBuilder : public BaseOpBuilder { private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override ORT_MUST_USE_RESULT; + + // Operator support related. + bool IsOpSupportedImpl(const GraphViewer&, const Node& node, + WebnnDeviceType device_type, const logging::Logger& logger) const override; }; // Add operator related. @@ -61,6 +65,20 @@ Status ArgMaxMinOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, return Status::OK(); } +// Operator support related. +bool ArgMaxMinOpBuilder::IsOpSupportedImpl(const GraphViewer& /* initializers */, + const Node& node, + WebnnDeviceType device_type, + const logging::Logger& logger) const { + const auto& input_defs = node.InputDefs(); + + std::vector input_shape; + if (!GetShape(*input_defs[0], input_shape, logger)) + return false; + + return true; +} + void CreateArgMaxMinOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { if (op_registrations.op_builder_map.count(op_type) > 0) return; diff --git a/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc index 3c8e7fa34f7ed..b0ec006db6986 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc @@ -62,12 +62,13 @@ bool BaseOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& node, int32_t input_type; if (!GetType(input, input_type, logger)) return false; - const std::string_view webnn_op_type = GetWebNNOpType(op_type); + if (webnn_op_type.empty()) + return false; + const std::string_view webnn_input_name = GetWebNNOpFirstInputName(op_type); return IsDataTypeSupportedByWebNNOp(op_type, webnn_op_type, input_type, wnn_limits, - webnn_input_name, "input", logger) && - IsInputRankSupportedByOp(node, wnn_limits, logger); + webnn_input_name, "input", logger); } bool BaseOpBuilder::HasSupportedOutputs(const Node& node, const emscripten::val& wnn_limits, diff --git a/onnxruntime/core/providers/webnn/builders/impl/binary_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/binary_op_builder.cc index 851dc373923ac..280ffc83eae89 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/binary_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/binary_op_builder.cc @@ -73,10 +73,9 @@ bool BinaryOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& nod return false; } - const std::string_view webnn_input_name = GetWebNNOpFirstInputName(op_type); + std::string webnn_input_name = op_type == "PRelu" ? "input" : "a"; std::string onnx_input_name = op_type == "PRelu" || op_type == "Pow" ? "X" : "A"; - return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, webnn_input_name, onnx_input_name, logger) && - IsInputRankSupportedByOp(node, wnn_limits, logger); + return IsInputRankSupportedByOp(node, wnn_limits, logger) && IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, webnn_input_name, onnx_input_name, logger); } void CreateBinaryOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/concat_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/concat_op_builder.cc index db5e8cd51656c..8589237617745 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/concat_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/concat_op_builder.cc @@ -75,8 +75,7 @@ bool ConcatOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& nod } } - return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "inputs", "inputs", logger) && - IsInputRankSupportedByOp(node, wnn_limits, logger); + return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "inputs", "inputs", logger); } void CreateConcatOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc index e0bfb3bd682e8..b9383a63fe307 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc @@ -324,7 +324,7 @@ Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N x_zero_point = model_builder.CreateOrGetConstant(x_type, 0); } - // Scale is not used by ConvInteger but required by DequantizeLinear. So set it to default value 1.0f. + // Scale is not used by ConvInteger but required by DequantizeLinear. So set it to deafult value 1.0f. // The x_zero_point must be a scalar and the scale input should have the same shape as the zero point input. // So the x_scale must be a scalar too. x_scale = model_builder.CreateOrGetConstant(ONNX_NAMESPACE::TensorProto_DataType_FLOAT, 1.0f); diff --git a/onnxruntime/core/providers/webnn/builders/impl/cumsum_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/cumsum_op_builder.cc index f3c392b608e45..7528d9ad2ff51 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/cumsum_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/cumsum_op_builder.cc @@ -77,6 +77,10 @@ bool CumSumOpBuilder::IsOpSupportedImpl(const GraphViewer& graph_viewer, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); + std::vector input_shape; + if (!GetShape(*input_defs[0], input_shape, logger)) + return false; + const std::string axis_name = GetTensorName(input_defs, 1); // Inputs contain optional 'axis' input. const auto* init = graph_viewer.GetConstantInitializer(axis_name); diff --git a/onnxruntime/core/providers/webnn/builders/impl/dropout_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/dropout_op_builder.cc index 37a00fcb12abd..c22dd9e97bb1a 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/dropout_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/dropout_op_builder.cc @@ -21,6 +21,11 @@ class DropoutOpBuilder : public BaseOpBuilder { private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override ORT_MUST_USE_RESULT; + + // Operator support related. + private: + bool IsOpSupportedImpl(const GraphViewer&, const Node& node, + const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; }; // Add operator related. @@ -60,13 +65,26 @@ Status DropoutOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, emscripten::val options = emscripten::val::object(); options.set("label", output_defs[1]->Name() + "_identity"); // Add additional identity op in case the mask is the output of a WebNN graph, - // because WebNN does not support a constant operand as output. + // beacuse WebNN does not support a constant operand as output. emscripten::val mask_output = model_builder.GetBuilder().call("identity", one_constant, options); model_builder.AddOperand(output_defs[1]->Name(), std::move(mask_output)); } return Status::OK(); } +// Operator support related. +bool DropoutOpBuilder::IsOpSupportedImpl(const GraphViewer&, + const Node& node, + const WebnnDeviceType /* device_type */, + const logging::Logger& logger) const { + const auto& input_defs = node.InputDefs(); + std::vector input_shape; + if (!GetShape(*input_defs[0], input_shape, logger)) + return false; + + return true; +} + void CreateDropoutOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { op_registrations.builders.push_back(std::make_unique()); op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); diff --git a/onnxruntime/core/providers/webnn/builders/impl/einsum_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/einsum_op_builder.cc index 6aa760c0f4baf..e5b4fcddc4221 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/einsum_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/einsum_op_builder.cc @@ -28,8 +28,6 @@ class EinsumOpBuilder : public BaseOpBuilder { const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; bool HasSupportedInputsImpl(const GraphViewer&, const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const override; - bool HasSupportedOutputsImpl(const Node& /* node */, const emscripten::val& /* wnn_limits */, - const logging::Logger& /* logger */) const override; }; // Helper functions, thanks for DML EP's OperatorHelper. @@ -44,6 +42,12 @@ enum class RecognizedOperatorType { Total, }; +struct RecognizedOperatorInfo { + RecognizedOperatorType recognized_operator_type; + std::initializer_list component_ranks; + std::initializer_list label_indices; +}; + struct Component { uint32_t label_index_begin; uint32_t label_index_end; @@ -594,7 +598,7 @@ Status EinsumOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, } } - // transpose input + // tranpose input std::vector permutation(input_labels.size()); for (uint32_t idx = 0; idx < input_labels.size(); idx++) { if (idx != diagonal_idx_1 && idx != diagonal_idx_2) { @@ -616,7 +620,7 @@ Status EinsumOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, options_trilu.set("upper", false); output = model_builder.GetBuilder().call("triangular", output, options_trilu); // tril - // reduceSum to achieve the diagonal values + // reducesum to achieve the diagonal values std::vector input_shape; std::vector reduced_axes; ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get shape"); @@ -696,6 +700,12 @@ bool EinsumOpBuilder::IsOpSupportedImpl(const GraphViewer&, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); + if (input_defs.size() > 2) { + // TODO: Support more than two inputs. + LOGS(logger, VERBOSE) << "EinSum only supports up to two inputs."; + return false; + } + NodeAttrHelper helper(node); const auto equation = helper.Get("equation", std::string(" ")); std::vector label_indices; @@ -714,6 +724,13 @@ bool EinsumOpBuilder::IsOpSupportedImpl(const GraphViewer&, return false; } + RecognizedOperatorType recognized_operator_type = DetermineRecognizedOperatorType(label_indices, components, + output_dimensions); + if (recognized_operator_type == RecognizedOperatorType::None) { + LOGS(logger, VERBOSE) << "The equation is not supported in Einsum."; + return false; + } + return true; } @@ -721,14 +738,9 @@ bool EinsumOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& nod const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); - if (input_defs.size() > 2) { - // TODO: Support more than two inputs. - LOGS(logger, VERBOSE) << "EinSum only supports up to two inputs."; - return false; - } - const std::string_view op_type = node.OpType(); - int32_t input0_type, input1_type; + int32_t input0_type; + int32_t input1_type; bool has_input1 = TensorExists(input_defs, 1); if (!GetType(*input_defs[0], input0_type, logger) || @@ -742,13 +754,6 @@ bool EinsumOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& nod return false; } - std::vector input0_shape; - std::vector input1_shape; - if (!GetShape(*input_defs[0], input0_shape, logger) || - (has_input1 && !GetShape(*input_defs[1], input1_shape, logger))) { - return false; - } - NodeAttrHelper helper(node); const auto equation = helper.Get("equation", std::string(" ")); std::vector label_indices; @@ -765,54 +770,17 @@ bool EinsumOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& nod RecognizedOperatorType recognized_operator_type = DetermineRecognizedOperatorType(label_indices, components, output_dimensions); - std::string_view decomposed_op_type; if (recognized_operator_type == RecognizedOperatorType::None) { LOGS(logger, VERBOSE) << "The equation is not supported in Einsum."; return false; - } else if (recognized_operator_type == RecognizedOperatorType::Multiply) { - decomposed_op_type = "Mul"; - } else if (recognized_operator_type == RecognizedOperatorType::ReduceSum) { - decomposed_op_type = "ReduceSum"; - } else if (recognized_operator_type == RecognizedOperatorType::Diagonal) { - decomposed_op_type = "Trilu"; - } else if (recognized_operator_type == RecognizedOperatorType::Transpose) { - decomposed_op_type = "Transpose"; } else if (recognized_operator_type == RecognizedOperatorType::Pairwise) { - decomposed_op_type = "MatMul"; - } else { // Identity - // For the Identity case, we simply forward the input to the output without any modification. - return true; - } - - const std::string_view wnn_input0_name = GetWebNNInputName(decomposed_op_type, 0); - const std::string_view decompose_wnn_op_type = GetWebNNOpType(decomposed_op_type); - if (decompose_wnn_op_type.empty() || - !IsDataTypeSupportedByWebNNOp(op_type, decompose_wnn_op_type, input0_type, - wnn_limits, wnn_input0_name, "inputs", logger) || - !IsInputRankSupported(wnn_limits, decompose_wnn_op_type, wnn_input0_name, - input0_shape.size(), node.Name(), logger)) { - return false; - } - - if (has_input1) { - const std::string_view wnn_input1_name = GetWebNNInputName(decomposed_op_type, 1); - return IsDataTypeSupportedByWebNNOp(op_type, decompose_wnn_op_type, input1_type, - wnn_limits, wnn_input1_name, "inputs", logger) && - IsInputRankSupported(wnn_limits, decompose_wnn_op_type, wnn_input1_name, - input1_shape.size(), node.Name(), logger); + // Map to WebNN's gemm or matmul + return IsDataTypeSupportedByWebNNOp(op_type, "matmul", input0_type, wnn_limits, "a", "inputs", logger); + } else if (recognized_operator_type == RecognizedOperatorType::ReduceSum) { + return IsDataTypeSupportedByWebNNOp(op_type, "reduceSum", input0_type, wnn_limits, "input", "inputs", logger); + } else { + return IsDataTypeSupportedByWebNNOp(op_type, "identity", input0_type, wnn_limits, "input", "inputs", logger); } - - return true; -} - -bool EinsumOpBuilder::HasSupportedOutputsImpl(const Node& /* node */, - const emscripten::val& /* wnn_limits */, - const logging::Logger& /* logger */) const { - // The Einsum op produces output with the same data type as its input. - // Therefore, checking the output data type is unnecessary. - // This override prevents calling the base class implementation, as the base implementation - // would return false due to Einsum being a decomposed op. - return true; } void CreateEinsumOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/gatherElements_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gatherElements_op_builder.cc index ae4c3705fdb2e..06beb56415609 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/gatherElements_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/gatherElements_op_builder.cc @@ -56,14 +56,14 @@ bool GatherElementsOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const N const auto& indices = *node.InputDefs()[1]; const std::string_view op_type = node.OpType(); - int32_t data_type, indices_type; + int32_t data_type; + int32_t indices_type; if (!GetType(data, data_type, logger) || !GetType(indices, indices_type, logger)) { return false; } return IsDataTypeSupportedByOp(op_type, data_type, wnn_limits, "input", "data", logger) && - IsDataTypeSupportedByOp(op_type, indices_type, wnn_limits, "indices", "indices", logger) && - IsInputRankSupportedByOp(node, wnn_limits, logger); + IsDataTypeSupportedByOp(op_type, indices_type, wnn_limits, "indices", "indices", logger); } void CreateGatherElementsOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/gatherND_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gatherND_op_builder.cc index af508c2800f4b..9200c596c0e53 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/gatherND_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/gatherND_op_builder.cc @@ -61,14 +61,14 @@ bool GatherNDOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& n const auto& indices = *node.InputDefs()[1]; const std::string_view op_type = node.OpType(); - int32_t data_type, indices_type; + int32_t data_type; + int32_t indices_type; if (!GetType(data, data_type, logger) || !GetType(indices, indices_type, logger)) { return false; } return IsDataTypeSupportedByOp(op_type, data_type, wnn_limits, "input", "data", logger) && - IsDataTypeSupportedByOp(op_type, indices_type, wnn_limits, "indices", "indices", logger) && - IsInputRankSupportedByOp(node, wnn_limits, logger); + IsDataTypeSupportedByOp(op_type, indices_type, wnn_limits, "indices", "indices", logger); } void CreateGatherNDOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/gather_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gather_op_builder.cc index 7111a8f6beaa3..d84c70032e1d1 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/gather_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/gather_op_builder.cc @@ -20,6 +20,8 @@ class GatherOpBuilder : public BaseOpBuilder { const logging::Logger& logger) const override ORT_MUST_USE_RESULT; // Operator support related. + bool IsOpSupportedImpl(const GraphViewer&, const Node& node, + const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; bool HasSupportedInputsImpl(const GraphViewer&, const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; @@ -48,20 +50,38 @@ Status GatherOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, } // Operator support related. + +bool GatherOpBuilder::IsOpSupportedImpl(const GraphViewer&, + const Node& node, + const WebnnDeviceType /* device_type */, + const logging::Logger& logger) const { + const auto& input_defs = node.InputDefs(); + std::vector input_shape; + if (!GetShape(*input_defs[0], input_shape, logger)) + return false; + const auto rank = input_shape.size(); + if (rank < 1) { + LOGS(logger, VERBOSE) << "Gather only supports input shapes >= 1D, but input is " + << rank << "d shape"; + return false; + } + + return true; +} + bool GatherOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input = *node.InputDefs()[0]; const auto& indices = *node.InputDefs()[1]; const std::string_view op_type = node.OpType(); - int32_t input_type, indices_type; - + int32_t input_type; + int32_t indices_type; if (!GetType(input, input_type, logger) || !GetType(indices, indices_type, logger)) return false; return IsDataTypeSupportedByOp(op_type, input_type, wnn_limits, "input", "data", logger) && - IsDataTypeSupportedByOp(op_type, indices_type, wnn_limits, "indices", "indices", logger) && - IsInputRankSupportedByOp(node, wnn_limits, logger); + IsDataTypeSupportedByOp(op_type, indices_type, wnn_limits, "indices", "indices", logger); } void CreateGatherOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc index 7af17fdc5db78..02f46c85d1d06 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc @@ -91,7 +91,7 @@ Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N a_zero_point = model_builder.GetOperand(node.InputDefs()[2]->Name()); std::vector a_zero_point_shape; ORT_RETURN_IF_NOT(GetShape(*input_defs[2], a_zero_point_shape, logger), "Cannot get shape of a_zero_point"); - // Scale is not used by MatMulInteger but required by DequantizeLinear. So set it to default value 1.0f. + // Scale is not used by MatMulInteger but required by DequantizeLinear. So set it to deafult value 1.0f. // The scale input should have the same shape as the zero point input. a_scale = model_builder.CreateOrGetConstant(ONNX_NAMESPACE::TensorProto_DataType_FLOAT, 1.0f, @@ -268,45 +268,11 @@ bool GemmOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& node, return false; } - if (op_type == "Gemm") { - return IsInputRankSupportedByOp(node, wnn_limits, logger) && - IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "a", "A", logger); - } else if (op_type == "MatMulInteger") { - // Check up to 4 inputs for MatMulInteger - for (size_t i = 0; i < input_defs.size(); ++i) { - std::vector shape; - if (!GetShape(*input_defs[i], shape, logger)) { - return false; - } - - // We made workaround to support 1D for input A and B, skip further checks if they are 1D - if (i <= 1 && shape.size() == 1) { - continue; - } - - // For DequantizeLinear, input indices: 0 (x), 1 (scale), 2 (zero_point) - if (!IsInputRankSupported(wnn_limits, "dequantizeLinear", - (i < 2) ? "input" : "zeroPoint", - shape.size(), node.Name(), logger)) { - return false; - } - } + if (op_type == "MatMulInteger") { + // The first decomposed op of MatMulInteger is DequantizeLinear, and so + // we only need to ensure it supports the input0_type. return IsDataTypeSupportedByOp("DequantizeLinear", input0_type, wnn_limits, "input", "x", logger); - } else { // MatMul - for (int i = 0; i < 2; ++i) { - std::vector shape; - if (!GetShape(*input_defs[i], shape, logger)) { - return false; - } - - if (shape.size() == 1) { - continue; - } - - if (!IsInputRankSupported(wnn_limits, "matmul", (i == 0) ? "a" : "b", shape.size(), node.Name(), logger)) { - return false; - } - } + } else { return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "a", "A", logger); } } diff --git a/onnxruntime/core/providers/webnn/builders/impl/gru_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gru_op_builder.cc index 95e75a3083cc2..dfe80dd419092 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/gru_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/gru_op_builder.cc @@ -219,8 +219,7 @@ bool GruOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& node, return false; } - return IsDataTypeSupportedByOp(op_type, input_X_type, wnn_limits, "input", "X", logger) && - IsInputRankSupportedByOp(node, wnn_limits, logger); + return IsDataTypeSupportedByOp(op_type, input_X_type, wnn_limits, "input", "X", logger); } bool GruOpBuilder::HasSupportedOutputsImpl(const Node& node, diff --git a/onnxruntime/core/providers/webnn/builders/impl/logical_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/logical_op_builder.cc index 55d468c4843cb..42940083cad8e 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/logical_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/logical_op_builder.cc @@ -91,10 +91,8 @@ bool LogicalOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& no } } - const std::string_view webnn_input_name = GetWebNNOpFirstInputName(op_type); std::string onnx_input_name = op_type == "Not" ? "X" : "A"; - return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, webnn_input_name, onnx_input_name, logger) && - IsInputRankSupportedByOp(node, wnn_limits, logger); + return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "a", onnx_input_name, logger); } void CreateLogicalOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/lrn_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/lrn_op_builder.cc index e8aab725375ad..8936bda875aef 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/lrn_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/lrn_op_builder.cc @@ -21,6 +21,8 @@ class LRNOpBuilder : public BaseOpBuilder { // Operator support related. private: + bool IsOpSupportedImpl(const GraphViewer&, const Node& node, + const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; bool HasSupportedInputsImpl(const GraphViewer&, const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const override; bool HasSupportedOutputsImpl(const Node& node, const emscripten::val& wnn_limits, @@ -126,10 +128,11 @@ Status LRNOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, } // Operator support related. -bool LRNOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& node, - const emscripten::val& wnn_limits, const logging::Logger& logger) const { +bool LRNOpBuilder::IsOpSupportedImpl(const GraphViewer&, + const Node& node, + const WebnnDeviceType /* device_type */, + const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); - std::vector input_shape; if (!GetShape(*input_defs[0], input_shape, logger)) return false; @@ -140,6 +143,12 @@ bool LRNOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& node, return false; } + return true; +} + +bool LRNOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const { + const auto& input_defs = node.InputDefs(); const std::string_view op_type = node.OpType(); int32_t input_type = 0; if (!GetType(*input_defs[0], input_type, logger)) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/lstm_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/lstm_op_builder.cc index 04d59e2f30d15..09e584bc66f8a 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/lstm_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/lstm_op_builder.cc @@ -242,8 +242,7 @@ bool LstmOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& node, return false; } - return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "input", "X", logger) && - IsInputRankSupportedByOp(node, wnn_limits, logger); + return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "input", "X", logger); } bool LstmOpBuilder::HasSupportedOutputsImpl(const Node& node, diff --git a/onnxruntime/core/providers/webnn/builders/impl/matMulNBits_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/matMulNBits_op_builder.cc index 9ab403b7051d2..111d03571e974 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/matMulNBits_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/matMulNBits_op_builder.cc @@ -48,7 +48,7 @@ void MatMulNBitsBuilder::AddInitializersToSkip(ModelBuilder& model_builder, cons // DequantizeLinear + Transpose + MatMul. Given that the CPU EP currently only supports // 4-bit quantization, we only handle 4-bit quantization here. // -// To align with WebNN's dequantizeLinear op constraints, the following transformations are +// To align with WebNN's dequantizeLinear op contraints, the following transformations are // required for MatMulNBits inputs: // 1. B: must be a constant initializer and registered as a 'uint4' WebNN constant with shape // [N, n_blocks_per_col, blob_size * 2]. @@ -159,6 +159,10 @@ bool MatMulNBitsBuilder::IsOpSupportedImpl(const GraphViewer& graph_viewer, const logging::Logger& logger) const { const auto& name = node.Name(); const auto& input_defs = node.InputDefs(); + std::vector input_shape; + if (!GetShape(*input_defs[0], input_shape, logger)) { + return false; + } // Inputs B and zero_points (if present) must be initializers if (!graph_viewer.GetConstantInitializer(input_defs[1]->Name())) { // B @@ -189,10 +193,6 @@ bool MatMulNBitsBuilder::HasSupportedInputsImpl(const GraphViewer&, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); const std::string_view op_type = node.OpType(); - std::vector input_shape; - if (!GetShape(*input_defs[0], input_shape, logger)) { - return false; - } int32_t A_type = 0; int32_t B_type = 0; @@ -227,13 +227,10 @@ bool MatMulNBitsBuilder::HasSupportedInputsImpl(const GraphViewer&, return false; } - // Data type: Currently, only 4-bit quantization is supported, represented as the uint4 data type in WebNN. - // Ensure that the uint4 data type is supported by WebNN's dequantizeLinear op. - // Input rank: Only the rank of the first input (A) is flexible. Verify that its rank is supported by - // WebNN's matmul op. + // We only support 4-bit quantization, which is represented as the uint4 data type in WebNN. + // Ensure that uint4 is supported. return IsDataTypeSupportedByOp("DequantizeLinear", ONNX_NAMESPACE::TensorProto_DataType_UINT4, - wnn_limits, "input", "x", logger) && - IsInputRankSupported(wnn_limits, "matmul", "a", input_shape.size(), node.Name(), logger); + wnn_limits, "input", "x", logger); } bool MatMulNBitsBuilder::HasSupportedOutputsImpl(const Node& node, const emscripten::val& wnn_limits, diff --git a/onnxruntime/core/providers/webnn/builders/impl/max_min_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/max_min_op_builder.cc index 9f5ac6ef15735..4e4014e3553ea 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/max_min_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/max_min_op_builder.cc @@ -20,6 +20,8 @@ class MaxMinOpBuilder : public BaseOpBuilder { const logging::Logger& logger) const override ORT_MUST_USE_RESULT; // Operator support related. + bool IsOpSupportedImpl(const GraphViewer& graph_viewer, const Node& node, + WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; bool HasSupportedInputsImpl(const GraphViewer&, const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; @@ -66,6 +68,25 @@ Status MaxMinOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, } // Operator support related. +bool MaxMinOpBuilder::IsOpSupportedImpl(const GraphViewer&, + const Node& node, + WebnnDeviceType /* device_type */, + const logging::Logger& logger) const { + const auto& input_defs = node.InputDefs(); + const auto& op_type = node.OpType(); + + std::vector input_shape; + if (!GetShape(*input_defs[0], input_shape, logger)) + return false; + + if (input_defs.size() < 1) { + LOGS(logger, VERBOSE) << op_type << " requires at least one input (data)"; + return false; + } + + return true; +} + bool MaxMinOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); @@ -87,8 +108,7 @@ bool MaxMinOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& nod } } - return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "a", "data_0", logger) && - IsInputRankSupportedByOp(node, wnn_limits, logger); + return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "a", "data_0", logger); } void CreateMaxMinOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc index 9fb643f055ef3..148eacac98e4a 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc @@ -46,14 +46,28 @@ Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder emscripten::val options = emscripten::val::object(); options.set("label", node.Name()); + std::vector scale_shape; const size_t scale_input_index = op_type == "SkipSimplifiedLayerNormalization" ? 2 : 1; + ORT_RETURN_IF_NOT(GetShape(*input_defs[scale_input_index], scale_shape, logger), "Cannot get scale shape"); + const auto scale_size = scale_shape.size(); + // Except LayerNormalization, other normalization ops' scale input should be 1-D. + if (op_type == "LayerNormalization") { + ORT_RETURN_IF_NOT(scale_size >= 1 && scale_size <= rank, + "The scale size should be less than or equal to input size."); + } else { + ORT_RETURN_IF_NOT(scale_size == 1, "The scale size should be one."); + } + emscripten::val scale = model_builder.GetOperand(input_defs[scale_input_index]->Name()); options.set("scale", scale); const size_t bias_input_index = op_type == "SkipSimplifiedLayerNormalization" ? 3 : 2; emscripten::val bias = emscripten::val::undefined(); if (TensorExists(input_defs, bias_input_index)) { - // Bias input exists. + // Bias input exists, and bias's shape should be the same as scale's shape. + std::vector bias_shape; + ORT_RETURN_IF_NOT(GetShape(*input_defs[bias_input_index], bias_shape, logger), "Cannot get bias shape"); + ORT_RETURN_IF_NOT(bias_shape == scale_shape, "The bias' shape should be equal to scale's shape."); bias = model_builder.GetOperand(input_defs[bias_input_index]->Name()); options.set("bias", bias); } @@ -265,6 +279,12 @@ bool NormalizationOpBuilder::IsOpSupportedImpl(const GraphViewer&, return false; } + std::vector input_shape; + if (!GetShape(*input_defs[0], input_shape, logger)) { + LOGS(logger, VERBOSE) << "Cannot get input shape."; + return false; + } + const auto& output_defs = node.OutputDefs(); if (op_type == "SkipSimplifiedLayerNormalization") { if (output_defs.size() > 4) { @@ -296,28 +316,33 @@ bool NormalizationOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const No const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); const std::string_view op_type = node.OpType(); - - std::vector input_types; - bool all_types_valid = true; - - // Iterate through all inputs and check their existence and types - for (size_t i = 0; i <= input_defs.size(); ++i) { - if (TensorExists(input_defs, i)) { - int32_t input_type; - if (!GetType(*input_defs[i], input_type, logger)) { - all_types_valid = false; - break; - } - input_types.push_back(input_type); - } - } - - // Return false if any input type is invalid - if (!all_types_valid) { + int32_t input0_type; // input data type + int32_t input1_type; // scale data type + int32_t input2_type; // B data type + int32_t input3_type; // mean data type + int32_t input4_type; // var data type + bool has_input2 = TensorExists(input_defs, 2); + bool has_input3 = TensorExists(input_defs, 3); + bool has_input4 = TensorExists(input_defs, 4); + + if (!GetType(*input_defs[0], input0_type, logger) || + !GetType(*input_defs[1], input1_type, logger) || + (has_input2 && !GetType(*input_defs[2], input2_type, logger)) || + (has_input3 && !GetType(*input_defs[3], input3_type, logger)) || + (has_input4 && !GetType(*input_defs[4], input4_type, logger))) { return false; } - // Check if all input data types are the same + std::vector input_types = {input0_type, input1_type}; + if (has_input2) { + input_types.push_back(input2_type); + } + if (has_input3) { + input_types.push_back(input3_type); + } + if (has_input4) { + input_types.push_back(input4_type); + } if (!AreDataTypesSame(op_type, input_types, logger)) { return false; } @@ -330,29 +355,13 @@ bool NormalizationOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const No const std::string_view webnn_op_type = GetWebNNOpType(decomposed_op_type); const std::string_view webnn_input_name = GetWebNNOpFirstInputName(decomposed_op_type); if (!IsDataTypeSupportedByWebNNOp( - op_type, webnn_op_type, input_types[0], wnn_limits, webnn_input_name, "input", logger)) { + op_type, webnn_op_type, input0_type, wnn_limits, webnn_input_name, "input", logger)) { return false; } } - - std::vector input_shape; - if (!GetShape(*input_defs[0], input_shape, logger)) { - return false; - } - // It's complicated to check all the decomposed ops' input rank support. - // Ensure at least the first input rank is supported by the decomposed ops (pow and div accept the first input). - return IsInputRankSupported(wnn_limits, "pow", "a", input_shape.size(), node.Name(), logger) && - IsInputRankSupported(wnn_limits, "div", "a", input_shape.size(), node.Name(), logger); + return true; } else { - bool is_data_type_supported = IsDataTypeSupportedByOp(op_type, input_types[0], wnn_limits, "input", "X", logger); - if (op_type == "InstanceNormalization") { - // Skip input rank check for InstanceNormalization, as we will reshape the input to 4D if necessary. - return is_data_type_supported; - } - - // For other ops, check both data type and input rank compatibility. - bool is_input_rank_supported = IsInputRankSupportedByOp(node, wnn_limits, logger); - return is_input_rank_supported && is_data_type_supported; + return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "input", "X", logger); } } diff --git a/onnxruntime/core/providers/webnn/builders/impl/pool_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/pool_op_builder.cc index 5d921c5176a64..f2a3f08b73148 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/pool_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/pool_op_builder.cc @@ -133,6 +133,20 @@ bool PoolOpBuilder::IsOpSupportedImpl(const GraphViewer&, const WebnnDeviceType /* device_type */, const logging::Logger& logger) const { const auto& op_type = node.OpType(); + const auto& input_defs = node.InputDefs(); + + std::vector input_shape; + if (!GetShape(*input_defs[0], input_shape, logger)) + return false; + + const auto input_size = input_shape.size(); + if (input_size != 4) { + LOGS(logger, VERBOSE) + << op_type << " only supports rank-4 tensor, input [" + << input_defs[0]->Name() << "] has actual dim count " << input_size; + return false; + } + NodeAttrHelper helper(node); if (op_type == "AveragePool" || op_type == "LpPool" || op_type == "MaxPool") { if (helper.Get("kernel_shape", std::vector{1, 1}).size() != 2) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/qdq_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/qdq_op_builder.cc index 053c41773db40..dd25fb9bf9315 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/qdq_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/qdq_op_builder.cc @@ -167,8 +167,7 @@ bool QDQOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& node, return false; } - return IsInputRankSupportedByOp(node, wnn_limits, logger) && - IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "input", "x", logger) && + return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "input", "x", logger) && IsDataTypeSupportedByOp(op_type, input1_type, wnn_limits, "scale", "x_scale", logger) && (!has_input2 || IsDataTypeSupportedByOp(op_type, input2_type, wnn_limits, "zeroPoint", "x_zero_point", logger)); } diff --git a/onnxruntime/core/providers/webnn/builders/impl/reduction_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/reduction_op_builder.cc index 6ea9b0a440d93..a3a0397eda4a3 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/reduction_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/reduction_op_builder.cc @@ -128,10 +128,16 @@ bool ReductionOpBuilder::IsOpSupportedImpl(const GraphViewer& graph_viewer, const WebnnDeviceType /* device_type */, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); + + std::vector input_shape; + if (!GetShape(*input_defs[0], input_shape, logger)) + return false; + + const auto& op_type = node.OpType(); const std::string axes_name = GetTensorName(input_defs, 1); // If the optional input 'axes' is provided, it must be an initializer. if (!axes_name.empty() && !graph_viewer.GetConstantInitializer(axes_name)) { - LOGS(logger, VERBOSE) << "Input axes of " << node.OpType() << " must be a constant"; + LOGS(logger, VERBOSE) << "Input axes of " << op_type << " must be a constant"; return false; } diff --git a/onnxruntime/core/providers/webnn/builders/impl/reshape_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/reshape_op_builder.cc index 0444ae3afb56a..8cbb381e0f53e 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/reshape_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/reshape_op_builder.cc @@ -79,6 +79,11 @@ bool ReshapeOpBuilder::IsOpSupportedImpl(const GraphViewer& graph_viewer, const WebnnDeviceType /* device_type */, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); + + std::vector input_shape; + if (!GetShape(*input_defs[0], input_shape, logger)) + return false; + const auto& perm_name = input_defs[1]->Name(); const auto* perm_init = graph_viewer.GetConstantInitializer(perm_name); if (!perm_init) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/rotaryEmbedding_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/rotaryEmbedding_op_builder.cc index 37071b1030e11..893ca9d2419c7 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/rotaryEmbedding_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/rotaryEmbedding_op_builder.cc @@ -285,7 +285,7 @@ Status RotaryEmbeddingOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_build sign_buffer.set(1, 1.0f); } else if (input_data_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) { if (model_builder.IsFloat16ArrayAvailable()) { - // Float16Array is available - use Float16Array. + // Float16Array is avaliable - use Float16Array. sign_buffer = emscripten::val::global("Float16Array").new_(2); sign_buffer.set(0, -1.0f); sign_buffer.set(1, 1.0f); diff --git a/onnxruntime/core/providers/webnn/builders/impl/scatterElements_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/scatterElements_op_builder.cc index c2974bd988f6b..f894e8bfbd517 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/scatterElements_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/scatterElements_op_builder.cc @@ -71,6 +71,7 @@ bool ScatterElementsOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const const auto& data = *node.InputDefs()[0]; const auto& indices = *node.InputDefs()[1]; const auto& updates = *node.InputDefs()[2]; + const std::string_view op_type = node.OpType(); int32_t data_type; int32_t indices_type; @@ -84,11 +85,8 @@ bool ScatterElementsOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const return false; } - const std::string_view op_type = node.OpType(); - return IsDataTypeSupportedByOp(op_type, data_type, wnn_limits, "input", "data", logger) && - IsDataTypeSupportedByOp(op_type, indices_type, wnn_limits, "indices", "indices", logger) && - IsInputRankSupportedByOp(node, wnn_limits, logger); + IsDataTypeSupportedByOp(op_type, indices_type, wnn_limits, "indices", "indices", logger); } void CreateScatterElementsOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/scatterND_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/scatterND_op_builder.cc index a7788cfd847e9..e61ac3dcc9617 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/scatterND_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/scatterND_op_builder.cc @@ -63,6 +63,7 @@ bool ScatterNDOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& const auto& data = *node.InputDefs()[0]; const auto& indices = *node.InputDefs()[1]; const auto& updates = *node.InputDefs()[2]; + const std::string_view op_type = node.OpType(); int32_t data_type; int32_t indices_type; @@ -75,10 +76,9 @@ bool ScatterNDOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& if (data_type != updates_type) { return false; } - const std::string_view op_type = node.OpType(); + return IsDataTypeSupportedByOp(op_type, data_type, wnn_limits, "input", "data", logger) && - IsDataTypeSupportedByOp(op_type, indices_type, wnn_limits, "indices", "indices", logger) && - IsInputRankSupportedByOp(node, wnn_limits, logger); + IsDataTypeSupportedByOp(op_type, indices_type, wnn_limits, "indices", "indices", logger); } void CreateScatterNDOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/slice_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/slice_op_builder.cc index 5efbfe932c602..8853891ff8ed6 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/slice_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/slice_op_builder.cc @@ -136,6 +136,10 @@ bool SliceOpBuilder::IsOpSupportedImpl(const GraphViewer& graph_viewer, const No const auto& name = node.Name(); const auto& op_type = node.OpType(); const auto& input_defs = node.InputDefs(); + std::vector input_shape; + if (!GetShape(*input_defs[0], input_shape, logger)) { + return false; + } if (input_defs.size() < 3) { LOGS(logger, VERBOSE) << op_type << " [" << name << "] requires at least 3 inputs (data, starts, ends) but got " @@ -162,17 +166,10 @@ bool SliceOpBuilder::HasSupportedInputsImpl(const GraphViewer& graph_viewer, con const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); const auto& input = *input_defs[0]; - std::vector input_shape; - if (!GetShape(*input_defs[0], input_shape, logger)) { - return false; - } - + const std::string_view op_type = node.OpType(); int32_t input_type; - if (!GetType(input, input_type, logger)) { + if (!GetType(input, input_type, logger)) return false; - } - - const std::string_view op_type = node.OpType(); // If there is step < 0, check data type support of reverse. if (TensorExists(input_defs, 4)) { @@ -181,15 +178,13 @@ bool SliceOpBuilder::HasSupportedInputsImpl(const GraphViewer& graph_viewer, con if (!init || !ReadIntArrayFrom1DTensor(*init, steps, graph_viewer, logger)) return false; if (std::any_of(steps.begin(), steps.end(), [](int64_t step) { return step < 0; })) { - if (!IsDataTypeSupportedByWebNNOp(op_type, "reverse", input_type, wnn_limits, "input", "data", logger) || - !IsInputRankSupported(wnn_limits, "reverse", "input", input_shape.size(), node.Name(), logger)) { + if (!IsDataTypeSupportedByWebNNOp(op_type, "reverse", input_type, wnn_limits, "input", "data", logger)) { return false; } } } - return IsDataTypeSupportedByOp(op_type, input_type, wnn_limits, "input", "data", logger) && - IsInputRankSupportedByOp(node, wnn_limits, logger); + return IsDataTypeSupportedByOp(op_type, input_type, wnn_limits, "input", "data", logger); } void CreateSliceOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/softmax_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/softmax_op_builder.cc index 99d137f81864c..23e73bb8f1e74 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/softmax_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/softmax_op_builder.cc @@ -18,6 +18,11 @@ class SoftmaxOpBuilder : public BaseOpBuilder { private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override ORT_MUST_USE_RESULT; + + // Operator support related. + private: + bool IsOpSupportedImpl(const GraphViewer&, const Node& node, + const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; }; Status SoftmaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, @@ -41,6 +46,20 @@ Status SoftmaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, return Status::OK(); } +// Operator support related. + +bool SoftmaxOpBuilder::IsOpSupportedImpl(const GraphViewer&, + const Node& node, + const WebnnDeviceType /* device_type */, + const logging::Logger& logger) const { + const auto& input_defs = node.InputDefs(); + std::vector input_shape; + if (!GetShape(*input_defs[0], input_shape, logger)) + return false; + + return true; +} + void CreateSoftmaxOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { op_registrations.builders.push_back(std::make_unique()); op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); diff --git a/onnxruntime/core/providers/webnn/builders/impl/squeeze_unsqueeze_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/squeeze_unsqueeze_op_builder.cc index 7e34e35ebac16..1ba6df9febf14 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/squeeze_unsqueeze_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/squeeze_unsqueeze_op_builder.cc @@ -127,6 +127,9 @@ bool SqueezeUnsqueezeOpBuilder::IsOpSupportedImpl(const GraphViewer& graph_viewe const logging::Logger& logger) const { const auto& op_type = node.OpType(); const auto& input_defs = node.InputDefs(); + std::vector input_shape; + if (!GetShape(*input_defs[0], input_shape, logger)) + return false; if (input_defs.size() < 1) { LOGS(logger, ERROR) << op_type << " has no input tensor"; diff --git a/onnxruntime/core/providers/webnn/builders/impl/ternary_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/ternary_op_builder.cc index 8973757a24e99..7a7f64b1ec96d 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/ternary_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/ternary_op_builder.cc @@ -66,8 +66,7 @@ bool TernaryOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& no return false; } - return IsDataTypeSupportedByOp(op_type, input1_type, wnn_limits, "trueValue", "X", logger) && - IsInputRankSupportedByOp(node, wnn_limits, logger); + return IsDataTypeSupportedByOp(op_type, input1_type, wnn_limits, "trueValue", "X", logger); } void CreateTernaryOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/tile_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/tile_op_builder.cc index 24d96588559ae..29b232026d7df 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/tile_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/tile_op_builder.cc @@ -77,6 +77,15 @@ bool TileOpBuilder::IsOpSupportedImpl(const GraphViewer& graph_viewer, return false; } + std::vector input_shape; + if (!GetShape(*input_defs[0], input_shape, logger)) + return false; + + if (input_shape.empty()) { + LOGS(logger, VERBOSE) << "Tile does not support empty input shape"; + return false; + } + return true; } diff --git a/onnxruntime/core/providers/webnn/builders/impl/triangular_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/triangular_op_builder.cc index 7a4d172c556fa..5a267557b9454 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/triangular_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/triangular_op_builder.cc @@ -76,6 +76,15 @@ bool TriangularOpBuilder::IsOpSupportedImpl(const GraphViewer& graph_viewer, const WebnnDeviceType /* device_type */, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); + std::vector input_shape; + if (!GetShape(*input_defs[0], input_shape, logger)) + return false; + const auto input_size = input_shape.size(); + if (input_size < 2) { + LOGS(logger, VERBOSE) << "Triangular only supports input size >= 2D shape, input is " + << input_size << "d shape"; + return false; + } const std::string diagonal_name = GetTensorName(input_defs, 1); // Inputs contain optional 'diagonal' input. diff --git a/onnxruntime/core/providers/webnn/builders/map_info.h b/onnxruntime/core/providers/webnn/builders/map_info.h index 1c30fed7a7916..5e860eea7cac9 100644 --- a/onnxruntime/core/providers/webnn/builders/map_info.h +++ b/onnxruntime/core/providers/webnn/builders/map_info.h @@ -47,7 +47,6 @@ constexpr std::array supported_fallback // Use ONNX-to-ONNX op mapping to improve the search complexity for WebNN ops in the op_inputs_map. const std::map> decomposed_op_map = { {"ConvInteger", {"Cast", "Conv", "DequantizeLinear"}}, - {"Einsum", {"MatMul", "Mul", "ReduceSum", "Reshape", "Transpose", "Trilu"}}, {"GroupQueryAttention", {"Add", "Cast", "Concat", "CumSum", "Div", "Expand", "Less", "MatMul", "Reshape", "ScatterND", "Softmax", "Transpose", "Where"}}, @@ -140,7 +139,7 @@ const std::unordered_map op_inputs_map = { {"Mul", {"mul", {{0, "a"}, {1, "b"}}}}, {"Pow", {"pow", {{0, "a"}, {1, "b"}}}}, {"Concat", {"concat", {{0, "inputs"}}}}, - {"Not", {"logicalNot", {{0, "a"}}}}, + {"Not", {"logicalNot", {{0, "input"}}}}, {"Flatten", {"reshape", {{0, "input"}}}}, {"LpPool", {"l2Pool2d", {{0, "input"}}}}, {"Reshape", {"reshape", {{0, "input"}}}}, @@ -160,6 +159,7 @@ const std::unordered_map op_inputs_map = { {"Softsign", {"softsign", {{0, "input"}}}}, {"Unsqueeze", {"reshape", {{0, "input"}}}}, {"Or", {"logicalOr", {{0, "a"}, {1, "b"}}}}, + {"Einsum", {"matmul", {{0, "a"}, {1, "b"}}}}, {"HardSwish", {"hardSwish", {{0, "input"}}}}, {"LeakyRelu", {"leakyRelu", {{0, "input"}}}}, {"MatMul", {"matmul", {{0, "a"}, {1, "b"}}}}, diff --git a/onnxruntime/core/providers/webnn/builders/model_builder.h b/onnxruntime/core/providers/webnn/builders/model_builder.h index d2cd0639affd0..4468831181d42 100644 --- a/onnxruntime/core/providers/webnn/builders/model_builder.h +++ b/onnxruntime/core/providers/webnn/builders/model_builder.h @@ -78,7 +78,7 @@ class ModelBuilder { const GraphViewer& graph_viewer_; const logging::Logger& logger_; const bool is_float16array_available_ = !emscripten::val::global("Float16Array").isUndefined() && - !emscripten::val::global("Float16Array")["from"].isUndefined(); + emscripten::val::global("Float16Array").hasOwnProperty("from"); emscripten::val wnn_context_ = emscripten::val::undefined(); emscripten::val wnn_builder_ = emscripten::val::undefined(); diff --git a/onnxruntime/core/session/compile_api.cc b/onnxruntime/core/session/compile_api.cc index 59b0992d827e1..d910e3ea74b57 100644 --- a/onnxruntime/core/session/compile_api.cc +++ b/onnxruntime/core/session/compile_api.cc @@ -128,35 +128,6 @@ ORT_API_STATUS_IMPL(OrtCompileAPI::ModelCompilationOptions_SetOutputModelPath, API_IMPL_END } -ORT_API_STATUS_IMPL(OrtCompileAPI::ModelCompilationOptions_SetEpContextBinaryInformation, - _In_ OrtModelCompilationOptions* ort_model_compile_options, - const ORTCHAR_T* output_directory, - const ORTCHAR_T* model_name) { - API_IMPL_BEGIN -#if !defined(ORT_MINIMAL_BUILD) - auto model_compile_options = reinterpret_cast(ort_model_compile_options); - - std::string output_dir = PathToUTF8String(output_directory); - if (output_dir.empty()) { - return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Invalid output directory: path is empty"); - } - - std::string model_name_str = ToUTF8String(model_name); - if (model_name_str.empty()) { - return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Invalid model name: string is empty"); - } - - ORT_API_RETURN_IF_STATUS_NOT_OK(model_compile_options->SetEpContextBinaryInformation(output_dir, model_name_str)); - return nullptr; -#else - ORT_UNUSED_PARAMETER(ort_model_compile_options); - ORT_UNUSED_PARAMETER(output_directory); - ORT_UNUSED_PARAMETER(model_name); - return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "Compile API is not supported in this build"); -#endif // !defined(ORT_MINIMAL_BUILD) - API_IMPL_END -} - ORT_API_STATUS_IMPL(OrtCompileAPI::ModelCompilationOptions_SetOutputModelExternalInitializersFile, _In_ OrtModelCompilationOptions* ort_model_compile_options, const ORTCHAR_T* external_initializers_file_path, @@ -277,7 +248,6 @@ static constexpr OrtCompileApi ort_compile_api = { // End of Version 22 - DO NOT MODIFY ABOVE &OrtCompileAPI::ModelCompilationOptions_SetFlags, - &OrtCompileAPI::ModelCompilationOptions_SetEpContextBinaryInformation, }; // checks that we don't violate the rule that the functions must remain in the slots they were originally assigned diff --git a/onnxruntime/core/session/compile_api.h b/onnxruntime/core/session/compile_api.h index 93cc5dbf20fce..5f11b894f2004 100644 --- a/onnxruntime/core/session/compile_api.h +++ b/onnxruntime/core/session/compile_api.h @@ -30,7 +30,5 @@ ORT_API_STATUS_IMPL(ModelCompilationOptions_SetEpContextEmbedMode, _In_ OrtModel ORT_API_STATUS_IMPL(CompileModel, _In_ const OrtEnv* env, _In_ const OrtModelCompilationOptions* model_options); ORT_API_STATUS_IMPL(ModelCompilationOptions_SetFlags, _In_ OrtModelCompilationOptions* model_options, size_t flags); -ORT_API_STATUS_IMPL(ModelCompilationOptions_SetEpContextBinaryInformation, _In_ OrtModelCompilationOptions* model_compile_options, - _In_ const ORTCHAR_T* output_dir, _In_ const ORTCHAR_T* model_name); } // namespace OrtCompileAPI diff --git a/onnxruntime/core/session/ep_api_utils.h b/onnxruntime/core/session/ep_api_utils.h index a0904c32011a7..daccd24453371 100644 --- a/onnxruntime/core/session/ep_api_utils.h +++ b/onnxruntime/core/session/ep_api_utils.h @@ -16,10 +16,6 @@ struct ForwardToFactory { return static_cast(this_ptr)->GetVendor(); } - static uint32_t ORT_API_CALL GetVendorId(const OrtEpFactory* this_ptr) noexcept { - return static_cast(this_ptr)->GetVendorId(); - } - static const char* ORT_API_CALL GetVersion(const OrtEpFactory* this_ptr) noexcept { return static_cast(this_ptr)->GetVersion(); } diff --git a/onnxruntime/core/session/ep_factory_internal.cc b/onnxruntime/core/session/ep_factory_internal.cc index fa4ef2515ca92..b289010cc6c5b 100644 --- a/onnxruntime/core/session/ep_factory_internal.cc +++ b/onnxruntime/core/session/ep_factory_internal.cc @@ -14,19 +14,17 @@ namespace onnxruntime { using Forward = ForwardToFactory; -EpFactoryInternal::EpFactoryInternal(const std::string& ep_name, const std::string& vendor, uint32_t vendor_id, +EpFactoryInternal::EpFactoryInternal(const std::string& ep_name, const std::string& vendor, GetSupportedFunc&& get_supported_func, CreateFunc&& create_func) : ep_name_{ep_name}, vendor_{vendor}, - vendor_id_{vendor_id}, get_supported_func_{std::move(get_supported_func)}, create_func_{create_func} { ort_version_supported = ORT_API_VERSION; OrtEpFactory::GetName = Forward::GetFactoryName; OrtEpFactory::GetVendor = Forward::GetVendor; - OrtEpFactory::GetVendorId = Forward::GetVendorId; OrtEpFactory::GetVersion = Forward::GetVersion; OrtEpFactory::GetSupportedDevices = Forward::GetSupportedDevices; OrtEpFactory::CreateEp = Forward::CreateEp; diff --git a/onnxruntime/core/session/ep_factory_internal.h b/onnxruntime/core/session/ep_factory_internal.h index ee08e2233c529..087c0c60f8f4e 100644 --- a/onnxruntime/core/session/ep_factory_internal.h +++ b/onnxruntime/core/session/ep_factory_internal.h @@ -33,13 +33,12 @@ class EpFactoryInternal : public OrtEpFactory { const OrtSessionOptions* session_options, const OrtLogger* logger, std::unique_ptr* ep)>; - EpFactoryInternal(const std::string& ep_name, const std::string& vendor, uint32_t vendor_id, + EpFactoryInternal(const std::string& ep_name, const std::string& vendor, GetSupportedFunc&& get_supported_func, CreateFunc&& create_func); const char* GetName() const noexcept { return ep_name_.c_str(); } const char* GetVendor() const noexcept { return vendor_.c_str(); } - uint32_t GetVendorId() const noexcept { return vendor_id_; } const char* GetVersion() const noexcept; OrtStatus* GetSupportedDevices(_In_reads_(num_devices) const OrtHardwareDevice* const* devices, @@ -68,7 +67,6 @@ class EpFactoryInternal : public OrtEpFactory { private: const std::string ep_name_; // EP name library was registered with const std::string vendor_; // EP vendor name - const uint32_t vendor_id_; // EP vendor ID const GetSupportedFunc get_supported_func_; // function to return supported devices const CreateFunc create_func_; // function to create the EP instance diff --git a/onnxruntime/core/session/ep_library_internal.cc b/onnxruntime/core/session/ep_library_internal.cc index ce5736f601b45..25f70f7549a16 100644 --- a/onnxruntime/core/session/ep_library_internal.cc +++ b/onnxruntime/core/session/ep_library_internal.cc @@ -61,8 +61,7 @@ std::unique_ptr EpLibraryInternal::CreateCpuEp() { }; std::string ep_name = kCpuExecutionProvider; - auto cpu_factory = std::make_unique(ep_name, "Microsoft", OrtDevice::VendorIds::MICROSOFT, - get_supported, create_cpu_ep); + auto cpu_factory = std::make_unique(ep_name, "Microsoft", get_supported, create_cpu_ep); return std::make_unique(std::move(cpu_factory)); } @@ -123,8 +122,7 @@ std::unique_ptr EpLibraryInternal::CreateDmlEp() { return nullptr; }; - auto dml_factory = std::make_unique(ep_name, "Microsoft", OrtDevice::VendorIds::MICROSOFT, - is_supported, create_dml_ep); + auto dml_factory = std::make_unique(ep_name, "Microsoft", is_supported, create_dml_ep); return std::make_unique(std::move(dml_factory)); } @@ -172,8 +170,7 @@ std::unique_ptr EpLibraryInternal::CreateWebGpuEp() { return nullptr; }; - auto webgpu_factory = std::make_unique(ep_name, "Microsoft", OrtDevice::VendorIds::MICROSOFT, - is_supported, create_webgpu_ep); + auto webgpu_factory = std::make_unique(ep_name, "Microsoft", is_supported, create_webgpu_ep); return std::make_unique(std::move(webgpu_factory)); } diff --git a/onnxruntime/core/session/ep_library_provider_bridge.cc b/onnxruntime/core/session/ep_library_provider_bridge.cc index 70937bdc5d3e8..73423a4744576 100644 --- a/onnxruntime/core/session/ep_library_provider_bridge.cc +++ b/onnxruntime/core/session/ep_library_provider_bridge.cc @@ -72,7 +72,6 @@ Status EpLibraryProviderBridge::Load() { auto internal_factory = std::make_unique(factory->GetName(factory), factory->GetVendor(factory), - factory->GetVendorId(factory), is_supported_fn, create_fn); factory_ptrs_.push_back(internal_factory.get()); diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index f147242da668f..86a61a4d0ee74 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -423,13 +423,7 @@ void InferenceSession::ConstructorCommon(const SessionOptions& session_options, { if (!external_intra_op_thread_pool_) { bool allow_intra_op_spinning = -#if !defined(ORT_CLIENT_PACKAGE_BUILD) session_options_.config_options.GetConfigOrDefault(kOrtSessionOptionsConfigAllowIntraOpSpinning, "1") == "1"; -#else - // default KOrtSessionOptionsConfigAllowIntraOpSpinning to "0" for ORT builds targeting client/on-device workloads, - // to reduce CPU utilization and improve power efficiency. - session_options_.config_options.GetConfigOrDefault(kOrtSessionOptionsConfigAllowIntraOpSpinning, "0") == "1"; -#endif OrtThreadPoolParams to = session_options_.intra_op_param; std::basic_stringstream ss; if (to.name) { @@ -467,13 +461,7 @@ void InferenceSession::ConstructorCommon(const SessionOptions& session_options, if (session_options_.execution_mode == ExecutionMode::ORT_PARALLEL) { if (!external_inter_op_thread_pool_) { bool allow_inter_op_spinning = -#if !defined(ORT_CLIENT_PACKAGE_BUILD) session_options_.config_options.GetConfigOrDefault(kOrtSessionOptionsConfigAllowInterOpSpinning, "1") == "1"; -#else - // default kOrtSessionOptionsConfigAllowInterOpSpinning to "0" for ORT builds targeting client/on-device workloads, - // to reduce CPU utilization and improve power efficiency. - session_options_.config_options.GetConfigOrDefault(kOrtSessionOptionsConfigAllowInterOpSpinning, "0") == "1"; -#endif OrtThreadPoolParams to = session_options_.inter_op_param; to.auto_set_affinity = to.thread_pool_size == 0 && session_options_.execution_mode == ExecutionMode::ORT_SEQUENTIAL; std::basic_stringstream ss; diff --git a/onnxruntime/core/session/model_compilation_options.cc b/onnxruntime/core/session/model_compilation_options.cc index bbb110033f54c..5de0f03fafc08 100644 --- a/onnxruntime/core/session/model_compilation_options.cc +++ b/onnxruntime/core/session/model_compilation_options.cc @@ -72,8 +72,8 @@ Status ModelCompilationOptions::SetOutputModelPath(const std::string& output_mod if (log_manager != nullptr && log_manager->HasDefaultLogger()) { const logging::Logger& logger = log_manager->DefaultLogger(); LOGS(logger, WARNING) << "Output model path length (" << ep_context_gen_options.output_model_file_path.size() - << ") exceeds limit of " << ConfigOptions::kMaxValueLength << " characters." - << "ORT will still generate the expected output file, but EPs will see an empty " + << ") exceeds limit of " << ConfigOptions::kMaxKeyLength << " characters." + << "ORT will still generated the expected output file, but EPs will see an empty " << "output model path in SessionOption's ConfigOptions."; } } @@ -98,36 +98,6 @@ Status ModelCompilationOptions::SetOutputModelBuffer(onnxruntime::AllocatorPtr a return Status::OK(); } -Status ModelCompilationOptions::SetEpContextBinaryInformation(const std::string& output_directory, - const std::string& model_name) { - if (output_directory.empty() || model_name.empty()) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "output_dir or model_name is empty."); - } - - std::filesystem::path output_dir_path(output_directory); - if (output_dir_path.has_filename() && output_dir_path.extension() == "") { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "output_dir is not a valid directory."); - } - - std::filesystem::path ctx_model_path = output_directory / std::filesystem::path(model_name); - - if (ctx_model_path.string().size() <= ConfigOptions::kMaxValueLength) { - ORT_RETURN_IF_ERROR(session_options_.value.config_options.AddConfigEntry(kOrtSessionOptionEpContextFilePath, - ctx_model_path.string().c_str())); - } else { - logging::LoggingManager* log_manager = env_.GetLoggingManager(); - if (log_manager != nullptr && log_manager->HasDefaultLogger()) { - const logging::Logger& logger = log_manager->DefaultLogger(); - LOGS(logger, WARNING) << "output_directory length with model_name length together exceeds limit of " - << ConfigOptions::kMaxValueLength << " characters." - << "ORT will still generate the expected output file, but EPs will see an empty " - << "output path in SessionOption's ConfigOptions."; - } - } - - return Status::OK(); -} - Status ModelCompilationOptions::SetEpContextEmbedMode(bool embed_ep_context_in_model) { ORT_RETURN_IF_ERROR(session_options_.value.config_options.AddConfigEntry( kOrtSessionOptionEpContextEmbedMode, embed_ep_context_in_model ? "1" : "0")); @@ -176,7 +146,7 @@ Status ModelCompilationOptions::ResetOutputModelSettings() { ep_context_gen_options.output_model_buffer_ptr = nullptr; ep_context_gen_options.output_model_buffer_size_ptr = nullptr; ep_context_gen_options.output_model_buffer_allocator = nullptr; - return Status::OK(); + return session_options_.value.config_options.AddConfigEntry(kOrtSessionOptionEpContextFilePath, ""); } Status ModelCompilationOptions::CheckInputModelSettings() const { diff --git a/onnxruntime/core/session/model_compilation_options.h b/onnxruntime/core/session/model_compilation_options.h index 2824df863013d..f96f0317cdaca 100644 --- a/onnxruntime/core/session/model_compilation_options.h +++ b/onnxruntime/core/session/model_compilation_options.h @@ -72,16 +72,6 @@ class ModelCompilationOptions { Status SetOutputModelBuffer(onnxruntime::AllocatorPtr allocator, void** output_model_buffer_ptr, size_t* output_model_buffer_size_ptr); - /// - /// Sets information relate to EP context binary file. - /// EP use this information to decide the location and context binary file name. - /// Used while compiling model with input and output in memory buffer - /// - /// The folder path to the generated context binary file - /// Model name used to decide the context binary file name: [model_name]_[ep].bin - /// Status indicating potential error - Status SetEpContextBinaryInformation(const std::string& output_directory, const std::string& model_name); - /// /// Enables or disables the embedding of EPContext binary data into the `ep_cache_context` attribute of EPContext /// nodes. Defaults to false (dumped to file). diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index db2a62c77d1bc..e7f60fd48a14f 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -2591,29 +2591,6 @@ ORT_API_STATUS_IMPL(OrtApis::Graph_GetOnnxIRVersion, _In_ const OrtGraph* graph, API_IMPL_END } -ORT_API_STATUS_IMPL(OrtApis::Graph_GetNumOperatorSets, _In_ const OrtGraph* graph, _Out_ size_t* num_operator_sets) { - API_IMPL_BEGIN - if (num_operator_sets == nullptr) { - return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "'num_operator_sets' argument is NULL"); - } - - ORT_API_RETURN_IF_STATUS_NOT_OK(graph->GetNumOperatorSets(*num_operator_sets)); - return nullptr; - API_IMPL_END -} - -ORT_API_STATUS_IMPL(OrtApis::Graph_GetOperatorSets, _In_ const OrtGraph* graph, - _Out_writes_(num_operator_sets) const char** domains, - _Out_writes_(num_operator_sets) int64_t* opset_versions, _In_ size_t num_operator_sets) { - API_IMPL_BEGIN - gsl::span domains_span(domains, num_operator_sets); - gsl::span versions_span(opset_versions, num_operator_sets); - ORT_API_RETURN_IF_STATUS_NOT_OK(graph->GetOperatorSets(domains_span, versions_span)); - - return nullptr; - API_IMPL_END -} - ORT_API_STATUS_IMPL(OrtApis::Graph_GetNumInputs, _In_ const OrtGraph* graph, _Out_ size_t* num_inputs) { API_IMPL_BEGIN if (num_inputs == nullptr) { @@ -2714,91 +2691,6 @@ ORT_API_STATUS_IMPL(OrtApis::Graph_GetParentNode, _In_ const OrtGraph* graph, _O API_IMPL_END } -ORT_API_STATUS_IMPL(OrtApis::Graph_GetGraphView, _In_ const OrtGraph* src_graph, - _In_ const OrtNode** nodes, - _In_ size_t num_nodes, - _Outptr_ OrtGraph** dst_graph) { - API_IMPL_BEGIN - - if (num_nodes == 0) { - return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "'num_nodes' argument should be > 0"); - } - - const EpGraph* ep_graph = EpGraph::ToInternal(src_graph); - if (ep_graph == nullptr) { - return OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "src_graph is a ModelEditorGraph which doesn't support Graph_GetSubGraph."); - } - const Graph& graph = ep_graph->GetGraphViewer().GetGraph(); - - // Create a GraphViewer with filtered info - std::unique_ptr indexed_sub_graph = std::make_unique(); - std::unique_ptr metadef = std::make_unique(); - metadef->name = "sub_graph"; - metadef->since_version = 1; - std::unordered_set outputs; - std::unordered_set initializers; - - auto add_inputs = [&](ConstPointerContainer> defs) { - for (const auto* def : defs) { - if (def->Exists()) { - // not the output of a previous node - if (outputs.count(def->Name()) == 0) { - metadef->inputs.push_back(def->Name()); - } else { - // consumed by node so no longer subgraph output - // NOTE: Ignoring edge case where a node output is an overall graph output AND a node input - outputs.erase(def->Name()); - } - - if (graph.IsInitializedTensor(def->Name())) { - initializers.insert(def); - } - } - } - }; - - auto add_node = [&](const Node& node) { - indexed_sub_graph->nodes.push_back(node.Index()); - add_inputs(node.InputDefs()); - add_inputs(node.ImplicitInputDefs()); - - for (const auto* def : node.OutputDefs()) { - outputs.insert(def->Name()); - } - }; - - // Add nodes - for (size_t node_idx = 0; node_idx < num_nodes; node_idx++) { - const OrtNode* ort_node = nodes[node_idx]; - const EpNode* ep_node = EpNode::ToInternal(ort_node); - if (ep_node == nullptr) { - return OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "node is a ModelEditorNode which doesn't support Graph_GetSubGraph."); - } - add_node(ep_node->GetInternalNode()); - } - - // Add initializers - for (auto& initializer : initializers) { - metadef->constant_initializers.push_back(initializer->Name()); - } - - // Add outputs - for (auto& output : outputs) { - metadef->outputs.push_back(output); - } - - indexed_sub_graph->SetMetaDef(std::move(metadef)); - auto graph_viewer = std::make_unique(graph, *indexed_sub_graph.get()); - - std::unique_ptr result; - ORT_API_RETURN_IF_STATUS_NOT_OK(EpGraph::Create(std::move(graph_viewer), std::move(indexed_sub_graph), result)); - - *dst_graph = result.release(); - - return nullptr; - API_IMPL_END -} - // // OrtNode // @@ -3030,11 +2922,10 @@ ORT_API_STATUS_IMPL(OrtApis::Node_GetNumSubgraphs, _In_ const OrtNode* node, _Ou } ORT_API_STATUS_IMPL(OrtApis::Node_GetSubgraphs, _In_ const OrtNode* node, - _Out_writes_(num_subgraphs) const OrtGraph** subgraphs, _In_ size_t num_subgraphs, - _Out_writes_opt_(num_subgraphs) const char** attribute_names) { + _Out_writes_(num_subgraphs) const OrtGraph** subgraphs, _In_ size_t num_subgraphs) { API_IMPL_BEGIN gsl::span graphs_span(subgraphs, num_subgraphs); - ORT_API_RETURN_IF_STATUS_NOT_OK(node->GetSubgraphs(graphs_span, attribute_names)); + ORT_API_RETURN_IF_STATUS_NOT_OK(node->GetSubgraphs(graphs_span)); return nullptr; API_IMPL_END } @@ -3052,23 +2943,6 @@ ORT_API_STATUS_IMPL(OrtApis::Node_GetGraph, _In_ const OrtNode* node, API_IMPL_END } -ORT_API_STATUS_IMPL(OrtApis::Node_GetEpName, _In_ const OrtNode* node, - _Outptr_result_maybenull_ const char** out) { - API_IMPL_BEGIN - if (out == nullptr) { - return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "'out' argument is NULL"); - } - - const EpNode* ep_node = EpNode::ToInternal(node); - if (ep_node == nullptr) { - return OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "node is a ModelEditorNode which doesn't support Node_GetEpName."); - } - - *out = ep_node->GetEpName().c_str(); - return nullptr; - API_IMPL_END -} - ORT_API(const OrtTrainingApi*, OrtApis::GetTrainingApi, uint32_t version) { #ifdef ENABLE_TRAINING_APIS if (version >= 13 && version <= ORT_API_VERSION) @@ -3720,8 +3594,6 @@ static constexpr OrtApi ort_api_1_to_23 = { &OrtApis::ValueInfo_IsFromOuterScope, &OrtApis::Graph_GetName, &OrtApis::Graph_GetOnnxIRVersion, - &OrtApis::Graph_GetNumOperatorSets, - &OrtApis::Graph_GetOperatorSets, &OrtApis::Graph_GetNumInputs, &OrtApis::Graph_GetInputs, &OrtApis::Graph_GetNumOutputs, @@ -3731,7 +3603,6 @@ static constexpr OrtApi ort_api_1_to_23 = { &OrtApis::Graph_GetNumNodes, &OrtApis::Graph_GetNodes, &OrtApis::Graph_GetParentNode, - &OrtApis::Graph_GetGraphView, &OrtApis::Node_GetId, &OrtApis::Node_GetName, &OrtApis::Node_GetOperatorType, @@ -3751,7 +3622,6 @@ static constexpr OrtApi ort_api_1_to_23 = { &OrtApis::Node_GetNumSubgraphs, &OrtApis::Node_GetSubgraphs, &OrtApis::Node_GetGraph, - &OrtApis::Node_GetEpName, &OrtApis::GetRunConfigEntry, diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index 9ab927006c320..cbacbfce0740d 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -631,10 +631,6 @@ ORT_API_STATUS_IMPL(ValueInfo_IsFromOuterScope, _In_ const OrtValueInfo* value_i // OrtGraph ORT_API_STATUS_IMPL(Graph_GetName, _In_ const OrtGraph* graph, _Outptr_ const char** graph_name); ORT_API_STATUS_IMPL(Graph_GetOnnxIRVersion, _In_ const OrtGraph* graph, _Out_ int64_t* onnx_ir_version); -ORT_API_STATUS_IMPL(Graph_GetNumOperatorSets, _In_ const OrtGraph* graph, _Out_ size_t* num_operator_sets); -ORT_API_STATUS_IMPL(Graph_GetOperatorSets, _In_ const OrtGraph* graph, - _Out_writes_(num_operator_sets) const char** domains, - _Out_writes_(num_operator_sets) int64_t* opset_versions, _In_ size_t num_operator_sets); ORT_API_STATUS_IMPL(Graph_GetNumInputs, _In_ const OrtGraph* graph, _Out_ size_t* num_inputs); ORT_API_STATUS_IMPL(Graph_GetInputs, _In_ const OrtGraph* graph, _Out_writes_(num_inputs) const OrtValueInfo** inputs, _In_ size_t num_inputs); @@ -649,8 +645,6 @@ ORT_API_STATUS_IMPL(Graph_GetNumNodes, _In_ const OrtGraph* graph, _Out_ size_t* ORT_API_STATUS_IMPL(Graph_GetNodes, const OrtGraph* graph, _Out_writes_(num_nodes) const OrtNode** nodes, _In_ size_t num_nodes); ORT_API_STATUS_IMPL(Graph_GetParentNode, _In_ const OrtGraph* graph, _Outptr_result_maybenull_ const OrtNode** node); -ORT_API_STATUS_IMPL(Graph_GetGraphView, _In_ const OrtGraph* graph, _In_ const OrtNode** nodes, _In_ size_t num_nodes, - _Outptr_ OrtGraph** subgraph); // OrtNode ORT_API_STATUS_IMPL(Node_GetId, _In_ const OrtNode* node, _Out_ size_t* node_id); @@ -677,10 +671,8 @@ ORT_API_STATUS_IMPL(OpAttr_GetType, _In_ const OrtOpAttr* attribute, _Out_ OrtOp ORT_API_STATUS_IMPL(OpAttr_GetName, _In_ const OrtOpAttr* attribute, _Outptr_ const char** name); ORT_API_STATUS_IMPL(Node_GetNumSubgraphs, _In_ const OrtNode* node, _Out_ size_t* num_subgraphs); ORT_API_STATUS_IMPL(Node_GetSubgraphs, _In_ const OrtNode* node, - _Out_writes_(num_subgraphs) const OrtGraph** subgraphs, _In_ size_t num_subgraphs, - _Out_writes_opt_(num_subgraphs) const char** attribute_names); + _Out_writes_(num_subgraphs) const OrtGraph** subgraphs, _In_ size_t num_subgraphs); ORT_API_STATUS_IMPL(Node_GetGraph, _In_ const OrtNode* node, _Outptr_result_maybenull_ const OrtGraph** graph); -ORT_API_STATUS_IMPL(Node_GetEpName, _In_ const OrtNode* node, _Outptr_result_maybenull_ const char** out); ORT_API_STATUS_IMPL(GetRunConfigEntry, _In_ const OrtRunOptions* options, _In_z_ const char* config_key, _Outptr_result_maybenull_z_ const char** config_value); diff --git a/onnxruntime/core/session/provider_policy_context.cc b/onnxruntime/core/session/provider_policy_context.cc index 211bf8b2d15a4..e8d62ab86f517 100644 --- a/onnxruntime/core/session/provider_policy_context.cc +++ b/onnxruntime/core/session/provider_policy_context.cc @@ -22,13 +22,7 @@ namespace onnxruntime { namespace { bool MatchesEpVendor(const OrtEpDevice* d) { - // match on vendor id if provided - uint32_t factory_vendor_id = d->ep_factory->GetVendorId(d->ep_factory); - if (factory_vendor_id != 0 && d->device->vendor_id == factory_vendor_id) { - return true; - } - - // match on vendor name + // TODO: Would be better to match on Id. Should the EP add that in EP metadata? return d->device->vendor == d->ep_vendor; } diff --git a/onnxruntime/core/util/qmath.h b/onnxruntime/core/util/qmath.h index f7d5cdb98aa1d..0172902bdf4e2 100644 --- a/onnxruntime/core/util/qmath.h +++ b/onnxruntime/core/util/qmath.h @@ -1001,53 +1001,4 @@ struct BlockedQuantizeLinear { #endif -/** - * @brief Run MlasDequantizeLinear in parallel, with provided thread pool - */ - -template -void ParDequantizeLinearStd(const InputQuantType* input, - float* output, - size_t num_elems, - float scale, - InputQuantType zero_point, - concurrency::ThreadPool* thread_pool) { - constexpr std::ptrdiff_t block_size = 128; - const std::ptrdiff_t num_blocks = (num_elems + block_size - 1) / block_size; - const TensorOpCost unit_cost{static_cast(block_size * sizeof(InputQuantType)), - static_cast(block_size * sizeof(float)), - static_cast(block_size) * 2.0}; - concurrency::ThreadPool::TryParallelFor(thread_pool, num_blocks, unit_cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) { - auto begin_idx = begin * block_size; - auto end_idx = std::min(static_cast(num_elems), end * block_size); - MlasDequantizeLinear(&(input[begin_idx]), &(output[begin_idx]), end_idx - begin_idx, scale, zero_point); - }); -} - -// Note: this doesn't use MLAS kernel. There are currently no MLAS kernels for fp16 QuantizeLinear or DequantizeLinear. -template -void ParDequantizeLinearStd(const InputQuantType* input, - MLFloat16* output, - size_t num_elems, - MLFloat16 scale, - InputQuantType zero_point, - concurrency::ThreadPool* thread_pool) { - constexpr std::ptrdiff_t block_size = 128; - const std::ptrdiff_t num_blocks = (num_elems + block_size - 1) / block_size; - const TensorOpCost unit_cost{static_cast(block_size * sizeof(InputQuantType)), - static_cast(block_size * sizeof(MLFloat16)), - static_cast(block_size) * 2.0}; - - const int32_t zp_s32 = static_cast(zero_point); - const float sc_f32 = scale.ToFloat(); - - concurrency::ThreadPool::TryParallelFor(thread_pool, num_blocks, unit_cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) { - auto begin_idx = begin * block_size; - auto end_idx = std::min(static_cast(num_elems), end * block_size); - for (; begin_idx != end_idx; ++begin_idx) { - output[begin_idx] = MLFloat16(static_cast(static_cast(input[begin_idx]) - zp_s32) * sc_f32); - } - }); -} - } // namespace onnxruntime diff --git a/onnxruntime/core/util/thread_utils.h b/onnxruntime/core/util/thread_utils.h index 0b99723b2c75b..d63d620dbc321 100644 --- a/onnxruntime/core/util/thread_utils.h +++ b/onnxruntime/core/util/thread_utils.h @@ -19,13 +19,7 @@ struct OrtThreadPoolParams { bool auto_set_affinity = false; // If it is true, the thread pool will spin a while after the queue became empty. -#if !defined(ORT_CLIENT_PACKAGE_BUILD) bool allow_spinning = true; -#else - // default allow_spinning to false for ORT builds targeting client/on-device workloads, - // to reduce CPU utilization and improve power efficiency. - bool allow_spinning = false; -#endif // It it is non-negative, thread pool will split a task by a decreasing block size // of remaining_of_total_iterations / (num_of_threads * dynamic_block_base_) diff --git a/onnxruntime/python/tools/quantization/base_quantizer.py b/onnxruntime/python/tools/quantization/base_quantizer.py index e3303dac6c8c5..9a297e451213a 100644 --- a/onnxruntime/python/tools/quantization/base_quantizer.py +++ b/onnxruntime/python/tools/quantization/base_quantizer.py @@ -42,7 +42,7 @@ def __init__(self, **data: dict[str, Any]): for k, v in data.items(): if not isinstance(k, str): raise TypeError(f"Keys must be strings not {type(k)} for k={k!r}.") - if k != "axis" and not isinstance(v, (int, str, np.ndarray, float)): + if k != "axis" and not isinstance(v, (int, str, np.ndarray)): raise TypeError(f"Values must be numpy arrays, int, float, str not {type(v)} for k={k!r}.") if k == "axis" and not isinstance(v, int) and v is not None: raise TypeError(f"Axis value must be an int or None, not {type(v)}.") diff --git a/onnxruntime/python/tools/quantization/registry.py b/onnxruntime/python/tools/quantization/registry.py index 319c5aa468f7e..fbeae39c39d21 100644 --- a/onnxruntime/python/tools/quantization/registry.py +++ b/onnxruntime/python/tools/quantization/registry.py @@ -86,7 +86,6 @@ "InstanceNormalization": QDQNormalization, "LayerNormalization": QDQNormalization, "BatchNormalization": QDQNormalization, - "TopK": QDQDirect8BitOp, } diff --git a/onnxruntime/python/tools/transformers/fusion_attention_clip.py b/onnxruntime/python/tools/transformers/fusion_attention_clip.py index 8711e368cd1e6..fe93f5cd358bf 100644 --- a/onnxruntime/python/tools/transformers/fusion_attention_clip.py +++ b/onnxruntime/python/tools/transformers/fusion_attention_clip.py @@ -269,48 +269,42 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): attention_last_node = reshape_qkv add_qk = "" - causal_mask_nodes_1 = None - causal_mask_nodes_2 = None if add_mask is not None: - if add_mask.input[1] == "attention_mask": + # 4D Add after Q x K' + add_qk_nodes = self.model.match_parent_path( + add_mask, + [ + "Where", + "Sub", + "Cast", + "Expand", + "Unsqueeze", + "Unsqueeze", + "Reshape", + "Reshape", + "Cast", + ], + [1, 2, 1, 0, 0, 0, 0, 0, 0], + ) + if add_qk_nodes is not None: add_qk = add_mask.input[1] else: - # 4D Add after Q x K' - add_qk_nodes = self.model.match_parent_path( + # Here we do not match the whole subgraph since it is very complex. Instead, we just check whether a key path + # of computing causal mask. + causal_mask_nodes_1 = self.model.match_parent_path( add_mask, - [ - "Where", - "Sub", - "Cast", - "Expand", - "Unsqueeze", - "Unsqueeze", - "Reshape", - "Reshape", - "Cast", - ], - [1, 2, 1, 0, 0, 0, 0, 0, 0], + ["Concat", "Expand", "Unsqueeze", "Unsqueeze", "Where", "Less"], + [causal_mask_input_index, 0, 0, 0, 0, 0], ) - if add_qk_nodes is not None: - add_qk = add_mask.input[1] - else: - # Here we do not match the whole subgraph since it is very complex. Instead, we just check whether a key path - # of computing causal mask. - causal_mask_nodes_1 = self.model.match_parent_path( - add_mask, - ["Concat", "Expand", "Unsqueeze", "Unsqueeze", "Where", "Less"], - [causal_mask_input_index, 0, 0, 0, 0, 0], - ) - # If the model is exported with batch_size == 1, there is no Concat node - causal_mask_nodes_2 = self.model.match_parent_path( - add_mask, - ["Expand", "Unsqueeze", "Unsqueeze", "Where", "Less"], - [causal_mask_input_index, 0, 0, 0, 0], - ) - - if causal_mask_nodes_1 is None and causal_mask_nodes_2 is None: - logger.debug("fuse_attention: failed to match causal mask subgraph") - return + # If the model is exported with batch_size == 1, there is no Concat node + causal_mask_nodes_2 = self.model.match_parent_path( + add_mask, + ["Expand", "Unsqueeze", "Unsqueeze", "Where", "Less"], + [causal_mask_input_index, 0, 0, 0, 0], + ) + if causal_mask_nodes_1 is None and causal_mask_nodes_2 is None: + logger.debug("fuse_attention: failed to match causal mask subgraph") + return new_node = self.create_attention_node( mask_index=None, @@ -326,7 +320,7 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): output=attention_last_node.output[0], add_qk_str=add_qk, scale=None, - causal=(causal_mask_nodes_1 is not None) or (causal_mask_nodes_2 is not None), + causal=(add_mask is not None), ) if new_node is None: logger.debug("fuse_attention: failed to create fused node") diff --git a/onnxruntime/python/tools/transformers/models/llama/requirements.txt b/onnxruntime/python/tools/transformers/models/llama/requirements.txt index e16957eab80a1..6bd698f8b75b4 100644 --- a/onnxruntime/python/tools/transformers/models/llama/requirements.txt +++ b/onnxruntime/python/tools/transformers/models/llama/requirements.txt @@ -1,7 +1,7 @@ onnxscript>=0.2.3 optimum>=1.14.1 optree -transformers==4.52.1 +transformers==4.48.0 torch>=2.7.0 onnx==1.17.0 datasets>=2.8.0 diff --git a/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py index e092285d57358..ac696ff3788aa 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py +++ b/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py @@ -410,7 +410,7 @@ def export_onnx_models( precision == Precision.FLOAT16, model.config.encoder_attention_heads, model.config.d_model, - model.config.decoder_layers, + model.config.num_hidden_layers, use_external_data_format, use_gpu=use_gpu, provider=provider, diff --git a/onnxruntime/python/tools/transformers/models/whisper/requirements.txt b/onnxruntime/python/tools/transformers/models/whisper/requirements.txt index 37fc72cd26e07..f1758cc52280f 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/requirements.txt +++ b/onnxruntime/python/tools/transformers/models/whisper/requirements.txt @@ -1,5 +1,5 @@ torch>=2.7.0 -transformers==4.52.3 +transformers>=4.52.3 openai-whisper==20240927 ffmpeg-python datasets diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_decoder.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_decoder.py index e10e616d35d38..fadf271ae913b 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_decoder.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_decoder.py @@ -187,7 +187,7 @@ def input_names(self): *list( chain.from_iterable( (f"past_key_self_{i}", f"past_value_self_{i}", f"past_key_cross_{i}", f"past_value_cross_{i}") - for i in range(self.config.decoder_layers) + for i in range(self.config.num_hidden_layers) ) ), ] @@ -205,7 +205,7 @@ def output_names(self): f"present_key_cross_{i}", f"present_value_cross_{i}", ) - for i in range(self.config.decoder_layers) + for i in range(self.config.num_hidden_layers) ) ), ] @@ -214,7 +214,8 @@ def output_names(self): "logits", *list( chain.from_iterable( - (f"present_key_self_{i}", f"present_value_self_{i}") for i in range(self.config.decoder_layers) + (f"present_key_self_{i}", f"present_value_self_{i}") + for i in range(self.config.num_hidden_layers) ) ), ] diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py index cd81edc1001be..26dc3aee7018b 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py @@ -127,7 +127,7 @@ def output_names(self): *list( chain.from_iterable( (f"present_key_cross_{i}", f"present_value_cross_{i}") - for i in range(self.config.decoder_layers) + for i in range(self.config.num_hidden_layers) ) ), ] @@ -143,7 +143,7 @@ def output_names(self): f"present_key_cross_{i}", f"present_value_cross_{i}", ) - for i in range(self.config.decoder_layers) + for i in range(self.config.num_hidden_layers) ) ), ] diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py index a236c4da1738e..f66aa22eb0972 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py @@ -763,7 +763,7 @@ def optimize_onnx( is_float16: bool, num_attention_heads: int, hidden_size: int, - num_decoder_layers: int, + num_layers: int, use_external_data_format: bool = False, use_gpu: bool = False, provider: str = "cpu", @@ -801,7 +801,7 @@ def optimize_onnx( m = add_cache_indirection_to_mha(m, past_seq_len_name) if output_qk: - m = add_output_qk_to_mha(m, skip_node_idxs=list(range(0, 2 * num_decoder_layers, 2))) + m = add_output_qk_to_mha(m, skip_node_idxs=list(range(0, 2 * num_layers, 2))) m.save_model_to_file(optimized_model_path, use_external_data_format, all_tensors_to_one_file=True) diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_inputs.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_inputs.py index 8937fea900d14..0b0882eface72 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_inputs.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_inputs.py @@ -94,14 +94,14 @@ def get_sample_past_key_values( torch.rand(batch_size, num_heads, past_seq_len, head_size, device=device, dtype=torch_dtype), torch.rand(batch_size, num_heads, past_seq_len, head_size, device=device, dtype=torch_dtype), ) - for _ in range(config.decoder_layers) + for _ in range(config.num_hidden_layers) ] cross_attention_kv_caches = [ ( torch.rand(batch_size, num_heads, max_source_positions, head_size, device=device, dtype=torch_dtype), torch.rand(batch_size, num_heads, max_source_positions, head_size, device=device, dtype=torch_dtype), ) - for _ in range(config.decoder_layers) + for _ in range(config.num_hidden_layers) ] return flatten_past_key_values(self_attention_kv_caches, cross_attention_kv_caches) @@ -187,7 +187,7 @@ def get_sample_QKs( # noqa: N802 torch.rand( batch_size, num_heads, sequence_length, config.max_source_positions, device=device, dtype=torch_dtype ) - for _ in range(config.decoder_layers) + for _ in range(config.num_hidden_layers) ] return QKs diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_jump_times.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_jump_times.py index 4dd5d7de1752b..a7c0d3538b8da 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_jump_times.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_jump_times.py @@ -156,7 +156,7 @@ def input_names(self): "alignment_heads", "sot_sequence_length", "segment_length", - *[f"cross_qk_{i}" for i in range(self.config.decoder_layers)], + *[f"cross_qk_{i}" for i in range(self.config.num_hidden_layers)], ] return input_names diff --git a/onnxruntime/test/autoep/library/ep.cc b/onnxruntime/test/autoep/library/ep.cc index 44b3f9a213abf..b498c40079f48 100644 --- a/onnxruntime/test/autoep/library/ep.cc +++ b/onnxruntime/test/autoep/library/ep.cc @@ -226,7 +226,7 @@ OrtStatus* ExampleEp::SaveConstantInitializers(const OrtGraph* graph) { /*static*/ OrtStatus* ORT_API_CALL ExampleEp::GetCapabilityImpl(OrtEp* this_ptr, const OrtGraph* graph, - OrtEpGraphSupportInfo* graph_support_info) noexcept { + OrtEpGraphSupportInfo* graph_support_info) { ExampleEp* ep = static_cast(this_ptr); size_t num_nodes = 0; @@ -290,7 +290,7 @@ OrtStatus* ORT_API_CALL ExampleEp::GetCapabilityImpl(OrtEp* this_ptr, const OrtG OrtStatus* ORT_API_CALL ExampleEp::CompileImpl(_In_ OrtEp* this_ptr, _In_ const OrtGraph** graphs, _In_ const OrtNode** fused_nodes, _In_ size_t count, _Out_writes_all_(count) OrtNodeComputeInfo** node_compute_infos, - _Out_writes_(count) OrtNode** ep_context_nodes) noexcept { + _Out_writes_(count) OrtNode** ep_context_nodes) { ExampleEp* ep = static_cast(this_ptr); const OrtApi& ort_api = ep->ort_api; @@ -328,12 +328,6 @@ OrtStatus* ORT_API_CALL ExampleEp::CompileImpl(_In_ OrtEp* this_ptr, _In_ const RETURN_IF_ERROR(ort_api.GetValueInfoName(node_inputs[0], &node_input_names[0])); RETURN_IF_ERROR(ort_api.GetValueInfoName(node_inputs[1], &node_input_names[1])); - const char* ep_name = nullptr; - RETURN_IF_ERROR(ort_api.Node_GetEpName(fused_nodes[0], &ep_name)); - if (std::strncmp(ep_name, "example_ep", 11) != 0) { - return ort_api.CreateStatus(ORT_EP_FAIL, "The fused node is expected to assigned to this EP to run on"); - } - // Associate the name of the fused node with our MulKernel. const char* fused_node_name = nullptr; RETURN_IF_ERROR(ort_api.Node_GetName(fused_nodes[0], &fused_node_name)); @@ -360,7 +354,7 @@ OrtStatus* ORT_API_CALL ExampleEp::CompileImpl(_In_ OrtEp* this_ptr, _In_ const /*static*/ void ORT_API_CALL ExampleEp::ReleaseNodeComputeInfosImpl(OrtEp* this_ptr, OrtNodeComputeInfo** node_compute_infos, - size_t num_node_compute_infos) noexcept { + size_t num_node_compute_infos) { (void)this_ptr; for (size_t i = 0; i < num_node_compute_infos; i++) { delete node_compute_infos[i]; diff --git a/onnxruntime/test/autoep/library/ep.h b/onnxruntime/test/autoep/library/ep.h index dfebcc52a0caf..b8c63f39438ba 100644 --- a/onnxruntime/test/autoep/library/ep.h +++ b/onnxruntime/test/autoep/library/ep.h @@ -31,14 +31,14 @@ class ExampleEp : public OrtEp, public ApiPtrs { private: static const char* ORT_API_CALL GetNameImpl(const OrtEp* this_ptr) noexcept; static OrtStatus* ORT_API_CALL ExampleEp::GetCapabilityImpl(OrtEp* this_ptr, const OrtGraph* graph, - OrtEpGraphSupportInfo* graph_support_info) noexcept; + OrtEpGraphSupportInfo* graph_support_info); static OrtStatus* ORT_API_CALL CompileImpl(_In_ OrtEp* this_ptr, _In_ const OrtGraph** graphs, _In_ const OrtNode** fused_nodes, _In_ size_t count, _Out_writes_all_(count) OrtNodeComputeInfo** node_compute_infos, - _Out_writes_(count) OrtNode** ep_context_nodes) noexcept; + _Out_writes_(count) OrtNode** ep_context_nodes); static void ORT_API_CALL ReleaseNodeComputeInfosImpl(OrtEp* this_ptr, OrtNodeComputeInfo** node_compute_infos, - size_t num_node_compute_infos) noexcept; + size_t num_node_compute_infos); OrtStatus* CreateEpContextNodes(gsl::span fused_nodes, /*out*/ gsl::span ep_context_nodes); diff --git a/onnxruntime/test/autoep/library/ep_factory.cc b/onnxruntime/test/autoep/library/ep_factory.cc index 19a44008b8c97..d4895102b0bf1 100644 --- a/onnxruntime/test/autoep/library/ep_factory.cc +++ b/onnxruntime/test/autoep/library/ep_factory.cc @@ -14,7 +14,6 @@ ExampleEpFactory::ExampleEpFactory(const char* ep_name, ApiPtrs apis) ort_version_supported = ORT_API_VERSION; // set to the ORT version we were compiled with. GetName = GetNameImpl; GetVendor = GetVendorImpl; - GetVendorId = GetVendorIdImpl; GetVersion = GetVersionImpl; GetSupportedDevices = GetSupportedDevicesImpl; @@ -88,12 +87,6 @@ const char* ORT_API_CALL ExampleEpFactory::GetVendorImpl(const OrtEpFactory* thi return factory->vendor_.c_str(); } -/*static*/ -uint32_t ORT_API_CALL ExampleEpFactory::GetVendorIdImpl(const OrtEpFactory* this_ptr) noexcept { - const auto* factory = static_cast(this_ptr); - return factory->vendor_id_; -} - /*static*/ const char* ORT_API_CALL ExampleEpFactory::GetVersionImpl(const OrtEpFactory* this_ptr) noexcept { const auto* factory = static_cast(this_ptr); diff --git a/onnxruntime/test/autoep/library/ep_factory.h b/onnxruntime/test/autoep/library/ep_factory.h index 72fa1c1301841..fda77f12c4814 100644 --- a/onnxruntime/test/autoep/library/ep_factory.h +++ b/onnxruntime/test/autoep/library/ep_factory.h @@ -21,7 +21,6 @@ class ExampleEpFactory : public OrtEpFactory, public ApiPtrs { static const char* ORT_API_CALL GetNameImpl(const OrtEpFactory* this_ptr) noexcept; static const char* ORT_API_CALL GetVendorImpl(const OrtEpFactory* this_ptr) noexcept; - static uint32_t ORT_API_CALL GetVendorIdImpl(const OrtEpFactory* this_ptr) noexcept; static const char* ORT_API_CALL GetVersionImpl(const OrtEpFactory* this_ptr) noexcept; @@ -54,7 +53,6 @@ class ExampleEpFactory : public OrtEpFactory, public ApiPtrs { const std::string ep_name_; // EP name const std::string vendor_{"Contoso"}; // EP vendor name - const uint32_t vendor_id_{0xB357}; // EP vendor ID const std::string ep_version_{"0.1.0"}; // EP version // CPU allocator so we can control the arena behavior. optional as ORT always provides a CPU allocator if needed. diff --git a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc index 4c3f9e8dd4dbd..7b77ca8c69225 100644 --- a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc @@ -527,20 +527,18 @@ void RunTest(int64_t M, int64_t N, int64_t K, int64_t block_size, bool has_zerop if (std::is_same_v) { #ifdef USE_CUDA execution_providers.push_back(DefaultCudaExecutionProvider()); - RunTest(opts, std::move(execution_providers)); #endif #ifdef USE_ROCM execution_providers.push_back(DefaultRocmExecutionProvider()); - RunTest(opts, std::move(execution_providers)); #endif #ifdef USE_DML execution_providers.push_back(DefaultDmlExecutionProvider()); - RunTest(opts, std::move(execution_providers)); #endif #ifdef USE_WEBGPU execution_providers.push_back(DefaultWebGpuExecutionProvider()); - RunTest(opts, std::move(execution_providers)); #endif + + RunTest(opts, std::move(execution_providers)); } else { #ifdef USE_ROCM execution_providers.push_back(DefaultRocmExecutionProvider()); diff --git a/onnxruntime/test/ep_graph/test_ep_graph.cc b/onnxruntime/test/ep_graph/test_ep_graph.cc index 17e829e37f729..60498e6510ec2 100644 --- a/onnxruntime/test/ep_graph/test_ep_graph.cc +++ b/onnxruntime/test/ep_graph/test_ep_graph.cc @@ -1,24 +1,16 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include -#include #include #include #include #include -#include #include "core/common/common.h" #include "core/framework/tensorprotoutils.h" #include "core/framework/tensor_type_and_shape.h" #include "core/framework/onnxruntime_typeinfo.h" #include "core/session/onnxruntime_cxx_api.h" -#include "core/graph/ep_api_types.h" -#include "core/graph/graph_proto_serializer.h" - -#define ORT_EP_UTILS_ORT_GRAPH_TO_PROTO_IMPL -#include "core/providers/utils/ort_graph_to_proto.h" #include "test/ep_graph/test_ep_graph_utils.h" #include "test/util/include/api_asserts.h" @@ -34,7 +26,6 @@ namespace test { // forward-declaration for utility that uses public C APIs to check that an OrtGraph is equivalent // to a graph represented by the internal ORT GraphViewer class. static void CheckGraphCApi(const GraphViewer& graph_viewer, const OrtGraph& api_graph); -static void Check_Graph_GetSubgraph(const OrtGraph& api_graph); // // Tests @@ -77,178 +68,6 @@ TEST(EpGraphTest, Check3LayerNestedSubgraph) { CheckGraphCApi(test_graph->GetGraphViewer(), test_graph->GetOrtGraph()); } -TEST(EpGraphTest, Check3LayerNestedSubgraphV2) { - // The overall structure of this model is similar to the one used in "Check3LayerNestedSubgraph" test. - // The model consists of a graph with subgraphs nested across three levels. - // In this scenario, a third-layer subgraph consumes an input from the first-layer graph (not an initializer). - auto test_graph = TestGraph::Load(ORT_TSTR("testdata/three_layer_nested_subgraph_v2.onnx")); - ASSERT_NE(test_graph, nullptr) << "Failed to load test model"; - - CheckGraphCApi(test_graph->GetGraphViewer(), test_graph->GetOrtGraph()); -} - -static void RunMNISTModel(const ORTCHAR_T* model_path, std::vector& output_data) { - auto memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); - Ort::SessionOptions sess_options; - Ort::Session session(*ort_env, model_path, sess_options); - - std::vector input_shape = {1, 1, 28, 28}; - std::vector input_data(28 * 28, 0.5f); - std::vector ort_inputs; - std::vector ort_input_names; - - // Add 'Input3' - ort_inputs.emplace_back(Ort::Value::CreateTensor( - memory_info, input_data.data(), input_data.size(), input_shape.data(), input_shape.size())); - ort_input_names.push_back("Input3"); - - // Run session and get outputs - std::array output_names{"Plus214_Output_0"}; - std::vector ort_outputs = session.Run(Ort::RunOptions{nullptr}, ort_input_names.data(), ort_inputs.data(), - ort_inputs.size(), output_names.data(), output_names.size()); - - // Check output type and number of elements. - Ort::Value& ort_output = ort_outputs[0]; - auto output_type_shape = ort_output.GetTensorTypeAndShapeInfo(); - size_t num_output_elems = output_type_shape.GetElementCount(); - - ASSERT_EQ(output_type_shape.GetElementType(), ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT); - ASSERT_EQ(num_output_elems, 10); - - // Return output data. - const float* output_values = ort_output.GetTensorData(); - output_data.assign(output_values, output_values + num_output_elems); -} - -// Test serializing an OrtGraph (MNIST) to GraphProto. Saves initializers to external file. -// Checks that the outputs of the serialized and original models are identical. -TEST(EpGraphTest, SerializeToProto_Mnist) { - const ORTCHAR_T* original_model_path = ORT_TSTR("testdata/mnist.onnx"); - const ORTCHAR_T* serialized_model_path = ORT_TSTR("mnist_serialized.onnx"); - std::filesystem::remove(serialized_model_path); - - { - auto test_graph = TestGraph::Load(original_model_path); - ASSERT_NE(test_graph, nullptr) << "Failed to load test model"; - - // Serialize OrtGraph to GraphProto. Save initializers to external file. - std::string ext_ini_file_path = "mnist_serialized.bin"; - std::filesystem::remove(ext_ini_file_path); - std::ofstream ext_ini_ofs(ext_ini_file_path, std::ios::binary); - auto handle_initializer_data = [&ext_ini_ofs, &ext_ini_file_path](const OrtValueInfo* value_info, - const void* data, size_t bytes, - bool& is_external, std::string& location, - int64_t& offset) -> Ort::Status { - // OrtValueInfo* could be used to query initializer's name, type, shape, - // node consumers, etc. - (void)value_info; - - if (bytes <= 127) { - is_external = false; // Keep small initializers stored inside the TensorProto. - return Ort::Status{nullptr}; - } - - offset = ext_ini_ofs.tellp(); - location = ext_ini_file_path; - ext_ini_ofs.write(static_cast(data), bytes); - ext_ini_ofs.flush(); - is_external = true; // True if is external initializer. - - return Ort::Status{nullptr}; - }; - - ONNX_NAMESPACE::ModelProto model_proto; - OrtEpUtils::OrtGraphToProto(test_graph->GetOrtGraph(), model_proto, handle_initializer_data); - - std::ofstream ofs(serialized_model_path, std::ios::binary); - model_proto.SerializeToOstream(&ofs); - ofs.flush(); - - ASSERT_TRUE(std::filesystem::exists(serialized_model_path)); - ASSERT_TRUE(std::filesystem::exists(ext_ini_file_path)); - } - - // Compare output of the original and serialized models. Should be identical. - std::vector output_original; - std::vector output_serialized; - - RunMNISTModel(original_model_path, output_original); - RunMNISTModel(serialized_model_path, output_serialized); - - EXPECT_EQ(output_serialized, output_original); -} - -static void Run3LayerModel(const ORTCHAR_T* model_path, bool input_cond, std::vector& output_data) { - auto memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); - Ort::SessionOptions sess_options; - Ort::Session session(*ort_env, model_path, sess_options); - - std::vector input_shape = {1}; - std::vector ort_inputs; - std::vector ort_input_names; - - // Add 'if_cond_input' - ort_inputs.emplace_back(Ort::Value::CreateTensor( - memory_info, &input_cond, 1, input_shape.data(), input_shape.size())); - ort_input_names.push_back("if_cond_input"); - - // Run session and get outputs - std::array output_names{"if_cond_output"}; - std::vector ort_outputs = session.Run(Ort::RunOptions{nullptr}, ort_input_names.data(), ort_inputs.data(), - ort_inputs.size(), output_names.data(), output_names.size()); - - // Check output type and number of elements. - Ort::Value& ort_output = ort_outputs[0]; - auto output_type_shape = ort_output.GetTensorTypeAndShapeInfo(); - size_t num_output_elems = output_type_shape.GetElementCount(); - - ASSERT_EQ(output_type_shape.GetElementType(), ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT); - ASSERT_EQ(num_output_elems, 1); - - // Return output data. - const float* output_values = ort_output.GetTensorData(); - output_data.assign(output_values, output_values + num_output_elems); -} - -// Test serializing an OrtGraph to GraphProto. The model has 3 layers of nested subgraphs. -// Checks that the outputs of the serialized and original models are identical. -TEST(EpGraphTest, SerializeToProto_3LayerSubgraphs) { - const ORTCHAR_T* original_model_path = ORT_TSTR("testdata/three_layer_nested_subgraph.onnx"); - const ORTCHAR_T* serialized_model_path = ORT_TSTR("three_layer_nested_subgraph_serialized.onnx"); - std::filesystem::remove(serialized_model_path); - - { - auto test_graph = TestGraph::Load(original_model_path); - ASSERT_NE(test_graph, nullptr) << "Failed to load test model"; - - // Serialize OrtGraph to ModelProto (all initializers stored within TensorProtos). - ONNX_NAMESPACE::ModelProto model_proto; - OrtEpUtils::OrtGraphToProto(test_graph->GetOrtGraph(), model_proto); - - std::ofstream ofs(serialized_model_path, std::ios::binary); - model_proto.SerializeToOstream(&ofs); - ofs.flush(); - - ASSERT_TRUE(std::filesystem::exists(serialized_model_path)); - } - - // Compare output of the original and serialized models. Should be identical. - std::vector output_original; - std::vector output_serialized; - - { - Run3LayerModel(original_model_path, true, output_original); - Run3LayerModel(serialized_model_path, true, output_serialized); - EXPECT_EQ(output_serialized, output_original); - } - - { - Run3LayerModel(original_model_path, false, output_original); - Run3LayerModel(serialized_model_path, false, output_serialized); - EXPECT_EQ(output_serialized, output_original); - } -} - // // Utils for traversing an OrtGraph and checking against GraphViewer. // @@ -488,48 +307,6 @@ static void CheckValueInfosCApi(const GraphViewer& graph_viewer, gsl::span nodes(num_nodes); - ASSERT_ORTSTATUS_OK(ort_api.Graph_GetNodes(&api_graph, nodes.data(), nodes.size())); - - // Select a half of nodes to create a OrtGraph - size_t num_selected_nodes = std::max((nodes.size() >> 1), (size_t)1); - std::vector selected_nodes(num_selected_nodes); - - for (size_t i = 0; i < num_selected_nodes; i++) { - selected_nodes[i] = nodes[i]; - } - - OrtGraph* sub_graph; - ASSERT_ORTSTATUS_OK(ort_api.Graph_GetGraphView(&api_graph, selected_nodes.data(), selected_nodes.size(), &sub_graph)); - - // Convert OrtGraph/GraphViewer to ModelProto and dump it to disk. - // If the GraphViewer associated with the OrtGraph somehow is incorrect, GraphViewerToProto() will throw. - const GraphViewer& sub_graph_viewer = EpGraph::ToInternal(sub_graph)->GetGraphViewer(); - std::unique_ptr model = std::make_unique(sub_graph_viewer.Name(), true, sub_graph_viewer.GetGraph().GetLogger()); - auto model_proto = std::make_unique(model->ToProto()); - GraphViewerToProto(sub_graph_viewer, *model_proto->mutable_graph(), true, true, static_cast(1)); - model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); - - const char* graph_name = nullptr; - ASSERT_ORTSTATUS_OK(ort_api.Graph_GetName(&api_graph, &graph_name)); - std::string name = graph_name; - name += "_half.onnx"; - - // Dump the graph for debugging - // std::fstream dump(name, std::ios::out | std::ios::trunc | std::ios::binary); - // model_proto->SerializeToOstream(&dump); - - ort_api.ReleaseGraph(sub_graph); -} - // Checks that the contents of the original GraphViewer matches the contents of the OrtGraph. // Uses the public C APIs to traverse the OrtGraph. static void CheckGraphCApi(const GraphViewer& graph_viewer, const OrtGraph& api_graph) { @@ -693,10 +470,9 @@ static void CheckGraphCApi(const GraphViewer& graph_viewer, const OrtGraph& api_ } // Check node subgraphs - std::unordered_map> node_subgraphs_map = - node->GetAttributeNameToSubgraphMap(); + std::vector> node_subgraphs = node->GetSubgraphs(); - if (!node_subgraphs_map.empty()) { + if (!node_subgraphs.empty()) { // Check node's implicit inputs to its subgraph nodes. const auto implicit_input_node_args = node->ImplicitInputDefs(); @@ -713,34 +489,18 @@ static void CheckGraphCApi(const GraphViewer& graph_viewer, const OrtGraph& api_ // Recursively check subgraphs. size_t api_num_node_subgraphs = 0; ASSERT_ORTSTATUS_OK(ort_api.Node_GetNumSubgraphs(api_node, &api_num_node_subgraphs)); - ASSERT_EQ(api_num_node_subgraphs, node_subgraphs_map.size()); std::vector api_node_subgraphs(api_num_node_subgraphs); - std::vector api_subgraph_attr_names(api_num_node_subgraphs); - ASSERT_ORTSTATUS_OK(ort_api.Node_GetSubgraphs(api_node, api_node_subgraphs.data(), api_node_subgraphs.size(), - api_subgraph_attr_names.data())); - - for (const auto& [attr_name, subgraph] : node_subgraphs_map) { - // find index of this subgraph. - size_t api_subgraph_idx = api_num_node_subgraphs; - for (size_t subgraph_idx = 0; subgraph_idx < api_num_node_subgraphs; subgraph_idx++) { - if (api_subgraph_attr_names[subgraph_idx] == attr_name) { - api_subgraph_idx = subgraph_idx; - break; - } - } - ASSERT_NE(api_subgraph_idx, api_num_node_subgraphs); + ASSERT_ORTSTATUS_OK(ort_api.Node_GetSubgraphs(api_node, api_node_subgraphs.data(), api_node_subgraphs.size())); + + for (size_t subgraph_idx = 0; subgraph_idx < node_subgraphs.size(); subgraph_idx++) { + auto subgraph_viewer = std::make_unique(*node_subgraphs[subgraph_idx]); + const OrtGraph* api_subgraph = api_node_subgraphs[subgraph_idx]; - // Recursively check the subgraph - auto subgraph_viewer = std::make_unique(*subgraph); - const OrtGraph* api_subgraph = api_node_subgraphs[api_subgraph_idx]; CheckGraphCApi(*subgraph_viewer, *api_subgraph); } } } - - // Check creating an OrtGraph from a subset of nodes in an OrtGraph - Check_Graph_GetSubgraph(api_graph); } } // namespace test diff --git a/onnxruntime/test/ep_graph/test_ep_graph_utils.cc b/onnxruntime/test/ep_graph/test_ep_graph_utils.cc index 3b3bc4c6da911..b7743e65061de 100644 --- a/onnxruntime/test/ep_graph/test_ep_graph_utils.cc +++ b/onnxruntime/test/ep_graph/test_ep_graph_utils.cc @@ -30,7 +30,6 @@ std::unique_ptr TestGraph::Load(const ORTCHAR_T* model_path) { const OrtGraph& TestGraph::GetOrtGraph() const { return *api_graph; } const GraphViewer& TestGraph::GetGraphViewer() const { return graph_viewer; } -const Model& TestGraph::GetModel() const { return *model; } static Status GetInputIndices(const Node& consumer_node, const std::string& name, /*out*/ std::vector& indices) { diff --git a/onnxruntime/test/ep_graph/test_ep_graph_utils.h b/onnxruntime/test/ep_graph/test_ep_graph_utils.h index 2ce107cf734c6..b0ed825f21d71 100644 --- a/onnxruntime/test/ep_graph/test_ep_graph_utils.h +++ b/onnxruntime/test/ep_graph/test_ep_graph_utils.h @@ -28,7 +28,6 @@ class TestGraph { static std::unique_ptr Load(const ORTCHAR_T* model_path); const OrtGraph& GetOrtGraph() const; const GraphViewer& GetGraphViewer() const; - const Model& GetModel() const; private: std::shared_ptr model; diff --git a/onnxruntime/test/framework/ep_plugin_provider_test.cc b/onnxruntime/test/framework/ep_plugin_provider_test.cc index 4c5dcd2bd7580..18bc9cf05b36d 100644 --- a/onnxruntime/test/framework/ep_plugin_provider_test.cc +++ b/onnxruntime/test/framework/ep_plugin_provider_test.cc @@ -36,7 +36,7 @@ struct TestOrtEp : ::OrtEp, ApiPtrs { // Individual tests should fill out the other function pointers as needed. } - static const char* ORT_API_CALL GetNameImpl(const OrtEp* /*this_ptr*/) noexcept { + static const char* ORT_API_CALL GetNameImpl(const OrtEp* /*this_ptr*/) { constexpr const char* ep_name = "TestOrtEp"; return ep_name; } @@ -50,7 +50,7 @@ struct TestOrtEpFactory : ::OrtEpFactory { ReleaseEp = ReleaseEpImpl; } - static void ORT_API_CALL ReleaseEpImpl(::OrtEpFactory* /*this_ptr*/, OrtEp* ep) noexcept { + static void ORT_API_CALL ReleaseEpImpl(::OrtEpFactory* /*this_ptr*/, OrtEp* ep) { delete static_cast(ep); } }; @@ -125,7 +125,7 @@ TEST(PluginExecutionProviderTest, GetPreferredLayout) { } { - auto prefer_nhwc_fn = [](OrtEp* /*this_ptr*/, OrtEpDataLayout* preferred_data_layout) noexcept -> ::OrtStatus* { + auto prefer_nhwc_fn = [](OrtEp* /*this_ptr*/, OrtEpDataLayout* preferred_data_layout) -> ::OrtStatus* { *preferred_data_layout = OrtEpDataLayout::OrtEpDataLayout_NCHW; return nullptr; }; @@ -135,7 +135,7 @@ TEST(PluginExecutionProviderTest, GetPreferredLayout) { #if !defined(ORT_NO_EXCEPTIONS) { - auto invalid_layout_fn = [](OrtEp* /*this_ptr*/, OrtEpDataLayout* preferred_data_layout) noexcept -> ::OrtStatus* { + auto invalid_layout_fn = [](OrtEp* /*this_ptr*/, OrtEpDataLayout* preferred_data_layout) -> ::OrtStatus* { *preferred_data_layout = static_cast(-1); return nullptr; }; @@ -144,7 +144,7 @@ TEST(PluginExecutionProviderTest, GetPreferredLayout) { } { - auto failing_fn = [](OrtEp* this_ptr, OrtEpDataLayout* /*preferred_data_layout*/) noexcept -> ::OrtStatus* { + auto failing_fn = [](OrtEp* this_ptr, OrtEpDataLayout* /*preferred_data_layout*/) -> ::OrtStatus* { auto* test_ort_ep = static_cast(this_ptr); return test_ort_ep->ort_api->CreateStatus(OrtErrorCode::ORT_FAIL, "I can't decide what data layout I prefer."); }; @@ -167,7 +167,7 @@ TEST(PluginExecutionProviderTest, ShouldConvertDataLayoutForOp) { const char* /*node_domain*/, const char* node_op_type, OrtEpDataLayout target_data_layout, - int* should_convert) noexcept -> ::OrtStatus* { + int* should_convert) -> ::OrtStatus* { EXPECT_EQ(target_data_layout, OrtEpDataLayout::OrtEpDataLayout_NHWC); if (node_op_type == std::string_view{"Conv"}) { @@ -201,7 +201,7 @@ TEST(PluginExecutionProviderTest, ShouldConvertDataLayoutForOp) { const char* /*node_domain*/, const char* /*node_op_type*/, OrtEpDataLayout /*target_data_layout*/, - int* /*should_convert*/) noexcept -> ::OrtStatus* { + int* /*should_convert*/) -> ::OrtStatus* { auto* test_ort_ep = static_cast(this_ptr); return test_ort_ep->ort_api->CreateStatus(OrtErrorCode::ORT_FAIL, "To convert to NHWC or not to convert to NHWC..."); diff --git a/onnxruntime/test/mlas/bench/bench_computesoftmax.cpp b/onnxruntime/test/mlas/bench/bench_computesoftmax.cpp index ea36383f70621..65822eb294d7d 100644 --- a/onnxruntime/test/mlas/bench/bench_computesoftmax.cpp +++ b/onnxruntime/test/mlas/bench/bench_computesoftmax.cpp @@ -58,10 +58,10 @@ void COMPUTESOFTMAXINPLACE(benchmark::State& state) { std::copy(data.begin(), data.end(), input); // Copy the data to the aligned memory // warming up run - MlasComputeSoftmax(input, output, N, D, false, false, 0.0f, tp.get()); + MlasComputeSoftmax(input, output, N, D, false, false, tp.get()); for (auto _ : state) { - MlasComputeSoftmax(input, output, N, D, false, false, 0.0f, tp.get()); + MlasComputeSoftmax(input, output, N, D, false, false, tp.get()); } free(ptr.underlying_buffer); diff --git a/onnxruntime/test/mlas/unittest/test_dequantizelinear.cpp b/onnxruntime/test/mlas/unittest/test_dequantizelinear.cpp deleted file mode 100644 index b994981364947..0000000000000 --- a/onnxruntime/test/mlas/unittest/test_dequantizelinear.cpp +++ /dev/null @@ -1,75 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "test_util.h" - -template -class MlasDequantizeLinearTest : public MlasTestBase { - private: - MatrixGuardBuffer BufferInput; - MatrixGuardBuffer BufferOutput; - MatrixGuardBuffer BufferOutputReference; - - void GenerateReference(const QuantInt* Input, float* OutputReference, size_t N, float Scale, QuantInt ZeroPoint) { - int32_t ZeroPointS32 = static_cast(ZeroPoint); - - for (size_t n = 0; n < N; n++) { - OutputReference[n] = static_cast(static_cast(Input[n]) - ZeroPointS32) * Scale; - } - } - - void Test(size_t N) { - QuantInt* Input = BufferInput.GetBuffer(N); - float* Output = BufferOutput.GetBuffer(N); - float* OutputReference = BufferOutputReference.GetBuffer(N); - - std::default_random_engine generator(static_cast(N)); - - std::uniform_real_distribution min_gen(-10.f, -10e-3f); - float MinimumValue = min_gen(generator); - - std::uniform_real_distribution max_gen(10e-3f, 10.f); - float MaximumValue = max_gen(generator); - - float Scale = (MaximumValue - MinimumValue) / 512.f; - - std::uniform_int_distribution zp_distribution(std::numeric_limits::min(), - std::numeric_limits::max()); - QuantInt ZeroPoint = static_cast(zp_distribution(generator)); - - for (size_t n = 0; n < N; n++) { - Input[n] = static_cast(zp_distribution(generator)); - } - - GenerateReference(Input, OutputReference, N, Scale, ZeroPoint); - MlasDequantizeLinear(Input, Output, N, Scale, ZeroPoint); - - for (size_t n = 0; n < N; n++) { - ASSERT_EQ(Output[n], OutputReference[n]) << ", size=" << N << ", index=" << n; - } - } - - public: - static const char* GetTestSuiteName() { - if constexpr (std::is_same_v) { - return "DequantizeLinearS8"; - } else { - return "DequantizeLinearU8"; - } - } - - void ExecuteShort(void) override { - for (size_t n = 1; n <= 512; n++) { - Test(n); - } - } -}; - -static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) { - size_t count = 0; - if (is_short_execute) { - count += MlasDirectShortExecuteTests>::RegisterShortExecute(); - count += MlasDirectShortExecuteTests>::RegisterShortExecute(); - } - return count; -}); diff --git a/onnxruntime/test/mlas/unittest/test_softmax.cpp b/onnxruntime/test/mlas/unittest/test_softmax.cpp index 4d7a45143b311..041b6c61cd5bf 100644 --- a/onnxruntime/test/mlas/unittest/test_softmax.cpp +++ b/onnxruntime/test/mlas/unittest/test_softmax.cpp @@ -152,7 +152,7 @@ class MlasSoftmaxTest : public MlasTestBase { } void Test(const float* Input, float* Output, float* OutputReference, size_t N, size_t D, bool LogSoftmax, bool SmoothSoftmax) { - MlasComputeSoftmax(Input, Output, N, D, LogSoftmax, SmoothSoftmax, 0.0f, threadpool_); + MlasComputeSoftmax(Input, Output, N, D, LogSoftmax, SmoothSoftmax, threadpool_); ReferenceSoftmax(Input, OutputReference, N, D, LogSoftmax, SmoothSoftmax); constexpr float AbsoluteTolerance = 1e-6f; @@ -206,7 +206,7 @@ class MlasSoftmaxTest : public MlasTestBase { InputReference[nd] = Input[nd].ToFloat(); } - MlasComputeSoftmax(Input, Output, N, D, LogSoftmax, SmoothSoftmax, 0.0f, threadpool_); + MlasComputeSoftmax(Input, Output, N, D, LogSoftmax, SmoothSoftmax, threadpool_); ReferenceSoftmax(InputReference, OutputReference, N, D, LogSoftmax, SmoothSoftmax); constexpr float AbsoluteTolerance = 5e-3f; diff --git a/onnxruntime/test/providers/cpu/math/softmax_test.cc b/onnxruntime/test/providers/cpu/math/softmax_test.cc index 215203b31f49c..649c9af7cc80b 100644 --- a/onnxruntime/test/providers/cpu/math/softmax_test.cc +++ b/onnxruntime/test/providers/cpu/math/softmax_test.cc @@ -61,8 +61,7 @@ TEST(SoftmaxOperator, webgpu_nan) { test.AddOutput("Y", dimensions, expected_result); // explicitly disable for EPs that do not handle NaN - test.Run(OpTester::ExpectResult::kExpectSuccess, "", - {kCpuExecutionProvider, kCoreMLExecutionProvider, kDmlExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCpuExecutionProvider, kCoreMLExecutionProvider}); } #endif diff --git a/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc b/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc index 8fdbf0060eaa0..4e7a6356a5129 100644 --- a/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc @@ -33,32 +33,6 @@ TEST(DequantizeLinearOpTest, Int8) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); } -// scalar zero & scale with uint8 (large enough input to execute MLAS vectorized loop) -TEST(DequantizeLinearOpTest, Uint8_Large) { - OpTester test("DequantizeLinear", 10); - std::vector dims{1, 1039}; // not evenly divisible by 16 (loop unroll amount) to test handling of leftover inputs - test.AddInput("x", dims, std::vector(1039, 1)); - test.AddInput("x_scale", {}, {1.0f}); - test.AddInput("x_zero_point", {}, {1}); - test.AddOutput("y", dims, std::vector(1039, 0.0f)); - // Disable Tensorrt EP due to error:node1_quantize_scale_node: out of bounds channel axis 1. Number of input dimensions is 1. - // Disable WebGPU EP because it requires dims.Size() to be multiple of 4. Fails with error: needs at least component size 4. - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kWebGpuExecutionProvider}); -} - -// scalar zero & scale with int8 (large enough input to execute MLAS vectorized loop) -TEST(DequantizeLinearOpTest, Int8_Large) { - OpTester test("DequantizeLinear", 10); - std::vector dims{1, 1039}; // not evenly divisible by 16 (loop unroll amount) to test handling of leftover inputs - test.AddInput("x", dims, std::vector(1039, 1)); - test.AddInput("x_scale", {}, {1.0f}); - test.AddInput("x_zero_point", {}, {1}); - test.AddOutput("y", dims, std::vector(1039, 0.0f)); - // Disable Tensorrt EP due to error:node1_quantize_scale_node: out of bounds channel axis 1. Number of input dimensions is 1. - // Disable WebGPU EP because it requires dims.Size() to be multiple of 4. Fails with error: needs at least component size 4. - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kWebGpuExecutionProvider}); -} - // scalar zero & scale with int4 TEST(DequantizeLinearOpTest, Int4) { OpTester test("DequantizeLinear", 21); diff --git a/onnxruntime/test/providers/cpu/tensor/scatter_nd_op_test.cc b/onnxruntime/test/providers/cpu/tensor/scatter_nd_op_test.cc index e6d113e1e4dca..895c8ab3e53e4 100644 --- a/onnxruntime/test/providers/cpu/tensor/scatter_nd_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/scatter_nd_op_test.cc @@ -235,16 +235,5 @@ TEST(ScatterNDOpTest, ScatterND_18_max) { test1.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); } -// Test for ScatterND with empty indices - output should be same as input -TEST(ScatterNDOpTest, ScatterND_empty_indices) { - // Test with float data type and minimal empty case - OpTester test1("ScatterND", 11); - test1.AddInput("data", {2, 3}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}); - test1.AddInput("indices", {0, 1}, {}); // Empty indices tensor - no indices to process - test1.AddInput("updates", {0, 3}, {}); // Empty updates tensor - test1.AddOutput("output", {2, 3}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}); // Same as input - test1.Run(OpTester::ExpectResult::kExpectSuccess, "", {kDmlExecutionProvider}); -} - } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc index 739e39a6975e2..4febfe7ba836d 100644 --- a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc +++ b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc @@ -509,11 +509,6 @@ TEST_F(QnnHTPBackendTests, CompileApi_FromSessionOptions_InputAndOutputModelsInB Ort::ModelCompilationOptions compile_options(*ort_env, session_options); compile_options.SetInputModelFromBuffer(reinterpret_cast(model_data.data()), model_data.size()); compile_options.SetOutputModelBuffer(allocator, &output_model_buffer, &output_model_buffer_size); - std::string target_dir = "./testdata/"; - std::string model_name = "test_model_in_mem.onnx"; - auto pos = model_name.rfind(".onnx"); - std::string bin_file_name = model_name.substr(0, pos) + "_qnn.bin"; - compile_options.SetEpContextBinaryInformation(ToWideString(target_dir).c_str(), ToWideString(model_name).c_str()); compile_options.SetEpContextEmbedMode(false); // Compile the model. @@ -524,18 +519,12 @@ TEST_F(QnnHTPBackendTests, CompileApi_FromSessionOptions_InputAndOutputModelsInB ASSERT_TRUE(output_model_buffer != nullptr); ASSERT_TRUE(output_model_buffer_size > 0); - ASSERT_TRUE(std::filesystem::exists(target_dir + bin_file_name)) << "expected context binary file should exist"; - // Check that the compiled model has the expected number of EPContext nodes. CheckEpContextNodeCounts(output_model_buffer, output_model_buffer_size, 2, 2); - // Add session option "ep.context_file_path" so that the session can use it to locate the [model_name]_qnn.bin file - std::string ctx_model = target_dir + model_name; - session_options.AddConfigEntry(kOrtSessionOptionEpContextFilePath, ctx_model.c_str()); // Should be able to create a session with the compiled model and the original session options. EXPECT_NO_THROW((Ort::Session(*ort_env, output_model_buffer, output_model_buffer_size, session_options))); - std::filesystem::remove(target_dir + bin_file_name); allocator.Free(output_model_buffer); } } @@ -1660,6 +1649,7 @@ static void DumpModelWithSharedCtx(ProviderOptions provider_options, Ort::Session session2(*ort_env, ToPathString(onnx_model_path2).c_str(), so); } +#if defined(__aarch64__) || defined(_M_ARM64) static void GetModelInputNames(const std::string& model_path, std::vector& input_names, std::vector& output_names, @@ -1679,6 +1669,7 @@ static void GetModelInputNames(const std::string& model_path, output_names.push_back(output->Name()); } } +#endif // 1. Create 2 QDQ models // 2. Initialize 2 Ort sessions which share the same QNN EP from these 2 QDQ models @@ -2003,73 +1994,6 @@ TEST_F(QnnHTPBackendTests, LoadFromArrayWithQnnEpContextGenPathValidation) { }); } } - -TEST_F(QnnHTPBackendTests, QnnEpDynamicOptions) { - ProviderOptions provider_options; - provider_options["backend_type"] = "htp"; - provider_options["offload_graph_io_quantization"] = "0"; - - Ort::SessionOptions so; - so.AppendExecutionProvider("QNN", provider_options); - so.SetLogSeverityLevel(ORT_LOGGING_LEVEL_VERBOSE); - - Ort::Session session(*ort_env, ORT_TSTR("testdata/qnn_ctx/qnn_multi_ctx_embed.onnx"), so); - - std::vector input_names; - std::vector output_names; - GetModelInputNames("testdata/qnn_ctx/qnn_multi_ctx_embed.onnx", input_names, output_names, - DefaultLoggingManager().DefaultLogger()); - - // Run sessions - // prepare input - std::vector input_dim{3, 4}; - std::vector input_value(3 * 4, 0.0f); - Ort::MemoryInfo info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault); - std::vector ort_inputs; - std::vector input_names_c; - for (size_t i = 0; i < input_names.size(); ++i) { - auto input_tensor = Ort::Value::CreateTensor(info, input_value.data(), input_value.size(), - input_dim.data(), input_dim.size()); - ort_inputs.push_back(std::move(input_tensor)); - input_names_c.push_back(input_names[i].c_str()); - } - std::vector output_names_c; - for (size_t i = 0; i < output_names.size(); ++i) { - output_names_c.push_back(output_names[i].c_str()); - } - - auto ort_output = session.Run(Ort::RunOptions{}, input_names_c.data(), ort_inputs.data(), ort_inputs.size(), - output_names_c.data(), 1); - - const char* const workload_type[] = {"ep.dynamic.workload_type"}; - const char* const efficient_type[] = {"Efficient"}; - const char* const default_type[] = {"Default"}; - - // Test Efficient & Default options - session.SetEpDynamicOptions(workload_type, efficient_type, 1); - ort_output = session.Run(Ort::RunOptions{}, input_names_c.data(), ort_inputs.data(), ort_inputs.size(), - output_names_c.data(), 1); - - session.SetEpDynamicOptions(workload_type, default_type, 1); - ort_output = session.Run(Ort::RunOptions{}, input_names_c.data(), ort_inputs.data(), ort_inputs.size(), - output_names_c.data(), 1); - - // Test invalid EP dynamic option and invalid workload type - const char* const dne[] = {"DNE"}; - try { - session.SetEpDynamicOptions(workload_type, dne, 1); - FAIL() << "Expected exception to be thrown for workload type DNE but was set successfully"; - } catch (const std::exception& e) { - EXPECT_STREQ("Invalid EP Workload Type.", e.what()); - } - - try { - session.SetEpDynamicOptions(dne, efficient_type, 1); - FAIL() << "Expected exception to be thrown for dynamic option DNE but was set successfully"; - } catch (const std::exception& e) { - EXPECT_STREQ("Unsupported EP Dynamic Option", e.what()); - } -} #endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) } // namespace test diff --git a/onnxruntime/test/providers/qnn/simple_op_htp_test.cc b/onnxruntime/test/providers/qnn/simple_op_htp_test.cc index 4c0a53e83e274..85f8250f70fc5 100644 --- a/onnxruntime/test/providers/qnn/simple_op_htp_test.cc +++ b/onnxruntime/test/providers/qnn/simple_op_htp_test.cc @@ -1254,38 +1254,6 @@ TEST_F(QnnHTPBackendTests, GridSample_U16_Nearest) { true); } -// Test QDQ GridSample with `linear` mode on opset 20+. -TEST_F(QnnHTPBackendTests, GridSample_Linear_ZerosPadding) { - RunQDQOpTest("GridSample", - {TestInputDef({1, 3, 4, 6}, false, GetFloatDataInRange(-10.0f, 10.0f, 72)), - TestInputDef({1, 4, 6, 2}, false, GetFloatDataInRange(-10.0f, 10.0f, 48))}, - {utils::MakeAttribute("mode", "linear"), utils::MakeAttribute("padding_mode", "zeros")}, - /*opset_version=*/20, - /*expected_ep_assignment=*/ExpectedEPNodeAssignment::All); -} - -TEST_F(QnnHTPBackendTests, GridSample_Linear_AlignCorners_BorderPadding) { - RunQDQOpTest("GridSample", - {TestInputDef({1, 3, 4, 6}, false, GetFloatDataInRange(-10.0f, 10.0f, 72)), - TestInputDef({1, 4, 6, 2}, false, GetFloatDataInRange(-10.0f, 10.0f, 48))}, - {utils::MakeAttribute("align_corners", static_cast(1)), - utils::MakeAttribute("mode", "linear"), - utils::MakeAttribute("padding_mode", "border")}, - /*opset_version=*/20, - /*expected_ep_assignment=*/ExpectedEPNodeAssignment::All); -} - -TEST_F(QnnHTPBackendTests, GridSample_Linear_ReflectionPadding_U16) { - RunQDQOpTest("GridSample", - {TestInputDef({1, 3, 4, 6}, false, GetFloatDataInRange(-10.0f, 10.0f, 72)), - TestInputDef({1, 4, 6, 2}, false, GetFloatDataInRange(-10.0f, 10.0f, 48))}, - {utils::MakeAttribute("mode", "linear"), utils::MakeAttribute("padding_mode", "reflection")}, - /*opset_version=*/21, - /*expected_ep_assignment=*/ExpectedEPNodeAssignment::All, - /*op_domain=*/kOnnxDomain, - /*use_contrib_qdq=*/true); -} - // Test QDQ GridSample with reflection padding mode // Inaccuracy detected for output 'output', element 2. // Output quant params: scale=0.024269860237836838, zero_point=0. diff --git a/onnxruntime/test/python/quantization/test_op_topk.py b/onnxruntime/test/python/quantization/test_op_topk.py deleted file mode 100644 index 1fdd0c987d1e8..0000000000000 --- a/onnxruntime/test/python/quantization/test_op_topk.py +++ /dev/null @@ -1,103 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- - -import unittest - -import numpy as np -from onnx import TensorProto, helper, save -from op_test_utils import TestDataFeeds, check_model_correctness, check_op_type_count, check_qtype_by_node_type - -from onnxruntime.quantization import QuantFormat, QuantType, quantize_static - - -class TestTopKModel(unittest.TestCase): - @staticmethod - def construct_model(model_path, input_shape, axis_attr, k): - input_tensor = helper.make_tensor_value_info("input", TensorProto.FLOAT, input_shape) - k_tensor = helper.make_tensor("k", TensorProto.INT64, [1], [k]) - output_shape = input_shape[:] - output_shape[axis_attr] = k - output_values = helper.make_tensor_value_info("values", TensorProto.FLOAT, [1, k]) - output_indices = helper.make_tensor_value_info("indices", TensorProto.INT64, [1, k]) - - node = helper.make_node( - "TopK", inputs=["input", "k"], outputs=["values", "indices"], name="topk_node", axis=axis_attr - ) - - graph = helper.make_graph( - [node], - "quant_topk_op_test", - [input_tensor], - [output_values, output_indices], - initializer=[k_tensor], - ) - - model = helper.make_model( - graph, opset_imports=[helper.make_opsetid("", 16), helper.make_opsetid("com.microsoft", 1)] - ) - save(model, model_path) - - def quantize_topk_test(self, activation_type, weight_type, extra_options={}): # noqa: B006 - model_fp32_path = "topk_fp32.onnx" - input_shape = [1, 10] - axis = 1 - k = 3 - self.construct_model(model_fp32_path, input_shape, axis, k) - - input_data_list = [ - {"input": np.array([[1.8, 2.5, -5.9, 5.2, 4.1, 7.3, 0.2, -0.5, 0.845, 3.9]], dtype=np.float32)} - ] - data_reader = TestDataFeeds(input_data_list) - - activation_proto_qtype = TensorProto.UINT8 if activation_type == QuantType.QUInt8 else TensorProto.INT8 - activation_type_str = "u8" if (activation_type == QuantType.QUInt8) else "s8" - weight_type_str = "u8" if (weight_type == QuantType.QUInt8) else "s8" - model_qdq_path = f"topk_{activation_type_str}{weight_type_str}_{'QNoInCk' if extra_options['ForceQuantizeNoInputCheck'] else 'NoQNoInCk'}_qdq.onnx" - - # Verify QDQ mode - data_reader.rewind() - quantize_static( - model_fp32_path, - model_qdq_path, - data_reader, - quant_format=QuantFormat.QDQ, - activation_type=activation_type, - weight_type=weight_type, - extra_options=extra_options, - ) - qdqnode_counts = ( - { - "TopK": 1, - "QuantizeLinear": 2, - "DequantizeLinear": 2, - } - if extra_options["ForceQuantizeNoInputCheck"] - else { - "TopK": 1, - "QuantizeLinear": 0, - "DequantizeLinear": 0, - } - ) - check_op_type_count(self, model_qdq_path, **qdqnode_counts) - qnode_io_qtypes = { - "QuantizeLinear": [ - ["i", 2, activation_proto_qtype], - ["o", 0, activation_proto_qtype], - ] - } - check_qtype_by_node_type(self, model_qdq_path, qnode_io_qtypes) - data_reader.rewind() - check_model_correctness(self, model_fp32_path, model_qdq_path, data_reader.get_next()) - - def test_quantize_topk_u8u8(self): - self.quantize_topk_test(QuantType.QUInt8, QuantType.QUInt8, extra_options={"ForceQuantizeNoInputCheck": True}) - - def test_quantize_topk_u8u8_no_force_quantize_no_input_check(self): - self.quantize_topk_test(QuantType.QUInt8, QuantType.QUInt8, extra_options={"ForceQuantizeNoInputCheck": False}) - - -if __name__ == "__main__": - unittest.main() diff --git a/onnxruntime/test/python/transformers/test_data/models/phi-4-v-instruct-vision-attention.onnx b/onnxruntime/test/python/transformers/test_data/models/phi-4-v-instruct-vision-attention.onnx deleted file mode 100644 index 34cf26c13d3fc98f8a97aa3f9999e3d99e5bf847..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 7729 zcmeI1%Wl&^6hLEFY1~UGX5^MANT3E)mAtT%B7%r2p=KFTgkaarXp*?Ki9O@mV_qv( zhz&fg>%mqFw?GGRPD z6iMVIX>Gab?Cdy=_J>{gs6jd4aVF0ilS#>)A&nF9(&+^(g-Xct2SVJCyL*EHZBmg* zwT3ooH(m^b_z8RKB~O)b+Nf__>R@5;j>$l9`zBPpI1NI<*S~z*et4p3?dyFJIZ@D0 zL@Ev?eAZxw2H4n>(o;?dP8;-i_=>*vf+JsoHQApVTY?g-DHp~oB9;zG)gAfdKKD|e z#U6chVg0q=WYkyAUu+XrcotFLV}rD+pJ@7|twWeA6wFcF+wFZO_p^{TTP<>@FhB(@ z534&KIuGLd%<=kiF%Q0K@D~X)12;dDFf~M$nym-5TbFW2RjNA*fPYCUfrtg19wjXH zZQrm=tuv*&`>a%Y|9FwNO><3W;Eoh5_OidP8dl-WWU{-btBZ66Wi1vBj3>qu89)ZE zG6VLHEp@u=s$U>nhuiw&DIl29N<{02x3AkO5=>89)Y*0b~FfKn9QjWB?hk zKk4}&u9;Q5?oaK1W8~o8xByFPP&G7SL4}liO!j>!lcm%<2Hmg@?oZV=H{q_Defwgz XZm5cGv7%^tn=mTw{Yh>|H`jgvT1N^V diff --git a/onnxruntime/test/python/transformers/test_gqa_cpu.py b/onnxruntime/test/python/transformers/test_gqa_cpu.py index 7f2134b2cda4f..461c243b82212 100644 --- a/onnxruntime/test/python/transformers/test_gqa_cpu.py +++ b/onnxruntime/test/python/transformers/test_gqa_cpu.py @@ -13,7 +13,6 @@ import random import unittest from dataclasses import dataclass -from enum import Enum import numpy import torch @@ -39,17 +38,11 @@ ATOL = None -class Formats(Enum): +class Formats: BSNH = 0 BNSH = 1 -class QKOutputType(Enum): - NO_OUTPUT = 0 - BEFORE_SOFTMAX = 1 - AFTER_SOFTMAX = 2 - - @dataclass class Config: batch_size: int = 0 @@ -61,8 +54,6 @@ class Config: head_size: int = 0 has_position_ids: bool = False has_attention_bias: bool = False - has_head_sink: bool = False - qk_output: QKOutputType = QKOutputType.NO_OUTPUT @dataclass @@ -76,8 +67,6 @@ class PromptConfig: head_size: int = 0 has_position_ids: bool = False has_attention_bias: bool = False - has_head_sink: bool = False - qk_output: QKOutputType = QKOutputType.NO_OUTPUT # LLaMA Microsoft model @@ -162,15 +151,6 @@ def create_group_query_attention_graph_prompt( ): past_kv_seqlen = config.buffer_sequence_length if share_buffer else 0 present_kv_seqlen = config.buffer_sequence_length if share_buffer else config.kv_sequence_length - - output_names = [ - "output", - "present_key", - "present_value", - ] - if config.qk_output != QKOutputType.NO_OUTPUT: - output_names.append("output_qk") - nodes = [ helper.make_node( "GroupQueryAttention", @@ -186,9 +166,8 @@ def create_group_query_attention_graph_prompt( "sin_cache" if rotary else "", "position_ids" if config.has_position_ids else "", "attention_bias" if config.has_attention_bias else "", - "head_sink" if config.has_head_sink else "", ], - output_names, + ["output", "present_key", "present_value"], "GroupQueryAttention_0", num_heads=config.num_heads, kv_num_heads=config.kv_num_heads, @@ -197,7 +176,6 @@ def create_group_query_attention_graph_prompt( rotary_interleaved=rotary_interleaved, softcap=softcap, smooth_softmax=1 if use_smooth_softmax else 0, - qk_output=config.qk_output.value, # is_past_bsnh=1 if past_kv_format == Formats.BSNH else 0, # kv_share_buffer=1 if share_buffer else 0, domain="com.microsoft", @@ -311,15 +289,6 @@ def create_group_query_attention_graph_prompt( ), ] - if config.has_head_sink: - graph_input += [ - helper.make_tensor_value_info( - "head_sink", - ort_type, - [config.num_heads], - ), - ] - graph_output = [ helper.make_tensor_value_info( "output", @@ -368,15 +337,6 @@ def create_group_query_attention_graph_prompt( ), ] - if config.qk_output != QKOutputType.NO_OUTPUT: - graph_output += [ - helper.make_tensor_value_info( - "output_qk", - ort_type, - [config.batch_size, config.num_heads, config.kv_sequence_length, config.kv_sequence_length], - ), - ] - graph = helper.make_graph( nodes, "GroupQueryAttention_Graph", @@ -405,15 +365,6 @@ def create_group_query_attention_graph_past( present_kv_seqlen = ( config.kv_sequence_length if share_buffer else config.kv_sequence_length + config.sequence_length ) - - output_names = [ - "output", - "present_key", - "present_value", - ] - if config.qk_output != QKOutputType.NO_OUTPUT: - output_names.append("output_qk") - nodes = [ helper.make_node( "GroupQueryAttention", @@ -429,9 +380,8 @@ def create_group_query_attention_graph_past( "sin_cache" if rotary else "", "position_ids" if config.has_position_ids else "", "attention_bias" if config.has_attention_bias else "", - "head_sink" if config.has_head_sink else "", ], - output_names, + ["output", "present_key", "present_value"], "GroupQueryAttention_0", num_heads=config.num_heads, kv_num_heads=config.kv_num_heads, @@ -440,7 +390,6 @@ def create_group_query_attention_graph_past( rotary_interleaved=rotary_interleaved, softcap=softcap, smooth_softmax=1 if use_smooth_softmax else 0, - qk_output=config.qk_output.value, # is_past_bsnh=1 if past_kv_format == Formats.BSNH else 0, # kv_share_buffer=1 if share_buffer else 0, domain="com.microsoft", @@ -492,7 +441,6 @@ def create_group_query_attention_graph_past( [1], ), ] - if not packed: graph_input += [ helper.make_tensor_value_info( @@ -514,7 +462,6 @@ def create_group_query_attention_graph_past( ], ), ] - if rotary: graph_input += [ helper.make_tensor_value_info( @@ -551,15 +498,6 @@ def create_group_query_attention_graph_past( ), ] - if config.has_head_sink: - graph_input += [ - helper.make_tensor_value_info( - "head_sink", - ort_type, - [config.num_heads], - ), - ] - graph_output = [ helper.make_tensor_value_info( "output", @@ -588,15 +526,6 @@ def create_group_query_attention_graph_past( ), ] - if config.qk_output != QKOutputType.NO_OUTPUT: - graph_output += [ - helper.make_tensor_value_info( - "output_qk", - ort_type, - [config.batch_size, config.num_heads, config.sequence_length, present_kv_seqlen], - ), - ] - graph = helper.make_graph( nodes, "GroupQueryAttention_Graph", @@ -623,17 +552,17 @@ def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random"): def generate_qkv(q, k, v, query_padding_mask=None, key_padding_mask=None, kvpacked=False, qkvpacked=False): """ Arguments: - q: (batch_size, seqlen_q, num_heads, d) - k: (batch_size, seqlen_k, num_heads_k, d) - v: (batch_size, seqlen_k, num_heads_k, d) + q: (batch_size, seqlen_q, nheads, d) + k: (batch_size, seqlen_k, nheads_k, d) + v: (batch_size, seqlen_k, nheads_k, d) query_padding_mask: (batch_size, seqlen), bool key_padding_mask: (batch_size, seqlen), bool """ assert not (kvpacked and qkvpacked) - batch_size, seqlen_q, num_heads, d = q.shape - _, seqlen_k, num_heads_k, _ = k.shape - assert k.shape == (batch_size, seqlen_k, num_heads_k, d) - assert v.shape == (batch_size, seqlen_k, num_heads_k, d) + batch_size, seqlen_q, nheads, d = q.shape + _, seqlen_k, nheads_k, _ = k.shape + assert k.shape == (batch_size, seqlen_k, nheads_k, d) + assert v.shape == (batch_size, seqlen_k, nheads_k, d) if query_padding_mask is not None: q_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, query_padding_mask) @@ -664,7 +593,7 @@ def output_pad_fn(output_unpad): if qkvpacked: assert (query_padding_mask == key_padding_mask).all() - assert num_heads == num_heads_k + assert nheads == nheads_k qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1) qkv = torch.stack([q, k, v], dim=2) if query_padding_mask is not None: @@ -785,8 +714,6 @@ def gqa_prompt_func( seqlens_k=None, position_ids=None, attention_bias=None, - head_sink=None, - output_qk=None, window_size=-1, past_kv_format=Formats.BSNH, share_buffer=True, @@ -819,18 +746,9 @@ def gqa_prompt_func( if config.has_attention_bias: assert attention_bias is not None - if config.qk_output != QKOutputType.NO_OUTPUT: - assert output_qk is not None - if new_k is not None: new_k = torch.reshape(new_k, (config.batch_size, config.kv_sequence_length, -1)) new_v = torch.reshape(new_v, (config.batch_size, config.kv_sequence_length, -1)) - - sess_options = SessionOptions() - ort_session = InferenceSession(onnx_model_str, sess_options, providers=["CPUExecutionProvider"]) - io_binding = ort_session.io_binding() - ort_outputs = {} - if share_buffer: ort_inputs = { "query": q.detach().cpu().numpy(), @@ -839,6 +757,10 @@ def gqa_prompt_func( "seqlens_k": seqlens_k.detach().cpu().numpy().astype(numpy.int32), "total_sequence_length": torch.tensor([config.q_sequence_length], dtype=torch.int32).detach().cpu().numpy(), } + + sess_options = SessionOptions() + ort_session = InferenceSession(onnx_model_str, sess_options, providers=["CPUExecutionProvider"]) + io_binding = ort_session.io_binding() if new_k is not None: ort_inputs["key"] = new_k.detach().cpu().numpy() ort_inputs["value"] = new_v.detach().cpu().numpy() @@ -875,18 +797,25 @@ def gqa_prompt_func( io_binding.bind_output("output") io_binding.bind_ortvalue_output("present_key", ort_inputs["past_key"]) io_binding.bind_ortvalue_output("present_value", ort_inputs["past_value"]) + ort_session.run_with_iobinding(io_binding) + ort_output, present_k, present_v = io_binding.copy_outputs_to_cpu() + ort_output = numpy.array(ort_output) + output = torch.tensor(ort_output) + return output, present_k, present_v else: ort_inputs = { "query": q.detach().cpu().numpy(), "seqlens_k": seqlens_k.detach().cpu().numpy().astype(numpy.int32), "total_sequence_length": torch.tensor([config.q_sequence_length], dtype=torch.int32).detach().cpu().numpy(), } + sess_options = SessionOptions() + ort_session = InferenceSession(onnx_model_str, sess_options, providers=["CPUExecutionProvider"]) + io_binding = ort_session.io_binding() if new_k is not None: ort_inputs["key"] = new_k.detach().cpu().numpy() ort_inputs["value"] = new_v.detach().cpu().numpy() io_binding.bind_cpu_input("key", ort_inputs["key"]) io_binding.bind_cpu_input("value", ort_inputs["value"]) - if cos is not None: ort_inputs["cos_cache"] = cos.detach().cpu().numpy() ort_inputs["sin_cache"] = sin.detach().cpu().numpy() @@ -907,26 +836,11 @@ def gqa_prompt_func( io_binding.bind_output("output") io_binding.bind_output("present_key") io_binding.bind_output("present_value") - - if config.has_head_sink: - ort_inputs["head_sink"] = head_sink.detach().cpu().numpy() - io_binding.bind_cpu_input("head_sink", ort_inputs["head_sink"]) - - if config.qk_output != QKOutputType.NO_OUTPUT: - ort_outputs["output_qk"] = OrtValue.ortvalue_from_numpy(output_qk.detach().cpu().numpy(), "cpu", 0) - io_binding.bind_ortvalue_output("output_qk", ort_outputs["output_qk"]) - - ort_session.run_with_iobinding(io_binding) - - out_qk = None - if config.qk_output != QKOutputType.NO_OUTPUT: - ort_output, present_k, present_v, out_qk = io_binding.copy_outputs_to_cpu() - else: + ort_session.run_with_iobinding(io_binding) ort_output, present_k, present_v = io_binding.copy_outputs_to_cpu() - ort_output = numpy.array(ort_output) - output = torch.tensor(ort_output) - - return output, present_k, present_v, out_qk + ort_output = numpy.array(ort_output) + output = torch.tensor(ort_output) + return output, present_k, present_v def gqa_past_func( @@ -941,8 +855,6 @@ def gqa_past_func( seqlens_k=None, position_ids=None, attention_bias=None, - head_sink=None, - output_qk=None, past_kv_format=Formats.BSNH, share_buffer=True, window_size=-1, @@ -975,18 +887,9 @@ def gqa_past_func( if config.has_attention_bias: assert attention_bias is not None - if config.qk_output != QKOutputType.NO_OUTPUT: - assert output_qk is not None - if new_k is not None: new_k = torch.reshape(new_k, (config.batch_size, config.sequence_length, -1)) new_v = torch.reshape(new_v, (config.batch_size, config.sequence_length, -1)) - - sess_options = SessionOptions() - ort_session = InferenceSession(onnx_model_str, sess_options, providers=["CPUExecutionProvider"]) - io_binding = ort_session.io_binding() - ort_outputs = {} - if share_buffer: ort_inputs = { "query": q.detach().cpu().numpy(), @@ -998,6 +901,9 @@ def gqa_past_func( .cpu() .numpy(), } + sess_options = SessionOptions() + ort_session = InferenceSession(onnx_model_str, sess_options, providers=["CPUExecutionProvider"]) + io_binding = ort_session.io_binding() if new_k is not None and new_v is not None: ort_inputs["key"] = new_k.detach().cpu().numpy() ort_inputs["value"] = new_v.detach().cpu().numpy() @@ -1034,6 +940,11 @@ def gqa_past_func( io_binding.bind_output("output") io_binding.bind_ortvalue_output("present_key", ort_inputs["past_key"]) io_binding.bind_ortvalue_output("present_value", ort_inputs["past_value"]) + ort_session.run_with_iobinding(io_binding) + ort_output, present_k, present_v = io_binding.copy_outputs_to_cpu() + ort_output = numpy.array(ort_output) + output = torch.tensor(ort_output) + return output, present_k, present_v else: ort_inputs = { "query": q.detach().cpu().numpy(), @@ -1047,6 +958,9 @@ def gqa_past_func( .cpu() .numpy(), } + sess_options = SessionOptions() + ort_session = InferenceSession(onnx_model_str, sess_options, providers=["CPUExecutionProvider"]) + io_binding = ort_session.io_binding() if new_k is not None and new_v is not None: ort_inputs["key"] = new_k.detach().cpu().numpy() ort_inputs["value"] = new_v.detach().cpu().numpy() @@ -1074,26 +988,11 @@ def gqa_past_func( io_binding.bind_output("output") io_binding.bind_output("present_key") io_binding.bind_output("present_value") - - if config.has_head_sink: - ort_inputs["head_sink"] = head_sink.detach().cpu().numpy() - io_binding.bind_cpu_input("head_sink", ort_inputs["head_sink"]) - - if config.qk_output != QKOutputType.NO_OUTPUT: - ort_outputs["output_qk"] = OrtValue.ortvalue_from_numpy(output_qk.detach().cpu().numpy(), "cpu", 0) - io_binding.bind_ortvalue_output("output_qk", ort_outputs["output_qk"]) - - ort_session.run_with_iobinding(io_binding) - - out_qk = None - if config.qk_output != QKOutputType.NO_OUTPUT: - ort_output, present_k, present_v, out_qk = io_binding.copy_outputs_to_cpu() - else: + ort_session.run_with_iobinding(io_binding) ort_output, present_k, present_v = io_binding.copy_outputs_to_cpu() - ort_output = numpy.array(ort_output) - output = torch.tensor(ort_output) - - return output, present_k, present_v, out_qk + ort_output = numpy.array(ort_output) + output = torch.tensor(ort_output) + return output, present_k, present_v def construct_causal_mask(seqlen_q, seqlen_k, query_padding_mask=None, key_padding_mask=None, device=None): @@ -1126,28 +1025,11 @@ def construct_local_mask( ) -def smooth_softmax_ref(x, head_sink): - """ - Arguments: - x: (batch_size, num_heads, seqlen_q, seqlen_k) - head_sink: (num_heads) or None - Output: - y: (batch_size, num_heads, seqlen_q, seqlen_k) - """ - assert len(x.shape) == 4 - b, n, s, t = x.shape - - if head_sink is not None: - assert len(head_sink.shape) == 1 - assert head_sink.shape[0] == x.shape[1] - sink = head_sink.reshape(1, n, 1, 1).expand(b, -1, s, -1) - else: - sink = torch.zeros(b, n, s, 1, dtype=x.dtype) - - y = torch.cat([x, sink], dim=-1) - y = torch.softmax(y, dim=-1) - y = y[..., :-1] - return y +def smooth_softmax_ref(x): + x_max = x.amax(axis=-1, keepdim=True) + x_max = torch.maximum(x_max, torch.zeros_like(x_max)) + w = torch.exp(x - x_max) + return w * torch.reciprocal(w.sum(axis=-1, keepdim=True) + torch.exp(-x_max)) def attention_ref( @@ -1164,17 +1046,16 @@ def attention_ref( upcast=True, reorder_ops=False, use_smooth_softmax=False, - head_sink=None, ): """ Arguments: - q: (batch_size, seqlen_q, num_heads, head_dim) - k: (batch_size, seqlen_k, num_heads_k, head_dim) - v: (batch_size, seqlen_k, num_heads_k, head_dim) + q: (batch_size, seqlen_q, nheads, head_dim) + k: (batch_size, seqlen_k, nheads_k, head_dim) + v: (batch_size, seqlen_k, nheads_k, head_dim) query_padding_mask: (batch_size, seqlen_q) key_padding_mask: (batch_size, seqlen_k) dropout_p: float - dropout_mask: (batch_size, num_heads, seqlen_q, seqlen_k) + dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k) causal: whether to apply causal masking window_size: (int, int), left and right window size upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast @@ -1183,10 +1064,8 @@ def attention_ref( without changing the math. This is to estimate the numerical error from operation reordering. use_smooth_softmax: whether use smooth softmax or not - head_sink: (num_heads) or None Output: output: (batch_size, seqlen_q, nheads, head_dim) - masked_scores: (batch_size, nheads, seqlen_q, seqlen_k), before softmax attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax after dropout """ if causal: @@ -1206,10 +1085,8 @@ def attention_ref( scores = scores / softcap scores = scores.tanh() scores = scores * softcap - masked_scores = scores.clone() if key_padding_mask is not None: scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf")) - masked_scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), 0) if window_size[0] >= 0 or window_size[1] >= 0: local_mask = construct_local_mask( seqlen_q, @@ -1219,11 +1096,10 @@ def attention_ref( key_padding_mask, q.device, ) - masked_scores.masked_fill_(local_mask, 0.0) scores.masked_fill_(local_mask, float("-inf")) - if use_smooth_softmax or (head_sink is not None): - attention = smooth_softmax_ref(scores, head_sink) + if use_smooth_softmax: + attention = smooth_softmax_ref(scores) else: attention = torch.softmax(scores, dim=-1) @@ -1245,7 +1121,7 @@ def attention_ref( if query_padding_mask is not None: output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0) - return output.to(dtype=dtype_og), masked_scores.to(dtype=dtype_og), attention.to(dtype=dtype_og) + return output.to(dtype=dtype_og), attention.to(dtype=dtype_og) def attention_qkvpacked_ref( @@ -1257,7 +1133,6 @@ def attention_qkvpacked_ref( upcast=True, reorder_ops=False, use_smooth_softmax=False, - head_sink=None, ): return attention_ref( qkv[:, :, 0], @@ -1271,7 +1146,6 @@ def attention_qkvpacked_ref( causal=causal, reorder_ops=reorder_ops, use_smooth_softmax=use_smooth_softmax, - head_sink=head_sink, ) @@ -1312,10 +1186,6 @@ def get_custom_position_ids(batch_size, sequence_length, seqlens_k=None, past=Fa return position_ids -def get_custom_head_sink(num_heads, torch_type=torch.float16): - return torch.rand(num_heads, dtype=torch_type) - - def parity_check_gqa_prompt( config, torch_type, @@ -1378,8 +1248,6 @@ def parity_check_gqa_prompt( requires_grad=False, ) - head_sink = get_custom_head_sink(config.num_heads, torch_type) if config.has_head_sink else None - window_size = (-1, -1) left_window_size = -1 if local: @@ -1437,20 +1305,6 @@ def parity_check_gqa_prompt( else None ) - output_qk = ( - torch.zeros( - config.batch_size, - config.num_heads, - config.kv_sequence_length, - config.q_sequence_length, - device="cpu", - dtype=torch_type, - requires_grad=False, - ) - if config.qk_output != QKOutputType.NO_OUTPUT - else None - ) - arange = rearrange(torch.arange(config.buffer_sequence_length, device="cpu"), "s -> 1 s") cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1") kv_seqlens = torch.tensor([config.kv_sequence_length], device="cpu").repeat(config.batch_size) @@ -1461,7 +1315,7 @@ def parity_check_gqa_prompt( k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) key_padding_mask = arange < cache_seqlens_expanded - out_ref, out_qk_pre_softmax_ref, out_qk_post_softmax_ref = attention_ref( + out_ref, _ = attention_ref( q_ro, k_cache_rep, v_cache_rep, @@ -1473,7 +1327,6 @@ def parity_check_gqa_prompt( window_size=window_size, softcap=softcap, use_smooth_softmax=use_smooth_softmax, - head_sink=head_sink, ) out_ref = out_ref.detach().cpu().numpy() if past_format == Formats.BNSH: @@ -1484,7 +1337,7 @@ def parity_check_gqa_prompt( # Cache seqlens is reduced by 1 since it is required to be past_seq_len + seq_len - 1 if packed: packed_qkv = torch.concatenate([q, new_k, new_v], dim=2) - out, present_k, present_v, out_qk = gqa_prompt_func( + out, present_k, present_v = gqa_prompt_func( packed_qkv, k, v, @@ -1496,8 +1349,6 @@ def parity_check_gqa_prompt( cache_seqlens - 1, position_ids, attention_bias, - head_sink, - output_qk, left_window_size, past_format, True, @@ -1508,7 +1359,7 @@ def parity_check_gqa_prompt( numpy_type=numpy_type, ) else: - out, present_k, present_v, out_qk = gqa_prompt_func( + out, present_k, present_v = gqa_prompt_func( q, k, v, @@ -1520,8 +1371,6 @@ def parity_check_gqa_prompt( cache_seqlens - 1, position_ids, attention_bias, - head_sink, - output_qk, left_window_size, past_format, True, @@ -1535,22 +1384,6 @@ def parity_check_gqa_prompt( out = torch.reshape(out, (config.batch_size, config.q_sequence_length, config.num_heads, config.head_size)) out = out.detach().cpu().numpy() - if config.qk_output != QKOutputType.NO_OUTPUT: - out_qk_ref = ( - out_qk_post_softmax_ref if config.qk_output == QKOutputType.AFTER_SOFTMAX else out_qk_pre_softmax_ref - ) - out_qk_ref = out_qk_ref.detach().cpu().numpy() - - for batch_idx in range(config.batch_size): - total_seqlen = cache_seqlens[batch_idx] - assert numpy.allclose( - out_qk[batch_idx, :, :, :total_seqlen], - out_qk_ref[batch_idx, :, :, :total_seqlen], - rtol=rtol, - atol=atol, - equal_nan=True, - ) - # Make sure past-present buffer updating correctly assert numpy.allclose(present_k, k_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True) assert numpy.allclose(present_v, v_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True) @@ -1592,8 +1425,6 @@ def parity_check_gqa_prompt( config.has_position_ids, " has_attention_bias:", config.has_attention_bias, - " qk_output:", - config.qk_output, " Mean Error:", numpy.mean(numpy.abs(out - out_ref)), correct, @@ -1700,28 +1531,12 @@ def parity_check_gqa_prompt_no_buff( else None ) - head_sink = get_custom_head_sink(config.num_heads, torch_type=torch_type) if config.has_head_sink else None - - output_qk = ( - torch.zeros( - config.batch_size, - config.num_heads, - config.kv_sequence_length, - config.q_sequence_length, - device="cpu", - dtype=torch_type, - requires_grad=False, - ) - if config.qk_output != QKOutputType.NO_OUTPUT - else None - ) - brange = rearrange(torch.arange(config.kv_sequence_length, device="cpu"), "s -> 1 s") cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1") new_mask = brange < cache_seqlens_expanded k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) - out_ref, out_qk_pre_softmax_ref, out_qk_post_softmax_ref = attention_ref( + out_ref, _ = attention_ref( q_ro, k_cache_rep, v_cache_rep, @@ -1733,7 +1548,6 @@ def parity_check_gqa_prompt_no_buff( window_size=window_size, softcap=softcap, use_smooth_softmax=use_smooth_softmax, - head_sink=head_sink, ) out_ref = out_ref.detach().cpu().numpy() if past_format == Formats.BNSH: @@ -1744,7 +1558,7 @@ def parity_check_gqa_prompt_no_buff( # Cache seqlens is reduced by 1 since it is required to be past_seq_len + seq_len - 1 if packed: packed_qkv = torch.concatenate([q, new_k, new_v], dim=2) - out, present_k, present_v, out_qk = gqa_prompt_func( + out, present_k, present_v = gqa_prompt_func( packed_qkv, None, None, @@ -1756,8 +1570,6 @@ def parity_check_gqa_prompt_no_buff( cache_seqlens - 1, position_ids, attention_bias, - head_sink, - output_qk, left_window_size, past_format, False, @@ -1768,7 +1580,7 @@ def parity_check_gqa_prompt_no_buff( numpy_type=numpy_type, ) else: - out, present_k, present_v, out_qk = gqa_prompt_func( + out, present_k, present_v = gqa_prompt_func( q, None, None, @@ -1780,8 +1592,6 @@ def parity_check_gqa_prompt_no_buff( cache_seqlens - 1, position_ids, attention_bias, - head_sink, - output_qk, left_window_size, past_format, False, @@ -1795,22 +1605,6 @@ def parity_check_gqa_prompt_no_buff( out = torch.reshape(out, (config.batch_size, config.q_sequence_length, config.num_heads, config.head_size)) out = out.detach().cpu().numpy() - if config.qk_output != QKOutputType.NO_OUTPUT: - out_qk_ref = ( - out_qk_post_softmax_ref if config.qk_output == QKOutputType.AFTER_SOFTMAX else out_qk_pre_softmax_ref - ) - out_qk_ref = out_qk_ref.detach().cpu().numpy() - - for batch_idx in range(config.batch_size): - total_seqlen = cache_seqlens[batch_idx] - assert numpy.allclose( - out_qk[batch_idx, :, :, :total_seqlen], - out_qk_ref[batch_idx, :, :, :total_seqlen], - rtol=rtol, - atol=atol, - equal_nan=True, - ) - # Make sure past-present buffer updating correctly assert numpy.allclose(present_k, k_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True) assert numpy.allclose(present_v, v_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True) @@ -1852,8 +1646,6 @@ def parity_check_gqa_prompt_no_buff( config.has_position_ids, " has_attention_bias:", config.has_attention_bias, - " qk_output:", - config.qk_output, " Mean Error:", numpy.mean(numpy.abs(out - out_ref)), correct, @@ -1967,8 +1759,6 @@ def parity_check_gqa_past( cos, sin = None, None q_ro, k_ro = q, new_k - head_sink = get_custom_head_sink(config.num_heads, torch_type=torch_type) if config.has_head_sink else None - arange = rearrange(torch.arange(config.kv_sequence_length, device="cpu"), "s -> 1 s") cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1") update_mask = torch.logical_and( @@ -1979,7 +1769,7 @@ def parity_check_gqa_past( k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) key_padding_mask = arange < cache_seqlens_expanded + config.sequence_length - out_ref, out_qk_pre_softmax_ref, out_qk_post_softmax_ref = attention_ref( + out_ref, _ = attention_ref( q_ro, k_cache_rep, v_cache_rep, @@ -1991,7 +1781,6 @@ def parity_check_gqa_past( window_size=window_size, softcap=softcap, use_smooth_softmax=use_smooth_softmax, - head_sink=head_sink, ) out_ref = out_ref.detach().cpu().numpy() if past_format == Formats.BNSH: @@ -2018,24 +1807,10 @@ def parity_check_gqa_past( else None ) - output_qk = ( - torch.zeros( - config.batch_size, - config.num_heads, - config.sequence_length, - config.kv_sequence_length, - device="cpu", - dtype=torch_type, - requires_grad=False, - ) - if config.qk_output != QKOutputType.NO_OUTPUT - else None - ) - # ORT function if packed: packed_qkv = torch.concatenate([q, new_k, new_v], dim=2) - out, present_k, present_v, out_qk = gqa_past_func( + out, present_k, present_v = gqa_past_func( packed_qkv, k, v, @@ -2047,8 +1822,6 @@ def parity_check_gqa_past( cache_seqlens, position_ids, attention_bias, - head_sink, - output_qk, past_format, True, left_window_size, @@ -2059,7 +1832,7 @@ def parity_check_gqa_past( numpy_type=numpy_type, ) else: - out, present_k, present_v, out_qk = gqa_past_func( + out, present_k, present_v = gqa_past_func( q, k, v, @@ -2071,8 +1844,6 @@ def parity_check_gqa_past( cache_seqlens, position_ids, attention_bias, - head_sink, - output_qk, past_format, True, left_window_size, @@ -2086,22 +1857,6 @@ def parity_check_gqa_past( out = torch.reshape(out, (config.batch_size, config.sequence_length, config.num_heads, config.head_size)) out = out.detach().cpu().numpy() - if config.qk_output != QKOutputType.NO_OUTPUT: - out_qk_ref = ( - out_qk_post_softmax_ref if config.qk_output == QKOutputType.AFTER_SOFTMAX else out_qk_pre_softmax_ref - ) - out_qk_ref = out_qk_ref.detach().cpu().numpy() - - for batch_idx in range(config.batch_size): - total_seqlen = cache_seqlens[batch_idx] + 1 - assert numpy.allclose( - out_qk[batch_idx, :, :, :total_seqlen], - out_qk_ref[batch_idx, :, :, :total_seqlen], - rtol=rtol, - atol=atol, - equal_nan=True, - ) - # Make sure past-present buffer updating correctly assert numpy.allclose(present_k, k_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True) assert numpy.allclose(present_v, v_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True) @@ -2127,8 +1882,6 @@ def parity_check_gqa_past( softcap, " smooth_softmax:", use_smooth_softmax, - " head_sink:", - config.has_head_sink, " B:", config.batch_size, " S:", @@ -2145,8 +1898,6 @@ def parity_check_gqa_past( config.has_position_ids, " has_attention_bias:", config.has_attention_bias, - " qk_output:", - config.qk_output, " Mean Error:", numpy.mean(numpy.abs(out - out_ref)), correct, @@ -2266,8 +2017,6 @@ def parity_check_gqa_past_no_buff( cos, sin = None, None q_ro, k_ro = q, new_k - head_sink = get_custom_head_sink(config.num_heads, torch_type) if config.has_head_sink else None - arange = rearrange(torch.arange(config.kv_sequence_length + config.sequence_length, device="cpu"), "s -> 1 s") cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1") update_mask = torch.logical_and( @@ -2278,7 +2027,7 @@ def parity_check_gqa_past_no_buff( k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) key_padding_mask = arange < cache_seqlens_expanded + config.sequence_length - out_ref, out_qk_pre_softmax_ref, out_qk_post_softmax_ref = attention_ref( + out_ref, _ = attention_ref( q_ro, k_cache_rep, v_cache_rep, @@ -2290,7 +2039,6 @@ def parity_check_gqa_past_no_buff( window_size=window_size, softcap=softcap, use_smooth_softmax=use_smooth_softmax, - head_sink=head_sink, ) out_ref = out_ref.detach().cpu().numpy() if past_format == Formats.BNSH: @@ -2317,24 +2065,10 @@ def parity_check_gqa_past_no_buff( else None ) - output_qk = ( - torch.zeros( - config.batch_size, - config.num_heads, - config.sequence_length, - config.kv_sequence_length + config.sequence_length, - device="cpu", - dtype=torch_type, - requires_grad=False, - ) - if config.qk_output != QKOutputType.NO_OUTPUT - else None - ) - # Flash function if packed: packed_qkv = torch.concatenate([q, new_k, new_v], dim=2) - out, present_k, present_v, out_qk = gqa_past_func( + out, present_k, present_v = gqa_past_func( packed_qkv, k, v, @@ -2346,8 +2080,6 @@ def parity_check_gqa_past_no_buff( cache_seqlens, position_ids, attention_bias, - head_sink, - output_qk, past_format, False, window_size=left_window_size, @@ -2358,7 +2090,7 @@ def parity_check_gqa_past_no_buff( numpy_type=numpy_type, ) else: - out, present_k, present_v, out_qk = gqa_past_func( + out, present_k, present_v = gqa_past_func( q, k, v, @@ -2370,8 +2102,6 @@ def parity_check_gqa_past_no_buff( cache_seqlens, position_ids, attention_bias, - head_sink, - output_qk, past_format, False, window_size=left_window_size, @@ -2385,22 +2115,6 @@ def parity_check_gqa_past_no_buff( out = torch.reshape(out, (config.batch_size, config.sequence_length, config.num_heads, config.head_size)) out = out.detach().cpu().numpy() - if config.qk_output != QKOutputType.NO_OUTPUT: - out_qk_ref = ( - out_qk_post_softmax_ref if config.qk_output == QKOutputType.AFTER_SOFTMAX else out_qk_pre_softmax_ref - ) - out_qk_ref = out_qk_ref.detach().cpu().numpy() - - for batch_idx in range(config.batch_size): - total_seqlen = cache_seqlens[batch_idx] + 1 - assert numpy.allclose( - out_qk[batch_idx, :, :, :total_seqlen], - out_qk_ref[batch_idx, :, :, :total_seqlen], - rtol=rtol, - atol=atol, - equal_nan=True, - ) - # Compare results all_close = numpy.allclose(out, out_ref, rtol=rtol, atol=atol, equal_nan=True) correct = GREEN + "True" + RESET if all_close else RED + "False" + RESET @@ -2420,8 +2134,6 @@ def parity_check_gqa_past_no_buff( softcap, " smooth_softmax:", use_smooth_softmax, - " head_sink:", - config.has_head_sink, "past kv format:", "BSNH" if past_format == Formats.BSNH else "BNSH", " B:", @@ -2440,8 +2152,6 @@ def parity_check_gqa_past_no_buff( config.has_position_ids, " has_attention_bias:", config.has_attention_bias, - " qk_output:", - config.qk_output, " Mean Error:", numpy.mean(numpy.abs(out - out_ref)), correct, @@ -2470,16 +2180,7 @@ def setUp(self): ] def run_test_config( - self, - test_func, - config_class, - batches, - seqs, - num_h, - h_sizes, - pos_ids_attn_bias, - qk_output, - additional_params=None, + self, test_func, config_class, batches, seqs, num_h, h_sizes, pos_ids_attn_bias, additional_params=None ): if additional_params is None: additional_params = {} @@ -2501,59 +2202,33 @@ def run_test_config( for softcap in [0.0, 50.0]: for use_smooth_softmax in [False, True]: for has_pos, has_attn in pos_ids_attn_bias: - for head_sink in [False, True]: - if use_smooth_softmax and head_sink: - continue - for output_qk in qk_output: - if config_class == PromptConfig: - config = config_class( - b, - s, - s2, - s + s2 + 8, - n, - n2, - h, - has_pos, - has_attn, - head_sink, - output_qk, - ) - else: # Config - sp = random.randint(1, s2 - s) if s2 - s > 0 else 0 - config = config_class( - b, - s, - s2, - sp, - n, - n2, - h, - has_pos, - has_attn, - head_sink, - output_qk, - ) - - params = { - "config": config, - "torch_type": precision["torch_type"], - "numpy_type": precision["numpy_type"], - "ort_type": precision["ort_type"], - "rtol": precision["rtol"], - "atol": precision["atol"], - "local": local, - "past_format": Formats.BNSH, - "rotary": rotary, - "rotary_interleaved": rotary_interleaved, - "packed": packed, - "softcap": softcap, - "use_smooth_softmax": use_smooth_softmax, - } - params.update(additional_params) - - all_close = test_func(**params) - self.assertTrue(all_close) + if config_class == PromptConfig: + config = config_class( + b, s, s2, s + s2 + 8, n, n2, h, has_pos, has_attn + ) + else: # Config + sp = random.randint(1, s2 - s) if s2 - s > 0 else 0 + config = config_class(b, s, s2, sp, n, n2, h, has_pos, has_attn) + + params = { + "config": config, + "torch_type": precision["torch_type"], + "numpy_type": precision["numpy_type"], + "ort_type": precision["ort_type"], + "rtol": precision["rtol"], + "atol": precision["atol"], + "local": local, + "past_format": Formats.BNSH, + "rotary": rotary, + "rotary_interleaved": rotary_interleaved, + "packed": packed, + "softcap": softcap, + "use_smooth_softmax": use_smooth_softmax, + } + params.update(additional_params) + + all_close = test_func(**params) + self.assertTrue(all_close) def test_gqa_no_past(self): print("-------- TEST GQA NO PAST (PROMPT CASE) ---------") @@ -2570,33 +2245,12 @@ def test_gqa_no_past(self): ) num_h = [(32, 8)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] h_sizes = [128] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] - qk_output = ( - [QKOutputType.NO_OUTPUT] - if pipeline_mode - else [QKOutputType.NO_OUTPUT, QKOutputType.BEFORE_SOFTMAX, QKOutputType.AFTER_SOFTMAX] - ) # Test with buffer - self.run_test_config( - parity_check_gqa_prompt, - PromptConfig, - batches, - seqs, - num_h, - h_sizes, - pos_ids_attn_bias, - qk_output, - ) + self.run_test_config(parity_check_gqa_prompt, PromptConfig, batches, seqs, num_h, h_sizes, pos_ids_attn_bias) # Test without buffer self.run_test_config( - parity_check_gqa_prompt_no_buff, - PromptConfig, - batches, - seqs, - num_h, - h_sizes, - pos_ids_attn_bias, - qk_output, + parity_check_gqa_prompt_no_buff, PromptConfig, batches, seqs, num_h, h_sizes, pos_ids_attn_bias ) def test_gqa_past(self): @@ -2614,25 +2268,11 @@ def test_gqa_past(self): ) num_h = [(9, 3)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] h_sizes = [64] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] - qk_output = ( - [QKOutputType.NO_OUTPUT] - if pipeline_mode - else [QKOutputType.NO_OUTPUT, QKOutputType.BEFORE_SOFTMAX, QKOutputType.AFTER_SOFTMAX] - ) # Test with buffer - self.run_test_config(parity_check_gqa_past, Config, batches, seqs, num_h, h_sizes, pos_ids_attn_bias, qk_output) + self.run_test_config(parity_check_gqa_past, Config, batches, seqs, num_h, h_sizes, pos_ids_attn_bias) # Test without buffer - self.run_test_config( - parity_check_gqa_past_no_buff, - Config, - batches, - seqs, - num_h, - h_sizes, - pos_ids_attn_bias, - qk_output, - ) + self.run_test_config(parity_check_gqa_past_no_buff, Config, batches, seqs, num_h, h_sizes, pos_ids_attn_bias) def test_gqa_interactive_one_batch(self): print("-------- TEST GQA INTERACTIVE ---------") @@ -2647,7 +2287,6 @@ def test_gqa_interactive_one_batch(self): if pipeline_mode else [(False, False), (True, True), (False, True), (True, False)] ) - qk_output = [QKOutputType.NO_OUTPUT, QKOutputType.BEFORE_SOFTMAX, QKOutputType.AFTER_SOFTMAX] num_h = [(32, 8)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] h_sizes = [32] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] @@ -2660,7 +2299,6 @@ def test_gqa_interactive_one_batch(self): num_h, h_sizes, pos_ids_attn_bias, - qk_output, additional_params={"softcap": 0.0, "use_smooth_softmax": False}, ) self.run_test_config( @@ -2671,7 +2309,6 @@ def test_gqa_interactive_one_batch(self): num_h, h_sizes, pos_ids_attn_bias, - qk_output, additional_params={"softcap": 0.0, "use_smooth_softmax": False}, ) diff --git a/onnxruntime/test/python/transformers/test_gqa_cuda.py b/onnxruntime/test/python/transformers/test_gqa_cuda.py index 79976a92e54bf..2f5b638a57d0c 100644 --- a/onnxruntime/test/python/transformers/test_gqa_cuda.py +++ b/onnxruntime/test/python/transformers/test_gqa_cuda.py @@ -782,8 +782,7 @@ def attention_ref( scores.masked_fill_(local_mask, float("-inf")) if use_smooth_softmax: - head_sink = None - attention = smooth_softmax_ref(scores, head_sink) + attention = smooth_softmax_ref(scores) else: attention = torch.softmax(scores, dim=-1) diff --git a/onnxruntime/test/python/transformers/test_paged_attention_cuda.py b/onnxruntime/test/python/transformers/test_paged_attention_cuda.py index ca5c9c2ce133f..410860a324a9d 100644 --- a/onnxruntime/test/python/transformers/test_paged_attention_cuda.py +++ b/onnxruntime/test/python/transformers/test_paged_attention_cuda.py @@ -401,8 +401,7 @@ def attention_ref( scores.masked_fill_(local_mask, float("-inf")) if use_smooth_softmax: - head_sink = None - attention = smooth_softmax_ref(scores, head_sink) + attention = smooth_softmax_ref(scores) else: attention = torch.softmax(scores, dim=-1) diff --git a/onnxruntime/test/python/transformers/test_phi_vision.py b/onnxruntime/test/python/transformers/test_phi_vision.py index d276366706af9..67f89e633a146 100644 --- a/onnxruntime/test/python/transformers/test_phi_vision.py +++ b/onnxruntime/test/python/transformers/test_phi_vision.py @@ -149,7 +149,7 @@ def __init__(self): self.attn = PhiVCLIPAttention() self.ln = torch.nn.LayerNorm(20, eps=1e-05) - def forward(self, x, attention_mask=None): + def forward(self, x): # SkipLayerNorm ------+ # | | # Attention | @@ -163,7 +163,8 @@ def forward(self, x, attention_mask=None): x = self.ln(x) residual = x - x = self.attn(x, attention_mask=attention_mask) + # Attention + MatMul + x = self.attn(x) # SkipLayerNorm x = residual + x @@ -193,31 +194,14 @@ def verify_fusion(self, optimized_model, expected_model_filename): ) def export(self, model, inputs): - path = os.path.join(os.path.dirname(__file__), "export.onnx") - - if len(inputs) == 2: - torch.onnx.export( - model, - args=inputs, - f=path, - export_params=True, - opset_version=14, - do_constant_folding=True, - input_names=["input", "attention_mask"], - dynamic_axes={ - "input": {0: "batch", 1: "seq"}, - "attention_mask": {0: "batch", 2: "seq", 3: "seq"}, - }, - ) - else: - torch.onnx.export( - model, - args=inputs, - f=path, - export_params=True, - opset_version=14, - do_constant_folding=True, - ) + torch.onnx.export( + model, + args=inputs, + f=os.path.join(os.path.dirname(__file__), "export.onnx"), + export_params=True, + opset_version=14, + do_constant_folding=True, + ) def tearDown(self): path = os.path.join(os.path.dirname(__file__), "export.onnx") @@ -265,38 +249,6 @@ def test_phi_vision_attention(self): ) self.verify_fusion(optimized_model, "phi-3.5-v-instruct-vision-attention.onnx") - def test_phi_vision_attention_with_mask(self): - model = PhiVCLIPAttentionAndLayerNorm() - - batch, seq_len, dim = 1, 2, 20 - mask = torch.zeros(batch, 1, seq_len, seq_len) - mask[:, 1:] = float("-inf") - - inputs = (torch.randn(batch, seq_len, dim), mask) - self.export(model, inputs) - original_model = onnx.load(os.path.join(os.path.dirname(__file__), "export.onnx")) - options = FusionOptions("clip") - optimized_model = optimize_model( - original_model, - model_type="clip", - num_heads=2, - hidden_size=20, - optimization_options=options, - opt_level=0, - use_gpu=True, - ) - self.verify_fusion(optimized_model, "phi-4-v-instruct-vision-attention.onnx") - - graph = optimized_model.model.graph - attention_node = next((n for n in graph.node if n.name == "Attention_0"), None) - self.assertIsNotNone(attention_node, "Could not find the Attention fused node") - attr_names = [attr.name for attr in attention_node.attribute] - self.assertNotIn( - "unidirectional", - attr_names, - f"The attention node should not have a 'unidirectional' attribute: {attr_names}", - ) - if __name__ == "__main__": unittest.main() diff --git a/onnxruntime/test/testdata/three_layer_nested_subgraph_v2.onnx b/onnxruntime/test/testdata/three_layer_nested_subgraph_v2.onnx deleted file mode 100644 index d036541a70aa087f6007ec7261f5f1115b0e22f2..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 1892 zcmc&#&2G~`5Y8GWaV8~=whcrEAaP&c$9+r)V_e7+CMS24QOn};JS@5uuwhojaXfGV8^v_J5P zdou3)00^KKITpQ`lc^PqiAS-P$a?eD%nd@~hP}}dH(80r++4G?cA)rhT<3}SzV zo0H)NM;CKS-}5BRvOL4jGD9eH!WEX$*~psB!&{MZ1XACWRiwTs0x4Tok{~7J8<3Kg zyCC)P1w$%{yoS^cLrIz#W*xi{bu2O*#%bu4m&0LSw8Ff{j<~^0?WofhLE6E5aOx9p z+-hn{z1&rhvR6xj#l0Rpft7%`1{)f}8YmiKUuB|a&*yC0BB8Y#SE$(ftwJ>%Q#WDR zcU7`1t}VgNkwx6ZGU0g_>iSUC`&kVX?S*?o`uvGX(i5vu@IN| z?*bMeM{k4CleJI+1N)Ij+#zSHS&GlH!_In#AF`<{cM)O@maxVVTc1$c`~O>;pxRP( zIXcxvj{r1AK$R14@*wLUUe<5PZY?Vr@Ax{%IKP6((s~;_f@}%gISCP9`Mn8CLM$zz ztjLTTtJ|gos#d`TJ`=yt>P%cC=%)K$iEO=@Z0!RYCTsuD^;qlE&7WDoU|`u|{wi#u zIc1`bUgE^=dGRX1OxccXz73K+AWBcXbEQ8{)4^}*ur}5cKJ0ex&TT6IS7u(=7R%@O gpK*_GjBuRe!e9%~W$t+nclOVTCER-|6zZFQ0aGjCcK`qY diff --git a/requirements-lintrunner.txt b/requirements-lintrunner.txt index f02e3e8058c29..450b955f161af 100644 --- a/requirements-lintrunner.txt +++ b/requirements-lintrunner.txt @@ -1,6 +1,6 @@ # This file is auto updated by dependabot # When any package below is changed, you shall run "lintrunner init" again. lintrunner==0.12.7 -lintrunner-adapters==0.12.5 -ruff==0.12.3 -clang-format==20.1.8 +lintrunner-adapters==0.12.4 +ruff==0.12.2 +clang-format==20.1.7 diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index 893f3c80fa4b8..f6e37d33b2414 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -284,8 +284,6 @@ def generate_vcpkg_install_options(build_dir, args): vcpkg_install_options.append("--x-feature=vsinpu-ep") if args.use_webgpu: vcpkg_install_options.append("--x-feature=webgpu-ep") - if args.wgsl_template == "dynamic": - vcpkg_install_options.append("--x-feature=webgpu-ep-wgsl-template-dynamic") if args.use_webnn: vcpkg_install_options.append("--x-feature=webnn-ep") if args.use_xnnpack: @@ -472,7 +470,6 @@ def generate_build_tree( else "OFF" ), "-Donnxruntime_REDUCED_OPS_BUILD=" + ("ON" if is_reduced_ops_build(args) else "OFF"), - "-Donnxruntime_CLIENT_PACKAGE_BUILD=" + ("ON" if args.client_package_build else "OFF"), "-Donnxruntime_BUILD_MS_EXPERIMENTAL_OPS=" + ("ON" if args.ms_experimental else "OFF"), "-Donnxruntime_ENABLE_LTO=" + ("ON" if args.enable_lto else "OFF"), "-Donnxruntime_USE_ACL=" + ("ON" if args.use_acl else "OFF"), diff --git a/tools/ci_build/build_args.py b/tools/ci_build/build_args.py index 53d53f3e15e99..ad27b8124c458 100644 --- a/tools/ci_build/build_args.py +++ b/tools/ci_build/build_args.py @@ -527,15 +527,6 @@ def add_size_reduction_args(parser: argparse.ArgumentParser) -> None: ) -def add_client_package_args(parser: argparse.ArgumentParser) -> None: - """Adds arguments for client package build package.""" - parser.add_argument( - "--client_package_build", - action="store_true", - help="Create ORT package with default settings more appropriate for client/on-device workloads.", - ) - - def add_python_binding_args(parser: argparse.ArgumentParser) -> None: """Adds arguments for Python bindings.""" parser.add_argument("--enable_pybind", action="store_true", help="Enable Python bindings.") @@ -842,7 +833,6 @@ def convert_arg_line_to_args(self, arg_line: str) -> list[str]: # Use list[str] add_dependency_args(parser) add_extension_args(parser) add_size_reduction_args(parser) - add_client_package_args(parser) # Language Bindings add_python_binding_args(parser) diff --git a/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml index e5e2a4749ef85..ee7f8f2fa386a 100644 --- a/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml @@ -32,7 +32,7 @@ parameters: - name: QnnSdk displayName: QNN SDK version type: string - default: 2.36.1.250708 + default: 2.36.0.250627 jobs: - job: Build_QNN_EP diff --git a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml index 202aa61da0b80..aa25e3f31166a 100644 --- a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml +++ b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml @@ -60,7 +60,7 @@ parameters: - name: QnnSdk displayName: QNN SDK Version type: string - default: 2.36.1.250708 + default: 2.36.0.250627 resources: repositories: diff --git a/tools/ci_build/github/azure-pipelines/custom-nuget-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/custom-nuget-packaging-pipeline.yml index 69dc9d1a8f63d..7addb3217072a 100644 --- a/tools/ci_build/github/azure-pipelines/custom-nuget-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/custom-nuget-packaging-pipeline.yml @@ -6,7 +6,7 @@ parameters: - name: QnnSdk displayName: QNN SDK Version type: string - default: 2.36.1.250708 + default: 2.36.0.250627 - name: IsReleaseBuild displayName: Is a release build? Set it to true if you are doing an Onnx Runtime release. diff --git a/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml index 526ed71df2006..cf8bbbed70525 100644 --- a/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml @@ -33,7 +33,7 @@ parameters: - name: QnnSdk displayName: QNN SDK version type: string - default: 2.36.1.250708 + default: 2.36.0.250627 jobs: - job: Build_QNN_EP diff --git a/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml index b99246625cb77..de024f0b3456f 100644 --- a/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml @@ -59,7 +59,7 @@ parameters: - name: qnn_sdk_version type: string displayName: 'QNN SDK version. Only for QNN packages.' - default: 2.36.1.250708 + default: 2.36.0.250627 trigger: none diff --git a/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml index 626a638121858..4fa916db0de39 100644 --- a/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml @@ -2,7 +2,7 @@ parameters: - name: QnnSdk displayName: QNN SDK Version type: string - default: 2.36.1.250708 + default: 2.36.0.250627 - name: build_config displayName: Build Configuration diff --git a/tools/ci_build/github/azure-pipelines/stages/nuget-cuda-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/nuget-cuda-packaging-stage.yml index a87bb55441ac7..84b6d30ee32ac 100644 --- a/tools/ci_build/github/azure-pipelines/stages/nuget-cuda-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/nuget-cuda-packaging-stage.yml @@ -72,8 +72,6 @@ stages: SpecificArtifact: ${{ parameters.SpecificArtifact }} BuildId: ${{ parameters.BuildId }} - - template: ../templates/set-version-number-variables-step.yml - # Reconstruct the build dir - task: PowerShell@2 displayName: 'PS: Extract nuget files gpu' @@ -116,7 +114,6 @@ stages: -p:OrtPackageId="Microsoft.ML.OnnxRuntime.Gpu" -p:IsReleaseBuild=${{ parameters.IsReleaseBuild }} -p:ReleaseVersionSuffix=$(ReleaseVersionSuffix) - -p:PackageVersion=$(OnnxRuntimeVersion) workingDirectory: '$(Build.SourcesDirectory)\csharp' - template: ../templates/win-esrp-dll.yml diff --git a/tools/ci_build/github/azure-pipelines/stages/py-cpu-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/py-cpu-packaging-stage.yml index e2c6b25f48b6d..433250f05125e 100644 --- a/tools/ci_build/github/azure-pipelines/stages/py-cpu-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/py-cpu-packaging-stage.yml @@ -59,7 +59,7 @@ parameters: - name: qnn_sdk_version type: string displayName: 'QNN SDK version. Only for QNN packages.' - default: 2.36.1.250708 + default: 2.36.0.250627 stages: - ${{ if eq(parameters.enable_windows_cpu, true) }}: diff --git a/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar-test.yml b/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar-test.yml index 74f7f782fe1b2..ab779e164b36e 100644 --- a/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar-test.yml +++ b/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar-test.yml @@ -19,7 +19,7 @@ parameters: - name: QnnSDKVersion displayName: QNN SDK Version type: string - default: '2.36.1.250708' + default: '2.36.0.250627' - name: enableWebGpu displayName: Enable WebGPU test diff --git a/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar.yml b/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar.yml index 92e862bd79008..110f83ff587c8 100644 --- a/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar.yml +++ b/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar.yml @@ -53,7 +53,7 @@ parameters: - name: QnnSDKVersion displayName: QNN SDK Version type: string - default: '2.36.1.250708' + default: '2.36.0.250627' - name: is1ES displayName: Is 1ES pipeline diff --git a/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml b/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml index 5b48a14e2afc3..535784933a087 100644 --- a/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml +++ b/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml @@ -47,7 +47,7 @@ parameters: - name: QnnSDKVersion displayName: QNN SDK Version type: string - default: 2.36.1.250708 + default: 2.36.0.250627 - name: is1ES displayName: Is 1ES pipeline diff --git a/tools/ci_build/github/azure-pipelines/templates/jobs/download_linux_qnn_sdk.yml b/tools/ci_build/github/azure-pipelines/templates/jobs/download_linux_qnn_sdk.yml index 930dc83b73460..3e7427cc7a2e3 100644 --- a/tools/ci_build/github/azure-pipelines/templates/jobs/download_linux_qnn_sdk.yml +++ b/tools/ci_build/github/azure-pipelines/templates/jobs/download_linux_qnn_sdk.yml @@ -1,7 +1,7 @@ parameters: - name: QnnSDKVersion type: string - default: '2.36.1.250708' + default: '2.36.0.250627' steps: - script: | diff --git a/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_qnn_sdk.yml b/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_qnn_sdk.yml index 96eea6cd6d2fb..e3f549e2d649f 100644 --- a/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_qnn_sdk.yml +++ b/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_qnn_sdk.yml @@ -1,7 +1,7 @@ parameters: - name: QnnSDKVersion type: string - default: '2.36.1.250708' + default: '2.36.0.250627' steps: - powershell: | diff --git a/tools/ci_build/github/azure-pipelines/templates/py-linux-qnn.yml b/tools/ci_build/github/azure-pipelines/templates/py-linux-qnn.yml index caee5367950e6..d533fb7c83ddd 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-linux-qnn.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-linux-qnn.yml @@ -26,7 +26,7 @@ parameters: - name: QnnSdk displayName: QNN SDK version type: string - default: 2.36.1.250708 + default: 2.36.0.250627 - name: is1ES displayName: 'Whether the pipeline is running in 1ES' diff --git a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml index 185f41822a7e5..cd060d1fbf19f 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml @@ -7,7 +7,7 @@ parameters: - name: QNN_SDK displayName: QNN SDK Version type: string - default: 2.36.1.250708 + default: 2.36.0.250627 - name: ENV_SETUP_SCRIPT type: string diff --git a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64ec-qnn.yml b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64ec-qnn.yml index 9a1e7e5e251c9..2a2ac49b4e073 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64ec-qnn.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64ec-qnn.yml @@ -7,7 +7,7 @@ parameters: - name: QNN_SDK displayName: QNN SDK Version type: string - default: 2.36.1.250708 + default: 2.36.0.250627 - name: ENV_SETUP_SCRIPT type: string diff --git a/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml b/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml index 5affc152a0a4a..8528fa3907e96 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml @@ -7,7 +7,7 @@ parameters: - name: QNN_SDK displayName: QNN SDK Version type: string - default: 2.36.1.250708 + default: 2.36.0.250627 - name: ENV_SETUP_SCRIPT type: string diff --git a/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml b/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml index 29ebb8c4e4e61..1406ce338f13e 100644 --- a/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml +++ b/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml @@ -1,5 +1,5 @@ parameters: - QnnSdk: '2.36.1.250708' + QnnSdk: '2.36.0.250627' build_config: 'RelWithDebInfo' IsReleaseBuild: false DoEsrp: false @@ -20,7 +20,7 @@ stages: name: ${{ parameters.qnn_ep_build_pool_name }} variables: OrtPackageId: ${{ parameters.OrtNugetPackageId }} - commonBuildArgs: '--compile_no_warning_as_error --skip_submodule_sync --build_shared_lib --client_package_build --cmake_generator "Visual Studio 17 2022" --config ${{ parameters.build_config }} --parallel --use_vcpkg --use_vcpkg_ms_internal_asset_cache --use_binskim_compliant_compile_flags ' + commonBuildArgs: '--compile_no_warning_as_error --skip_submodule_sync --build_shared_lib --cmake_generator "Visual Studio 17 2022" --config ${{ parameters.build_config }} --parallel --use_vcpkg --use_vcpkg_ms_internal_asset_cache --use_binskim_compliant_compile_flags ' steps: - template: set-version-number-variables-step.yml diff --git a/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml index 7ebf5394e4530..78fce1f9b9602 100644 --- a/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml @@ -33,7 +33,7 @@ parameters: - name: QnnSdk displayName: QNN SDK version type: string - default: 2.36.1.250708 + default: 2.36.0.250627 jobs: - job: 'BUILD_QNN_EP' @@ -50,7 +50,7 @@ jobs: matrix: SHARED_LIB: QnnLibKind: 'shared_lib' - ExtraQnnBuildArgs: '--client_package_build' + ExtraQnnBuildArgs: '' STATIC_LIB: QnnLibKind: 'static_lib' ExtraQnnBuildArgs: '' diff --git a/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml index ffeb577547f69..eb77c9422853d 100644 --- a/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml @@ -33,7 +33,7 @@ parameters: - name: QnnSdk displayName: QNN SDK version type: string - default: 2.36.1.250708 + default: 2.36.0.250627 jobs: - job: 'BUILD_QNN_EP' From ea8cf2173c853fd0ab8f78b45b3ed5897024e2a5 Mon Sep 17 00:00:00 2001 From: n1harika Date: Thu, 17 Jul 2025 05:21:43 -0700 Subject: [PATCH 068/138] Added support for 2025.2 and enabled SimplifiedLayerNormalization op (#714) * Added support for 2025.2 and SimplifiedLayerNormalization op * [OVEP] Update OV version to 2025.2.0 * Revert "[OVEP] Update OV version to 2025.2.0" This reverts commit d1292507f41ddd5ac0747728f73ca6100b00a567. --- cmake/onnxruntime_providers_openvino.cmake | 4 ++-- .../providers/openvino/ov_versions/capability.cc | 10 ++++------ .../core/providers/openvino/ov_versions/data_ops.cc | 13 ++++++++----- .../core/providers/openvino/ov_versions/data_ops.h | 3 ++- 4 files changed, 16 insertions(+), 14 deletions(-) diff --git a/cmake/onnxruntime_providers_openvino.cmake b/cmake/onnxruntime_providers_openvino.cmake index 552f4cd3b8988..5a831a106ae08 100644 --- a/cmake/onnxruntime_providers_openvino.cmake +++ b/cmake/onnxruntime_providers_openvino.cmake @@ -13,8 +13,8 @@ # Header paths find_package(OpenVINO REQUIRED COMPONENTS Runtime ONNX) - if(OpenVINO_VERSION VERSION_LESS 2024.5) - message(FATAL_ERROR "OpenVINO 2024.5 and newer are supported. Please, use latest OpenVINO release") + if(OpenVINO_VERSION VERSION_LESS 2025.0) + message(FATAL_ERROR "OpenVINO 2025.0 and newer are supported. Please, use latest OpenVINO release") endif() if(OpenVINO_VERSION VERSION_GREATER_EQUAL 2024.4) diff --git a/onnxruntime/core/providers/openvino/ov_versions/capability.cc b/onnxruntime/core/providers/openvino/ov_versions/capability.cc index 88ddde8610c6e..2309ff3de751b 100644 --- a/onnxruntime/core/providers/openvino/ov_versions/capability.cc +++ b/onnxruntime/core/providers/openvino/ov_versions/capability.cc @@ -41,16 +41,14 @@ GetCapability::GetCapability(const EPCtxHandler& ep_ctx_handler, npu_qdq_optimizer_enabled = true; // see data_ops.cc ~615 where we check for int16 types for gpu, this may change to a better approach later } -#if OPENVINO_VERSION_MAJOR == 2024 && OPENVINO_VERSION_MINOR == 5 - data_ops_ = new DataOps(graph_viewer_, V_2024_5, device_type_, npu_qdq_optimizer_enabled); -#elif OPENVINO_VERSION_MAJOR == 2024 && OPENVINO_VERSION_MINOR == 6 - data_ops_ = new DataOps(graph_viewer_, V_2024_6, device_type_, npu_qdq_optimizer_enabled); -#elif OPENVINO_VERSION_MAJOR == 2025 && OPENVINO_VERSION_MINOR == 0 +#if OPENVINO_VERSION_MAJOR == 2025 && OPENVINO_VERSION_MINOR == 0 data_ops_ = new DataOps(graph_viewer_, V_2025_0, device_type_, npu_qdq_optimizer_enabled); #elif OPENVINO_VERSION_MAJOR == 2025 && OPENVINO_VERSION_MINOR == 1 data_ops_ = new DataOps(graph_viewer_, V_2025_1, device_type_, npu_qdq_optimizer_enabled); +#elif OPENVINO_VERSION_MAJOR == 2025 && OPENVINO_VERSION_MINOR == 2 + data_ops_ = new DataOps(graph_viewer_, V_2025_2, device_type_, npu_qdq_optimizer_enabled); #else - data_ops_ = new DataOps(graph_viewer_, V_2025_1, device_type_, npu_qdq_optimizer_enabled); + data_ops_ = new DataOps(graph_viewer_, V_2025_2, device_type_, npu_qdq_optimizer_enabled); #endif } diff --git a/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc b/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc index 84001c1161efc..336b294117cba 100644 --- a/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc +++ b/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc @@ -229,6 +229,7 @@ std::vector supported_op_mode = { {"Sigmoid", V_2020_4, {"CPU", "GPU"}}, {"Sign", V_2020_4, {"CPU"}}, {"Sign", V_2022_1, {"GPU"}}, + {"SimplifiedLayerNormalization", V_2025_2, {"CPU", "GPU"}}, {"Sin", V_2022_1, {"CPU", "GPU"}}, {"Sinh", V_2020_4, {"CPU"}}, {"Size", V_2022_1, {"CPU", "GPU"}}, @@ -402,7 +403,7 @@ void DataOps::populate_op_mode_supported() { // populate unsupportedmode_t { - UnsupportedOpMode obj = {{V_2024_1, V_2024_2, V_2024_3, V_2024_4, V_2024_5, V_2024_6, V_2025_0, V_2025_1}, + UnsupportedOpMode obj = {{V_2024_1, V_2024_2, V_2024_3, V_2024_4, V_2024_5, V_2024_6, V_2025_0, V_2025_1, V_2025_2}, [this](const Node* node, const InitializedTensorSet&) { // If the Input of ReduceMax op is UINT8, it is rejected (Due to output mismatch) for (size_t i = 0; i < node->InputDefs().size(); i++) { @@ -418,7 +419,8 @@ void DataOps::populate_op_mode_supported() { } { UnsupportedOpMode obj = {{V_2023_1, V_2023_2, V_2023_3, V_2024_0, V_2024_1, V_2024_2, - V_2024_3, V_2024_4, V_2024_5, V_2024_6, V_2025_0, V_2025_1}, + V_2024_3, V_2024_4, V_2024_5, V_2024_6, V_2025_0, V_2025_1, + V_2025_2}, [this](const Node* node, const InitializedTensorSet&) { const auto& input_args = node->InputDefs(); const auto& input_arg = (input_args.size() > 1) ? input_args[1] : input_args[0]; @@ -437,7 +439,8 @@ void DataOps::populate_op_mode_supported() { } { UnsupportedOpMode obj = {{V_2023_1, V_2023_2, V_2023_3, V_2024_0, V_2024_1, V_2024_2, - V_2024_3, V_2024_4, V_2024_5, V_2024_6, V_2025_0, V_2025_1}, + V_2024_3, V_2024_4, V_2024_5, V_2024_6, V_2025_0, V_2025_1, + V_2025_2}, [this](const Node* node, const InitializedTensorSet&) { // If the operator is unsqueeze // If axes is an input, then we cannot produce a static graph. @@ -452,8 +455,8 @@ void DataOps::populate_op_mode_supported() { op_list_.insert({"Unsqueeze", obj}); } { - UnsupportedOpMode obj = {{V_2023_1, V_2023_2, V_2023_3, V_2024_0, V_2024_1, V_2024_2, V_2024_3, V_2024_4, V_2024_5, V_2024_6, - V_2025_0, V_2025_1}, + UnsupportedOpMode obj = {{V_2023_1, V_2023_2, V_2023_3, V_2024_0, V_2024_1, V_2024_2, V_2024_3, V_2024_4, V_2024_5, + V_2024_6, V_2025_0, V_2025_1, V_2025_2}, [this](const Node* node, const InitializedTensorSet&) { // check for attributes auto& upsample_attr = node->GetAttributes(); diff --git a/onnxruntime/core/providers/openvino/ov_versions/data_ops.h b/onnxruntime/core/providers/openvino/ov_versions/data_ops.h index cf7d834d6cfc7..95905e010541e 100644 --- a/onnxruntime/core/providers/openvino/ov_versions/data_ops.h +++ b/onnxruntime/core/providers/openvino/ov_versions/data_ops.h @@ -35,7 +35,8 @@ enum versionNum { V_2024_5, V_2024_6, V_2025_0, - V_2025_1 + V_2025_1, + V_2025_2 }; using VersionNum = enum versionNum; From 217f2852573190526bb105efbbb2132c2cef179c Mon Sep 17 00:00:00 2001 From: jatinwadhwa921 <110383850+jatinwadhwa921@users.noreply.github.com> Date: Mon, 21 Jul 2025 12:44:00 +0530 Subject: [PATCH 069/138] [OVEP] Fix coverity issues (#753) --- .../core/providers/openvino/backend_manager.cc | 11 ++++++++--- onnxruntime/core/providers/openvino/backend_utils.h | 2 +- .../core/providers/openvino/backends/basic_backend.cc | 2 +- .../providers/openvino/openvino_execution_provider.cc | 2 +- .../core/providers/openvino/openvino_parser_utils.cc | 6 +++--- onnxruntime/core/providers/openvino/ov_interface.cc | 2 +- onnxruntime/core/providers/openvino/ov_interface.h | 4 ++-- .../providers/openvino/ov_stateful_patch_utils.cc | 2 +- .../openvino/qdq_transformations/qdq_scales_fix.cpp | 4 ++-- 9 files changed, 20 insertions(+), 15 deletions(-) diff --git a/onnxruntime/core/providers/openvino/backend_manager.cc b/onnxruntime/core/providers/openvino/backend_manager.cc index 28804d2f76492..41e5aca08eab5 100644 --- a/onnxruntime/core/providers/openvino/backend_manager.cc +++ b/onnxruntime/core/providers/openvino/backend_manager.cc @@ -183,9 +183,13 @@ BackendManager::BackendManager(SessionContext& session_context, } if (session_context_.so_context_enable && (subgraph_context_.is_ep_ctx_ovir_encapsulated || !subgraph_context_.is_ep_ctx_graph)) { - auto status = onnxruntime::openvino_ep::BackendManager::ExportCompiledBlobAsEPCtxNode(subgraph); - if (!status.IsOK()) { - ORT_THROW(status); + if (concrete_backend_) { + auto status = onnxruntime::openvino_ep::BackendManager::ExportCompiledBlobAsEPCtxNode(subgraph); + if (!status.IsOK()) { + ORT_THROW(status); + } + } else { + ORT_THROW("[OpenVINO-EP] Cannot export compiled blob as EPCtx Node: Backend not initialized."); } } } @@ -660,6 +664,7 @@ void BackendManager::Compute(OrtKernelContext* context) { } void BackendManager::ShutdownBackendManager() { + std::unique_lock lock(mutex_); backend_map_.clear(); concrete_backend_.reset(); } diff --git a/onnxruntime/core/providers/openvino/backend_utils.h b/onnxruntime/core/providers/openvino/backend_utils.h index ec3df94c2d1d2..15145df651fa2 100644 --- a/onnxruntime/core/providers/openvino/backend_utils.h +++ b/onnxruntime/core/providers/openvino/backend_utils.h @@ -37,7 +37,7 @@ struct ParameterShape { std::transform(ort_shape.begin(), ort_shape.end(), ov_shape.begin(), [](int64_t dim) { return dim == -1 ? ov::Dimension::dynamic() : ov::Dimension(dim); }); - return ov::PartialShape(ov_shape); + return ov::PartialShape(std::move(ov_shape)); } static ort_shape_t ToOrtShape(const ov::PartialShape& ov_shape) { diff --git a/onnxruntime/core/providers/openvino/backends/basic_backend.cc b/onnxruntime/core/providers/openvino/backends/basic_backend.cc index 61235ef2138b5..8b7309e6a5a98 100644 --- a/onnxruntime/core/providers/openvino/backends/basic_backend.cc +++ b/onnxruntime/core/providers/openvino/backends/basic_backend.cc @@ -44,7 +44,7 @@ BasicBackend::BasicBackend(std::unique_ptr& model_pr // model_file_path will use so_context_file_path if the onnx_model_path_name is not available, // especially in case of CreateSessionFormArray() where user must explicitly // specify absolute path for so_context_file_path. - auto model_file_path = [this]() { + auto model_file_path = [this]() -> const std::filesystem::path& { if (!session_context_.onnx_model_path_name.empty() && std::filesystem::exists(session_context_.onnx_model_path_name)) return session_context_.onnx_model_path_name; diff --git a/onnxruntime/core/providers/openvino/openvino_execution_provider.cc b/onnxruntime/core/providers/openvino/openvino_execution_provider.cc index a0aa04293ac37..1b19517b07363 100644 --- a/onnxruntime/core/providers/openvino/openvino_execution_provider.cc +++ b/onnxruntime/core/providers/openvino/openvino_execution_provider.cc @@ -118,7 +118,7 @@ common::Status OpenVINOExecutionProvider::Compile( fs::path metadata_file_path = context_model_file_path.parent_path() / metadata_filename; std::ifstream file(metadata_file_path, std::ios::binary); ORT_RETURN_IF_NOT(file, "Metadata file was not found: " + metadata_file_path.string()); - shared_context_->shared_weights.metadata_filepath = metadata_file_path; + shared_context_->shared_weights.metadata_filepath = std::move(metadata_file_path); file >> metadata; } diff --git a/onnxruntime/core/providers/openvino/openvino_parser_utils.cc b/onnxruntime/core/providers/openvino/openvino_parser_utils.cc index a78bd1fe2effc..21fc7f935da23 100644 --- a/onnxruntime/core/providers/openvino/openvino_parser_utils.cc +++ b/onnxruntime/core/providers/openvino/openvino_parser_utils.cc @@ -142,7 +142,7 @@ reshape_t OpenVINOParserUtils::ParseInputShape(const std::string& reshape_input_ } // Process each tensor definition e.g. "input_1[1..5, 2, 3..4],data[1,2,3]" - for (std::sregex_iterator i = tensor_begin; i != tensor_end; ++i) { + for (std::sregex_iterator i = std::move(tensor_begin); i != tensor_end; ++i) { std::smatch tensor_match = *i; // Extract tensor name and trim whitespace @@ -165,7 +165,7 @@ reshape_t OpenVINOParserUtils::ParseInputShape(const std::string& reshape_input_ auto dim_end = std::sregex_iterator(); // Process each dimension - for (std::sregex_iterator j = dim_begin; j != dim_end; ++j) { + for (std::sregex_iterator j = std::move(dim_begin); j != dim_end; ++j) { std::smatch dim_match = *j; std::string dim_value = dim_match[1].str(); @@ -190,7 +190,7 @@ reshape_t OpenVINOParserUtils::ParseInputShape(const std::string& reshape_input_ } // Store parsed shape in result map - parsed_shape_map[tensor_name] = ov::PartialShape(dimensions); + parsed_shape_map[tensor_name] = ov::PartialShape(std::move(dimensions)); } return parsed_shape_map; diff --git a/onnxruntime/core/providers/openvino/ov_interface.cc b/onnxruntime/core/providers/openvino/ov_interface.cc index c59cc92d6cfa9..2d29df8eb4197 100644 --- a/onnxruntime/core/providers/openvino/ov_interface.cc +++ b/onnxruntime/core/providers/openvino/ov_interface.cc @@ -331,7 +331,7 @@ OVTensorPtr OVInferRequest::GetTensor(const std::string& input_name) { } std::string OVInferRequest::GetInputTensorName(uint32_t index) { - return OvExceptionBoundary([&]() { + return OvExceptionBoundary([&]() -> const std::string& { const auto& model = ovInfReq.get_compiled_model(); return *model.input(index).get_names().begin(); }, diff --git a/onnxruntime/core/providers/openvino/ov_interface.h b/onnxruntime/core/providers/openvino/ov_interface.h index f6bc5ad599e18..ee35a3ebef7cb 100644 --- a/onnxruntime/core/providers/openvino/ov_interface.h +++ b/onnxruntime/core/providers/openvino/ov_interface.h @@ -105,8 +105,8 @@ class OVExeNetwork { public: explicit OVExeNetwork(ov::CompiledModel compiled_model, std::string device, bool stateful_causallm = false) - : compiled_model_obj(compiled_model), target_device(device), is_stateful_causallm(stateful_causallm) {} - OVExeNetwork() : compiled_model_obj(ov::CompiledModel()) {} + : compiled_model_obj(std::move(compiled_model)), target_device(std::move(device)), is_stateful_causallm(stateful_causallm) {} + OVExeNetwork() : compiled_model_obj(ov::CompiledModel()), is_stateful_causallm(false) {} ov::CompiledModel& Get() { return compiled_model_obj; } std::shared_ptr CreateInferRequest(); }; diff --git a/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc b/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc index 67ba42884e4f0..b48b0efde7ab6 100644 --- a/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc +++ b/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc @@ -74,7 +74,7 @@ void FuseCacheReorder(std::shared_ptr ov_model, auto input_batch = ov_model->input(main_input_name).get_partial_shape()[0]; - auto beam_idx = std::make_shared(ov::element::i32, ov::PartialShape({input_batch})); + auto beam_idx = std::make_shared(ov::element::i32, ov::PartialShape({std::move(input_batch)})); beam_idx->set_friendly_name("beam_idx"); beam_idx->output(0).get_tensor().add_names({"beam_idx"}); ov_model->add_parameters({beam_idx}); diff --git a/onnxruntime/core/providers/openvino/qdq_transformations/qdq_scales_fix.cpp b/onnxruntime/core/providers/openvino/qdq_transformations/qdq_scales_fix.cpp index c1e4815c206a2..d159930d52845 100644 --- a/onnxruntime/core/providers/openvino/qdq_transformations/qdq_scales_fix.cpp +++ b/onnxruntime/core/providers/openvino/qdq_transformations/qdq_scales_fix.cpp @@ -512,7 +512,7 @@ struct CustomGraph { continue; } - auto scale_name = node->node_input_name[1]; // Scale + const auto& scale_name = node->node_input_name[1]; // Scale auto scale_value = get_initializer_value(original_graph, scale_name); if (scale_value / node->scale_factor < threshold) { remove_qdq_pair(*node, removed); @@ -699,7 +699,7 @@ bool scale_graph(CustomGraph& gen_graph, if (cur_node->op_type == "QuantizeLinear" && cur_node->to_node[0]->op_type == "DequantizeLinear") { needs_second_run = true; - auto scale_name = *std::next(cur_node->node_input_name.begin()); + const auto& scale_name = *std::next(cur_node->node_input_name.begin()); auto scale_value = get_initializer_value(gen_graph.original_graph, scale_name); // QDQ pair with scale over 1 From 138c6e0536c9fbe94285163372180e9a218bacc6 Mon Sep 17 00:00:00 2001 From: Ankit Maheshkar Date: Mon, 21 Jul 2025 21:42:07 +0530 Subject: [PATCH 070/138] [OVEP] feat: Integrate new ABI with Legacy OVEP Plugin (#747) * update: Implement OV Plugin using factories * fix: refactor plugin code * fix: map ep_metadata to device type using "ov_device" key * fix: block provider options for AppendExecutionProvider_V2 pass * minor fix for linux * Add OrtEpLibraryOv tests * ovep: Support multiple devices (i.e. AUTO) passed to CreateIExecutionProvider * CreateIExecutionProvider: comment out unused devices parameter * ovep factory: Implement CreateDataTransfer to avoid crash in RegisterExecutionProviderLibrary * update: Enable shared libs linker flags for linux & macos * CreateIExecutionProvider: For some disallowed provider options, give better guidance * Add PluginEp_CheckV2DisallowedProviderOptions test * ovep: Add CreateProvider_V2 & call it from CreateIExecutionProvider * disable data transfer for ovep * minor fix for linux * openvino_provider_factory: Add 'num_of_threads' to block_and_advise_entries --------- Co-authored-by: Ryan Metcalfe --- .../providers/openvino/exported_symbols.lst | 2 + .../openvino/openvino_provider_factory.cc | 122 +++++++ .../core/providers/openvino/ov_factory.cc | 182 +++++++++++ .../core/providers/openvino/ov_factory.h | 156 +++++++++ .../core/providers/openvino/ov_interface.h | 2 +- .../core/providers/openvino/symbols.def | 2 + .../providers/openvino/version_script.lds | 4 +- .../providers/openvino/openvino_plugin.cc | 302 ++++++++++++++++++ 8 files changed, 770 insertions(+), 2 deletions(-) create mode 100644 onnxruntime/core/providers/openvino/ov_factory.cc create mode 100644 onnxruntime/core/providers/openvino/ov_factory.h create mode 100644 onnxruntime/test/providers/openvino/openvino_plugin.cc diff --git a/onnxruntime/core/providers/openvino/exported_symbols.lst b/onnxruntime/core/providers/openvino/exported_symbols.lst index f4c41412594af..6dc5905ae4550 100644 --- a/onnxruntime/core/providers/openvino/exported_symbols.lst +++ b/onnxruntime/core/providers/openvino/exported_symbols.lst @@ -1 +1,3 @@ _GetProvider +_CreateEpFactories +_ReleaseEpFactory diff --git a/onnxruntime/core/providers/openvino/openvino_provider_factory.cc b/onnxruntime/core/providers/openvino/openvino_provider_factory.cc index bad1d416eeda2..fda7ef6534197 100644 --- a/onnxruntime/core/providers/openvino/openvino_provider_factory.cc +++ b/onnxruntime/core/providers/openvino/openvino_provider_factory.cc @@ -403,6 +403,18 @@ struct OpenVINOProviderFactory : IExecutionProviderFactory { return ov_ep; } + // This is called during session creation when AppendExecutionProvider_V2 is used. + // This one is called because ParseProviderInfo / ParseConfigOptions, etc. are already + // performed in CreateIExecutionProvider, and so provider_info_ has already been populated. + std::unique_ptr CreateProvider_V2(const OrtSessionOptions& /*session_options*/, + const OrtLogger& session_logger) { + ProviderInfo provider_info = provider_info_; + auto ov_ep = std::make_unique(provider_info, shared_context_); + ov_ep->SetLogger(reinterpret_cast(&session_logger)); + return ov_ep; + } + + private: ProviderInfo provider_info_; std::shared_ptr shared_context_; @@ -433,6 +445,116 @@ struct OpenVINO_Provider : Provider { return std::make_shared(pi, SharedContext::Get()); } + Status CreateIExecutionProvider(const OrtHardwareDevice* const* /*devices*/, + const OrtKeyValuePairs* const* ep_metadata, + size_t num_devices, + ProviderOptions& provider_options, + const OrtSessionOptions& session_options, + const OrtLogger& logger, + std::unique_ptr& ep) override { + // Check if no devices are provided + if (num_devices == 0) { + return Status(common::ONNXRUNTIME, ORT_EP_FAIL, "No devices provided to CreateIExecutionProvider"); + } + + // For provider options that we don't support directly but are still supported through load_config, + // give some specific guidance & example about how to make use of the option through load_config. + const std::vector> block_and_advise_entries = { + {"cache_dir", "\"CACHE_DIR\": \"\""}, + {"precision", "\"INFERENCE_PRECISION_HINT\": \"F32\""}, + {"num_of_threads", "\"INFERENCE_NUM_THREADS\": \"1\""}, + {"num_streams", "\"NUM_STREAMS\": \"1\""}, + {"model_priority", "\"MODEL_PRIORITY\": \"LOW\""}, + {"enable_opencl_throttling", "\"GPU\": {\"PLUGIN_THROTTLE\": \"1\"}"}, + {"enable_qdq_optimizer", "\"NPU\": {\"NPU_QDQ_OPTIMIZATION\": \"YES\"}"} + }; + + for (auto& block_and_advise_entry : block_and_advise_entries) { + if (provider_options.find(block_and_advise_entry.first) != provider_options.end()) { + std::string message = "OpenVINO EP: Option '" + block_and_advise_entry.first + + "' cannot be set when using AppendExecutionProvider_V2. " + + "It can instead be enabled by a load_config key / value pair. For example: " + + block_and_advise_entry.second; + return Status(common::ONNXRUNTIME, ORT_INVALID_ARGUMENT, message); + } + } + + // For the rest of the disallowed provider options, give a generic error message. + const std::vector blocked_provider_keys = { + "device_type", "device_id", "device_luid", "context", "disable_dynamic_shapes"}; + + for (const auto& key : blocked_provider_keys) { + if (provider_options.find(key) != provider_options.end()) { + return Status(common::ONNXRUNTIME, ORT_INVALID_ARGUMENT, + "OpenVINO EP: Option '" + key + "' cannot be set when using AppendExecutionProvider_V2."); + } + } + + const char* ov_device_key = "ov_device"; + const char* ov_meta_device_key = "ov_meta_device"; + + // Create a unique list of ov_devices that were passed in. + std::unordered_set unique_ov_devices; + std::vector ordered_unique_ov_devices; + for (size_t i = 0; i < num_devices; ++i) { + const auto& device_meta_data = ep_metadata[i]; + auto ov_device_it = device_meta_data->Entries().find(ov_device_key); + if (ov_device_it == device_meta_data->Entries().end()) { + return Status(common::ONNXRUNTIME, ORT_INVALID_ARGUMENT, "OpenVINO EP device metadata not found."); + } + auto &ov_device = ov_device_it->second; + + // Add to ordered_unique only if not already present + if (unique_ov_devices.insert(ov_device).second) { + ordered_unique_ov_devices.push_back(ov_device); + } + } + + std::string ov_meta_device_type = "NONE"; + { + auto ov_meta_device_it = ep_metadata[0]->Entries().find(ov_meta_device_key); + if (ov_meta_device_it != ep_metadata[0]->Entries().end()) { + ov_meta_device_type = ov_meta_device_it->second; + } + } + + bool is_meta_device_factory = (ov_meta_device_type != "NONE"); + + if (ordered_unique_ov_devices.size() > 1 && !is_meta_device_factory) { + LOGS_DEFAULT(WARNING) << "[OpenVINO EP] Multiple devices were specified that are not OpenVINO meta devices. Using first ov_device only: " << ordered_unique_ov_devices.at(0); + ordered_unique_ov_devices.resize(1); // Use only the first device if not a meta device factory + } + + std::string ov_device_string; + if (is_meta_device_factory) { + // Build up a meta device string based on the devices that are passed in. E.g. AUTO:NPU,GPU.0,CPU + ov_device_string = ov_meta_device_type; + ov_device_string += ":"; + } + + bool prepend_comma = false; + for (const auto& ov_device : ordered_unique_ov_devices) { + if (prepend_comma) { + ov_device_string += ","; + } + ov_device_string += ov_device; + prepend_comma = true; + } + + provider_options["device_type"] = ov_device_string; + + // Parse provider info with the device type + ProviderInfo pi; + const auto& config_options = session_options.GetConfigOptions(); + ParseProviderInfo(provider_options, &config_options, pi); + ParseConfigOptions(pi); + + // Create and return the execution provider + auto factory = std::make_unique(pi, SharedContext::Get()); + ep = factory->CreateProvider_V2(session_options, logger); + return Status::OK(); + } + void Initialize() override { } diff --git a/onnxruntime/core/providers/openvino/ov_factory.cc b/onnxruntime/core/providers/openvino/ov_factory.cc new file mode 100644 index 0000000000000..e347bcf1b1aef --- /dev/null +++ b/onnxruntime/core/providers/openvino/ov_factory.cc @@ -0,0 +1,182 @@ +// Copyright (C) Intel Corporation +// Licensed under the MIT License + +#include +#include +#include +#include +#include +#include +#include + +#define ORT_API_MANUAL_INIT +#include "onnxruntime_cxx_api.h" +#undef ORT_API_MANUAL_INIT + +#include "onnxruntime_c_api.h" +#include "ov_factory.h" +#include "openvino/openvino.hpp" +#include "ov_interface.h" + +using namespace onnxruntime::openvino_ep; +using ov_core_singleton = onnxruntime::openvino_ep::WeakSingleton; + +static void InitCxxApi(const OrtApiBase& ort_api_base) { + static std::once_flag init_api; + std::call_once(init_api, [&]() { + const OrtApi* ort_api = ort_api_base.GetApi(ORT_API_VERSION); + Ort::InitApi(ort_api); + }); +} + +OpenVINOEpPluginFactory::OpenVINOEpPluginFactory(ApiPtrs apis, const std::string& ov_metadevice_name, std::shared_ptr core) + : ApiPtrs{apis}, + ep_name_(ov_metadevice_name.empty() ? provider_name_ : std::string(provider_name_) + "." + ov_metadevice_name), + device_type_(ov_metadevice_name), + ov_core_(std::move(core)) { + OrtEpFactory::GetName = GetNameImpl; + OrtEpFactory::GetVendor = GetVendorImpl; + OrtEpFactory::GetVendorId = GetVendorIdImpl; + OrtEpFactory::GetSupportedDevices = GetSupportedDevicesImpl; + OrtEpFactory::GetVersion = GetVersionImpl; + OrtEpFactory::CreateDataTransfer = CreateDataTransferImpl; + + ort_version_supported = ORT_API_VERSION; // Set to the ORT version we were compiled with. +} + +const std::vector& OpenVINOEpPluginFactory::GetOvDevices() { + static std::vector devices = ov_core_singleton::Get()->get_available_devices(); + return devices; +} + +const std::vector& OpenVINOEpPluginFactory::GetOvMetaDevices() { + static std::vector virtual_devices = [ov_core = ov_core_singleton::Get()] { + std::vector supported_virtual_devices{}; + for (const auto& meta_device : known_meta_devices_) { + try { + ov_core->get_property(meta_device, ov::supported_properties); + supported_virtual_devices.push_back(meta_device); + } catch (ov::Exception&) { + // meta device isn't supported. + } + } + return supported_virtual_devices; + }(); + + return virtual_devices; +} + +OrtStatus* OpenVINOEpPluginFactory::GetSupportedDevices(const OrtHardwareDevice* const* devices, + size_t num_devices, + OrtEpDevice** ep_devices, + size_t max_ep_devices, + size_t* p_num_ep_devices) { + size_t& num_ep_devices = *p_num_ep_devices; + + // Create a map for device type mapping + static const std::map ort_to_ov_device_name = { + {OrtHardwareDeviceType::OrtHardwareDeviceType_CPU, "CPU"}, + {OrtHardwareDeviceType::OrtHardwareDeviceType_GPU, "GPU"}, + {OrtHardwareDeviceType::OrtHardwareDeviceType_NPU, "NPU"}, + }; + + for (size_t i = 0; i < num_devices && num_ep_devices < max_ep_devices; ++i) { + const OrtHardwareDevice& device = *devices[i]; + if (ort_api.HardwareDevice_VendorId(&device) != vendor_id_) { + // Not an Intel Device. + continue; + } + + auto device_type = ort_api.HardwareDevice_Type(&device); + auto device_it = ort_to_ov_device_name.find(device_type); + if (device_it == ort_to_ov_device_name.end()) { + // We don't know about this device type + continue; + } + + const auto& ov_device_type = device_it->second; + std::string ov_device_name; + auto get_pci_device_id = [&](const std::string& ov_device) { + try { + ov::device::PCIInfo pci_info = ov_core_->get_property(ov_device, ov::device::pci_info); + return pci_info.device; + } catch (ov::Exception&) { + return 0u; // If we can't get the PCI info, we won't have a device ID. + } + }; + + auto filtered_devices = GetOvDevices(ov_device_type); + auto matched_device = filtered_devices.begin(); + if (filtered_devices.size() > 1 && device_type == OrtHardwareDeviceType::OrtHardwareDeviceType_GPU) { + // If there are multiple devices of the same type, we need to match by device ID. + matched_device = std::find_if(filtered_devices.begin(), filtered_devices.end(), [&](const std::string& ov_device) { + uint32_t ort_device_id = ort_api.HardwareDevice_DeviceId(&device); + return ort_device_id == get_pci_device_id(ov_device); + }); + } + + if (matched_device == filtered_devices.end()) { + // We didn't find a matching OpenVINO device for the OrtHardwareDevice. + continue; + } + + // these can be returned as nullptr if you have nothing to add. + OrtKeyValuePairs* ep_metadata = nullptr; + OrtKeyValuePairs* ep_options = nullptr; + ort_api.CreateKeyValuePairs(&ep_metadata); + ort_api.AddKeyValuePair(ep_metadata, ov_device_key_, matched_device->c_str()); + + if (IsMetaDeviceFactory()) { + ort_api.AddKeyValuePair(ep_metadata, ov_meta_device_key_, device_type_.c_str()); + } + + // Create EP device + auto* status = ort_api.GetEpApi()->CreateEpDevice(this, &device, ep_metadata, ep_options, + &ep_devices[num_ep_devices++]); + + ort_api.ReleaseKeyValuePairs(ep_metadata); + ort_api.ReleaseKeyValuePairs(ep_options); + + if (status != nullptr) { + return status; + } + } + + return nullptr; +} + +extern "C" { +// +// Public symbols +// +OrtStatus* CreateEpFactories(const char* /*registration_name*/, const OrtApiBase* ort_api_base, + OrtEpFactory** factories, size_t max_factories, size_t* num_factories) { + InitCxxApi(*ort_api_base); + const ApiPtrs api_ptrs{Ort::GetApi(), Ort::GetEpApi(), Ort::GetModelEditorApi()}; + + // Get available devices from OpenVINO + auto ov_core = ov_core_singleton::Get(); + std::vector supported_factories = {""}; + const auto& meta_devices = OpenVINOEpPluginFactory::GetOvMetaDevices(); + supported_factories.insert(supported_factories.end(), meta_devices.begin(), meta_devices.end()); + + const size_t required_factories = supported_factories.size(); + if (max_factories < required_factories) { + return Ort::Status(std::format("Not enough space to return EP factories. Need at least {} factories.", required_factories).c_str(), ORT_INVALID_ARGUMENT); + } + + size_t factory_index = 0; + for (const auto& device_name : supported_factories) { + // Create a factory for this specific device + factories[factory_index++] = new OpenVINOEpPluginFactory(api_ptrs, device_name, ov_core); + } + + *num_factories = factory_index; + return nullptr; +} + +OrtStatus* ReleaseEpFactory(OrtEpFactory* factory) { + delete static_cast(factory); + return nullptr; +} +} diff --git a/onnxruntime/core/providers/openvino/ov_factory.h b/onnxruntime/core/providers/openvino/ov_factory.h new file mode 100644 index 0000000000000..37739f67323c1 --- /dev/null +++ b/onnxruntime/core/providers/openvino/ov_factory.h @@ -0,0 +1,156 @@ +// Copyright (C) Intel Corporation +// Licensed under the MIT License + +#pragma once + +#include +#include +#include +#include + +#include "core/providers/shared_library/provider_api.h" +#include "openvino/openvino.hpp" + +namespace onnxruntime { +namespace openvino_ep { + +struct ApiPtrs { + const OrtApi& ort_api; + const OrtEpApi& ep_api; + const OrtModelEditorApi& model_editor_api; +}; + +#define OVEP_DISABLE_MOVE(class_name) \ + class_name(class_name&&) = delete; \ + class_name& operator=(class_name&&) = delete; + +#define OVEP_DISABLE_COPY(class_name) \ + class_name(const class_name&) = delete; \ + class_name& operator=(const class_name&) = delete; + +#define OVEP_DISABLE_COPY_AND_MOVE(class_name) \ + OVEP_DISABLE_COPY(class_name) \ + OVEP_DISABLE_MOVE(class_name) + +template +static auto ApiEntry(Func&& func, std::optional> logger = std::nullopt) { + try { + return func(); + } catch (const Ort::Exception& ex) { + if (logger) { + ORT_CXX_LOG_NOEXCEPT(logger->get(), ORT_LOGGING_LEVEL_ERROR, ex.what()); + } + if constexpr (std::is_same_v) { + return Ort::Status(ex.what(), ex.GetOrtErrorCode()).release(); + } + } catch (const std::exception& ex) { + if (logger) { + ORT_CXX_LOG_NOEXCEPT(logger->get(), ORT_LOGGING_LEVEL_ERROR, ex.what()); + } + if constexpr (std::is_same_v) { + return Ort::Status(ex.what(), ORT_RUNTIME_EXCEPTION).release(); + } + } catch (...) { + if (logger) { + ORT_CXX_LOG_NOEXCEPT(logger->get(), ORT_LOGGING_LEVEL_ERROR, "Unknown exception occurred."); + } + if constexpr (std::is_same_v) { + return Ort::Status("Unknown exception occurred.", ORT_RUNTIME_EXCEPTION).release(); + } + } +} + +class OpenVINOEpPluginFactory : public OrtEpFactory, public ApiPtrs { + public: + OpenVINOEpPluginFactory(ApiPtrs apis, const std::string& ov_device, std::shared_ptr ov_core); + ~OpenVINOEpPluginFactory() = default; + + OVEP_DISABLE_COPY_AND_MOVE(OpenVINOEpPluginFactory) + + static const std::vector& GetOvDevices(); + + std::vector GetOvDevices(const std::string& device_type) { + std::vector filtered_devices; + const auto& devices = GetOvDevices(); + std::copy_if(devices.begin(), devices.end(), std::back_inserter(filtered_devices), + [&device_type](const std::string& device) { + return device.find(device_type) != std::string::npos; + }); + return filtered_devices; + } + + static const std::vector& GetOvMetaDevices(); + + // Member functions + const char* GetName() const { + return ep_name_.c_str(); + } + + const char* GetVendor() const { + return vendor_; + } + + OrtStatus* GetSupportedDevices(const OrtHardwareDevice* const* devices, + size_t num_devices, + OrtEpDevice** ep_devices, + size_t max_ep_devices, + size_t* p_num_ep_devices); + + bool IsMetaDeviceFactory() const { + return known_meta_devices_.find(device_type_) != known_meta_devices_.end(); + } + + // Constants + static constexpr const char* vendor_ = "Intel"; + static constexpr uint32_t vendor_id_{0x8086}; // Intel's PCI vendor ID + static constexpr const char* ov_device_key_ = "ov_device"; + static constexpr const char* ov_meta_device_key_ = "ov_meta_device"; + static constexpr const char* provider_name_ = "OpenVINOExecutionProvider"; + + private: + std::string ep_name_; + std::string device_type_; + std::vector ov_devices_; + std::shared_ptr ov_core_; + inline static const std::set known_meta_devices_ = { + "AUTO"}; + + public: + // Static callback methods for the OrtEpFactory interface + static const char* ORT_API_CALL GetNameImpl(const OrtEpFactory* this_ptr) noexcept { + const auto* factory = static_cast(this_ptr); + return factory->GetName(); + } + + static const char* ORT_API_CALL GetVendorImpl(const OrtEpFactory* this_ptr) noexcept { + const auto* factory = static_cast(this_ptr); + return factory->GetVendor(); + } + + static uint32_t ORT_API_CALL GetVendorIdImpl(const OrtEpFactory* /*this_ptr*/) noexcept { + return OpenVINOEpPluginFactory::vendor_id_; + } + + static OrtStatus* ORT_API_CALL GetSupportedDevicesImpl(OrtEpFactory* this_ptr, + const OrtHardwareDevice* const* devices, + size_t num_devices, + OrtEpDevice** ep_devices, + size_t max_ep_devices, + size_t* p_num_ep_devices) noexcept { + auto* factory = static_cast(this_ptr); + return ApiEntry([&]() { return factory->GetSupportedDevices(devices, num_devices, ep_devices, max_ep_devices, p_num_ep_devices); }); + } + + static OrtStatus* ORT_API_CALL CreateDataTransferImpl(OrtEpFactory* /*this_ptr*/, + OrtDataTransferImpl** data_transfer) noexcept { + *data_transfer = nullptr; // return nullptr to indicate that this EP does not support data transfer. + return nullptr; + } + + static const char* ORT_API_CALL GetVersionImpl(const OrtEpFactory*) noexcept { + return ORT_VERSION; + } +}; + +} // namespace openvino_ep +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/openvino/ov_interface.h b/onnxruntime/core/providers/openvino/ov_interface.h index ee35a3ebef7cb..6d1db4366410b 100644 --- a/onnxruntime/core/providers/openvino/ov_interface.h +++ b/onnxruntime/core/providers/openvino/ov_interface.h @@ -159,7 +159,7 @@ class OVInferRequest { ov::InferRequest& GetNewObj() { return ovInfReq; } - virtual void RewindKVCache(size_t index) {} + virtual void RewindKVCache([[maybe_unused]] size_t index) {} }; class StatefulOVInferRequest : public OVInferRequest { diff --git a/onnxruntime/core/providers/openvino/symbols.def b/onnxruntime/core/providers/openvino/symbols.def index 4ec2f7914c208..3afed01da1966 100644 --- a/onnxruntime/core/providers/openvino/symbols.def +++ b/onnxruntime/core/providers/openvino/symbols.def @@ -1,2 +1,4 @@ EXPORTS GetProvider + CreateEpFactories + ReleaseEpFactory diff --git a/onnxruntime/core/providers/openvino/version_script.lds b/onnxruntime/core/providers/openvino/version_script.lds index 094abb3329781..3600a4f8f4b51 100644 --- a/onnxruntime/core/providers/openvino/version_script.lds +++ b/onnxruntime/core/providers/openvino/version_script.lds @@ -1,7 +1,9 @@ #_init and _fini should be local VERS_1.0 { global: - GetProvider; + GetProvider; + CreateEpFactories; + ReleaseEpFactory; # Hide everything else. local: diff --git a/onnxruntime/test/providers/openvino/openvino_plugin.cc b/onnxruntime/test/providers/openvino/openvino_plugin.cc new file mode 100644 index 0000000000000..5abca55820a24 --- /dev/null +++ b/onnxruntime/test/providers/openvino/openvino_plugin.cc @@ -0,0 +1,302 @@ +#include +#include + +#include "gtest/gtest.h" +#include "core/common/common.h" +#include "core/session/onnxruntime_session_options_config_keys.h" +#include "onnxruntime_cxx_api.h" +#include "api_asserts.h" +#include "core/session/onnxruntime_session_options_config_keys.h" + +extern std::unique_ptr ort_env; + +struct OrtEpLibraryOv : public ::testing::Test { + static const inline std::filesystem::path library_path = +#if _WIN32 + "onnxruntime_providers_openvino.dll"; +#else + "libonnxruntime_providers_openvino.so"; +#endif + static const inline std::string registration_name = "OpenVINOExecutionProvider"; + + void SetUp() override { +#ifndef _WIN32 + GTEST_SKIP() << "Skipping OpenVINO EP tests as the OpenVINO plugin is not built."; +#endif + ort_env->RegisterExecutionProviderLibrary(registration_name.c_str(), library_path.c_str()); + } + + void TearDown() override { +#ifndef _WIN32 + GTEST_SKIP() << "Skipping OpenVINO EP tests as the OpenVINO plugin is not built."; +#endif + ort_env->UnregisterExecutionProviderLibrary(registration_name.c_str()); + } + + void RunModelWithSession(Ort::Session& session) { + Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); + std::vector shape = {3, 2}; + std::vector input0_data(6, 2.0f); + std::vector ort_inputs; + std::vector ort_input_names; + ort_inputs.emplace_back(Ort::Value::CreateTensor( + memory_info, input0_data.data(), input0_data.size(), shape.data(), shape.size())); + ort_input_names.push_back("X"); + std::array output_names{"Y"}; + std::vector ort_outputs = session.Run(Ort::RunOptions{nullptr}, ort_input_names.data(), ort_inputs.data(), + ort_inputs.size(), output_names.data(), output_names.size()); + Ort::Value& ort_output = ort_outputs[0]; + const float* output_data = ort_output.GetTensorData(); + gsl::span output_span(output_data, 6); + EXPECT_THAT(output_span, ::testing::ElementsAre(2, 4, 6, 8, 10, 12)); + } + + void RunModelWithPluginEp(Ort::SessionOptions& session_options) { + Ort::Session session(*ort_env, ORT_TSTR("testdata/mul_1.onnx"), session_options); + RunModelWithSession(session); + } + + void GenerateEpContextOnLegacyPath(std::filesystem::path epctx, bool embed_mode) { + Ort::SessionOptions session_options{}; + std::filesystem::remove(epctx); + // Add config option to enable EP context + session_options.SetGraphOptimizationLevel(ORT_DISABLE_ALL); + session_options.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1"); + session_options.AddConfigEntry(kOrtSessionOptionEpContextFilePath, epctx.string().c_str()); + session_options.AddConfigEntry(kOrtSessionOptionEpContextEmbedMode, embed_mode ? "1" : "0"); + session_options.AppendExecutionProvider_OpenVINO_V2({{"device_type", "CPU"}}); + Ort::Session session(*ort_env, ORT_TSTR("testdata/mul_1.onnx"), session_options); + RunModelWithSession(session); + } + + void GenerateEpContextOnPluginPath(std::filesystem::path epctx, bool embed_mode) { + Ort::SessionOptions session_options{}; + std::filesystem::remove(epctx); + // Add config option to enable EP context + session_options.SetGraphOptimizationLevel(ORT_DISABLE_ALL); + session_options.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1"); + session_options.AddConfigEntry(kOrtSessionOptionEpContextFilePath, epctx.string().c_str()); + session_options.AddConfigEntry(kOrtSessionOptionEpContextEmbedMode, embed_mode ? "1" : "0"); + Ort::ConstEpDevice plugin_ep_device = GetOvCpuEpDevice(); + ASSERT_NE(plugin_ep_device, nullptr); + std::unordered_map ep_options; + session_options.AppendExecutionProvider_V2(*ort_env, std::vector{plugin_ep_device}, ep_options); + Ort::Session session(*ort_env, ORT_TSTR("testdata/mul_1.onnx"), session_options); + RunModelWithSession(session); + } + + Ort::ConstEpDevice GetOvCpuEpDevice(std::string device_type = "CPU") { + auto ep_devices = ort_env->GetEpDevices(); + Ort::ConstEpDevice plugin_ep_device{}; + + for (Ort::ConstEpDevice& device : ep_devices) { + if (device.Device().Type() == OrtHardwareDeviceType_CPU && + std::string_view(device.EpName()).find(registration_name) != std::string::npos) { + const auto& meta_kv = device.EpMetadata().GetKeyValuePairs(); + auto device_type_it = meta_kv.find("ov_device"); + if (device_type_it != meta_kv.end()) { + if (device_type_it->second == device_type) { + plugin_ep_device = device; + break; + } + } + } + } + + return plugin_ep_device; + } +}; + +TEST_F(OrtEpLibraryOv, LoadUnloadPluginLibrary) { + auto ep_devices = ort_env->GetEpDevices(); + auto test_cpu_ep_device = GetOvCpuEpDevice(); + ASSERT_NE(test_cpu_ep_device, nullptr); + ASSERT_STREQ(test_cpu_ep_device.EpVendor(), "Intel"); + Ort::ConstHardwareDevice device = test_cpu_ep_device.Device(); + ASSERT_EQ(device.Type(), OrtHardwareDeviceType_CPU); + ASSERT_GE(device.VendorId(), 0); + ASSERT_GE(device.DeviceId(), 0); + ASSERT_NE(device.Vendor(), nullptr); + std::unordered_map ep_metadata_entries = test_cpu_ep_device.EpMetadata().GetKeyValuePairs(); + ASSERT_GT(ep_metadata_entries.size(), 0); + ASSERT_GT(ep_metadata_entries.count("ov_device"), 0); +} + +TEST_F(OrtEpLibraryOv, MetaDevicesAvailable) { + auto ep_devices = ort_env->GetEpDevices(); + auto expected_meta_devices = {"AUTO"}; + + for (auto& expected_meta_device : expected_meta_devices) { + std::string expected_ep_name = registration_name + "." + expected_meta_device; + auto it = std::find_if(ep_devices.begin(), ep_devices.end(), + [&](Ort::ConstEpDevice& device) { + return std::string_view(device.EpName()).find(expected_ep_name) != std::string::npos; + }); + bool meta_device_found = it != ep_devices.end(); + ASSERT_TRUE(meta_device_found) << "Expected to find " << expected_ep_name; + } +} + +TEST_F(OrtEpLibraryOv, RunSessionWithAllAUTODevices) { + auto ep_devices = ort_env->GetEpDevices(); + std::vector matching_devices; + + for (const auto& device : ep_devices) { + std::string ep_name = device.EpName(); + if (ep_name.find(registration_name) != std::string::npos && + (ep_name == registration_name + ".AUTO")) { + matching_devices.push_back(device); + } + } + Ort::SessionOptions session_options; + session_options.AppendExecutionProvider_V2(*ort_env, matching_devices, std::unordered_map{}); + Ort::Session session(*ort_env, ORT_TSTR("testdata/mul_1.onnx"), session_options); +} + +TEST_F(OrtEpLibraryOv, PluginEp_AppendV2_MulInference) { + auto plugin_ep_device = GetOvCpuEpDevice(); + ASSERT_NE(plugin_ep_device, nullptr); + + Ort::SessionOptions session_options; + std::unordered_map ep_options; + session_options.AppendExecutionProvider_V2(*ort_env, std::vector{plugin_ep_device}, ep_options); + RunModelWithPluginEp(session_options); +} + +TEST_F(OrtEpLibraryOv, PluginEp_PreferCpu_MulInference) { + Ort::SessionOptions session_options; + session_options.SetEpSelectionPolicy(OrtExecutionProviderDevicePolicy_PREFER_CPU); + RunModelWithPluginEp(session_options); +} + +struct EpCtxTestCases { + const ORTCHAR_T* ctx_filename; + bool embed_mode; +}; + +static const std::vector ep_context_cases = { + {ORT_TSTR("mul_1_ctx_cpu_embed1.onnx"), true}, + {ORT_TSTR("mul_1_ctx_cpu_embed0.onnx"), false}, + {ORT_TSTR("testdata/mul_1_ctx_cpu_embed0.onnx"), false}}; + +TEST_F(OrtEpLibraryOv, PluginEp_AppendV2_cpu_epctx_variants) { + auto plugin_ep_device = GetOvCpuEpDevice(); + ASSERT_NE(plugin_ep_device, nullptr); + + for (const auto& test_case : ep_context_cases) { + GenerateEpContextOnLegacyPath(test_case.ctx_filename, test_case.embed_mode); + + Ort::SessionOptions session_options; + std::unordered_map ep_options; + session_options.AppendExecutionProvider_V2(*ort_env, std::vector{plugin_ep_device}, ep_options); + Ort::Session session(*ort_env, test_case.ctx_filename, session_options); + RunModelWithSession(session); + } +} + +TEST_F(OrtEpLibraryOv, PluginEp_CheckV2DisallowedProviderOptions) { + auto plugin_ep_device = GetOvCpuEpDevice(); + ASSERT_NE(plugin_ep_device, nullptr); + std::vector> disallowed_provider_option_examples = { + {{"device_type", "CPU"}}, + {{"device_id", "CPU"}}, + {{"device_luid", "1234"}}, + {{"cache_dir", "cache"}}, + {{"precision", "F32"}}, + {{"context", "4"}}, + {{"num_of_threads", "1"}}, + {{"model_priority", "DEFAULT"}}, + {{"num_streams", "1"}}, + {{"enable_opencl_throttling", "true"}}, + {{"enable_qdq_optimizer", "true"}}, + {{"disable_dynamic_shapes", "true"}}, + }; + for (auto& example : disallowed_provider_option_examples) { + EXPECT_THROW({ + Ort::SessionOptions session_options; + session_options.AppendExecutionProvider_V2(*ort_env, std::vector{plugin_ep_device}, example); + Ort::Session session(*ort_env, ORT_TSTR("testdata/mul_1.onnx"), session_options); }, Ort::Exception); + } +} + +TEST_F(OrtEpLibraryOv, GenerateEpContextEmbedded) { + GenerateEpContextOnPluginPath(ORT_TSTR("mul_1_ctx_cpu_embed1.onnx"), true); +} + +TEST_F(OrtEpLibraryOv, GenerateEpContext) { + GenerateEpContextOnPluginPath(ORT_TSTR("mul_1_ctx_cpu_embed0.onnx"), false); +} + +TEST_F(OrtEpLibraryOv, PluginEp_AppendV2_cpu_epctx_plugin_roundtrip_variants) { + auto plugin_ep_device = GetOvCpuEpDevice(); + ASSERT_NE(plugin_ep_device, nullptr); + + for (const auto& test_case : ep_context_cases) { + if (test_case.embed_mode) { + // TODO(ericcraw) Re-enable. + // Skip the embed mode until upstream fix. + continue; + } + + GenerateEpContextOnPluginPath(test_case.ctx_filename, test_case.embed_mode); + + Ort::SessionOptions session_options; + std::unordered_map ep_options; + session_options.AppendExecutionProvider_V2(*ort_env, std::vector{plugin_ep_device}, ep_options); + Ort::Session session(*ort_env, test_case.ctx_filename, session_options); + RunModelWithSession(session); + } +} + +TEST_F(OrtEpLibraryOv, PluginEp_AppendV2_cpu_epctx_plugin_roundtrip_variants_absolute) { + auto plugin_ep_device = GetOvCpuEpDevice(); + ASSERT_NE(plugin_ep_device, nullptr); + + for (const auto& test_case : ep_context_cases) { + if (test_case.embed_mode) { + // TODO(ericcraw) Re-enable. + // Skip the embed mode until upstream fix. + continue; + } + + auto absolute_path = std::filesystem::absolute(test_case.ctx_filename).native(); + GenerateEpContextOnPluginPath(absolute_path.c_str(), test_case.embed_mode); + + Ort::SessionOptions session_options; + std::unordered_map ep_options; + session_options.AppendExecutionProvider_V2(*ort_env, std::vector{plugin_ep_device}, ep_options); + Ort::Session session(*ort_env, absolute_path.c_str(), session_options); + RunModelWithSession(session); + } +} + +TEST_F(OrtEpLibraryOv, PluginEp_AppendV2_multiple_devices) { + auto plugin_ep_device = GetOvCpuEpDevice(); + ASSERT_NE(plugin_ep_device, nullptr); + + std::vector multi_device_list(2, plugin_ep_device); // 2 copies of cpu device. + + Ort::SessionOptions session_options; + session_options.AppendExecutionProvider_V2(*ort_env, multi_device_list, std::unordered_map{}); + Ort::Session session(*ort_env, ORT_TSTR("testdata/mul_1.onnx"), session_options); +} + +TEST_F(OrtEpLibraryOv, PluginEp_AppendV2_mixed_factory_devices_throw_exception) { + auto ep_devices = ort_env->GetEpDevices(); + std::vector matching_devices; + + for (const auto& device : ep_devices) { + std::string ep_name = device.EpName(); + if (ep_name.find(registration_name) != std::string::npos && + (ep_name == registration_name || ep_name == registration_name + ".AUTO")) { + matching_devices.push_back(device); + } + } + + ASSERT_GT(matching_devices.size(), 1) << "Expected more than one matching EP device"; + + EXPECT_THROW({ + Ort::SessionOptions session_options; + session_options.AppendExecutionProvider_V2(*ort_env, matching_devices, std::unordered_map{}); + Ort::Session session(*ort_env, ORT_TSTR("testdata/mul_1.onnx"), session_options); }, Ort::Exception); +} From 3faf7d972d91f351ca792de30422ca04c6a397c6 Mon Sep 17 00:00:00 2001 From: Ryan Metcalfe <107415876+RyanMetcalfeInt8@users.noreply.github.com> Date: Wed, 23 Jul 2025 10:12:20 -0700 Subject: [PATCH 071/138] ov_factory: Use 'GPU_DEVICE_ID' property to match with ORT device_id (#759) * ov_factory: Use 'GPU_DEVICE_ID' property to match with ORT device_id * clean up comment --- onnxruntime/core/providers/openvino/ov_factory.cc | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/onnxruntime/core/providers/openvino/ov_factory.cc b/onnxruntime/core/providers/openvino/ov_factory.cc index 94e8cbfeca8d1..746d6cc44cb66 100644 --- a/onnxruntime/core/providers/openvino/ov_factory.cc +++ b/onnxruntime/core/providers/openvino/ov_factory.cc @@ -96,12 +96,12 @@ OrtStatus* OpenVINOEpPluginFactory::GetSupportedDevices(const OrtHardwareDevice* const auto& ov_device_type = device_it->second; std::string ov_device_name; - auto get_pci_device_id = [&](const std::string& ov_device) { + auto get_gpu_device_id = [&](const std::string& ov_device) { try { - ov::device::PCIInfo pci_info = ov_core_->get_property(ov_device, ov::device::pci_info); - return pci_info.device; + auto device_id_str = ov_core_->get_property(ov_device, "GPU_DEVICE_ID").as(); + return static_cast(std::stoul(device_id_str, nullptr, 0)); } catch (ov::Exception&) { - return 0u; // If we can't get the PCI info, we won't have a device ID. + return 0u; // If we can't get the GPU_DEVICE_ID info, we won't have a device ID. } }; @@ -111,7 +111,7 @@ OrtStatus* OpenVINOEpPluginFactory::GetSupportedDevices(const OrtHardwareDevice* // If there are multiple devices of the same type, we need to match by device ID. matched_device = std::find_if(filtered_devices.begin(), filtered_devices.end(), [&](const std::string& ov_device) { uint32_t ort_device_id = ort_api.HardwareDevice_DeviceId(&device); - return ort_device_id == get_pci_device_id(ov_device); + return ort_device_id == get_gpu_device_id(ov_device); }); } From 2306b4a73e283445b694e166495f7dbd3c433afe Mon Sep 17 00:00:00 2001 From: jatinwadhwa921 <110383850+jatinwadhwa921@users.noreply.github.com> Date: Fri, 25 Jul 2025 11:47:06 +0530 Subject: [PATCH 072/138] [OVEP] Fix for upsample optype (#761) --- .../core/providers/openvino/ov_versions/data_ops.cc | 8 -------- 1 file changed, 8 deletions(-) diff --git a/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc b/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc index 336b294117cba..17e69ad080b90 100644 --- a/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc +++ b/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc @@ -469,15 +469,7 @@ void DataOps::populate_op_mode_supported() { } } - // check for input dimensions const auto& x_arg = node->InputDefs()[0]; - auto shape = x_arg->Shape(); - if (shape != nullptr) { - // input tensor rank cannot be of one dimension - if (shape->dim_size() == 1 || shape->dim_size() == 4) { - return true; - } - } // x_arg supports only float, int8 and float16 type if ((x_arg->TypeAsProto()->tensor_type().elem_type() == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT) || From f4da9f13b98fd4f9e502dc86d3bc0734b36467dd Mon Sep 17 00:00:00 2001 From: jatinwadhwa921 <110383850+jatinwadhwa921@users.noreply.github.com> Date: Fri, 1 Aug 2025 11:31:47 +0530 Subject: [PATCH 073/138] [OVEP] Remove checks from load_config (#765) --- .../openvino/backends/basic_backend.cc | 31 +++---------------- 1 file changed, 5 insertions(+), 26 deletions(-) diff --git a/onnxruntime/core/providers/openvino/backends/basic_backend.cc b/onnxruntime/core/providers/openvino/backends/basic_backend.cc index 8b7309e6a5a98..6efd866d47c3c 100644 --- a/onnxruntime/core/providers/openvino/backends/basic_backend.cc +++ b/onnxruntime/core/providers/openvino/backends/basic_backend.cc @@ -274,31 +274,14 @@ void BasicBackend::PopulateConfigValue(ov::AnyMap& device_config) { return devices; }; - // Check if a property is supported and mutable - auto is_supported_and_mutable = [&](const std::string& key, - const std::vector& supported_config) -> bool { - auto it = std::find_if(supported_config.begin(), supported_config.end(), [&](const ov::PropertyName& property) { - return property == key && property.is_mutable(); - }); - return it != supported_config.end(); - }; - - // Set properties if they are valid, else log a warning if the property is missing or immutable by skipping the same - auto set_target_properties = [&](const std::string& device, const ov::AnyMap& config_options, - const std::vector& supported_properties) { + // Set properties, Validation will be handled by OpenVINO Core + auto set_target_properties = [&](const std::string& device, const ov::AnyMap& config_options) { for (const auto& [key, value] : config_options) { if ((key.find("NPUW") != std::string::npos) || ((device_config.find(key) != device_config.end()) && session_context_.enable_causallm)) { continue; } - if (is_supported_and_mutable(key, supported_properties)) { - OVCore::Get()->core.set_property(device, ov::AnyMap{{key, value}}); - } else { - LOGS_DEFAULT(WARNING) << "WARNING: Property \"" << key - << "\" is either unsupported in current OpenVINO version" - << " or property is immutable for target device \"" - << device << "\". Skipping setting this property."; - } + OVCore::Get()->core.set_property(device, ov::AnyMap{{key, value}}); } }; @@ -317,18 +300,14 @@ void BasicBackend::PopulateConfigValue(ov::AnyMap& device_config) { // Set properties only for individual devices (e.g., "CPU", "GPU") for (const std::string& device : individual_devices) { if (target_config.count(device)) { - // Get supported properties for each individual device - auto device_properties = OVCore::Get()->core.get_property(device, ov::supported_properties); // Set properties for the device - set_target_properties(device, target_config.at(device), device_properties); + set_target_properties(device, target_config.at(device)); } } } else { if (target_config.count(session_context_.device_type)) { - auto supported_properties = OVCore::Get()->core.get_property(session_context_.device_type, - ov::supported_properties); set_target_properties(session_context_.device_type, - target_config.at(session_context_.device_type), supported_properties); + target_config.at(session_context_.device_type)); } } } From ed9e42506b171523d270afaf56cf8e3749cac386 Mon Sep 17 00:00:00 2001 From: "Klimenko, Mikhail" Date: Fri, 1 Aug 2025 15:08:23 +0200 Subject: [PATCH 074/138] Add self-detecting on-the-fly bfloat16->float16 conversion pass (#741) * Add on-the-fly bfloat16->float16 conversion pass * Fix undetected bfloat16 initializers * Remove the option and make the logic implicit * Add tests * Rename detection function * Fix CI for strict aliasing rules --------- Co-authored-by: Vishnudas Thaniel S --- .../providers/openvino/backend_manager.cc | 22 ++++ .../openvino/ov_versions/data_ops.cc | 7 +- .../qdq_transformations/qdq_scales_fix.cpp | 50 ++++++++ .../qdq_transformations/qdq_scales_fix.h | 5 + .../openvino_ep_bfloat16_pass_test.cc | 116 ++++++++++++++++++ 5 files changed, 197 insertions(+), 3 deletions(-) create mode 100644 onnxruntime/test/providers/openvino/openvino_ep_bfloat16_pass_test.cc diff --git a/onnxruntime/core/providers/openvino/backend_manager.cc b/onnxruntime/core/providers/openvino/backend_manager.cc index be59b1ae07020..cadeab4cbd4cc 100644 --- a/onnxruntime/core/providers/openvino/backend_manager.cc +++ b/onnxruntime/core/providers/openvino/backend_manager.cc @@ -375,6 +375,18 @@ static bool IsQDQGraph(const onnxruntime::GraphViewer& graph_viewer) { return false; } +static bool IsModelBF16(const onnxruntime::GraphViewer& graph_viewer) { + const auto& node_indices = graph_viewer.GetNodesInTopologicalOrder(); + for (std::size_t i = 0; i < node_indices.size(); i++) { + gsl::not_null node(graph_viewer.GetNode(node_indices[i])); + for (auto& output : node->OutputDefs()) { + if (output->ToProto().type().tensor_type().elem_type() == ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16) + return true; + } + } + return false; +} + static void DumpOpenVINOEPModel([[maybe_unused]] const std::filesystem::path& onnx_model_path_name, [[maybe_unused]] ONNX_NAMESPACE::ModelProto* model_proto, [[maybe_unused]] const onnxruntime::Node& fused_node) { @@ -456,6 +468,16 @@ BackendManager::GetModelProtoFromFusedNode(const onnxruntime::Node& fused_node, DumpOpenVINOEPModel(onnx_model_path_name, model_proto.get(), fused_node); ORT_ENFORCE(status.IsOK(), status.ErrorMessage()); return model_proto; + } else if (IsModelBF16(subgraph)) { + LOGS_DEFAULT(INFO) << "[OpenVINO-EP] OVEP bfloat16->float16 optimization pass is enabled"; + std::unique_ptr model; + Status status = bfloat16_fix::Transform(subgraph, logger, model); + auto model_proto = model->ToProto(); + model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); + print_model_proto_duration(); + DumpOpenVINOEPModel(onnx_model_path_name, model_proto.get(), fused_node); + ORT_ENFORCE(status.IsOK(), status.ErrorMessage()); + return model_proto; } else { LOGS_DEFAULT(INFO) << "[OpenVINO-EP] OVEP QDQ optimization pass is disabled"; auto model = subgraph.CreateModel(logger); diff --git a/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc b/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc index 17e69ad080b90..f991e85ebe518 100644 --- a/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc +++ b/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc @@ -555,8 +555,11 @@ bool DataOps::type_is_supported(const NodeArg* node_arg, bool is_initializer) { return false; } + auto dtype = type_proto->tensor_type().elem_type(); + // Enable bfloat16 -> float16 on-the-fly conversion + if (dtype == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_BFLOAT16) + return true; if (is_initializer) { - auto dtype = type_proto->tensor_type().elem_type(); for (auto const& var : supported_types_initializer_) { if ((var.first <= version_id_) && (var.second == dtype)) { @@ -571,8 +574,6 @@ bool DataOps::type_is_supported(const NodeArg* node_arg, bool is_initializer) { #endif return false; } else { - auto dtype = type_proto->tensor_type().elem_type(); - if (device_id_.find("HETERO") != std::string::npos || device_id_.find("MULTI") != std::string::npos || device_id_.find("AUTO") != std::string::npos) { for (auto const& var : supported_types_npu_) { diff --git a/onnxruntime/core/providers/openvino/qdq_transformations/qdq_scales_fix.cpp b/onnxruntime/core/providers/openvino/qdq_transformations/qdq_scales_fix.cpp index d159930d52845..f1ce230387565 100644 --- a/onnxruntime/core/providers/openvino/qdq_transformations/qdq_scales_fix.cpp +++ b/onnxruntime/core/providers/openvino/qdq_transformations/qdq_scales_fix.cpp @@ -3,6 +3,7 @@ #include "qdq_scales_fix.h" #include "core/providers/openvino/ov_protobuf_utils.h" +#include "core/framework/float16.h" #include #include @@ -940,5 +941,54 @@ Status Transform(const GraphViewer& src_graph_viewer, return status; } } // namespace qdq_scales_fix + +namespace bfloat16_fix { +void replace_bf16_with_fp16(qdq_scales_fix::CustomGraph& gen_graph) { + for (auto& const_node : gen_graph.original_graph.Nodes()) { + auto node = const_cast(const_node); + if (node->OpType() == "Cast") { + for (auto& [name, const_attribute] : node->GetAttributes()) { + auto& attribute = const_cast(const_attribute); + if (name == "to" && attribute.type() == ONNX_NAMESPACE::AttributeProto_AttributeType_INT) + if (attribute.i() == ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16) + attribute.set_i(ONNX_NAMESPACE::TensorProto_DataType_FLOAT16); + } + } + for (auto& output : node->OutputDefs()) { + auto& output_proto = const_cast(output->ToProto().type()); + if (output_proto.mutable_tensor_type()->elem_type() == ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16) + output_proto.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT16); + } + } + + const auto& init_set = gen_graph.original_graph.GetAllInitializedTensors(); + for (auto& [key, const_tensor_proto] : init_set) { + auto tensor_proto = const_cast(const_tensor_proto); + auto dt = tensor_proto->data_type(); + if (dt == ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16) { + auto raw_data = tensor_proto->has_raw_data() ? reinterpret_cast(tensor_proto->mutable_raw_data()->data()) : nullptr; + if (raw_data) { + tensor_proto->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT16); + std::int64_t size = 1; + for (int i = 0; i < tensor_proto->dims_size(); ++i) + size *= tensor_proto->dims()[i]; + for (std::int64_t i = 0; i < size; ++i) { + raw_data[i] = onnxruntime::MLFloat16(onnxruntime::BFloat16::FromBits(raw_data[i])).val; + } + } + } + } +} + +Status Transform(const GraphViewer& src_graph_viewer, + const logging::Logger& logger, + /*out*/ std::unique_ptr& model) { + auto status = qdq_scales_fix::copy_model(src_graph_viewer, logger, model); + auto g = qdq_scales_fix::generate_graph_from_onnx(model->MainGraph()); + + replace_bf16_with_fp16(g); + return status; +} +} // namespace bfloat16_fix } // namespace openvino_ep } // namespace onnxruntime diff --git a/onnxruntime/core/providers/openvino/qdq_transformations/qdq_scales_fix.h b/onnxruntime/core/providers/openvino/qdq_transformations/qdq_scales_fix.h index c54c531e1bd40..2182850d96c43 100644 --- a/onnxruntime/core/providers/openvino/qdq_transformations/qdq_scales_fix.h +++ b/onnxruntime/core/providers/openvino/qdq_transformations/qdq_scales_fix.h @@ -15,5 +15,10 @@ Status Transform(const GraphViewer& src_graph, const logging::Logger& logger, /*out*/ std::unique_ptr& model); } +namespace bfloat16_fix { +Status Transform(const GraphViewer& src_graph, + const logging::Logger& logger, + /*out*/ std::unique_ptr& model); +} } // namespace openvino_ep } // namespace onnxruntime diff --git a/onnxruntime/test/providers/openvino/openvino_ep_bfloat16_pass_test.cc b/onnxruntime/test/providers/openvino/openvino_ep_bfloat16_pass_test.cc new file mode 100644 index 0000000000000..fc90563a61bb1 --- /dev/null +++ b/onnxruntime/test/providers/openvino/openvino_ep_bfloat16_pass_test.cc @@ -0,0 +1,116 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include +#include + +#include "core/session/onnxruntime_cxx_api.h" +#include "core/framework/float16.h" + +#include "test/util/include/test/test_environment.h" +#include "test/optimizer/qdq_test_utils.h" + +#include "gtest/gtest.h" +#include "gmock/gmock.h" + +using namespace ONNX_NAMESPACE; +using namespace onnxruntime::logging; + +extern std::unique_ptr ort_env; + +class OVEP_BF16_Tests : public ::testing::TestWithParam {}; + +namespace detail { +auto ConstructModel() { + using namespace onnxruntime; + using namespace test; + + std::unordered_map domain_to_version; + domain_to_version[kOnnxDomain] = 19; + Model model("Bfloat16Tester", true, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), + domain_to_version, {}, DefaultLoggingManager().DefaultLogger()); + + Graph& graph = model.MainGraph(); + ModelTestBuilder builder(graph); + auto dim = 4; + std::vector input_data(dim, 1.0f); + auto* input = builder.MakeInput({dim}, input_data); + builder.graph_.SetInputs({input}); + + auto* cast_to_bf16 = builder.MakeIntermediate(); + Node& cast_node = builder.AddNode("Cast", {input}, {cast_to_bf16}, ""); + cast_node.AddAttribute("to", static_cast(ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16)); + + std::vector weight_data(dim * dim); + for (std::size_t i = 0; i < weight_data.size(); ++i) + weight_data[i] = onnxruntime::BFloat16(static_cast(i % dim) / dim); + auto* weights = builder.MakeInitializer({dim, dim}, weight_data); + + auto* matmul_out = builder.MakeIntermediate(); + builder.AddNode("MatMul", {cast_to_bf16, weights}, {matmul_out}); + + std::vector weight_data_2(dim * dim); + for (std::size_t i = 0; i < weight_data_2.size(); ++i) + weight_data_2[i] = onnxruntime::BFloat16(static_cast(i % dim) / dim); + auto* weights_2 = builder.MakeInitializer({dim, dim}, weight_data_2); + + auto* matmul_out_2 = builder.MakeIntermediate(); + builder.AddNode("MatMul", {matmul_out, weights_2}, {matmul_out_2}); + + auto* output = builder.MakeOutput(); + Node& cast2_node = builder.AddNode("Cast", {matmul_out_2}, {output}); + cast2_node.AddAttribute("to", static_cast(ONNX_NAMESPACE::TensorProto_DataType_FLOAT)); + + builder.SetGraphOutputs(); + auto st = model.MainGraph().Resolve(); + if (st != Status::OK()) + throw std::runtime_error(st.ErrorMessage()); + return model; +} + +auto ProbeDevice(const std::string& device) { + static std::map is_present; + if (is_present.find(device) == is_present.end()) { + Ort::SessionOptions sessionOptions; + std::unordered_map ov_options; + ov_options["device_type"] = device; + try { + sessionOptions.AppendExecutionProvider_OpenVINO_V2(ov_options); + is_present[device] = true; + } catch (...) { + is_present[device] = false; + } + } + return is_present[device]; +} +} // namespace detail + +namespace onnxruntime { +namespace test { + +TEST_P(OVEP_BF16_Tests, TestModelConversion) { + Ort::SessionOptions sessionOptions; + std::unordered_map ov_options; + const auto& device = GetParam(); + if (!::detail::ProbeDevice(device)) + GTEST_SKIP() << device + " is not available on this machine"; + + ov_options["device_type"] = device; + auto model = ::detail::ConstructModel(); + sessionOptions.AppendExecutionProvider_OpenVINO_V2(ov_options); + + std::string model_data; + model.ToProto().SerializeToString(&model_data); + auto model_data_span = AsByteSpan(model_data.data(), model_data.size()); + try { + Ort::Session session(*ort_env, model_data_span.data(), model_data_span.size(), sessionOptions); + } catch (...) { + FAIL(); + } +} +INSTANTIATE_TEST_SUITE_P(OVEP_Tests, + OVEP_BF16_Tests, + ::testing::Values("CPU", "GPU", "NPU")); +} // namespace test +} // namespace onnxruntime From 47a231ad8175fe426a5a0a632fe4c32af7d3a923 Mon Sep 17 00:00:00 2001 From: n1harika Date: Thu, 31 Jul 2025 21:48:36 -0700 Subject: [PATCH 075/138] [OVEP] Mild weight sharing- quantization paramters are kept as initialisers --- .../qdq_transformations/qdq_stripping.cc | 53 +++++++++++++++---- 1 file changed, 43 insertions(+), 10 deletions(-) diff --git a/onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.cc b/onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.cc index 24e8892622175..7f88879a7a456 100644 --- a/onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.cc +++ b/onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.cc @@ -677,6 +677,27 @@ static void AddInitializerAsInput(onnxruntime::Graph& dst_graph, } } +// To check if the input parameters of a DQ or Q node are quantization parameters +// Scale and Zero point parameters are quantization parameters +static bool IsQuantizationParameter(const std::string& initializer_name, + const onnxruntime::GraphViewer& src_graph) { + // Check if this initializer is used as scale or zero_point in any DQ/Q node + for (auto& node_idx : src_graph.GetNodesInTopologicalOrder()) { + const auto* node = src_graph.GetNode(node_idx); + if (node->OpType() == "DequantizeLinear" || node->OpType() == "QuantizeLinear") { + const auto& input_defs = node->InputDefs(); + // Check if this initializer is used as scale (input 1) or zero_point (input 2) + if (input_defs.size() >= 2 && input_defs[1]->Name() == initializer_name) { + return true; // This is a scale parameter + } + if (input_defs.size() >= 3 && input_defs[2]->Name() == initializer_name) { + return true; // This is a zero_point parameter + } + } + } + return false; +} + // Creates a new model without the DQ/Q operators in the src graph. Status CreateModelWithStrippedQDQNodes(const GraphViewer& src_graph, const logging::Logger& logger, @@ -845,19 +866,31 @@ Status CreateModelWithStrippedQDQNodes(const GraphViewer& src_graph, if (!init_with_data && utils::HasExternalData(initializer_tensor) && enable_ovep_weight_sharing) { - insert_metadata(initializer_tensor); - // Add initializer with external data as input - AddInitializerAsInput(dst_graph, accumulated_inputs, src_graph, name); - } else { - // Add as an initialized tensor if it does not have external data - if (initializers_to_keep.count(name) > 0) { - if (init_with_data) { - dst_graph.AddInitializedTensor(*init_with_data); + // Only convert to input if it's not a quantization parameter + bool is_quant_param = IsQuantizationParameter(name, src_graph); + + if (!is_quant_param) { + // This is actual weight data - so to convert to input for weight sharing + insert_metadata(initializer_tensor); + AddInitializerAsInput(dst_graph, accumulated_inputs, src_graph, name); } else { - dst_graph.AddInitializedTensor(initializer_tensor); + // This is a quantization parameter - keep as initializer even if external + + if (initializers_to_keep.count(name) > 0) { + + dst_graph.AddInitializedTensor(initializer_tensor); + } + } + } else { + // Add as an initialized tensor if it does not have external data + if (initializers_to_keep.count(name) > 0) { + if (init_with_data) { + dst_graph.AddInitializedTensor(*init_with_data); + } else { + dst_graph.AddInitializedTensor(initializer_tensor); + } } - } } current_scope_initializer_set.insert(name); From e4f8acb09159ce0fb9b94ffbf94f3eeee649dea3 Mon Sep 17 00:00:00 2001 From: TejalKhade28 Date: Fri, 11 Jul 2025 15:42:38 +0530 Subject: [PATCH 076/138] Cluster Change to avoid Dangling DQLinear --- .../openvino/ov_versions/capability.cc | 22 ++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/onnxruntime/core/providers/openvino/ov_versions/capability.cc b/onnxruntime/core/providers/openvino/ov_versions/capability.cc index 2309ff3de751b..06fc26c44ed75 100644 --- a/onnxruntime/core/providers/openvino/ov_versions/capability.cc +++ b/onnxruntime/core/providers/openvino/ov_versions/capability.cc @@ -166,17 +166,33 @@ std::vector> GetCapability::Execute() { auto connected_clusters = GetConnectedClusters(graph_viewer_, ng_clusters); int no_of_clusters = 0; + std::vector prev_cluster; + bool try_next_cluster = false; for (auto this_cluster : connected_clusters) { + bool omit_subgraph = false; + if (try_next_cluster) { + // no need to check previous cluster + for (auto idx : prev_cluster) { + if ((std::find(this_cluster.begin(), this_cluster.end(), idx)) == this_cluster.end()) { + this_cluster.emplace_back(idx); + } + } + try_next_cluster = false; + } + // If subgraph has less then three, graph is considered trivial unless its an epctx cluster - if (this_cluster.size() < 3) { + if (!try_next_cluster && this_cluster.size() < 3) { bool is_epctx_node = false; for (auto node_idx : this_cluster) { if (graph_viewer_.GetNode(node_idx)->OpType() == "EPContext") is_epctx_node = true; } - if (!is_epctx_node) - continue; + if (!is_epctx_node) { + omit_subgraph = true; + prev_cluster = this_cluster; + try_next_cluster = true; + } } std::vector cluster_graph_inputs, cluster_inputs, cluster_outputs; From aa31709b3a0d382a8d41af3663662cf4a480a59d Mon Sep 17 00:00:00 2001 From: TejalKhade28 Date: Fri, 11 Jul 2025 18:18:05 +0530 Subject: [PATCH 077/138] Error in subgraph --- onnxruntime/core/providers/openvino/ov_versions/capability.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/openvino/ov_versions/capability.cc b/onnxruntime/core/providers/openvino/ov_versions/capability.cc index 06fc26c44ed75..4ad8f8dd85f4d 100644 --- a/onnxruntime/core/providers/openvino/ov_versions/capability.cc +++ b/onnxruntime/core/providers/openvino/ov_versions/capability.cc @@ -204,7 +204,7 @@ std::vector> GetCapability::Execute() { cluster_inputs, cluster_outputs); - bool omit_subgraph = false; + // Omitting zero dim subgraphs for (auto index : this_cluster) { const Node* node = graph_viewer_.GetNode(index); From 725744a2f33b2429eff7bf1f0c652e89430000f5 Mon Sep 17 00:00:00 2001 From: Preetha Veeramalai Date: Wed, 13 Aug 2025 22:50:53 +0530 Subject: [PATCH 078/138] Fix to set precision from config (#778) Not setting default precision if it is not set via provider option. --- .../core/providers/openvino/openvino_provider_factory.cc | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/openvino/openvino_provider_factory.cc b/onnxruntime/core/providers/openvino/openvino_provider_factory.cc index 9dba8623031d0..ce5269d9298b6 100644 --- a/onnxruntime/core/providers/openvino/openvino_provider_factory.cc +++ b/onnxruntime/core/providers/openvino/openvino_provider_factory.cc @@ -224,7 +224,9 @@ static void ParseProviderInfo(const ProviderOptions& provider_options, pi.cache_dir = provider_options.at("cache_dir"); } - pi.precision = OpenVINOParserUtils::ParsePrecision(provider_options, pi.device_type, "precision"); + if (provider_options.contains("precision")) { + pi.precision = OpenVINOParserUtils::ParsePrecision(provider_options, pi.device_type, "precision"); + } if (provider_options.contains("reshape_input")) { pi.reshape = OpenVINOParserUtils::ParseInputShape(provider_options.at("reshape_input")); From 609dfbf885f243ba69bb1c36478fcefd999a8b21 Mon Sep 17 00:00:00 2001 From: liang Date: Thu, 14 Aug 2025 02:05:15 +0800 Subject: [PATCH 079/138] Fix the load_config not work when set INFERENCE_PRECISION_HINT (#777) --- .../core/providers/openvino/backends/basic_backend.cc | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/openvino/backends/basic_backend.cc b/onnxruntime/core/providers/openvino/backends/basic_backend.cc index 6efd866d47c3c..f023e064f98ee 100644 --- a/onnxruntime/core/providers/openvino/backends/basic_backend.cc +++ b/onnxruntime/core/providers/openvino/backends/basic_backend.cc @@ -277,8 +277,14 @@ void BasicBackend::PopulateConfigValue(ov::AnyMap& device_config) { // Set properties, Validation will be handled by OpenVINO Core auto set_target_properties = [&](const std::string& device, const ov::AnyMap& config_options) { for (const auto& [key, value] : config_options) { + // Update the device_config map from the target_config to avoid load_config being overridden + // by the device_config set by the OpenVINO EP. + auto it = device_config.find(key); + if (it != device_config.end()) { + it->second = value; + } if ((key.find("NPUW") != std::string::npos) || - ((device_config.find(key) != device_config.end()) && session_context_.enable_causallm)) { + ((it != device_config.end()) && session_context_.enable_causallm)) { continue; } OVCore::Get()->core.set_property(device, ov::AnyMap{{key, value}}); From a780d5b31a840a3f4ba56471ae37d45d004069b8 Mon Sep 17 00:00:00 2001 From: Javier Martinez Date: Wed, 13 Aug 2025 11:47:06 -0700 Subject: [PATCH 080/138] =?UTF-8?q?Fix=20failing=20case=20where=20input=20?= =?UTF-8?q?onnx=20model=20is=20used=20with=20shared=20context=20e=E2=80=A6?= =?UTF-8?q?=20(#776)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Fix failing case where input onnx model is used with shared context enabled * Update onnxruntime/core/providers/openvino/openvino_execution_provider.cc Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --------- Co-authored-by: MayureshV1 <47039074+MayureshV1@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- .../core/providers/openvino/openvino_execution_provider.cc | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/openvino/openvino_execution_provider.cc b/onnxruntime/core/providers/openvino/openvino_execution_provider.cc index 1b19517b07363..a0fa885cbfc38 100644 --- a/onnxruntime/core/providers/openvino/openvino_execution_provider.cc +++ b/onnxruntime/core/providers/openvino/openvino_execution_provider.cc @@ -94,18 +94,23 @@ common::Status OpenVINOExecutionProvider::Compile( auto& logger = *GetLogger(); Status status = Status::OK(); + bool is_epctx_model = false; if (!fused_nodes.empty()) { // Assume these properties are constant for all the model subgraphs, otherwise move to SubGraphContext const auto& graph_body_viewer_0 = fused_nodes[0].filtered_graph.get(); session_context_.onnx_model_path_name = graph_body_viewer_0.ModelPath().string(); session_context_.onnx_opset_version = graph_body_viewer_0.DomainToVersionMap().at(kOnnxDomain); + + // OVIR wrapped in epctx should be treated as source but this code does not + // This corner case is not in use and will be addressed in a future commit + is_epctx_model = ep_ctx_handle_.CheckForOVEPCtxNodeInGraph(graph_body_viewer_0); } // The block below is executed during EP context model inference auto& metadata = shared_context_->shared_weights.metadata; // Metadata object in memory if (session_context_.so_share_ep_contexts && - !session_context_.so_context_enable && + is_epctx_model && metadata.empty()) { fs::path context_model_file_path = session_context_.so_context_file_path; if (context_model_file_path.empty()) { From a6359eedc644b723e7a524f4653a34116d2802ab Mon Sep 17 00:00:00 2001 From: jatinwadhwa921 <110383850+jatinwadhwa921@users.noreply.github.com> Date: Wed, 13 Aug 2025 22:31:06 -0700 Subject: [PATCH 081/138] [OVEP] Support for providing layout to input/output to OpenVINO (#767) * [OVEP] Support for providing layout to input/output to OpenVINO * [OVEP] Minor bug fixes for layout feature --- .../core/providers/openvino/backend_utils.cc | 40 ++++++++++ .../core/providers/openvino/backend_utils.h | 2 + .../openvino/backends/basic_backend.cc | 1 + .../core/providers/openvino/contexts.h | 4 +- .../openvino/openvino_parser_utils.cc | 74 +++++++++++++++++++ .../openvino/openvino_parser_utils.h | 2 + .../openvino/openvino_provider_factory.cc | 4 + .../test/perftest/command_args_parser.cc | 4 +- onnxruntime/test/perftest/ort_test_session.cc | 4 +- 9 files changed, 132 insertions(+), 3 deletions(-) diff --git a/onnxruntime/core/providers/openvino/backend_utils.cc b/onnxruntime/core/providers/openvino/backend_utils.cc index 73fbe9a0fa76f..7027861f0c4dc 100644 --- a/onnxruntime/core/providers/openvino/backend_utils.cc +++ b/onnxruntime/core/providers/openvino/backend_utils.cc @@ -150,6 +150,11 @@ CreateOVModel(std::string&& model, LOGS_DEFAULT(INFO) << log_tag << "Reshaping the ov tensor to specified shape"; ov_model->reshape(session_context.reshape); } + + if (!session_context.layout.empty()) { + LOGS_DEFAULT(INFO) << log_tag << "Setting the ov tensor layout to specified layout"; + ov_model = Set_Layout(ov_model, session_context.layout); + } // Check for Constant Folding if ((session_context.device_type != "NPU") && !session_context.is_wholly_supported_graph) { ov::pass::ConstantFolding pass_const_obj; @@ -199,6 +204,41 @@ GetOutputTensor(Ort::KernelContext& context, return context.GetOutput(index, output_shape); } +std::shared_ptr Set_Layout(std::shared_ptr ov_model, const layout_t& layout) { + ov::preprocess::PrePostProcessor preproc(ov_model); + + const auto& inputs = ov_model->inputs(); + const auto& outputs = ov_model->outputs(); + + auto find_tensor_index = [](const std::vector>& tensors, const std::string& name) -> std::optional { + for (size_t i = 0; i < tensors.size(); ++i) { + const auto& tensor = tensors[i]; + if (tensor.get_any_name() == name || tensor.get_tensor().get_names().count(name) > 0) { + return i; + } + } + return std::nullopt; + }; + + for (const auto& [tensor_name, layout_value] : layout) { + bool tensor_found = false; + + if (auto input_idx = find_tensor_index(inputs, tensor_name)) { + preproc.input(*input_idx).tensor().set_layout(layout_value); + tensor_found = true; + } else if (auto output_idx = find_tensor_index(outputs, tensor_name)) { + preproc.output(*output_idx).tensor().set_layout(layout_value); + tensor_found = true; + } + + if (!tensor_found) { + LOGS_DEFAULT(WARNING) << "Tensor '" << tensor_name << "' not found in model inputs or outputs"; + } + } + + return preproc.build(); +} + int GetFirstAvailableDevice(SessionContext& session_context) { int i = 0; // Get the first available VAD-M device and set the device to busy diff --git a/onnxruntime/core/providers/openvino/backend_utils.h b/onnxruntime/core/providers/openvino/backend_utils.h index 15145df651fa2..27f791c7a5bd1 100644 --- a/onnxruntime/core/providers/openvino/backend_utils.h +++ b/onnxruntime/core/providers/openvino/backend_utils.h @@ -79,6 +79,8 @@ int GetFirstAvailableDevice(SessionContext& session_context); void FillOutputsWithConstantData(std::shared_ptr node, Ort::UnownedValue& out_tensor); +std::shared_ptr Set_Layout(std::shared_ptr ov_model, const layout_t& layout); + template void FillOutputHelper(Ort::UnownedValue& out_tensor, std::shared_ptr node); diff --git a/onnxruntime/core/providers/openvino/backends/basic_backend.cc b/onnxruntime/core/providers/openvino/backends/basic_backend.cc index f023e064f98ee..93d9c3276ab97 100644 --- a/onnxruntime/core/providers/openvino/backends/basic_backend.cc +++ b/onnxruntime/core/providers/openvino/backends/basic_backend.cc @@ -98,6 +98,7 @@ BasicBackend::BasicBackend(std::unique_ptr& model_pr !subgraph_context_.has_dynamic_input_shape && !session_context_.so_context_enable && session_context_.reshape.empty() && + session_context_.layout.empty() && !enable_causallm && !eligible_for_cpu_fallback && auto_unified_compile); diff --git a/onnxruntime/core/providers/openvino/contexts.h b/onnxruntime/core/providers/openvino/contexts.h index 6a2b375d733f9..07b09899ac214 100644 --- a/onnxruntime/core/providers/openvino/contexts.h +++ b/onnxruntime/core/providers/openvino/contexts.h @@ -70,6 +70,7 @@ class SharedContext : public WeakSingleton { using config_t = std::map; using reshape_t = std::map; +using layout_t = std::map; struct ProviderInfo { std::string device_type{""}; // [device_type]: Overrides the accelerator hardware type and @@ -88,6 +89,7 @@ struct ProviderInfo { // (GPU) feature. If blob files are already present, // it will be directly loaded. reshape_t reshape{}; // Used for reshaping the ov input tensor shape at runtime. + layout_t layout{}; // Used for specifying the ov input/output tensor layout at runtime. std::string model_priority{"DEFAULT"}; // High-level OpenVINO model priority hint // Defines what model should be provided with more performant // bounded resource first @@ -110,7 +112,7 @@ struct ProviderInfo { const ConfigOptions* config_options{NULL}; const std::unordered_set valid_provider_keys = {"device_type", "device_id", "device_luid", "cache_dir", "precision", "load_config", "context", "num_of_threads", "model_priority", "num_streams", "enable_opencl_throttling", "enable_qdq_optimizer", - "enable_causallm", "disable_dynamic_shapes", "reshape_input"}; + "enable_causallm", "disable_dynamic_shapes", "reshape_input", "layout"}; }; // Holds context applicable to the entire EP instance. diff --git a/onnxruntime/core/providers/openvino/openvino_parser_utils.cc b/onnxruntime/core/providers/openvino/openvino_parser_utils.cc index 21fc7f935da23..a290fea73e0e8 100644 --- a/onnxruntime/core/providers/openvino/openvino_parser_utils.cc +++ b/onnxruntime/core/providers/openvino/openvino_parser_utils.cc @@ -236,5 +236,79 @@ ov::Dimension OpenVINOParserUtils::ParseDimensionRange(const std::string& range_ return ov::Dimension(range_start, range_end); } +layout_t OpenVINOParserUtils::ParseLayout(const std::string& layout_definition) { + layout_t parsed_layout_map; + + // Return empty map for empty input + if (layout_definition.empty()) { + ORT_THROW("Empty layout definition provided in layout parameter"); + } + + // Regular expression for parsing layout definitions + const std::regex layout_pattern(R"(([^\[\],]+)\s*\[(.*?)\])"); // e.g. "input_1[NC],data[CHW]" + + // Find all tensor layout definitions using regex + auto layout_begin = std::sregex_iterator( + layout_definition.begin(), + layout_definition.end(), + layout_pattern); + auto layout_end = std::sregex_iterator(); + + // If no matches found, throw error + if (layout_begin == layout_end) { + ORT_THROW("Invalid layout definition format: " + layout_definition); + } + + // Process each tensor definition + for (std::sregex_iterator i = std::move(layout_begin); i != layout_end; ++i) { + std::smatch layout_match = *i; + + // Extract tensor name and trim whitespace + std::string tensor_name = layout_match[1].str(); // Group 1: tensor name e.g. "input_1" + tensor_name = TrimWhitespace(tensor_name); + + if (tensor_name.empty()) { + ORT_THROW("Empty tensor name provided in layout parameter"); + } + + // Extract dimensions string + std::string dimensions_str = layout_match[2].str(); // Group 2: dimensions string [e.g. "NC", "CHW"] + + if (!Check_Valid_Layout(dimensions_str, tensor_name)) { + ORT_THROW("Invalid dimensions string provided in layout parameter"); + } + + // Store parsed shape in result map + parsed_layout_map[tensor_name] = ov::Layout(dimensions_str); + } + + return parsed_layout_map; +} + +bool OpenVINOParserUtils::Check_Valid_Layout(const std::string& layout_str, const std::string& tensor_name) { + // Check if the layout string is empty + if (layout_str.empty()) { + return false; + } + + std::unordered_set seen_alphabets; + for (char c : layout_str) { + if (std::isalpha(c)) { + char upper_c = static_cast(std::toupper(c)); // Convert to uppercase for case-insensitive comparison + if (seen_alphabets.find(upper_c) != seen_alphabets.end()) { + ORT_THROW("Repeated Dim '" + std::string(1, c) + + "' found in layout dimensions for tensor '" + tensor_name + "'"); + } + seen_alphabets.insert(upper_c); + } else if (c != '?') { + // Only '?' is allowed as non-alphabetic character + ORT_THROW("Invalid character '" + std::string(1, c) + + "' found in layout dimensions for tensor '" + tensor_name + "'"); + } + } + + return true; +} + } // namespace openvino_ep } // namespace onnxruntime diff --git a/onnxruntime/core/providers/openvino/openvino_parser_utils.h b/onnxruntime/core/providers/openvino/openvino_parser_utils.h index e6aa0e0a46a3b..a0936d627df40 100644 --- a/onnxruntime/core/providers/openvino/openvino_parser_utils.h +++ b/onnxruntime/core/providers/openvino/openvino_parser_utils.h @@ -18,8 +18,10 @@ class OpenVINOParserUtils { std::string& device_type, const std::string& option_name); static reshape_t ParseInputShape(const std::string& reshape_input_definition); + static layout_t ParseLayout(const std::string& layout_definition); static std::string TrimWhitespace(const std::string& str); static ov::Dimension ParseDimensionRange(const std::string& range_str, const std::string& tensor_name); + static bool Check_Valid_Layout(const std::string& layout_str, const std::string& tensor_name); }; } // namespace openvino_ep diff --git a/onnxruntime/core/providers/openvino/openvino_provider_factory.cc b/onnxruntime/core/providers/openvino/openvino_provider_factory.cc index ce5269d9298b6..bebdb25ccc058 100644 --- a/onnxruntime/core/providers/openvino/openvino_provider_factory.cc +++ b/onnxruntime/core/providers/openvino/openvino_provider_factory.cc @@ -232,6 +232,10 @@ static void ParseProviderInfo(const ProviderOptions& provider_options, pi.reshape = OpenVINOParserUtils::ParseInputShape(provider_options.at("reshape_input")); } + if (provider_options.contains("layout")) { + pi.layout = OpenVINOParserUtils::ParseLayout(provider_options.at("layout")); + } + if (provider_options.contains("load_config")) { auto parse_config = [&](const std::string& config_str) -> std::map { // If the config string is empty, return an empty map and skip processing diff --git a/onnxruntime/test/perftest/command_args_parser.cc b/onnxruntime/test/perftest/command_args_parser.cc index 5c81696d5c57e..1e76bcc7bc386 100644 --- a/onnxruntime/test/perftest/command_args_parser.cc +++ b/onnxruntime/test/perftest/command_args_parser.cc @@ -60,7 +60,9 @@ ABSL_FLAG(std::string, i, "", " [OpenVINO only] [num_of_threads]: Overrides the accelerator hardware type and precision with these values at runtime.\n" " [OpenVINO only] [cache_dir]: Explicitly specify the path to dump and load the blobs(Model caching) or cl_cache (Kernel Caching) files feature. If blob files are already present, it will be directly loaded.\n" " [OpenVINO only] [enable_opencl_throttling]: Enables OpenCL queue throttling for GPU device(Reduces the CPU Utilization while using GPU) \n" - " [Example] [For OpenVINO EP] -e openvino -i \"device_type|CPU num_of_threads|5 enable_opencl_throttling|true cache_dir|\"\"\"\n" + " [OpenVINO only] [reshape_input]: Sets model input shapes with support for bounded dynamic dimensions using 'min..max' syntax (e.g., [1..10,3,224,224]) \n" + " [OpenVINO only] [layout]: Specifies the layout for inputs/outputs to interpret tensor dimensions correctly. \n" + " [Example] [For OpenVINO EP] -e openvino -i \"device_type|CPU num_of_threads|5 enable_opencl_throttling|true reshape_input|[1,3,60,60..100] layout|[NCHW] cache_dir|\"\"\"\n" "\n" " [QNN only] [backend_type]: QNN backend type. E.g., 'cpu', 'htp'. Mutually exclusive with 'backend_path'.\n" " [QNN only] [backend_path]: QNN backend path. E.g., '/folderpath/libQnnHtp.so', '/winfolderpath/QnnHtp.dll'. Mutually exclusive with 'backend_type'.\n" diff --git a/onnxruntime/test/perftest/ort_test_session.cc b/onnxruntime/test/perftest/ort_test_session.cc index 7156a1eb5c347..1026cfe41182c 100644 --- a/onnxruntime/test/perftest/ort_test_session.cc +++ b/onnxruntime/test/perftest/ort_test_session.cc @@ -906,12 +906,14 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)"); ov_options[key] = value; } else if (key == "reshape_input") { ov_options[key] = value; + } else if (key == "layout") { + ov_options[key] = value; } else { ORT_THROW( "[ERROR] [OpenVINO] wrong key type entered. Choose from the following runtime key options that are available for OpenVINO." " ['device_type', 'device_id', 'num_of_threads', 'load_config', 'cache_dir', 'num_streams', " "'enable_opencl_throttling', 'disable_dynamic_shapes', 'enable_qdq_optimizer'," - " 'enable_causallm', 'model_priority'] \n"); + " 'enable_causallm', 'reshape_input', 'layout', 'model_priority'] \n"); } } session_options.AppendExecutionProvider_OpenVINO_V2(ov_options); From e6346544d41d8349b06421d78d217b68ddc0a89f Mon Sep 17 00:00:00 2001 From: Ryan Metcalfe <107415876+RyanMetcalfeInt8@users.noreply.github.com> Date: Thu, 21 Aug 2025 05:27:58 -0700 Subject: [PATCH 082/138] OVInferRequest::SetTensor: Set tensor upon cached_binding shape mismatch (#783) --- .../core/providers/openvino/ov_interface.h | 29 +++++-------------- 1 file changed, 8 insertions(+), 21 deletions(-) diff --git a/onnxruntime/core/providers/openvino/ov_interface.h b/onnxruntime/core/providers/openvino/ov_interface.h index 6d1db4366410b..59c2cf95874b0 100644 --- a/onnxruntime/core/providers/openvino/ov_interface.h +++ b/onnxruntime/core/providers/openvino/ov_interface.h @@ -126,29 +126,16 @@ class OVInferRequest { OVTensorPtr GetTensor(const std::string& name); std::string GetInputTensorName(uint32_t index); - // Set tensor described param_info and ort_ptr. Overrides shape in param_info with shape_override. Call infer req tensor if ort_ptr is last set. + // Set tensor call infer req tensor if ort_ptr differs from last set ptr. void SetTensor(const std::string& name, const ov::element::Type& type, const ov::Shape& shape, void* ort_ptr) { auto& cached_binding = bindings_cache_[name]; - if (cached_binding.ort_ptr != ort_ptr) { - auto tensor_ptr = std::make_shared(type, shape, const_cast(ort_ptr)); - SetTensor(name, tensor_ptr); - cached_binding = {tensor_ptr, ort_ptr}; - } else if (ort_ptr == nullptr) { - // a null ort_ptr is expected for a tensor that has 0 elements. - // for example, a tensor of shape=[1, 8, 0, 64], which is valid. - // So, we check to see if at least one shape entry is 0. - auto contains_zero = [](const ov::Shape& shape) { - for (auto& s : shape) - if (s == 0) return true; - return false; - }; - if (contains_zero(shape)) { - // if there are zero elements (i.e. at least one shape entry is 0), - // then create and set the tensor anyway. - auto tensor_ptr = std::make_shared(type, shape); - SetTensor(name, tensor_ptr); - cached_binding = {tensor_ptr, ort_ptr}; - } + if (cached_binding.ort_ptr != ort_ptr || + !cached_binding.tensor_ptr || + cached_binding.tensor_ptr->get_shape() != shape) { + cached_binding.tensor_ptr.reset(); + auto ov_tensor = std::make_shared(type, shape, const_cast(ort_ptr)); + ovInfReq.set_tensor(name, *ov_tensor); + cached_binding = {ov_tensor, ort_ptr}; } } From 0bad3d7ec19f61678842a704509fd4131ad2652c Mon Sep 17 00:00:00 2001 From: Ankit Maheshkar Date: Sat, 23 Aug 2025 05:11:59 +0530 Subject: [PATCH 083/138] Updated load_config mapping to make it a passthrough to OV properties (#782) --- .../openvino/backends/basic_backend.cc | 110 +++--------------- 1 file changed, 16 insertions(+), 94 deletions(-) diff --git a/onnxruntime/core/providers/openvino/backends/basic_backend.cc b/onnxruntime/core/providers/openvino/backends/basic_backend.cc index 93d9c3276ab97..a0c6b0c5984e1 100644 --- a/onnxruntime/core/providers/openvino/backends/basic_backend.cc +++ b/onnxruntime/core/providers/openvino/backends/basic_backend.cc @@ -214,107 +214,29 @@ void BasicBackend::PopulateConfigValue(ov::AnyMap& device_config) { if (!session_context_.load_config.empty()) { const std::map& target_config = session_context_.load_config; - if ((session_context_.device_type.find("NPU") != std::string::npos) && session_context_.enable_causallm) { - if (target_config.find("NPU") != target_config.end()) { - auto npu_genai_config = target_config.at("NPU"); - CausalLMConfig().ApplyConfig(npu_genai_config, device_config); - } else { - LOGS_DEFAULT(WARNING) << "ORT GenAI CausalLMConfig Configuration not found."; - } - } - - if (session_context_.device_type.find("NPU") != std::string::npos) { - auto npuw_config = target_config.at("NPU"); - - // Check if "NPU_USE_NPUW" exists and is set to "YES" - auto npu_use_npuw_it = npuw_config.find("NPU_USE_NPUW"); - if (npu_use_npuw_it != npuw_config.end() && - npu_use_npuw_it->second.is() && - npu_use_npuw_it->second.as() == "YES") { - // Only add NPUW-related keys if NPU_USE_NPUW is "YES" - for (const auto& [key, value] : npuw_config) { - if (key.find("NPUW") != std::string::npos) { - if (!value.is()) { - LOGS_DEFAULT(ERROR) << "Invalid value type for key: " << key; - continue; - } - device_config[key] = value; - } - } - } else { - // Check if there are any "NPUW" keys and log a warning - if (std::any_of(npuw_config.begin(), npuw_config.end(), - [&](const auto& pair) { return pair.first.find("NPUW") != std::string::npos; })) { - LOGS_DEFAULT(WARNING) << "Skipping NPUW-related configurations as NPU_USE_NPUW is not set to 'YES'."; - } - } - } - auto find_device_type_mode = [&](const std::string& device_type) -> std::string { - std::string device_mode = ""; - auto delimiter_pos = device_type.find(':'); - if (delimiter_pos != std::string::npos) { - std::stringstream str_stream(device_type.substr(0, delimiter_pos)); - std::getline(str_stream, device_mode, ','); - } - return device_mode; - }; + // Extract device names from device string and apply their configs + // Examples: "GPU" -> ["GPU"], "AUTO:GPU.0,CPU" -> ["AUTO", "GPU", "CPU"] + auto apply_device_config = [&](std::string_view device) { + if (device.empty()) return; - // Parse device types like "AUTO:CPU,GPU" and extract individual devices - auto parse_individual_devices = [&](const std::string& device_type) -> std::vector { - std::vector devices; - auto delimiter_pos = device_type.find(':'); - if (delimiter_pos != std::string::npos) { - std::stringstream str_stream(device_type.substr(delimiter_pos + 1)); - std::string device; - while (std::getline(str_stream, device, ',')) { - devices.emplace_back(device); - } - } else { - devices.emplace_back(device_type); - } - return devices; - }; + // Remove device index: "GPU.0" -> "GPU" + auto base_device = device.substr(0, device.find('.')); - // Set properties, Validation will be handled by OpenVINO Core - auto set_target_properties = [&](const std::string& device, const ov::AnyMap& config_options) { - for (const auto& [key, value] : config_options) { - // Update the device_config map from the target_config to avoid load_config being overridden - // by the device_config set by the OpenVINO EP. - auto it = device_config.find(key); - if (it != device_config.end()) { - it->second = value; - } - if ((key.find("NPUW") != std::string::npos) || - ((it != device_config.end()) && session_context_.enable_causallm)) { - continue; + if (auto config_it = target_config.find(std::string(base_device)); config_it != target_config.end()) { + for (const auto& [key, value] : config_it->second) { + device_config[key] = value; } - OVCore::Get()->core.set_property(device, ov::AnyMap{{key, value}}); } }; - // Check if the device type is AUTO, HETERO, or MULTI - if (session_context_.device_type.find("AUTO") == 0 || - session_context_.device_type.find("HETERO") == 0 || - session_context_.device_type.find("MULTI") == 0) { - //// Parse to get the device mode (e.g., "AUTO:CPU,GPU" -> "AUTO") - std::unordered_set supported_mode = {"AUTO", "HETERO", "MULTI"}; - auto device_mode = find_device_type_mode(session_context_.device_type); - ORT_ENFORCE(supported_mode.find(device_mode) != supported_mode.end(), " Invalid device mode is passed : ", session_context_.device_type); - // Parse individual devices (e.g., "AUTO:CPU,GPU" -> ["CPU", "GPU"]) - auto individual_devices = parse_individual_devices(session_context_.device_type); - if (!device_mode.empty()) individual_devices.emplace_back(device_mode); - - // Set properties only for individual devices (e.g., "CPU", "GPU") - for (const std::string& device : individual_devices) { - if (target_config.count(device)) { - // Set properties for the device - set_target_properties(device, target_config.at(device)); + // Parse device string by splitting on ':' and ',' delimiters + const auto& device_str = session_context_.device_type; + for (size_t start = 0, pos = 0; pos <= device_str.size(); ++pos) { + if (pos == device_str.size() || device_str[pos] == ':' || device_str[pos] == ',') { + if (pos > start) { + apply_device_config(std::string_view(device_str).substr(start, pos - start)); } - } - } else { - if (target_config.count(session_context_.device_type)) { - set_target_properties(session_context_.device_type, - target_config.at(session_context_.device_type)); + start = pos + 1; } } } From 8ecdbd06c3e7c493f18faa0caa8ca5088ce9776e Mon Sep 17 00:00:00 2001 From: "Klimenko, Mikhail" Date: Mon, 25 Aug 2025 10:04:59 +0200 Subject: [PATCH 084/138] Fix model copying for QDQ stripping (#784) * Reintroduce #768 with a small fix * Fix model copying with help from microsoft#25761 * Remove unused debug variables --- .../providers/openvino/backend_manager.cc | 44 ++++++++++++++++++- .../openvino/ov_versions/data_ops.cc | 7 ++- .../qdq_transformations/qdq_scales_fix.cpp | 22 +++------- 3 files changed, 52 insertions(+), 21 deletions(-) diff --git a/onnxruntime/core/providers/openvino/backend_manager.cc b/onnxruntime/core/providers/openvino/backend_manager.cc index cadeab4cbd4cc..2af414bd359bf 100644 --- a/onnxruntime/core/providers/openvino/backend_manager.cc +++ b/onnxruntime/core/providers/openvino/backend_manager.cc @@ -387,6 +387,44 @@ static bool IsModelBF16(const onnxruntime::GraphViewer& graph_viewer) { return false; } +static bool Is16BitTensor(const onnxruntime::NodeArg* node_arg) { + const auto* type_proto = node_arg ? node_arg->TypeAsProto() : nullptr; + return type_proto && type_proto->has_tensor_type() && + (type_proto->tensor_type().elem_type() == ONNX_NAMESPACE::TensorProto_DataType_UINT16 || + type_proto->tensor_type().elem_type() == ONNX_NAMESPACE::TensorProto_DataType_INT16); +} + +// Check to see if the graph has Q/DQ nodes with int16 or uint16 quantization +static bool IsQDQGraphWithUint16OrInt16(const onnxruntime::GraphViewer& graph_viewer) { + std::unordered_set qdq_ops = {"QuantizeLinear", "DequantizeLinear"}; + const auto& node_indices = graph_viewer.GetNodesInTopologicalOrder(); + + for (size_t i = 0; i < node_indices.size(); i++) { + gsl::not_null node(graph_viewer.GetNode(node_indices[i])); + + if (qdq_ops.find(node->OpType()) != qdq_ops.end()) { + const auto& input_defs = node->InputDefs(); + + if (node->OpType() == "DequantizeLinear") { + // DequantizeLinear: [quantized_input, scale, zero_point] -> [float_output] + // Check quantized input tensor and optional zero point + if (Is16BitTensor(input_defs.empty() ? nullptr : input_defs[0]) || + (input_defs.size() >= 3 && Is16BitTensor(input_defs[2]))) { + return true; + } + } else if (node->OpType() == "QuantizeLinear") { + // QuantizeLinear: [float_input, scale, zero_point] -> [quantized_output] + const auto& output_defs = node->OutputDefs(); + if (Is16BitTensor(output_defs.empty() ? nullptr : output_defs[0]) || + (input_defs.size() >= 3 && Is16BitTensor(input_defs[2]))) { + return true; + } + } + } + } + return false; +} + static void DumpOpenVINOEPModel([[maybe_unused]] const std::filesystem::path& onnx_model_path_name, [[maybe_unused]] ONNX_NAMESPACE::ModelProto* model_proto, [[maybe_unused]] const onnxruntime::Node& fused_node) { @@ -445,6 +483,10 @@ BackendManager::GetModelProtoFromFusedNode(const onnxruntime::Node& fused_node, } #endif + // Check if the graph is QDQ and has int16 or uint16 quantization + // If so, we will apply the QDQ scales fix transformation (for GPU device only) + bool is_qdq_graph_uint16_or_int16 = IsQDQGraphWithUint16OrInt16(subgraph); + const auto& onnx_model_path_name = subgraph.ModelPath(); // QDQ stripping enabled only for the NPU and experimentally on the GPU if ((session_context_.device_type.find("NPU") != std::string::npos) && @@ -458,7 +500,7 @@ BackendManager::GetModelProtoFromFusedNode(const onnxruntime::Node& fused_node, ORT_ENFORCE(status.IsOK(), status.ErrorMessage()); return model_proto; } else if ((session_context_.device_type.find("GPU") != std::string::npos) && - enable_ovep_qdq_optimizer) { + is_qdq_graph_uint16_or_int16) { // Create a copy of the model std::unique_ptr model; Status status = qdq_scales_fix::Transform(subgraph, logger, model); diff --git a/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc b/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc index f991e85ebe518..3b25d67b6b376 100644 --- a/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc +++ b/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc @@ -557,7 +557,9 @@ bool DataOps::type_is_supported(const NodeArg* node_arg, bool is_initializer) { auto dtype = type_proto->tensor_type().elem_type(); // Enable bfloat16 -> float16 on-the-fly conversion - if (dtype == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_BFLOAT16) + if (dtype == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_BFLOAT16 || + dtype == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT16 || + dtype == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT16) return true; if (is_initializer) { for (auto const& var : supported_types_initializer_) { @@ -610,9 +612,6 @@ bool DataOps::type_is_supported(const NodeArg* node_arg, bool is_initializer) { (var.second == dtype)) { return true; } - // experimentally for GPU and qdq stripping mode allow int16 types - if (npu_qdq_optimizer_enabled_ && (dtype == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT16 || dtype == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT16)) - return true; } #ifndef NDEBUG if (openvino_ep::backend_utils::IsDebugEnabled()) { diff --git a/onnxruntime/core/providers/openvino/qdq_transformations/qdq_scales_fix.cpp b/onnxruntime/core/providers/openvino/qdq_transformations/qdq_scales_fix.cpp index f1ce230387565..3a39152b5d17d 100644 --- a/onnxruntime/core/providers/openvino/qdq_transformations/qdq_scales_fix.cpp +++ b/onnxruntime/core/providers/openvino/qdq_transformations/qdq_scales_fix.cpp @@ -3,6 +3,7 @@ #include "qdq_scales_fix.h" #include "core/providers/openvino/ov_protobuf_utils.h" +#include "core/framework/ort_value.h" #include "core/framework/float16.h" #include @@ -904,22 +905,11 @@ Status copy_model(const GraphViewer& src_graph_viewer, } for (auto& [name, tensor_proto] : src_graph.GetAllInitializedTensors()) { - dst_graph.AddInitializedTensor(*tensor_proto); - } - - for (auto node_arg : src_graph.GetInputsIncludingInitializers()) { - auto check_inputs = [node_arg](auto input_node_arg) { - return input_node_arg->Name() == node_arg->Name(); - }; - if (std::find_if(dst_graph_inputs.begin(), dst_graph_inputs.end(), check_inputs) != dst_graph_inputs.end()) - continue; - - auto src_tensor_proto = src_graph.GetConstantInitializer(node_arg->Name(), true); - if (src_tensor_proto) { - auto dst_tensor_proto = onnx::TensorProto::Create(); - dst_tensor_proto->copy_from(src_tensor_proto); - dst_graph.AddInitializedTensor(*dst_tensor_proto); - } + auto ort_value = OrtValue(); + if (src_graph.GetOrtValueInitializer(name, ort_value)) + ORT_RETURN_IF_ERROR(dst_graph.AddInitializedOrtValue(*tensor_proto, ort_value)); + else + dst_graph.AddInitializedTensor(*tensor_proto); } ORT_RETURN_IF_ERROR(dst_graph.Resolve()); From e974fbeef7dec4a113f10c501b7e91eaf48b1f15 Mon Sep 17 00:00:00 2001 From: sfatimar Date: Mon, 25 Aug 2025 15:45:50 +0530 Subject: [PATCH 085/138] Sahar/psu lora fix 2 (#788) * Changed fix * Fix to omit subgraph * Commit a fix for cluster index len * Fixing the Warning with size_t on clusters * Loop Test fix --- .../openvino/ov_versions/capability.cc | 51 ++++++++----------- .../providers/openvino/ov_versions/utils.cc | 20 ++++++++ .../providers/openvino/ov_versions/utils.h | 4 ++ .../providers/cpu/controlflow/loop_test.cc | 2 +- 4 files changed, 47 insertions(+), 30 deletions(-) diff --git a/onnxruntime/core/providers/openvino/ov_versions/capability.cc b/onnxruntime/core/providers/openvino/ov_versions/capability.cc index 4ad8f8dd85f4d..4be4bff039df4 100644 --- a/onnxruntime/core/providers/openvino/ov_versions/capability.cc +++ b/onnxruntime/core/providers/openvino/ov_versions/capability.cc @@ -166,33 +166,24 @@ std::vector> GetCapability::Execute() { auto connected_clusters = GetConnectedClusters(graph_viewer_, ng_clusters); int no_of_clusters = 0; - std::vector prev_cluster; - bool try_next_cluster = false; - + size_t cluster_index = 0; + size_t total_clusters = connected_clusters.size(); for (auto this_cluster : connected_clusters) { bool omit_subgraph = false; - if (try_next_cluster) { - // no need to check previous cluster - for (auto idx : prev_cluster) { - if ((std::find(this_cluster.begin(), this_cluster.end(), idx)) == this_cluster.end()) { - this_cluster.emplace_back(idx); - } - } - try_next_cluster = false; - } - // If subgraph has less then three, graph is considered trivial unless its an epctx cluster - if (!try_next_cluster && this_cluster.size() < 3) { - bool is_epctx_node = false; - for (auto node_idx : this_cluster) { - if (graph_viewer_.GetNode(node_idx)->OpType() == "EPContext") - is_epctx_node = true; - } - if (!is_epctx_node) { - omit_subgraph = true; - prev_cluster = this_cluster; - try_next_cluster = true; - } + //auto id = this_cluster.at(0); + if (this_cluster.size() == 1) { + //check next cluster + auto index = this_cluster.at(0); + if (graph_viewer_.GetNode(index)->OpType() == "EPContext") { + omit_subgraph=false; + } else if(cluster_index < total_clusters-1) { + bool append_node = AddTrivialClusterToNextClusterIfConnected(graph_viewer_, index, connected_clusters[cluster_index+1]); + if(append_node) { + connected_clusters[cluster_index+1].emplace_back(index); + } + omit_subgraph=true; + } } std::vector cluster_graph_inputs, cluster_inputs, cluster_outputs; @@ -233,15 +224,17 @@ std::vector> GetCapability::Execute() { } } } - if (omit_subgraph) - continue; /* In scenarios, when there are no inputs or all inputs being initializers, ConstantFolding optimization in onnxruntime pre-computes the value.*/ - if (!cluster_inputs.empty()) { - AppendClusterToSubGraph(this_cluster, cluster_inputs, cluster_outputs, result); - no_of_clusters++; + if (!omit_subgraph) { + if (!cluster_inputs.empty()) { + AppendClusterToSubGraph(this_cluster, cluster_inputs, cluster_outputs, result); + no_of_clusters++; + } } + + cluster_index = cluster_index+1; } LOGS_DEFAULT(INFO) << "[OpenVINO-EP] Supported subgraphs on OpenVINO: " << no_of_clusters; } diff --git a/onnxruntime/core/providers/openvino/ov_versions/utils.cc b/onnxruntime/core/providers/openvino/ov_versions/utils.cc index f924fa0c8205c..814378eab47d5 100644 --- a/onnxruntime/core/providers/openvino/ov_versions/utils.cc +++ b/onnxruntime/core/providers/openvino/ov_versions/utils.cc @@ -153,6 +153,26 @@ GetConnectedClusters(const GraphViewer& graph_viewer, const std::vector& search_cluster) { + + for(auto index: search_cluster) { + auto curr_node = graph_viewer.GetNode(index); + for (auto node = curr_node->InputNodesBegin(); node != curr_node->InputNodesEnd(); ++node) { + if((*node).Index() == curr_node_index) + return true; + } + + for (auto node = curr_node->OutputNodesBegin(); node != curr_node->OutputNodesEnd(); ++node) { + if((*node).Index() == curr_node_index) + return true; + } + } + return false; +} + + void GetInputsOutputsOfCluster(const GraphViewer& graph_viewer, const std::vector& cluster, const std::unordered_set& ng_required_initializers, diff --git a/onnxruntime/core/providers/openvino/ov_versions/utils.h b/onnxruntime/core/providers/openvino/ov_versions/utils.h index 34aa762ba9b67..bdad047a422c1 100644 --- a/onnxruntime/core/providers/openvino/ov_versions/utils.h +++ b/onnxruntime/core/providers/openvino/ov_versions/utils.h @@ -40,6 +40,10 @@ void IdentifyConnectedNodes( std::vector> GetConnectedClusters(const GraphViewer& graph_viewer, const std::vector>& clusters); +bool AddTrivialClusterToNextClusterIfConnected(const GraphViewer& graph_viewer, + const NodeIndex index, + const std::vector& search_cluster); + void GetInputsOutputsOfCluster(const GraphViewer& graph_viewer, const std::vector& cluster, const std::unordered_set& ng_required_initializers, diff --git a/onnxruntime/test/providers/cpu/controlflow/loop_test.cc b/onnxruntime/test/providers/cpu/controlflow/loop_test.cc index a5fd37361a255..a92c1ed47f69b 100644 --- a/onnxruntime/test/providers/cpu/controlflow/loop_test.cc +++ b/onnxruntime/test/providers/cpu/controlflow/loop_test.cc @@ -1162,7 +1162,7 @@ TEST(Loop, SequenceAsLoopCarriedDependency) { test.AddSeqOutput("loop_var_0_final", seq_output); // Disable TensorRT on unsupported data type BOOL - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); } #if !defined(DISABLE_OPTIONAL_TYPE) From ed6d8e0314c6e51a1329411d7c8ca535983f7d24 Mon Sep 17 00:00:00 2001 From: sfatimar Date: Mon, 25 Aug 2025 23:07:08 +0530 Subject: [PATCH 086/138] Commit PSU Lora fix (#791) --- .../providers/openvino/ov_versions/capability.cc | 12 ++++++++---- .../test/providers/cpu/controlflow/loop_test.cc | 2 +- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/onnxruntime/core/providers/openvino/ov_versions/capability.cc b/onnxruntime/core/providers/openvino/ov_versions/capability.cc index 4be4bff039df4..593a78491080a 100644 --- a/onnxruntime/core/providers/openvino/ov_versions/capability.cc +++ b/onnxruntime/core/providers/openvino/ov_versions/capability.cc @@ -171,16 +171,20 @@ std::vector> GetCapability::Execute() { for (auto this_cluster : connected_clusters) { bool omit_subgraph = false; - //auto id = this_cluster.at(0); if (this_cluster.size() == 1) { //check next cluster auto index = this_cluster.at(0); + size_t j = cluster_index; if (graph_viewer_.GetNode(index)->OpType() == "EPContext") { omit_subgraph=false; - } else if(cluster_index < total_clusters-1) { - bool append_node = AddTrivialClusterToNextClusterIfConnected(graph_viewer_, index, connected_clusters[cluster_index+1]); + } else if(j < total_clusters-1) { + bool append_node = false; + while(j Date: Mon, 25 Aug 2025 14:11:02 -0700 Subject: [PATCH 087/138] Fix to resolve EPCtx filename confusion with wrapped OVIRs (#793) * Fix to resolve EPCtx filename confusion with wrapped OVIRs * Update onnxruntime/core/providers/openvino/backend_manager.cc ORT_THROW can take more than one argument, however, the code style does seem to be one error string per ORT THROW, so we can accept this update. Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: MayureshV1 <47039074+MayureshV1@users.noreply.github.com> --- .../core/providers/openvino/backend_manager.cc | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/openvino/backend_manager.cc b/onnxruntime/core/providers/openvino/backend_manager.cc index 2af414bd359bf..e1899405629c8 100644 --- a/onnxruntime/core/providers/openvino/backend_manager.cc +++ b/onnxruntime/core/providers/openvino/backend_manager.cc @@ -90,7 +90,12 @@ BackendManager::BackendManager(SessionContext& session_context, "[OpenVINO-EP] Bounded dynamic model execution using provider option reshape_input is not supported for OVEP EPContext model"; ORT_THROW(exception_str); } - model_stream = ep_ctx_handle_.GetModelBlobStream(session_context_.so_context_file_path, subgraph); + if (subgraph_context_.is_ep_ctx_ovir_encapsulated) { + model_stream = ep_ctx_handle_.GetModelBlobStream(session_context_.onnx_model_path_name.replace_extension("xml").string(), subgraph); + } else { + model_stream = ep_ctx_handle_.GetModelBlobStream(session_context_.so_context_file_path, subgraph); + } + } else { model_proto = GetModelProtoFromFusedNode(fused_node, subgraph, logger); } @@ -236,7 +241,9 @@ Status BackendManager::ExportCompiledBlobAsEPCtxNode(const onnxruntime::GraphVie std::ofstream blob_file(blob_filename, std::ios::out | std::ios::trunc | std::ios::binary); if (!blob_file) { - ORT_THROW("Unable to open file for epctx model dump."); + std::ostringstream err_msg; + err_msg << "Unable to open file for epctx model dump: " << blob_filename; + ORT_THROW(err_msg.str()); } compiled_model.export_model(blob_file); model_blob_str = blob_filename.filename().string(); From 866a24cfe740375da4f05fcc56e09512a727e695 Mon Sep 17 00:00:00 2001 From: Eric Crawford Date: Mon, 25 Aug 2025 17:35:21 -0700 Subject: [PATCH 088/138] Allow mmapping native binaries in ov 2025.3 (#794) Reduces the total commit while running the model on NPU device. --- .../providers/openvino/backends/basic_backend.cc | 2 +- onnxruntime/core/providers/openvino/ibackend.h | 2 +- .../providers/openvino/onnx_ctx_model_helper.cc | 10 +++++++--- .../providers/openvino/onnx_ctx_model_helper.h | 8 +++++++- .../core/providers/openvino/ov_interface.cc | 14 ++++++++++++-- onnxruntime/core/providers/openvino/ov_interface.h | 3 ++- 6 files changed, 30 insertions(+), 9 deletions(-) diff --git a/onnxruntime/core/providers/openvino/backends/basic_backend.cc b/onnxruntime/core/providers/openvino/backends/basic_backend.cc index a0c6b0c5984e1..2f174110dd31b 100644 --- a/onnxruntime/core/providers/openvino/backends/basic_backend.cc +++ b/onnxruntime/core/providers/openvino/backends/basic_backend.cc @@ -59,7 +59,7 @@ BasicBackend::BasicBackend(std::unique_ptr& model_pr }; // If the EPContext node with OVIR Encapsulation, then create // an executable network from EP_CACHE_CONTEXT using read_model() & compile_model() - exe_network_ = OVCore::Get()->ImportEPCtxOVIREncapsulation(*model_stream, + exe_network_ = OVCore::Get()->ImportEPCtxOVIREncapsulation(*model_stream->stream_, hw_target, device_config, enable_causallm, diff --git a/onnxruntime/core/providers/openvino/ibackend.h b/onnxruntime/core/providers/openvino/ibackend.h index ec38425f602eb..365a4625815d6 100644 --- a/onnxruntime/core/providers/openvino/ibackend.h +++ b/onnxruntime/core/providers/openvino/ibackend.h @@ -19,7 +19,7 @@ class IBackend { virtual ~IBackend() = default; virtual void RewindKVCache(size_t index) {} }; -using ptr_stream_t = std::unique_ptr; +using ptr_stream_t = std::unique_ptr; class BackendFactory { public: static std::shared_ptr diff --git a/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.cc b/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.cc index 9e70756a254aa..051a39bd4f205 100644 --- a/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.cc +++ b/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.cc @@ -100,7 +100,8 @@ Status EPCtxHandler::AddOVEPCtxNodeToGraph(const GraphViewer& graph_viewer, return Status::OK(); } -std::unique_ptr EPCtxHandler::GetModelBlobStream(const std::filesystem::path& so_context_file_path, const GraphViewer& graph_viewer) const { +std::unique_ptr +EPCtxHandler::GetModelBlobStream(const std::filesystem::path& so_context_file_path, const GraphViewer& graph_viewer) const { auto first_index = *graph_viewer.GetNodesInTopologicalOrder().begin(); auto node = graph_viewer.GetNode(first_index); ORT_ENFORCE(node != nullptr); @@ -113,10 +114,11 @@ std::unique_ptr EPCtxHandler::GetModelBlobStream(const std::filesy bool embed_mode = static_cast(attrs.at(EMBED_MODE).i()); std::unique_ptr result; + std::filesystem::path blob_filepath{}; if (embed_mode) { result.reset((std::istream*)new std::istringstream(ep_cache_context)); } else { - auto blob_filepath = so_context_file_path; + blob_filepath = so_context_file_path; if (blob_filepath.empty() && !graph_viewer.ModelPath().empty()) { blob_filepath = graph_viewer.ModelPath(); } @@ -126,16 +128,18 @@ std::unique_ptr EPCtxHandler::GetModelBlobStream(const std::filesy } bool isXML = backend_utils::IsModelStreamXML(*result); + std::filesystem::path native_blob_path{}; if (!isXML) { // If the model stream is not an XML (i.e. precompiled blob), the OpenVINO SDK version that it was // exported with must match the version that is currently running. + native_blob_path = std::move(blob_filepath); ORT_ENFORCE((attrs.count(EP_SDK_VER) == 1) && (attrs.at(EP_SDK_VER).s() == openvino_sdk_version_), "EPCtx blob was exported / is compatible with OpenVINO SDK version " + attrs.at(EP_SDK_VER).s() + ", but OpenVINO SDK version currently in use is " + openvino_sdk_version_); } LOGS_DEFAULT(VERBOSE) << "[OpenVINO EP] Read blob from EPContext Node"; - return result; + return std::make_unique(std::move(result), native_blob_path); } bool EPCtxHandler::CheckForOVEPCtxNodeInGraph(const GraphViewer& graph_viewer) const { diff --git a/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.h b/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.h index b9ddb40a7a233..f207f5014ca1f 100644 --- a/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.h +++ b/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.h @@ -12,6 +12,12 @@ namespace onnxruntime { namespace openvino_ep { +struct ModelBlobWrapper { + ModelBlobWrapper(std::unique_ptr stream, const std::filesystem::path& native_blob_path) : stream_(std::move(stream)), maybe_native_blob_path_(native_blob_path) {} + std::unique_ptr stream_; + std::filesystem::path maybe_native_blob_path_; +}; + // Utilities to handle EPContext node export and parsing of an EPContext node // to create the compiled_model object to infer on static const char EPCONTEXT_OP[] = "EPContext"; @@ -31,7 +37,7 @@ class EPCtxHandler { const std::string& graph_name, const bool embed_mode, std::string&& model_blob_str) const; - std::unique_ptr GetModelBlobStream(const std::filesystem::path& so_context_file_path, const GraphViewer& graph_viewer) const; + std::unique_ptr GetModelBlobStream(const std::filesystem::path& so_context_file_path, const GraphViewer& graph_viewer) const; InlinedVector GetEPCtxNodes() const; bool CheckEPCacheContextAttribute(const GraphViewer& graph_viewer, const std::string& target_attr_extn) const; diff --git a/onnxruntime/core/providers/openvino/ov_interface.cc b/onnxruntime/core/providers/openvino/ov_interface.cc index 2d29df8eb4197..899845d4890cf 100644 --- a/onnxruntime/core/providers/openvino/ov_interface.cc +++ b/onnxruntime/core/providers/openvino/ov_interface.cc @@ -11,6 +11,7 @@ #include "core/providers/openvino/backend_utils.h" #include "core/providers/openvino/backends/basic_backend.h" #include "core/providers/openvino/ov_stateful_patch_utils.h" +#include "core/providers/openvino/onnx_ctx_model_helper.h" namespace onnxruntime { namespace openvino_ep { @@ -191,14 +192,23 @@ OVExeNetwork OVCore::CompileModel(const std::string& onnx_model, "Exception while Loading Network for graph {}", name); } -OVExeNetwork OVCore::ImportModel(std::istream& model_stream, +OVExeNetwork OVCore::ImportModel(ModelBlobWrapper& model_blob, std::string hw_target, const ov::AnyMap& device_config, std::string name) { return OvExceptionBoundary([&]() { ov::CompiledModel obj; - obj = core.import_model(model_stream, hw_target, device_config); +#if (OPENVINO_VERSION_MAJOR > 2025 || (OPENVINO_VERSION_MAJOR == 2025 && OPENVINO_VERSION_MINOR >= 3)) + if (!model_blob.maybe_native_blob_path_.empty()) { + obj = core.import_model(ov::read_tensor_data(model_blob.maybe_native_blob_path_), hw_target, device_config); + } else { + obj = core.import_model(*model_blob.stream_, hw_target, device_config); + } +#else + obj = core.import_model(*model_blob.stream_, hw_target, device_config); +#endif OVExeNetwork exe(obj, hw_target); + #ifndef NDEBUG printDebugInfo(exe.Get()); #endif diff --git a/onnxruntime/core/providers/openvino/ov_interface.h b/onnxruntime/core/providers/openvino/ov_interface.h index 59c2cf95874b0..3e1f829258608 100644 --- a/onnxruntime/core/providers/openvino/ov_interface.h +++ b/onnxruntime/core/providers/openvino/ov_interface.h @@ -26,6 +26,7 @@ namespace openvino_ep { class OVCore; class OVInferRequest; class OVExeNetwork; +struct ModelBlobWrapper; typedef ov::Tensor OVTensor; typedef ov::ProfilingInfo OVProfilingInfo; @@ -82,7 +83,7 @@ struct OVCore : WeakSingleton { ov::AnyMap& device_config, const std::string& name); // OV Interface for Import model Stream - OVExeNetwork ImportModel(std::istream& model_stream, + OVExeNetwork ImportModel(ModelBlobWrapper& model_blob, std::string hw_target, const ov::AnyMap& device_config, std::string name); From 7eafccb9609a8711bb8d0fad6e4e19b45755eb5b Mon Sep 17 00:00:00 2001 From: Ankit Maheshkar Date: Tue, 26 Aug 2025 12:11:57 +0530 Subject: [PATCH 089/138] fix: lint fixes (#795) --- .../providers/openvino/backend_manager.cc | 2 +- .../openvino/ov_versions/capability.cc | 37 +++++----- .../providers/openvino/ov_versions/utils.cc | 22 +++--- .../qdq_transformations/qdq_stripping.cc | 68 +++++++++---------- 4 files changed, 62 insertions(+), 67 deletions(-) diff --git a/onnxruntime/core/providers/openvino/backend_manager.cc b/onnxruntime/core/providers/openvino/backend_manager.cc index e1899405629c8..68d15bdfdcee0 100644 --- a/onnxruntime/core/providers/openvino/backend_manager.cc +++ b/onnxruntime/core/providers/openvino/backend_manager.cc @@ -95,7 +95,7 @@ BackendManager::BackendManager(SessionContext& session_context, } else { model_stream = ep_ctx_handle_.GetModelBlobStream(session_context_.so_context_file_path, subgraph); } - + } else { model_proto = GetModelProtoFromFusedNode(fused_node, subgraph, logger); } diff --git a/onnxruntime/core/providers/openvino/ov_versions/capability.cc b/onnxruntime/core/providers/openvino/ov_versions/capability.cc index 593a78491080a..1893700cab09c 100644 --- a/onnxruntime/core/providers/openvino/ov_versions/capability.cc +++ b/onnxruntime/core/providers/openvino/ov_versions/capability.cc @@ -166,28 +166,28 @@ std::vector> GetCapability::Execute() { auto connected_clusters = GetConnectedClusters(graph_viewer_, ng_clusters); int no_of_clusters = 0; - size_t cluster_index = 0; - size_t total_clusters = connected_clusters.size(); + size_t cluster_index = 0; + size_t total_clusters = connected_clusters.size(); for (auto this_cluster : connected_clusters) { bool omit_subgraph = false; if (this_cluster.size() == 1) { - //check next cluster - auto index = this_cluster.at(0); - size_t j = cluster_index; - if (graph_viewer_.GetNode(index)->OpType() == "EPContext") { - omit_subgraph=false; - } else if(j < total_clusters-1) { - bool append_node = false; - while(jOpType() == "EPContext") { + omit_subgraph = false; + } else if (j < total_clusters - 1) { + bool append_node = false; + while (j < total_clusters && !append_node) { + j = j + 1; + append_node = AddTrivialClusterToNextClusterIfConnected(graph_viewer_, index, connected_clusters[j]); } + if (append_node) { + connected_clusters[j].emplace_back(index); + } + omit_subgraph = true; + } } std::vector cluster_graph_inputs, cluster_inputs, cluster_outputs; @@ -199,7 +199,6 @@ std::vector> GetCapability::Execute() { cluster_inputs, cluster_outputs); - // Omitting zero dim subgraphs for (auto index : this_cluster) { const Node* node = graph_viewer_.GetNode(index); @@ -238,7 +237,7 @@ std::vector> GetCapability::Execute() { } } - cluster_index = cluster_index+1; + cluster_index = cluster_index + 1; } LOGS_DEFAULT(INFO) << "[OpenVINO-EP] Supported subgraphs on OpenVINO: " << no_of_clusters; } diff --git a/onnxruntime/core/providers/openvino/ov_versions/utils.cc b/onnxruntime/core/providers/openvino/ov_versions/utils.cc index 814378eab47d5..791341218913f 100644 --- a/onnxruntime/core/providers/openvino/ov_versions/utils.cc +++ b/onnxruntime/core/providers/openvino/ov_versions/utils.cc @@ -156,23 +156,21 @@ GetConnectedClusters(const GraphViewer& graph_viewer, const std::vector& search_cluster) { + for (auto index : search_cluster) { + auto curr_node = graph_viewer.GetNode(index); + for (auto node = curr_node->InputNodesBegin(); node != curr_node->InputNodesEnd(); ++node) { + if ((*node).Index() == curr_node_index) + return true; + } - for(auto index: search_cluster) { - auto curr_node = graph_viewer.GetNode(index); - for (auto node = curr_node->InputNodesBegin(); node != curr_node->InputNodesEnd(); ++node) { - if((*node).Index() == curr_node_index) - return true; - } - - for (auto node = curr_node->OutputNodesBegin(); node != curr_node->OutputNodesEnd(); ++node) { - if((*node).Index() == curr_node_index) - return true; - } + for (auto node = curr_node->OutputNodesBegin(); node != curr_node->OutputNodesEnd(); ++node) { + if ((*node).Index() == curr_node_index) + return true; + } } return false; } - void GetInputsOutputsOfCluster(const GraphViewer& graph_viewer, const std::vector& cluster, const std::unordered_set& ng_required_initializers, diff --git a/onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.cc b/onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.cc index 7f88879a7a456..e010851f22e50 100644 --- a/onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.cc +++ b/onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.cc @@ -680,22 +680,22 @@ static void AddInitializerAsInput(onnxruntime::Graph& dst_graph, // To check if the input parameters of a DQ or Q node are quantization parameters // Scale and Zero point parameters are quantization parameters static bool IsQuantizationParameter(const std::string& initializer_name, - const onnxruntime::GraphViewer& src_graph) { - // Check if this initializer is used as scale or zero_point in any DQ/Q node - for (auto& node_idx : src_graph.GetNodesInTopologicalOrder()) { - const auto* node = src_graph.GetNode(node_idx); - if (node->OpType() == "DequantizeLinear" || node->OpType() == "QuantizeLinear") { - const auto& input_defs = node->InputDefs(); - // Check if this initializer is used as scale (input 1) or zero_point (input 2) - if (input_defs.size() >= 2 && input_defs[1]->Name() == initializer_name) { - return true; // This is a scale parameter - } - if (input_defs.size() >= 3 && input_defs[2]->Name() == initializer_name) { - return true; // This is a zero_point parameter - } - } + const onnxruntime::GraphViewer& src_graph) { + // Check if this initializer is used as scale or zero_point in any DQ/Q node + for (auto& node_idx : src_graph.GetNodesInTopologicalOrder()) { + const auto* node = src_graph.GetNode(node_idx); + if (node->OpType() == "DequantizeLinear" || node->OpType() == "QuantizeLinear") { + const auto& input_defs = node->InputDefs(); + // Check if this initializer is used as scale (input 1) or zero_point (input 2) + if (input_defs.size() >= 2 && input_defs[1]->Name() == initializer_name) { + return true; // This is a scale parameter + } + if (input_defs.size() >= 3 && input_defs[2]->Name() == initializer_name) { + return true; // This is a zero_point parameter + } } - return false; + } + return false; } // Creates a new model without the DQ/Q operators in the src graph. @@ -866,31 +866,29 @@ Status CreateModelWithStrippedQDQNodes(const GraphViewer& src_graph, if (!init_with_data && utils::HasExternalData(initializer_tensor) && enable_ovep_weight_sharing) { + // Only convert to input if it's not a quantization parameter + bool is_quant_param = IsQuantizationParameter(name, src_graph); - // Only convert to input if it's not a quantization parameter - bool is_quant_param = IsQuantizationParameter(name, src_graph); - - if (!is_quant_param) { - // This is actual weight data - so to convert to input for weight sharing - insert_metadata(initializer_tensor); - AddInitializerAsInput(dst_graph, accumulated_inputs, src_graph, name); - } else { - // This is a quantization parameter - keep as initializer even if external - - if (initializers_to_keep.count(name) > 0) { + if (!is_quant_param) { + // This is actual weight data - so to convert to input for weight sharing + insert_metadata(initializer_tensor); + AddInitializerAsInput(dst_graph, accumulated_inputs, src_graph, name); + } else { + // This is a quantization parameter - keep as initializer even if external - dst_graph.AddInitializedTensor(initializer_tensor); - } + if (initializers_to_keep.count(name) > 0) { + dst_graph.AddInitializedTensor(initializer_tensor); } + } } else { - // Add as an initialized tensor if it does not have external data - if (initializers_to_keep.count(name) > 0) { - if (init_with_data) { - dst_graph.AddInitializedTensor(*init_with_data); - } else { - dst_graph.AddInitializedTensor(initializer_tensor); - } + // Add as an initialized tensor if it does not have external data + if (initializers_to_keep.count(name) > 0) { + if (init_with_data) { + dst_graph.AddInitializedTensor(*init_with_data); + } else { + dst_graph.AddInitializedTensor(initializer_tensor); } + } } current_scope_initializer_set.insert(name); From e019aa1531f63a3ab5fd9afdb0da9a41d569910f Mon Sep 17 00:00:00 2001 From: Preetha Veeramalai Date: Tue, 26 Aug 2025 16:11:36 +0530 Subject: [PATCH 090/138] Update operator support status for OpenVINO 2025.2 (#792) * Update operator support status for OpenVINO 2025.2 * Disable unsupported tests --------- Co-authored-by: MayureshV1 <47039074+MayureshV1@users.noreply.github.com> --- onnxruntime/core/providers/openvino/ov_versions/data_ops.cc | 6 +++++- .../providers/cpu/tensor/dynamic_quantize_linear_test.cc | 3 ++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc b/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc index 3b25d67b6b376..f848b89ed10c8 100644 --- a/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc +++ b/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc @@ -121,6 +121,7 @@ std::vector supported_op_mode = { {"DepthToSpace", V_2020_4, {"CPU", "GPU"}}, {"DequantizeLinear", V_2021_4, {"CPU", "GPU"}}, {"DequantizeLinear", V_2024_4, {"NPU"}}, + {"DynamicQuantizeLinear", V_2025_2, {"CPU", "GPU"}}, {"DynamicQuantizeMatMul", V_2025_0, {"CPU", "GPU"}}, {"Div", V_2020_4, {"CPU", "GPU"}}, {"Dropout", V_2020_4, {"CPU", "GPU"}}, @@ -172,6 +173,7 @@ std::vector supported_op_mode = { {"LSTM", V_2020_4, {"CPU", "GPU"}}, {"MatMul", V_2020_4, {"CPU", "GPU"}}, {"MatMulInteger", V_2022_1, {"CPU"}}, + {"MatMulInteger", V_2025_2, {"GPU"}}, {"MatMulNBits", V_2024_5, {"CPU", "GPU"}}, {"Max", V_2020_4, {"CPU", "GPU"}}, {"MaxPool", V_2020_4, {"CPU", "GPU"}}, @@ -191,7 +193,7 @@ std::vector supported_op_mode = { {"Pad", V_2020_4, {"CPU", "GPU"}}, {"Pow", V_2020_4, {"CPU", "GPU"}}, {"PRelu", V_2020_4, {"CPU", "GPU"}}, - {"QLinearMatMul", V_2022_3, {"CPU"}}, + // {"QLinearMatMul", V_2022_3, {"CPU"}}, {"QuantizeLinear", V_2021_4, {"CPU", "GPU"}}, {"QuickGelu", V_2025_0, {"CPU", "GPU"}}, {"RNN", V_2023_1, {"CPU", "GPU"}}, @@ -361,6 +363,7 @@ void DataOps::populate_op_mode_supported() { no_dimension_supported_.push_back({"Clip", V_2022_1, {"All"}}); no_dimension_supported_.push_back({"Div", V_2020_4, {"All"}}); no_dimension_supported_.push_back({"DequantizeLinear", V_2021_4, {"All"}}); + no_dimension_supported_.push_back({"DynamicQuantizeLinear", V_2025_2, {"All"}}); no_dimension_supported_.push_back({"Equal", V_2022_1, {"CPU"}}); no_dimension_supported_.push_back({"Equal", V_2023_0, {"GPU"}}); no_dimension_supported_.push_back({"Expand", V_2023_3, {"CPU"}}); @@ -374,6 +377,7 @@ void DataOps::populate_op_mode_supported() { no_dimension_supported_.push_back({"Loop", V_2021_4, {"All"}}); no_dimension_supported_.push_back({"Max", V_2024_4, {"All"}}); no_dimension_supported_.push_back({"Min", V_2020_4, {"All"}}); + no_dimension_supported_.push_back({"MatMulInteger", V_2025_2, {"All"}}); no_dimension_supported_.push_back({"Mul", V_2020_4, {"All"}}); no_dimension_supported_.push_back({"Neg", V_2023_0, {"CPU", "GPU"}}); no_dimension_supported_.push_back({"Pow", V_2023_0, {"CPU", "GPU"}}); diff --git a/onnxruntime/test/providers/cpu/tensor/dynamic_quantize_linear_test.cc b/onnxruntime/test/providers/cpu/tensor/dynamic_quantize_linear_test.cc index f4d8cad90a714..1a71da6d95135 100644 --- a/onnxruntime/test/providers/cpu/tensor/dynamic_quantize_linear_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/dynamic_quantize_linear_test.cc @@ -11,7 +11,8 @@ namespace test { // range = [-ve, +ve] TEST(QuantizeLinearOpTest, DynamicQuantizeLinear) { // TODO: Unskip when fixed #41968513 - if (DefaultDmlExecutionProvider().get() != nullptr) { + if (DefaultDmlExecutionProvider().get() != nullptr || + DefaultOpenVINOExecutionProvider().get() != nullptr) { GTEST_SKIP() << "Skipping because of the following error: Expected equality of these values: 26 and 25"; } From 167f2ad87faf08fda00f287e01f55eaae3fe4ee9 Mon Sep 17 00:00:00 2001 From: jatinwadhwa921 <110383850+jatinwadhwa921@users.noreply.github.com> Date: Wed, 27 Aug 2025 20:33:21 -0700 Subject: [PATCH 091/138] [OVEP] Fix cov issues (#796) --- .../core/providers/openvino/openvino_provider_factory.cc | 6 +++--- onnxruntime/core/providers/openvino/ov_factory.cc | 2 +- onnxruntime/core/providers/openvino/ov_interface.h | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/onnxruntime/core/providers/openvino/openvino_provider_factory.cc b/onnxruntime/core/providers/openvino/openvino_provider_factory.cc index bebdb25ccc058..480e4c068664e 100644 --- a/onnxruntime/core/providers/openvino/openvino_provider_factory.cc +++ b/onnxruntime/core/providers/openvino/openvino_provider_factory.cc @@ -171,7 +171,7 @@ std::string ParseDeviceType(std::shared_ptr ov_core, const ProviderOptio if (!device_mode.empty()) { selected_device = device_mode + ":" + ov_luid_devices; for (const auto& dev_str : devices_to_check) { - const auto default_dev = split(dev_str, '.')[0]; + const std::string default_dev = split(dev_str, '.')[0]; if (ov_luid_devices.find(default_dev) == std::string::npos) selected_device = selected_device + "," + dev_str; @@ -532,7 +532,7 @@ struct OpenVINO_Provider : Provider { std::string ov_device_string; if (is_meta_device_factory) { // Build up a meta device string based on the devices that are passed in. E.g. AUTO:NPU,GPU.0,CPU - ov_device_string = ov_meta_device_type; + ov_device_string = std::move(ov_meta_device_type); ov_device_string += ":"; } @@ -545,7 +545,7 @@ struct OpenVINO_Provider : Provider { prepend_comma = true; } - provider_options["device_type"] = ov_device_string; + provider_options["device_type"] = std::move(ov_device_string); // Parse provider info with the device type ProviderInfo pi; diff --git a/onnxruntime/core/providers/openvino/ov_factory.cc b/onnxruntime/core/providers/openvino/ov_factory.cc index 8860405338409..2853cc17726ab 100644 --- a/onnxruntime/core/providers/openvino/ov_factory.cc +++ b/onnxruntime/core/providers/openvino/ov_factory.cc @@ -105,7 +105,7 @@ OrtStatus* OpenVINOEpPluginFactory::GetSupportedDevices(const OrtHardwareDevice* std::string ov_device_name; auto get_gpu_device_id = [&](const std::string& ov_device) { try { - auto device_id_str = ov_core_->get_property(ov_device, "GPU_DEVICE_ID").as(); + const std::string device_id_str = ov_core_->get_property(ov_device, "GPU_DEVICE_ID").as(); return static_cast(std::stoul(device_id_str, nullptr, 0)); } catch (ov::Exception&) { return 0u; // If we can't get the GPU_DEVICE_ID info, we won't have a device ID. diff --git a/onnxruntime/core/providers/openvino/ov_interface.h b/onnxruntime/core/providers/openvino/ov_interface.h index 3e1f829258608..38ea883078e85 100644 --- a/onnxruntime/core/providers/openvino/ov_interface.h +++ b/onnxruntime/core/providers/openvino/ov_interface.h @@ -136,7 +136,7 @@ class OVInferRequest { cached_binding.tensor_ptr.reset(); auto ov_tensor = std::make_shared(type, shape, const_cast(ort_ptr)); ovInfReq.set_tensor(name, *ov_tensor); - cached_binding = {ov_tensor, ort_ptr}; + cached_binding = {std::move(ov_tensor), ort_ptr}; } } From b7244f1a217e73ef3a2413deb9e1abdbf16ce40d Mon Sep 17 00:00:00 2001 From: Ryan Metcalfe <107415876+RyanMetcalfeInt8@users.noreply.github.com> Date: Thu, 28 Aug 2025 13:03:02 -0700 Subject: [PATCH 092/138] remove onnxruntime/test/providers/openvino/openvino_plugin.cc (#798) --- .../providers/openvino/openvino_plugin.cc | 302 ------------------ 1 file changed, 302 deletions(-) delete mode 100644 onnxruntime/test/providers/openvino/openvino_plugin.cc diff --git a/onnxruntime/test/providers/openvino/openvino_plugin.cc b/onnxruntime/test/providers/openvino/openvino_plugin.cc deleted file mode 100644 index 5abca55820a24..0000000000000 --- a/onnxruntime/test/providers/openvino/openvino_plugin.cc +++ /dev/null @@ -1,302 +0,0 @@ -#include -#include - -#include "gtest/gtest.h" -#include "core/common/common.h" -#include "core/session/onnxruntime_session_options_config_keys.h" -#include "onnxruntime_cxx_api.h" -#include "api_asserts.h" -#include "core/session/onnxruntime_session_options_config_keys.h" - -extern std::unique_ptr ort_env; - -struct OrtEpLibraryOv : public ::testing::Test { - static const inline std::filesystem::path library_path = -#if _WIN32 - "onnxruntime_providers_openvino.dll"; -#else - "libonnxruntime_providers_openvino.so"; -#endif - static const inline std::string registration_name = "OpenVINOExecutionProvider"; - - void SetUp() override { -#ifndef _WIN32 - GTEST_SKIP() << "Skipping OpenVINO EP tests as the OpenVINO plugin is not built."; -#endif - ort_env->RegisterExecutionProviderLibrary(registration_name.c_str(), library_path.c_str()); - } - - void TearDown() override { -#ifndef _WIN32 - GTEST_SKIP() << "Skipping OpenVINO EP tests as the OpenVINO plugin is not built."; -#endif - ort_env->UnregisterExecutionProviderLibrary(registration_name.c_str()); - } - - void RunModelWithSession(Ort::Session& session) { - Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); - std::vector shape = {3, 2}; - std::vector input0_data(6, 2.0f); - std::vector ort_inputs; - std::vector ort_input_names; - ort_inputs.emplace_back(Ort::Value::CreateTensor( - memory_info, input0_data.data(), input0_data.size(), shape.data(), shape.size())); - ort_input_names.push_back("X"); - std::array output_names{"Y"}; - std::vector ort_outputs = session.Run(Ort::RunOptions{nullptr}, ort_input_names.data(), ort_inputs.data(), - ort_inputs.size(), output_names.data(), output_names.size()); - Ort::Value& ort_output = ort_outputs[0]; - const float* output_data = ort_output.GetTensorData(); - gsl::span output_span(output_data, 6); - EXPECT_THAT(output_span, ::testing::ElementsAre(2, 4, 6, 8, 10, 12)); - } - - void RunModelWithPluginEp(Ort::SessionOptions& session_options) { - Ort::Session session(*ort_env, ORT_TSTR("testdata/mul_1.onnx"), session_options); - RunModelWithSession(session); - } - - void GenerateEpContextOnLegacyPath(std::filesystem::path epctx, bool embed_mode) { - Ort::SessionOptions session_options{}; - std::filesystem::remove(epctx); - // Add config option to enable EP context - session_options.SetGraphOptimizationLevel(ORT_DISABLE_ALL); - session_options.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1"); - session_options.AddConfigEntry(kOrtSessionOptionEpContextFilePath, epctx.string().c_str()); - session_options.AddConfigEntry(kOrtSessionOptionEpContextEmbedMode, embed_mode ? "1" : "0"); - session_options.AppendExecutionProvider_OpenVINO_V2({{"device_type", "CPU"}}); - Ort::Session session(*ort_env, ORT_TSTR("testdata/mul_1.onnx"), session_options); - RunModelWithSession(session); - } - - void GenerateEpContextOnPluginPath(std::filesystem::path epctx, bool embed_mode) { - Ort::SessionOptions session_options{}; - std::filesystem::remove(epctx); - // Add config option to enable EP context - session_options.SetGraphOptimizationLevel(ORT_DISABLE_ALL); - session_options.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1"); - session_options.AddConfigEntry(kOrtSessionOptionEpContextFilePath, epctx.string().c_str()); - session_options.AddConfigEntry(kOrtSessionOptionEpContextEmbedMode, embed_mode ? "1" : "0"); - Ort::ConstEpDevice plugin_ep_device = GetOvCpuEpDevice(); - ASSERT_NE(plugin_ep_device, nullptr); - std::unordered_map ep_options; - session_options.AppendExecutionProvider_V2(*ort_env, std::vector{plugin_ep_device}, ep_options); - Ort::Session session(*ort_env, ORT_TSTR("testdata/mul_1.onnx"), session_options); - RunModelWithSession(session); - } - - Ort::ConstEpDevice GetOvCpuEpDevice(std::string device_type = "CPU") { - auto ep_devices = ort_env->GetEpDevices(); - Ort::ConstEpDevice plugin_ep_device{}; - - for (Ort::ConstEpDevice& device : ep_devices) { - if (device.Device().Type() == OrtHardwareDeviceType_CPU && - std::string_view(device.EpName()).find(registration_name) != std::string::npos) { - const auto& meta_kv = device.EpMetadata().GetKeyValuePairs(); - auto device_type_it = meta_kv.find("ov_device"); - if (device_type_it != meta_kv.end()) { - if (device_type_it->second == device_type) { - plugin_ep_device = device; - break; - } - } - } - } - - return plugin_ep_device; - } -}; - -TEST_F(OrtEpLibraryOv, LoadUnloadPluginLibrary) { - auto ep_devices = ort_env->GetEpDevices(); - auto test_cpu_ep_device = GetOvCpuEpDevice(); - ASSERT_NE(test_cpu_ep_device, nullptr); - ASSERT_STREQ(test_cpu_ep_device.EpVendor(), "Intel"); - Ort::ConstHardwareDevice device = test_cpu_ep_device.Device(); - ASSERT_EQ(device.Type(), OrtHardwareDeviceType_CPU); - ASSERT_GE(device.VendorId(), 0); - ASSERT_GE(device.DeviceId(), 0); - ASSERT_NE(device.Vendor(), nullptr); - std::unordered_map ep_metadata_entries = test_cpu_ep_device.EpMetadata().GetKeyValuePairs(); - ASSERT_GT(ep_metadata_entries.size(), 0); - ASSERT_GT(ep_metadata_entries.count("ov_device"), 0); -} - -TEST_F(OrtEpLibraryOv, MetaDevicesAvailable) { - auto ep_devices = ort_env->GetEpDevices(); - auto expected_meta_devices = {"AUTO"}; - - for (auto& expected_meta_device : expected_meta_devices) { - std::string expected_ep_name = registration_name + "." + expected_meta_device; - auto it = std::find_if(ep_devices.begin(), ep_devices.end(), - [&](Ort::ConstEpDevice& device) { - return std::string_view(device.EpName()).find(expected_ep_name) != std::string::npos; - }); - bool meta_device_found = it != ep_devices.end(); - ASSERT_TRUE(meta_device_found) << "Expected to find " << expected_ep_name; - } -} - -TEST_F(OrtEpLibraryOv, RunSessionWithAllAUTODevices) { - auto ep_devices = ort_env->GetEpDevices(); - std::vector matching_devices; - - for (const auto& device : ep_devices) { - std::string ep_name = device.EpName(); - if (ep_name.find(registration_name) != std::string::npos && - (ep_name == registration_name + ".AUTO")) { - matching_devices.push_back(device); - } - } - Ort::SessionOptions session_options; - session_options.AppendExecutionProvider_V2(*ort_env, matching_devices, std::unordered_map{}); - Ort::Session session(*ort_env, ORT_TSTR("testdata/mul_1.onnx"), session_options); -} - -TEST_F(OrtEpLibraryOv, PluginEp_AppendV2_MulInference) { - auto plugin_ep_device = GetOvCpuEpDevice(); - ASSERT_NE(plugin_ep_device, nullptr); - - Ort::SessionOptions session_options; - std::unordered_map ep_options; - session_options.AppendExecutionProvider_V2(*ort_env, std::vector{plugin_ep_device}, ep_options); - RunModelWithPluginEp(session_options); -} - -TEST_F(OrtEpLibraryOv, PluginEp_PreferCpu_MulInference) { - Ort::SessionOptions session_options; - session_options.SetEpSelectionPolicy(OrtExecutionProviderDevicePolicy_PREFER_CPU); - RunModelWithPluginEp(session_options); -} - -struct EpCtxTestCases { - const ORTCHAR_T* ctx_filename; - bool embed_mode; -}; - -static const std::vector ep_context_cases = { - {ORT_TSTR("mul_1_ctx_cpu_embed1.onnx"), true}, - {ORT_TSTR("mul_1_ctx_cpu_embed0.onnx"), false}, - {ORT_TSTR("testdata/mul_1_ctx_cpu_embed0.onnx"), false}}; - -TEST_F(OrtEpLibraryOv, PluginEp_AppendV2_cpu_epctx_variants) { - auto plugin_ep_device = GetOvCpuEpDevice(); - ASSERT_NE(plugin_ep_device, nullptr); - - for (const auto& test_case : ep_context_cases) { - GenerateEpContextOnLegacyPath(test_case.ctx_filename, test_case.embed_mode); - - Ort::SessionOptions session_options; - std::unordered_map ep_options; - session_options.AppendExecutionProvider_V2(*ort_env, std::vector{plugin_ep_device}, ep_options); - Ort::Session session(*ort_env, test_case.ctx_filename, session_options); - RunModelWithSession(session); - } -} - -TEST_F(OrtEpLibraryOv, PluginEp_CheckV2DisallowedProviderOptions) { - auto plugin_ep_device = GetOvCpuEpDevice(); - ASSERT_NE(plugin_ep_device, nullptr); - std::vector> disallowed_provider_option_examples = { - {{"device_type", "CPU"}}, - {{"device_id", "CPU"}}, - {{"device_luid", "1234"}}, - {{"cache_dir", "cache"}}, - {{"precision", "F32"}}, - {{"context", "4"}}, - {{"num_of_threads", "1"}}, - {{"model_priority", "DEFAULT"}}, - {{"num_streams", "1"}}, - {{"enable_opencl_throttling", "true"}}, - {{"enable_qdq_optimizer", "true"}}, - {{"disable_dynamic_shapes", "true"}}, - }; - for (auto& example : disallowed_provider_option_examples) { - EXPECT_THROW({ - Ort::SessionOptions session_options; - session_options.AppendExecutionProvider_V2(*ort_env, std::vector{plugin_ep_device}, example); - Ort::Session session(*ort_env, ORT_TSTR("testdata/mul_1.onnx"), session_options); }, Ort::Exception); - } -} - -TEST_F(OrtEpLibraryOv, GenerateEpContextEmbedded) { - GenerateEpContextOnPluginPath(ORT_TSTR("mul_1_ctx_cpu_embed1.onnx"), true); -} - -TEST_F(OrtEpLibraryOv, GenerateEpContext) { - GenerateEpContextOnPluginPath(ORT_TSTR("mul_1_ctx_cpu_embed0.onnx"), false); -} - -TEST_F(OrtEpLibraryOv, PluginEp_AppendV2_cpu_epctx_plugin_roundtrip_variants) { - auto plugin_ep_device = GetOvCpuEpDevice(); - ASSERT_NE(plugin_ep_device, nullptr); - - for (const auto& test_case : ep_context_cases) { - if (test_case.embed_mode) { - // TODO(ericcraw) Re-enable. - // Skip the embed mode until upstream fix. - continue; - } - - GenerateEpContextOnPluginPath(test_case.ctx_filename, test_case.embed_mode); - - Ort::SessionOptions session_options; - std::unordered_map ep_options; - session_options.AppendExecutionProvider_V2(*ort_env, std::vector{plugin_ep_device}, ep_options); - Ort::Session session(*ort_env, test_case.ctx_filename, session_options); - RunModelWithSession(session); - } -} - -TEST_F(OrtEpLibraryOv, PluginEp_AppendV2_cpu_epctx_plugin_roundtrip_variants_absolute) { - auto plugin_ep_device = GetOvCpuEpDevice(); - ASSERT_NE(plugin_ep_device, nullptr); - - for (const auto& test_case : ep_context_cases) { - if (test_case.embed_mode) { - // TODO(ericcraw) Re-enable. - // Skip the embed mode until upstream fix. - continue; - } - - auto absolute_path = std::filesystem::absolute(test_case.ctx_filename).native(); - GenerateEpContextOnPluginPath(absolute_path.c_str(), test_case.embed_mode); - - Ort::SessionOptions session_options; - std::unordered_map ep_options; - session_options.AppendExecutionProvider_V2(*ort_env, std::vector{plugin_ep_device}, ep_options); - Ort::Session session(*ort_env, absolute_path.c_str(), session_options); - RunModelWithSession(session); - } -} - -TEST_F(OrtEpLibraryOv, PluginEp_AppendV2_multiple_devices) { - auto plugin_ep_device = GetOvCpuEpDevice(); - ASSERT_NE(plugin_ep_device, nullptr); - - std::vector multi_device_list(2, plugin_ep_device); // 2 copies of cpu device. - - Ort::SessionOptions session_options; - session_options.AppendExecutionProvider_V2(*ort_env, multi_device_list, std::unordered_map{}); - Ort::Session session(*ort_env, ORT_TSTR("testdata/mul_1.onnx"), session_options); -} - -TEST_F(OrtEpLibraryOv, PluginEp_AppendV2_mixed_factory_devices_throw_exception) { - auto ep_devices = ort_env->GetEpDevices(); - std::vector matching_devices; - - for (const auto& device : ep_devices) { - std::string ep_name = device.EpName(); - if (ep_name.find(registration_name) != std::string::npos && - (ep_name == registration_name || ep_name == registration_name + ".AUTO")) { - matching_devices.push_back(device); - } - } - - ASSERT_GT(matching_devices.size(), 1) << "Expected more than one matching EP device"; - - EXPECT_THROW({ - Ort::SessionOptions session_options; - session_options.AppendExecutionProvider_V2(*ort_env, matching_devices, std::unordered_map{}); - Ort::Session session(*ort_env, ORT_TSTR("testdata/mul_1.onnx"), session_options); }, Ort::Exception); -} From be346bb72042b31bbad8d78dbf5723308767ec98 Mon Sep 17 00:00:00 2001 From: jatinwadhwa921 <110383850+jatinwadhwa921@users.noreply.github.com> Date: Sun, 31 Aug 2025 22:13:38 -0700 Subject: [PATCH 093/138] OVEP-CI updating version (#799) --- .github/workflows/reusable_linux_build_intel.yml | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/.github/workflows/reusable_linux_build_intel.yml b/.github/workflows/reusable_linux_build_intel.yml index 00859bb99d7f0..a9b718bb2e736 100644 --- a/.github/workflows/reusable_linux_build_intel.yml +++ b/.github/workflows/reusable_linux_build_intel.yml @@ -73,7 +73,7 @@ jobs: id-token: write steps: - name: Checkout code - uses: actions/checkout@v4 + uses: actions/checkout@v5 - name: Set up Python ${{ inputs.python_version }} uses: actions/setup-python@v5 @@ -81,7 +81,7 @@ jobs: python-version: ${{ inputs.python_version }} - name: Build Docker Image (${{ inputs.architecture }} / ${{ inputs.build_config }}) - uses: microsoft/onnxruntime-github-actions/build-docker-image@v0.0.5 + uses: microsoft/onnxruntime-github-actions/build-docker-image@v0.0.7 id: build_docker_image_step with: dockerfile: ${{ github.workspace }}/${{ inputs.dockerfile_path }} @@ -101,7 +101,7 @@ jobs: # ------------- Update Step (CMake Generation) ------------- - name: Generate Build Files (CMake) (${{ inputs.architecture }} / ${{ inputs.build_config }}) id: update_step - uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.5 + uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.7 with: docker_image: ${{ steps.build_docker_image_step.outputs.full-image-name }} build_config: ${{ inputs.build_config }} @@ -113,7 +113,7 @@ jobs: # ------------- Build Step (Compilation) ------------- - name: Build ONNX Runtime (${{ inputs.architecture }} / ${{ inputs.build_config }}) id: build_step - uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.5 + uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.7 with: docker_image: ${{ steps.build_docker_image_step.outputs.full-image-name }} build_config: ${{ inputs.build_config }} @@ -126,7 +126,7 @@ jobs: - name: Test ONNX Runtime (${{ inputs.architecture }} / ${{ inputs.build_config }}) id: test_step if: inputs.run_tests == true - uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.5 + uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.7 with: docker_image: ${{ steps.build_docker_image_step.outputs.full-image-name }} build_config: ${{ inputs.build_config }} From 2f1ad9d04d6c900d3c2749838f8196e720456e81 Mon Sep 17 00:00:00 2001 From: Jaswanth51 Date: Mon, 1 Sep 2025 00:29:54 -0700 Subject: [PATCH 094/138] Sync with Microsoft ONNX Runtime - 01/09/2025 (#801) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * [CPU] Optimize GQA attention bias application for FP16 (#25871) ### Description When using attention bias input for GQA op with FP16, on the platforms that don't natively support FP16 math a cast to fp32 needs to be performed, and thus a temporary buffer needs to be created to store the fp32 values. The issue is that this temporary buffer was being allocated / deallocated inside of a loop for every token being processed. Refactored the implementation so that the allocation takes place only once. Phi model throughput increased by 15%. * Fixes for DynamicQuantizeMatMul and Attention3D tests (#25814) ### Description This change fixes correctness issues in two areas that were causing failures in onnxruntime_test_all: - DynamicQuantizeMatMul.WithConstantBInputs - AttentionTest.Attention3DDefault - AttentionTest.Attention3DWithPastAndPresentQkMatmul What was wrong and how it’s fixed 1) DynamicQuantizeMatMul.WithConstantBInputs - Root cause: The Kleidi dynamic quantization GEMM path could be selected even when the B scales contained values such as (zero, negative, or non-finite). That violates kernel assumptions and can lead to incorrect results. - Fix: In `onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_matmul.cc`, we now explicitly validate that all B scales are finite and strictly positive before enabling the Kleidi/MLAS dynamic path. If any scale is invalid, we disable that path. 2) Attention tests (Attention3DDefault, Attention3DWithPastAndPresentQkMatmul) - Root causes in `onnxruntime/core/mlas/lib/kleidiai/sgemm_kleidiai.cpp`: - Incorrect handling of GEMM corner cases for alpha/beta and K==0 (e.g., not respecting C = beta*C when alpha==0 or K==0). - Unnecessary or premature fallbacks for small shapes. - Fixes: - Add early-outs for degenerate sizes: if M==0 or N==0, return handled. - Correctly implement alpha/beta semantics: --------- Signed-off-by: Jonathan Clohessy * Fix MoE CPP tests (#25877) This change adds skip test for QMoE CPU tests when running on TensorRT or CUDA EP. In the QMoE kernel there was a memory overwrite bug in the accumulate part, updated that and this fixed the python tests back * [c++] Eliminate dynamic initialization of static Ort::Global::api_ (#25741) ### Description Delay the call to `OrtGetApiBase()` until the first call to `Ort::GetApi()` so that `OrtGetApiBase()` is typically called after dynamic library loading. ### Motivation and Context When ORT_API_MANUAL_INIT is not defined (which is the default), the static `Ort::Global::api_` has a dynamic initializer that calls `OrtGetApiBase()->GetApi(ORT_API_VERSION)` This dynamic initialization can cause problems when it interacts with other global/static initialization. On Windows in particular, it can also cause deadlocks when used in a dynamic library if OrtGetApiBase()->GetApi() attempts to load any other libraries. * Replace the templated `Global::api_` with an inline static initialized to nullptr. * `Ort::GetApi()` now calls `detail::Global::GetApi()` which calls `detail::Global::DefaultInit()` if initialization is needed. * When `ORT_API_MANUAL_INIT` is defined, `DefaultInit()` returns nullptr, which will eventually cause the program to crash. The callers have violated the initialization contract by not calling one of the `Ort::InitApi` overloads. * When `ORT_API_MANUAL_INIT` is not defined, `DefaultInit()` uses a function-level static to compute the result of `OrtGetApiBase()->GetApi(ORT_API_VERSION)` once and return it. * `Ort::Global` has been replaced with a non-templated type and moved inside a `detail` namespace. Since the `Global` object was documented as being used internally, it is believed that these changes here are non-breaking, as they do not impact a public API. The public APIs, `Ort::InitApi()` and `Ort::InitApi(const OrtApi*)` remain unchanged. * Add `#pragma detect_mismatch` to surface issues with compilation units that disagree on how ORT_API_MANUAL_INIT is defined. (MSVC only.) --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * python GPU IO Bindings for NVIDIA (#25776) ### Description 1. A Small change to use the shared allocator in Python binding. 2. Remove the FP64 support from the EP. ### Motivation and Context The Python GPU IO binding is necessary for performance. The change will enable the shared allocator for GPU allocation. The FP64 was using the FP32 inference—aligned WRT TRT RTX support. --------- Co-authored-by: Gaurav Garg * [CANN] Add a `enable_cann_subgraph` feature parameter (#25867) ### Description Add a `enable_cann_subgraph` feature parameter. this parameter controls whether graph splitting is performed and can help quickly identify issues in certain scenarios. * [EP ABI] Add OpAttr_GetTensorAttributeAsOrtValue and replace the existing Node_GetTensorAttributeAsOrtValue (#25886) ### Description Replace `Node_GetTensorAttributeAsOrtValue` with `OpAttr_GetTensorAttributeAsOrtValue`. Change the API signature to make it one of the `OpAttr` interfaces instead of the `OrtNode` interface. The original API was added [here](https://github.com/microsoft/onnxruntime/pull/25566). * Language bindings for model compatibility API (#25878) ### Description This change builds on top of #25841 , and adds the scaffolding necessary to call into this API from C++ / C# / Python. ### Motivation and Context #25454 talks more about the broader notion of precompiled model compatibility. This change is directed at app developers whose apps may want to determine if a particular precompiled model (e.g. on a server somewhere) is compatible with the device where the application is running. There is functionality in `OrtEpFactory` for making this determination, which was exposed as a C API in #25841, and this change makes the API more broadly available in other languages. ### Testing and Validation Introduced new unit test cases across each language, and verified that the API was being called and returned the correct result for the default CPU EP. --------- Co-authored-by: Aditya Rastogi * [QNN-EP] Introduce Level1 Transformer into qnn.preprocess (#25883) ### Description - Introduce Level1 Transformer into qnn.preprocess to support various optimizations. ### Motivation and Context - This change brings in several useful optimizations such as `ConvBnFusion` and `ConstantFolding`, which are part of `TransformerLevel::Level1` and can benefit QNNEP. - The goal is to optimize the ONNX model before quantization by integrating these passes into the Python tooling workflow. * [QNN EP] Minor fix weight name missing when not valid QDQ node group (#25887) ### Description Minor fix weight name missing when not valid QDQ node group ### Motivation and Context Some quantized model failed QDQ node group validation, the weights then won't be folded as initializer. QNN EP failed to handle the dynamic weights here due to the transpose op input name look up. This change make sure we process the weights tensor before adding transposes. * Add custom ops library_path to EP metadata (#25830) ## Summary Adds EP metadata library path support to enable custom ops DLL registration with proper path resolution. ## Changes - Added `library_path` metadata key to EP metadata infrastructure - Pass resolved library path directly to `EpLibraryProviderBridge` constructor - Simplified implementation per reviewer feedback (removed virtual method complexity) - Added `#include ` for std::move compliance ## Purpose Enables downstream applications (like onnxruntime-genai) to resolve relative custom ops library paths using EP metadata, improving DLL registration reliability. ## Files Modified - `plugin_ep/ep_factory_provider_bridge.h` - `plugin_ep/ep_library.h` - `plugin_ep/ep_library_plugin.h` - `plugin_ep/ep_library_provider_bridge.cc` - `plugin_ep/ep_library_provider_bridge.h` - `utils.cc` * [OVEP] OpenVINO EP Features and bug-fixes for ORT-1.23 (#25884) ### Description This update introduces multiple improvements, fixes, and feature enhancements to the OpenVINO Execution Provider (OVEP) and related components in ONNX Runtime: #### Configuration & Properties - Updated load_config mapping to act as a passthrough to OpenVINO properties. - Added support for providing layout information to inputs/outputs in OpenVINO. #### Inference & Tensor Handling - Improved OVInferRequest::SetTensor to correctly handle cached binding shape mismatches. - Added support for self-detecting on-the-fly bfloat16 → float16 conversion. - Fixed issues with input ONNX models when used with shared execution contexts. #### Model Handling & Operator Support - Fixed model copying behavior for QDQ stripping. - Updated operator support status for OpenVINO 2025.2. #### Platform & Integration Fixes - Applied multiple PSU Lora fixes and related updates. - Resolved filename confusion issues with wrapped OVIRs in EPCtx. - Enabled memory-mapped native binaries for OpenVINO 2025.3. #### Quality & Maintenance - Addressed linting issues. - Fixed coverage gaps in OVEP. - Added a new test script for OpenVINO with ORT ABI integration. --------- Co-authored-by: Ankit Maheshkar Co-authored-by: Ryan Metcalfe <107415876+RyanMetcalfeInt8@users.noreply.github.com> Co-authored-by: Klimenko, Mikhail Co-authored-by: sfatimar Co-authored-by: Garth Long Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: MayureshV1 <47039074+MayureshV1@users.noreply.github.com> Co-authored-by: Eric Crawford Co-authored-by: jatinwadhwa921 <110383850+jatinwadhwa921@users.noreply.github.com> Co-authored-by: Vishnudas Thaniel S Co-authored-by: Javier Martinez * [java] Auto EP and compile model support (#25131) ### Description Java API for compile model and EP discovery APIs. Roughly equivalent to the C# version in #24604. cc: @skottmckay. I haven't quite got the CMake configured so the Java tests for the ep registration only run when the ONNX Runtime shared provider support is built, but everything else works. I expect that to be a quick fix, but I'm not sure in what conditions it should be built and how we should handle it so I don't know where/when to plumb it through. ### Motivation and Context API parity for Java. * Add error handling to extract_nuget_files.ps1 (#25866) ### Description 1. Check process exit code when running 7z.exe . Currently the errors were silently ignored. 2. Add snld20 flag to the 7z.exe commands, which is needed to be compatible with the latest 7z release. * [Fix] illegal memory access in GetInputIndices with optional inputs (#25881) ### Description Fix illegal memory access in GetInputIndices with optional inputs ### Motivation and Context When an input is optional, its ValueInfo may be nullptr. The current implementation directly calls InputValueInfo->GetName(), leading to illegal memory access. Update logic to skip optional inputs when valueInfo is nullptr . * Re-enable cpuinfo for ARM64EC (#25863) ### Description Re-enable cpuinfo for ARM64EC build and fix `CPUIDINFO_ARCH_ARM` so it is actually used. Patch cpuinfo to support vcpkg ARM64EC build. See https://github.com/pytorch/cpuinfo/pull/324. ### Motivation and Context Fix for workaround in #25831. --------- Signed-off-by: Jonathan Clohessy Co-authored-by: derdeljan-msft Co-authored-by: Jonathan Clohessy Co-authored-by: Akshay Sonawane <111780983+apsonawane@users.noreply.github.com> Co-authored-by: Christopher Warrington Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: Ishwar Raut Co-authored-by: Gaurav Garg Co-authored-by: Xinpeng Dou <15529241576@163.com> Co-authored-by: Chi Lo <54722500+chilo-ms@users.noreply.github.com> Co-authored-by: adrastogi Co-authored-by: Aditya Rastogi Co-authored-by: qti-hungjuiw Co-authored-by: qti-yuduo Co-authored-by: Pradeep Sakhamoori Co-authored-by: Preetha Veeramalai Co-authored-by: Ankit Maheshkar Co-authored-by: Ryan Metcalfe <107415876+RyanMetcalfeInt8@users.noreply.github.com> Co-authored-by: Klimenko, Mikhail Co-authored-by: sfatimar Co-authored-by: Garth Long Co-authored-by: MayureshV1 <47039074+MayureshV1@users.noreply.github.com> Co-authored-by: Eric Crawford Co-authored-by: jatinwadhwa921 <110383850+jatinwadhwa921@users.noreply.github.com> Co-authored-by: Vishnudas Thaniel S Co-authored-by: Javier Martinez Co-authored-by: Adam Pocock Co-authored-by: Changming Sun Co-authored-by: mingyue <131847423+mingyueliuh@users.noreply.github.com> Co-authored-by: Edward Chen <18449977+edgchen1@users.noreply.github.com> --- cmake/CMakeLists.txt | 8 +- .../external/onnxruntime_external_deps.cmake | 61 +-- cmake/onnxruntime.cmake | 13 +- cmake/onnxruntime_common.cmake | 57 +-- cmake/onnxruntime_java.cmake | 4 +- cmake/onnxruntime_nodejs.cmake | 1 + cmake/onnxruntime_unittests.cmake | 4 + .../cpuinfo/patch_vcpkg_arm64ec_support.patch | 91 ++++ .../cpuinfo/patch_vcpkg_arm64ec_support.patch | 91 ++++ cmake/vcpkg-ports/cpuinfo/portfile.cmake | 1 + .../NativeMethods.shared.cs | 98 ++++ .../Microsoft.ML.OnnxRuntime/OrtEnv.shared.cs | 40 ++ .../EpCompatibilityTests.cs | 49 ++ .../providers/cann/cann_provider_options.h | 2 + .../core/providers/utils/ort_graph_to_proto.h | 8 +- .../core/session/onnxruntime_c_api.h | 3 +- .../core/session/onnxruntime_cxx_api.h | 125 ++++- .../core/session/onnxruntime_cxx_inline.h | 20 + .../onnxruntime_ep_device_ep_metadata_keys.h | 5 +- .../main/java/ai/onnxruntime/OnnxRuntime.java | 18 +- .../java/ai/onnxruntime/OrtEnvironment.java | 82 ++- .../main/java/ai/onnxruntime/OrtEpDevice.java | 117 +++++ .../onnxruntime/{providers => }/OrtFlags.java | 4 +- .../ai/onnxruntime/OrtHardwareDevice.java | 156 ++++++ .../OrtModelCompilationOptions.java | 280 +++++++++++ .../main/java/ai/onnxruntime/OrtSession.java | 78 ++- .../src/main/java/ai/onnxruntime/OrtUtil.java | 51 +- .../ai/onnxruntime/providers/CoreMLFlags.java | 4 +- .../ai/onnxruntime/providers/NNAPIFlags.java | 4 +- java/src/main/native/OrtJniUtil.c | 30 ++ java/src/main/native/OrtJniUtil.h | 2 + .../main/native/ai_onnxruntime_OnnxRuntime.c | 13 + .../native/ai_onnxruntime_OrtEnvironment.c | 70 +++ .../main/native/ai_onnxruntime_OrtEpDevice.c | 82 +++ .../native/ai_onnxruntime_OrtHardwareDevice.c | 96 ++++ ...i_onnxruntime_OrtModelCompilationOptions.c | 193 ++++++++ ...ai_onnxruntime_OrtSession_SessionOptions.c | 53 +- .../java/ai/onnxruntime/CompileApiTest.java | 53 ++ .../java/ai/onnxruntime/EpDeviceTest.java | 123 +++++ js/node/src/inference_session_wrap.cc | 2 +- .../contrib_ops/cpu/bert/gqa_attention_base.h | 15 +- .../cpu/moe/moe_quantization_cpu.cc | 11 +- .../quantization/dynamic_quantize_matmul.cc | 17 +- .../core/common/cpuid_arch_definition.h | 2 +- onnxruntime/core/graph/abi_graph_types.h | 10 - onnxruntime/core/graph/ep_api_types.cc | 29 +- onnxruntime/core/graph/ep_api_types.h | 3 - .../core/graph/model_editor_api_types.h | 5 - .../core/mlas/lib/kleidiai/sgemm_kleidiai.cpp | 73 +-- .../providers/cann/cann_execution_provider.cc | 9 +- .../cann/cann_execution_provider_info.cc | 4 + .../cann/cann_execution_provider_info.h | 1 + .../providers/cann/cann_provider_factory.cc | 2 + .../nv_tensorrt_rtx/nv_execution_provider.cc | 151 +----- .../nv_tensorrt_rtx/nv_execution_provider.h | 2 - .../qnn/builder/opbuilder/conv_op_builder.cc | 14 +- .../shared_library/provider_ort_api_init.cc | 4 +- .../core/providers/vitisai/imp/global_api.cc | 6 +- onnxruntime/core/session/onnxruntime_c_api.cc | 38 +- onnxruntime/core/session/ort_apis.h | 2 +- .../plugin_ep/ep_factory_provider_bridge.cc | 7 + .../plugin_ep/ep_factory_provider_bridge.h | 15 +- .../core/session/plugin_ep/ep_library.h | 1 + .../plugin_ep/ep_library_provider_bridge.cc | 4 +- .../plugin_ep/ep_library_provider_bridge.h | 9 +- .../core/session/provider_bridge_ort.cc | 1 + onnxruntime/core/session/utils.cc | 5 +- .../python/onnxruntime_pybind_state.cc | 19 +- .../execution_providers/qnn/preprocess.py | 24 +- onnxruntime/test/autoep/library/ep_arena.h | 3 + onnxruntime/test/contrib_ops/moe_test.cc | 55 ++ .../test/framework/ep_compatibility_test.cc | 29 ++ .../test/platform/device_discovery_test.cc | 4 +- ...nnxruntime_test_python_ep_compatibility.py | 46 ++ ...me_test_python_nv_tensorrt_rtx_ep_tests.py | 468 ++++++++++++++++++ .../custom_op_library/custom_op_library.cc | 2 +- .../github/windows/extract_nuget_files.ps1 | 148 +++--- .../windows/extract_nuget_files_gpu.ps1 | 86 +++- 78 files changed, 3007 insertions(+), 509 deletions(-) create mode 100644 cmake/patches/cpuinfo/patch_vcpkg_arm64ec_support.patch create mode 100644 cmake/vcpkg-ports/cpuinfo/patch_vcpkg_arm64ec_support.patch create mode 100644 csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/EpCompatibilityTests.cs create mode 100644 java/src/main/java/ai/onnxruntime/OrtEpDevice.java rename java/src/main/java/ai/onnxruntime/{providers => }/OrtFlags.java (88%) create mode 100644 java/src/main/java/ai/onnxruntime/OrtHardwareDevice.java create mode 100644 java/src/main/java/ai/onnxruntime/OrtModelCompilationOptions.java create mode 100644 java/src/main/native/ai_onnxruntime_OrtEpDevice.c create mode 100644 java/src/main/native/ai_onnxruntime_OrtHardwareDevice.c create mode 100644 java/src/main/native/ai_onnxruntime_OrtModelCompilationOptions.c create mode 100644 java/src/test/java/ai/onnxruntime/CompileApiTest.java create mode 100644 java/src/test/java/ai/onnxruntime/EpDeviceTest.java create mode 100644 onnxruntime/test/python/onnxruntime_test_python_ep_compatibility.py create mode 100644 onnxruntime/test/python/onnxruntime_test_python_nv_tensorrt_rtx_ep_tests.py diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index 98548957d0b42..40e6a8da28e45 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -1607,7 +1607,6 @@ if ("${CMAKE_SYSTEM_NAME}" STREQUAL "Linux") endif() endif() - #Now the 'onnxruntime_EXTERNAL_LIBRARIES' variable should be sealed. It will be used in onnxruntime.cmake which will be included in the next. #The order of the following targets matters. Right depends on left. If target A appears before target B. Then A.cmake can not use variables defined in B.cmake. set(ONNXRUNTIME_CMAKE_FILES onnxruntime_flatbuffers onnxruntime_common onnxruntime_mlas onnxruntime_graph onnxruntime_lora onnxruntime_framework onnxruntime_util onnxruntime_providers onnxruntime_optimizer onnxruntime_session ${ONNXRUNTIME_EAGER_CMAKE_FILE_NAME}) @@ -1623,9 +1622,6 @@ if (onnxruntime_USE_WINML) list(APPEND ONNXRUNTIME_CMAKE_FILES winml) endif() # if (onnxruntime_USE_WINML) -if (onnxruntime_BUILD_APPLE_FRAMEWORK AND NOT ${CMAKE_SYSTEM_NAME} MATCHES "Darwin|iOS|visionOS|tvOS") - message(FATAL_ERROR "onnxruntime_BUILD_APPLE_FRAMEWORK can only be enabled for macOS or iOS or visionOS or tvOS.") -endif() list(APPEND ONNXRUNTIME_CMAKE_FILES onnxruntime) if (onnxruntime_BUILD_JAVA) @@ -1690,8 +1686,8 @@ if (WIN32 AND NOT GDK_PLATFORM AND NOT CMAKE_CROSSCOMPILING) endif() endif() -foreach(target_name ${ONNXRUNTIME_CMAKE_FILES}) - include(${target_name}.cmake) +foreach(onnxruntime_cmake_file ${ONNXRUNTIME_CMAKE_FILES}) + include(${onnxruntime_cmake_file}.cmake) endforeach() if (UNIX) option(BUILD_PKGCONFIG_FILES "Build and install pkg-config files" ON) diff --git a/cmake/external/onnxruntime_external_deps.cmake b/cmake/external/onnxruntime_external_deps.cmake index 3095968795d1a..827be3e6dea2a 100644 --- a/cmake/external/onnxruntime_external_deps.cmake +++ b/cmake/external/onnxruntime_external_deps.cmake @@ -313,41 +313,32 @@ onnxruntime_fetchcontent_makeavailable(nlohmann_json) if (onnxruntime_ENABLE_CPUINFO) # Adding pytorch CPU info library # TODO!! need a better way to find out the supported architectures - list(LENGTH CMAKE_OSX_ARCHITECTURES CMAKE_OSX_ARCHITECTURES_LEN) + set(CPUINFO_SUPPORTED FALSE) if (APPLE) + list(LENGTH CMAKE_OSX_ARCHITECTURES CMAKE_OSX_ARCHITECTURES_LEN) if (CMAKE_OSX_ARCHITECTURES_LEN LESS_EQUAL 1) set(CPUINFO_SUPPORTED TRUE) - elseif (onnxruntime_BUILD_APPLE_FRAMEWORK) - # We stitch multiple static libraries together when onnxruntime_BUILD_APPLE_FRAMEWORK is true, - # but that would not work for universal static libraries - message(FATAL_ERROR "universal binary is not supported for apple framework") - endif() - else() - # if xnnpack is enabled in a wasm build it needs clog from cpuinfo, but we won't internally use cpuinfo - # so we don't set CPUINFO_SUPPORTED in the CXX flags below. - if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten" AND NOT onnxruntime_USE_XNNPACK) - set(CPUINFO_SUPPORTED FALSE) else() + message(WARNING "cpuinfo is not supported when CMAKE_OSX_ARCHITECTURES has more than one value.") + endif() + elseif (CMAKE_SYSTEM_NAME STREQUAL "Emscripten") + # if xnnpack is enabled in a wasm build it needs clog from cpuinfo, but we won't internally use cpuinfo. + if (onnxruntime_USE_XNNPACK) set(CPUINFO_SUPPORTED TRUE) endif() - if (WIN32) - # There's an error when linking with cpuinfo on arm64ec with a vcpkg build (--use_vcpkg). - # TODO Fix it and then re-enable cpuinfo on arm64ec. - if (onnxruntime_target_platform STREQUAL "ARM64EC") - set(CPUINFO_SUPPORTED FALSE) - else() - set(CPUINFO_SUPPORTED TRUE) - endif() - elseif (NOT ${onnxruntime_target_platform} MATCHES "^(i[3-6]86|AMD64|x86(_64)?|armv[5-8].*|aarch64|arm64)$") - message(WARNING - "Target processor architecture \"${onnxruntime_target_platform}\" is not supported in cpuinfo. " - "cpuinfo not included." - ) - set(CPUINFO_SUPPORTED FALSE) + elseif (WIN32) + set(CPUINFO_SUPPORTED TRUE) + else() + if (onnxruntime_target_platform MATCHES "^(i[3-6]86|AMD64|x86(_64)?|armv[5-8].*|aarch64|arm64)$") + set(CPUINFO_SUPPORTED TRUE) + else() + message(WARNING "Target processor architecture \"${onnxruntime_target_platform}\" is not supported in cpuinfo.") endif() endif() -else() - set(CPUINFO_SUPPORTED FALSE) + + if(NOT CPUINFO_SUPPORTED) + message(WARNING "onnxruntime_ENABLE_CPUINFO was set but cpuinfo is not supported.") + endif() endif() if (CPUINFO_SUPPORTED) @@ -358,23 +349,26 @@ if (CPUINFO_SUPPORTED) # if this is a wasm build with xnnpack (only type of wasm build where cpuinfo is involved) # we do not use cpuinfo in ORT code, so don't define CPUINFO_SUPPORTED. - if (NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten") - string(APPEND CMAKE_CXX_FLAGS " -DCPUINFO_SUPPORTED") + if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten" AND onnxruntime_USE_XNNPACK) + else() + add_compile_definitions(CPUINFO_SUPPORTED) endif() - set(CPUINFO_BUILD_TOOLS OFF CACHE INTERNAL "") set(CPUINFO_BUILD_UNIT_TESTS OFF CACHE INTERNAL "") set(CPUINFO_BUILD_MOCK_TESTS OFF CACHE INTERNAL "") set(CPUINFO_BUILD_BENCHMARKS OFF CACHE INTERNAL "") if (onnxruntime_target_platform STREQUAL "ARM64EC" OR onnxruntime_target_platform STREQUAL "ARM64") - message(STATUS "Applying a patch for Windows ARM64/ARM64EC in cpuinfo") + message(STATUS "Applying patches for Windows ARM64/ARM64EC in cpuinfo") onnxruntime_fetchcontent_declare( pytorch_cpuinfo URL ${DEP_URL_pytorch_cpuinfo} URL_HASH SHA1=${DEP_SHA1_pytorch_cpuinfo} EXCLUDE_FROM_ALL - PATCH_COMMAND ${Patch_EXECUTABLE} -p1 < ${PROJECT_SOURCE_DIR}/patches/cpuinfo/patch_cpuinfo_h_for_arm64ec.patch + PATCH_COMMAND + ${Patch_EXECUTABLE} -p1 < ${PROJECT_SOURCE_DIR}/patches/cpuinfo/patch_cpuinfo_h_for_arm64ec.patch && + # https://github.com/pytorch/cpuinfo/pull/324 + ${Patch_EXECUTABLE} -p1 < ${PROJECT_SOURCE_DIR}/patches/cpuinfo/patch_vcpkg_arm64ec_support.patch FIND_PACKAGE_ARGS NAMES cpuinfo ) else() @@ -584,8 +578,7 @@ endif() set(onnxruntime_EXTERNAL_LIBRARIES ${onnxruntime_EXTERNAL_LIBRARIES_XNNPACK} ${WIL_TARGET} nlohmann_json::nlohmann_json onnx onnx_proto ${PROTOBUF_LIB} re2::re2 Boost::mp11 safeint_interface - flatbuffers::flatbuffers ${GSL_TARGET} ${ABSEIL_LIBS} date::date - ${ONNXRUNTIME_CLOG_TARGET_NAME} Eigen3::Eigen) + flatbuffers::flatbuffers ${GSL_TARGET} ${ABSEIL_LIBS} date::date Eigen3::Eigen) # The source code of onnx_proto is generated, we must build this lib first before starting to compile the other source code that uses ONNX protobuf types. # The other libs do not have the problem. All the sources are already there. We can compile them in any order. diff --git a/cmake/onnxruntime.cmake b/cmake/onnxruntime.cmake index 010696a61022c..e1d98109208d4 100644 --- a/cmake/onnxruntime.cmake +++ b/cmake/onnxruntime.cmake @@ -350,8 +350,19 @@ if (winml_is_inbox) endif() endif() -# Assemble the Apple static framework (iOS and macOS) +# Assemble the Apple static framework if(onnxruntime_BUILD_APPLE_FRAMEWORK) + if (NOT CMAKE_SYSTEM_NAME MATCHES "Darwin|iOS|visionOS|tvOS") + message(FATAL_ERROR "onnxruntime_BUILD_APPLE_FRAMEWORK can only be enabled for macOS or iOS or visionOS or tvOS.") + endif() + + list(LENGTH CMAKE_OSX_ARCHITECTURES CMAKE_OSX_ARCHITECTURES_LEN) + if (CMAKE_OSX_ARCHITECTURES_LEN GREATER 1) + # We stitch multiple static libraries together when onnxruntime_BUILD_APPLE_FRAMEWORK is true, + # but that would not work for universal static libraries + message(FATAL_ERROR "universal binary is not supported for apple framework") + endif() + # when building for mac catalyst, the CMAKE_OSX_SYSROOT is set to MacOSX as well, to avoid duplication, # we specify as `-macabi` in the name of the output static apple framework directory. if (PLATFORM_NAME STREQUAL "macabi") diff --git a/cmake/onnxruntime_common.cmake b/cmake/onnxruntime_common.cmake index d927489372e7c..0218994e537a0 100644 --- a/cmake/onnxruntime_common.cmake +++ b/cmake/onnxruntime_common.cmake @@ -194,59 +194,10 @@ if(APPLE) target_link_libraries(onnxruntime_common PRIVATE "-framework Foundation") endif() -if(MSVC) - if(onnxruntime_target_platform STREQUAL "ARM64") - set(ARM64 TRUE) - elseif (onnxruntime_target_platform STREQUAL "ARM") - set(ARM TRUE) - elseif(onnxruntime_target_platform STREQUAL "x64") - set(X64 TRUE) - elseif(onnxruntime_target_platform STREQUAL "x86") - set(X86 TRUE) - endif() -elseif(APPLE) - if(CMAKE_OSX_ARCHITECTURES_LEN LESS_EQUAL 1) - set(X64 TRUE) - endif() -elseif(NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten") - if (CMAKE_SYSTEM_NAME STREQUAL "Android") - if (CMAKE_ANDROID_ARCH_ABI STREQUAL "armeabi-v7a") - set(ARM TRUE) - elseif (CMAKE_ANDROID_ARCH_ABI STREQUAL "arm64-v8a") - set(ARM64 TRUE) - elseif (CMAKE_ANDROID_ARCH_ABI STREQUAL "x86_64") - set(X86_64 TRUE) - elseif (CMAKE_ANDROID_ARCH_ABI STREQUAL "x86") - set(X86 TRUE) - endif() - else() - execute_process( - COMMAND ${CMAKE_C_COMPILER} -dumpmachine - OUTPUT_VARIABLE dumpmachine_output - ERROR_QUIET - ) - if(dumpmachine_output MATCHES "^arm64.*") - set(ARM64 TRUE) - elseif(dumpmachine_output MATCHES "^arm.*") - set(ARM TRUE) - elseif(dumpmachine_output MATCHES "^aarch64.*") - set(ARM64 TRUE) - elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^riscv64.*") - set(RISCV64 TRUE) - elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^(i.86|x86?)$") - set(X86 TRUE) - elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^(x86_64|amd64)$") - set(X86_64 TRUE) - endif() - endif() -endif() - -if (RISCV64 OR ARM64 OR ARM OR X86 OR X64 OR X86_64) - # Link cpuinfo if supported - if (CPUINFO_SUPPORTED) - onnxruntime_add_include_to_target(onnxruntime_common cpuinfo::cpuinfo) - list(APPEND onnxruntime_EXTERNAL_LIBRARIES cpuinfo::cpuinfo ${ONNXRUNTIME_CLOG_TARGET_NAME}) - endif() +if(CPUINFO_SUPPORTED) + # Link cpuinfo if supported + onnxruntime_add_include_to_target(onnxruntime_common cpuinfo::cpuinfo) + list(APPEND onnxruntime_EXTERNAL_LIBRARIES cpuinfo::cpuinfo) endif() if (NOT onnxruntime_BUILD_SHARED_LIB) diff --git a/cmake/onnxruntime_java.cmake b/cmake/onnxruntime_java.cmake index 6b638b3e5d8bc..7da63b523be70 100644 --- a/cmake/onnxruntime_java.cmake +++ b/cmake/onnxruntime_java.cmake @@ -159,7 +159,7 @@ if (WIN32) if(NOT onnxruntime_ENABLE_STATIC_ANALYSIS) add_custom_command(TARGET onnxruntime4j_jni POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different $ ${JAVA_PACKAGE_LIB_DIR}/$) add_custom_command(TARGET onnxruntime4j_jni POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different $ ${JAVA_PACKAGE_JNI_DIR}/$) - if (onnxruntime_USE_CUDA OR onnxruntime_USE_DNNL OR onnxruntime_USE_OPENVINO OR onnxruntime_USE_TENSORRT OR (onnxruntime_USE_QNN AND NOT onnxruntime_BUILD_QNN_EP_STATIC_LIB)) + if (TARGET onnxruntime_providers_shared) add_custom_command(TARGET onnxruntime4j_jni POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different $ ${JAVA_PACKAGE_LIB_DIR}/$) endif() if (onnxruntime_USE_CUDA) @@ -207,7 +207,7 @@ if (WIN32) else() add_custom_command(TARGET onnxruntime4j_jni POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different $ ${JAVA_PACKAGE_LIB_DIR}/$) add_custom_command(TARGET onnxruntime4j_jni POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different $ ${JAVA_PACKAGE_JNI_DIR}/$) - if (onnxruntime_USE_CUDA OR onnxruntime_USE_DNNL OR onnxruntime_USE_OPENVINO OR onnxruntime_USE_TENSORRT OR (onnxruntime_USE_QNN AND NOT onnxruntime_BUILD_QNN_EP_STATIC_LIB)) + if (TARGET onnxruntime_providers_shared) add_custom_command(TARGET onnxruntime4j_jni POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different $ ${JAVA_PACKAGE_LIB_DIR}/$) endif() if (onnxruntime_USE_CUDA) diff --git a/cmake/onnxruntime_nodejs.cmake b/cmake/onnxruntime_nodejs.cmake index b28bda6c94276..cce0810c5bbe8 100644 --- a/cmake/onnxruntime_nodejs.cmake +++ b/cmake/onnxruntime_nodejs.cmake @@ -10,6 +10,7 @@ include(node_helper.cmake) # setup ARCH if (APPLE) + list(LENGTH CMAKE_OSX_ARCHITECTURES CMAKE_OSX_ARCHITECTURES_LEN) if (CMAKE_OSX_ARCHITECTURES_LEN GREATER 1) message(FATAL_ERROR "CMake.js does not support multi-architecture for macOS") endif() diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index 6847db64004ca..b31849440c426 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -1640,6 +1640,10 @@ if (NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten") add_custom_command(TARGET onnxruntime_providers_qnn POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy ${QNN_LIB_FILES} ${JAVA_NATIVE_TEST_DIR}) endif() + if (WIN32) + set(EXAMPLE_PLUGIN_EP_DST_FILE_NAME $,$,$>) + add_custom_command(TARGET custom_op_library POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different $ ${JAVA_NATIVE_TEST_DIR}/${EXAMPLE_PLUGIN_EP_DST_FILE_NAME}) + endif() # delegate to gradle's test runner diff --git a/cmake/patches/cpuinfo/patch_vcpkg_arm64ec_support.patch b/cmake/patches/cpuinfo/patch_vcpkg_arm64ec_support.patch new file mode 100644 index 0000000000000..af0f039b6c2a3 --- /dev/null +++ b/cmake/patches/cpuinfo/patch_vcpkg_arm64ec_support.patch @@ -0,0 +1,91 @@ +diff --git a/CMakeLists.txt b/CMakeLists.txt +index aedc983..dab589e 100644 +--- a/CMakeLists.txt ++++ b/CMakeLists.txt +@@ -72,6 +72,17 @@ IF(CMAKE_SYSTEM_NAME MATCHES "FreeBSD" AND CPUINFO_TARGET_PROCESSOR STREQUAL "am + ENDIF() + IF(IS_APPLE_OS AND CMAKE_OSX_ARCHITECTURES MATCHES "^(x86_64|arm64.*)$") + SET(CPUINFO_TARGET_PROCESSOR "${CMAKE_OSX_ARCHITECTURES}") ++ELSEIF(MSVC AND CMAKE_VERSION VERSION_GREATER_EQUAL "3.10") ++ # Use CMAKE_C_COMPILER_ARCHITECTURE_ID. MSVC values are documented as available since CMake 3.10. ++ IF(CMAKE_C_COMPILER_ARCHITECTURE_ID STREQUAL "X86") ++ SET(CPUINFO_TARGET_PROCESSOR "x86") ++ ELSEIF(CMAKE_C_COMPILER_ARCHITECTURE_ID STREQUAL "x64") ++ SET(CPUINFO_TARGET_PROCESSOR "x86_64") ++ ELSEIF(CMAKE_C_COMPILER_ARCHITECTURE_ID MATCHES "^(ARM64|ARM64EC)$") ++ SET(CPUINFO_TARGET_PROCESSOR "arm64") ++ ELSE() ++ MESSAGE(FATAL_ERROR "Unsupported MSVC compiler architecture ID \"${CMAKE_C_COMPILER_ARCHITECTURE_ID}\"") ++ ENDIF() + ELSEIF(CMAKE_GENERATOR MATCHES "^Visual Studio " AND CMAKE_VS_PLATFORM_NAME) + IF(CMAKE_VS_PLATFORM_NAME STREQUAL "Win32") + SET(CPUINFO_TARGET_PROCESSOR "x86") +@@ -88,7 +99,7 @@ ENDIF() + + # ---[ Build flags + SET(CPUINFO_SUPPORTED_PLATFORM TRUE) +-IF(NOT CMAKE_SYSTEM_PROCESSOR) ++IF(NOT CPUINFO_TARGET_PROCESSOR) + IF(NOT IOS) + MESSAGE(WARNING + "Target processor architecture is not specified. " +@@ -201,12 +212,12 @@ IF(CPUINFO_SUPPORTED_PLATFORM) + src/arm/linux/chipset.c + src/arm/linux/midr.c + src/arm/linux/hwcap.c) +- IF(CMAKE_SYSTEM_PROCESSOR MATCHES "^armv[5-8]") ++ IF(CPUINFO_TARGET_PROCESSOR MATCHES "^armv[5-8]") + LIST(APPEND CPUINFO_SRCS src/arm/linux/aarch32-isa.c) + IF(CMAKE_SYSTEM_NAME STREQUAL "Android" AND ANDROID_ABI STREQUAL "armeabi") + SET_SOURCE_FILES_PROPERTIES(src/arm/linux/aarch32-isa.c PROPERTIES COMPILE_FLAGS -marm) + ENDIF() +- ELSEIF(CMAKE_SYSTEM_PROCESSOR MATCHES "^(aarch64|arm64)$") ++ ELSEIF(CPUINFO_TARGET_PROCESSOR MATCHES "^(aarch64|arm64)$") + LIST(APPEND CPUINFO_SRCS src/arm/linux/aarch64-isa.c) + ENDIF() + ELSEIF(IS_APPLE_OS AND CPUINFO_TARGET_PROCESSOR MATCHES "arm64.*") +@@ -395,7 +406,7 @@ IF(CPUINFO_SUPPORTED_PLATFORM AND CPUINFO_BUILD_MOCK_TESTS) + TARGET_COMPILE_DEFINITIONS(cpuinfo_mock PRIVATE _GNU_SOURCE=1) + ENDIF() + +- IF(CMAKE_SYSTEM_NAME STREQUAL "Android" AND CMAKE_SYSTEM_PROCESSOR MATCHES "^(armv5te|armv7-a)$") ++ IF(CMAKE_SYSTEM_NAME STREQUAL "Android" AND CPUINFO_TARGET_PROCESSOR MATCHES "^(armv5te|armv7-a)$") + ADD_EXECUTABLE(atm7029b-tablet-test test/mock/atm7029b-tablet.cc) + TARGET_INCLUDE_DIRECTORIES(atm7029b-tablet-test BEFORE PRIVATE test/mock) + TARGET_LINK_LIBRARIES(atm7029b-tablet-test PRIVATE cpuinfo_mock gtest) +@@ -577,7 +588,7 @@ IF(CPUINFO_SUPPORTED_PLATFORM AND CPUINFO_BUILD_MOCK_TESTS) + ADD_TEST(NAME xperia-sl-test COMMAND xperia-sl-test) + ENDIF() + +- IF(CMAKE_SYSTEM_NAME STREQUAL "Android" AND CMAKE_SYSTEM_PROCESSOR MATCHES "^(armv5te|armv7-a|aarch64)$") ++ IF(CMAKE_SYSTEM_NAME STREQUAL "Android" AND CPUINFO_TARGET_PROCESSOR MATCHES "^(armv5te|armv7-a|aarch64)$") + ADD_EXECUTABLE(alcatel-revvl-test test/mock/alcatel-revvl.cc) + TARGET_INCLUDE_DIRECTORIES(alcatel-revvl-test BEFORE PRIVATE test/mock) + TARGET_LINK_LIBRARIES(alcatel-revvl-test PRIVATE cpuinfo_mock gtest) +@@ -774,7 +785,7 @@ IF(CPUINFO_SUPPORTED_PLATFORM AND CPUINFO_BUILD_MOCK_TESTS) + ADD_TEST(NAME xperia-c4-dual-test COMMAND xperia-c4-dual-test) + ENDIF() + +- IF(CMAKE_SYSTEM_NAME STREQUAL "Android" AND CMAKE_SYSTEM_PROCESSOR MATCHES "^(i686|x86_64)$") ++ IF(CMAKE_SYSTEM_NAME STREQUAL "Android" AND CPUINFO_TARGET_PROCESSOR MATCHES "^(i686|x86_64)$") + ADD_EXECUTABLE(alldocube-iwork8-test test/mock/alldocube-iwork8.cc) + TARGET_INCLUDE_DIRECTORIES(alldocube-iwork8-test BEFORE PRIVATE test/mock) + TARGET_LINK_LIBRARIES(alldocube-iwork8-test PRIVATE cpuinfo_mock gtest) +@@ -831,7 +842,7 @@ IF(CPUINFO_SUPPORTED_PLATFORM AND CPUINFO_BUILD_UNIT_TESTS) + ADD_TEST(NAME brand-string-test COMMAND brand-string-test) + ENDIF() + +- IF(CMAKE_SYSTEM_NAME STREQUAL "Android" AND CMAKE_SYSTEM_PROCESSOR MATCHES "^(armv[5-8].*|aarch64)$") ++ IF(CMAKE_SYSTEM_NAME STREQUAL "Android" AND CPUINFO_TARGET_PROCESSOR MATCHES "^(armv[5-8].*|aarch64)$") + ADD_LIBRARY(android_properties_interface STATIC test/name/android-properties-interface.c) + CPUINFO_TARGET_ENABLE_C99(android_properties_interface) + CPUINFO_TARGET_RUNTIME_LIBRARY(android_properties_interface) +@@ -879,7 +890,7 @@ IF(CPUINFO_SUPPORTED_PLATFORM AND CPUINFO_BUILD_TOOLS) + TARGET_LINK_LIBRARIES(cache-info PRIVATE cpuinfo) + INSTALL(TARGETS cache-info RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}) + +- IF(CMAKE_SYSTEM_NAME MATCHES "^(Android|Linux)$" AND CMAKE_SYSTEM_PROCESSOR MATCHES "^(armv[5-8].*|aarch64)$") ++ IF(CMAKE_SYSTEM_NAME MATCHES "^(Android|Linux)$" AND CPUINFO_TARGET_PROCESSOR MATCHES "^(armv[5-8].*|aarch64)$") + ADD_EXECUTABLE(auxv-dump tools/auxv-dump.c) + CPUINFO_TARGET_ENABLE_C99(auxv-dump) + CPUINFO_TARGET_RUNTIME_LIBRARY(auxv-dump) diff --git a/cmake/vcpkg-ports/cpuinfo/patch_vcpkg_arm64ec_support.patch b/cmake/vcpkg-ports/cpuinfo/patch_vcpkg_arm64ec_support.patch new file mode 100644 index 0000000000000..af0f039b6c2a3 --- /dev/null +++ b/cmake/vcpkg-ports/cpuinfo/patch_vcpkg_arm64ec_support.patch @@ -0,0 +1,91 @@ +diff --git a/CMakeLists.txt b/CMakeLists.txt +index aedc983..dab589e 100644 +--- a/CMakeLists.txt ++++ b/CMakeLists.txt +@@ -72,6 +72,17 @@ IF(CMAKE_SYSTEM_NAME MATCHES "FreeBSD" AND CPUINFO_TARGET_PROCESSOR STREQUAL "am + ENDIF() + IF(IS_APPLE_OS AND CMAKE_OSX_ARCHITECTURES MATCHES "^(x86_64|arm64.*)$") + SET(CPUINFO_TARGET_PROCESSOR "${CMAKE_OSX_ARCHITECTURES}") ++ELSEIF(MSVC AND CMAKE_VERSION VERSION_GREATER_EQUAL "3.10") ++ # Use CMAKE_C_COMPILER_ARCHITECTURE_ID. MSVC values are documented as available since CMake 3.10. ++ IF(CMAKE_C_COMPILER_ARCHITECTURE_ID STREQUAL "X86") ++ SET(CPUINFO_TARGET_PROCESSOR "x86") ++ ELSEIF(CMAKE_C_COMPILER_ARCHITECTURE_ID STREQUAL "x64") ++ SET(CPUINFO_TARGET_PROCESSOR "x86_64") ++ ELSEIF(CMAKE_C_COMPILER_ARCHITECTURE_ID MATCHES "^(ARM64|ARM64EC)$") ++ SET(CPUINFO_TARGET_PROCESSOR "arm64") ++ ELSE() ++ MESSAGE(FATAL_ERROR "Unsupported MSVC compiler architecture ID \"${CMAKE_C_COMPILER_ARCHITECTURE_ID}\"") ++ ENDIF() + ELSEIF(CMAKE_GENERATOR MATCHES "^Visual Studio " AND CMAKE_VS_PLATFORM_NAME) + IF(CMAKE_VS_PLATFORM_NAME STREQUAL "Win32") + SET(CPUINFO_TARGET_PROCESSOR "x86") +@@ -88,7 +99,7 @@ ENDIF() + + # ---[ Build flags + SET(CPUINFO_SUPPORTED_PLATFORM TRUE) +-IF(NOT CMAKE_SYSTEM_PROCESSOR) ++IF(NOT CPUINFO_TARGET_PROCESSOR) + IF(NOT IOS) + MESSAGE(WARNING + "Target processor architecture is not specified. " +@@ -201,12 +212,12 @@ IF(CPUINFO_SUPPORTED_PLATFORM) + src/arm/linux/chipset.c + src/arm/linux/midr.c + src/arm/linux/hwcap.c) +- IF(CMAKE_SYSTEM_PROCESSOR MATCHES "^armv[5-8]") ++ IF(CPUINFO_TARGET_PROCESSOR MATCHES "^armv[5-8]") + LIST(APPEND CPUINFO_SRCS src/arm/linux/aarch32-isa.c) + IF(CMAKE_SYSTEM_NAME STREQUAL "Android" AND ANDROID_ABI STREQUAL "armeabi") + SET_SOURCE_FILES_PROPERTIES(src/arm/linux/aarch32-isa.c PROPERTIES COMPILE_FLAGS -marm) + ENDIF() +- ELSEIF(CMAKE_SYSTEM_PROCESSOR MATCHES "^(aarch64|arm64)$") ++ ELSEIF(CPUINFO_TARGET_PROCESSOR MATCHES "^(aarch64|arm64)$") + LIST(APPEND CPUINFO_SRCS src/arm/linux/aarch64-isa.c) + ENDIF() + ELSEIF(IS_APPLE_OS AND CPUINFO_TARGET_PROCESSOR MATCHES "arm64.*") +@@ -395,7 +406,7 @@ IF(CPUINFO_SUPPORTED_PLATFORM AND CPUINFO_BUILD_MOCK_TESTS) + TARGET_COMPILE_DEFINITIONS(cpuinfo_mock PRIVATE _GNU_SOURCE=1) + ENDIF() + +- IF(CMAKE_SYSTEM_NAME STREQUAL "Android" AND CMAKE_SYSTEM_PROCESSOR MATCHES "^(armv5te|armv7-a)$") ++ IF(CMAKE_SYSTEM_NAME STREQUAL "Android" AND CPUINFO_TARGET_PROCESSOR MATCHES "^(armv5te|armv7-a)$") + ADD_EXECUTABLE(atm7029b-tablet-test test/mock/atm7029b-tablet.cc) + TARGET_INCLUDE_DIRECTORIES(atm7029b-tablet-test BEFORE PRIVATE test/mock) + TARGET_LINK_LIBRARIES(atm7029b-tablet-test PRIVATE cpuinfo_mock gtest) +@@ -577,7 +588,7 @@ IF(CPUINFO_SUPPORTED_PLATFORM AND CPUINFO_BUILD_MOCK_TESTS) + ADD_TEST(NAME xperia-sl-test COMMAND xperia-sl-test) + ENDIF() + +- IF(CMAKE_SYSTEM_NAME STREQUAL "Android" AND CMAKE_SYSTEM_PROCESSOR MATCHES "^(armv5te|armv7-a|aarch64)$") ++ IF(CMAKE_SYSTEM_NAME STREQUAL "Android" AND CPUINFO_TARGET_PROCESSOR MATCHES "^(armv5te|armv7-a|aarch64)$") + ADD_EXECUTABLE(alcatel-revvl-test test/mock/alcatel-revvl.cc) + TARGET_INCLUDE_DIRECTORIES(alcatel-revvl-test BEFORE PRIVATE test/mock) + TARGET_LINK_LIBRARIES(alcatel-revvl-test PRIVATE cpuinfo_mock gtest) +@@ -774,7 +785,7 @@ IF(CPUINFO_SUPPORTED_PLATFORM AND CPUINFO_BUILD_MOCK_TESTS) + ADD_TEST(NAME xperia-c4-dual-test COMMAND xperia-c4-dual-test) + ENDIF() + +- IF(CMAKE_SYSTEM_NAME STREQUAL "Android" AND CMAKE_SYSTEM_PROCESSOR MATCHES "^(i686|x86_64)$") ++ IF(CMAKE_SYSTEM_NAME STREQUAL "Android" AND CPUINFO_TARGET_PROCESSOR MATCHES "^(i686|x86_64)$") + ADD_EXECUTABLE(alldocube-iwork8-test test/mock/alldocube-iwork8.cc) + TARGET_INCLUDE_DIRECTORIES(alldocube-iwork8-test BEFORE PRIVATE test/mock) + TARGET_LINK_LIBRARIES(alldocube-iwork8-test PRIVATE cpuinfo_mock gtest) +@@ -831,7 +842,7 @@ IF(CPUINFO_SUPPORTED_PLATFORM AND CPUINFO_BUILD_UNIT_TESTS) + ADD_TEST(NAME brand-string-test COMMAND brand-string-test) + ENDIF() + +- IF(CMAKE_SYSTEM_NAME STREQUAL "Android" AND CMAKE_SYSTEM_PROCESSOR MATCHES "^(armv[5-8].*|aarch64)$") ++ IF(CMAKE_SYSTEM_NAME STREQUAL "Android" AND CPUINFO_TARGET_PROCESSOR MATCHES "^(armv[5-8].*|aarch64)$") + ADD_LIBRARY(android_properties_interface STATIC test/name/android-properties-interface.c) + CPUINFO_TARGET_ENABLE_C99(android_properties_interface) + CPUINFO_TARGET_RUNTIME_LIBRARY(android_properties_interface) +@@ -879,7 +890,7 @@ IF(CPUINFO_SUPPORTED_PLATFORM AND CPUINFO_BUILD_TOOLS) + TARGET_LINK_LIBRARIES(cache-info PRIVATE cpuinfo) + INSTALL(TARGETS cache-info RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}) + +- IF(CMAKE_SYSTEM_NAME MATCHES "^(Android|Linux)$" AND CMAKE_SYSTEM_PROCESSOR MATCHES "^(armv[5-8].*|aarch64)$") ++ IF(CMAKE_SYSTEM_NAME MATCHES "^(Android|Linux)$" AND CPUINFO_TARGET_PROCESSOR MATCHES "^(armv[5-8].*|aarch64)$") + ADD_EXECUTABLE(auxv-dump tools/auxv-dump.c) + CPUINFO_TARGET_ENABLE_C99(auxv-dump) + CPUINFO_TARGET_RUNTIME_LIBRARY(auxv-dump) diff --git a/cmake/vcpkg-ports/cpuinfo/portfile.cmake b/cmake/vcpkg-ports/cpuinfo/portfile.cmake index 3fcf76b7adafc..eeb0007195ca3 100644 --- a/cmake/vcpkg-ports/cpuinfo/portfile.cmake +++ b/cmake/vcpkg-ports/cpuinfo/portfile.cmake @@ -11,6 +11,7 @@ vcpkg_from_github( HEAD_REF master PATCHES patch_cpuinfo_h_for_arm64ec.patch + patch_vcpkg_arm64ec_support.patch # https://github.com/pytorch/cpuinfo/pull/324 ) vcpkg_check_features(OUT_FEATURE_OPTIONS FEATURE_OPTIONS diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs index 8cca2b42e987a..3c92400715740 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs @@ -368,6 +368,88 @@ public struct OrtApi public IntPtr EpDevice_Device; public IntPtr GetEpApi; public IntPtr GetTensorSizeInBytes; + + public IntPtr AllocatorGetStats; + + public IntPtr CreateMemoryInfo_V2; + public IntPtr MemoryInfoGetDeviceMemType; + public IntPtr MemoryInfoGetVendorId; + + public IntPtr ValueInfo_GetValueProducer; + public IntPtr ValueInfo_GetValueNumConsumers; + public IntPtr ValueInfo_GetValueConsumers; + public IntPtr ValueInfo_GetInitializerValue; + public IntPtr ValueInfo_GetExternalInitializerInfo; + public IntPtr ValueInfo_IsRequiredGraphInput; + public IntPtr ValueInfo_IsOptionalGraphInput; + public IntPtr ValueInfo_IsGraphOutput; + public IntPtr ValueInfo_IsConstantInitializer; + public IntPtr ValueInfo_IsFromOuterScope; + public IntPtr Graph_GetName; + public IntPtr Graph_GetModelPath; + public IntPtr Graph_GetOnnxIRVersion; + public IntPtr Graph_GetNumOperatorSets; + public IntPtr Graph_GetOperatorSets; + public IntPtr Graph_GetNumInputs; + public IntPtr Graph_GetInputs; + public IntPtr Graph_GetNumOutputs; + public IntPtr Graph_GetOutputs; + public IntPtr Graph_GetNumInitializers; + public IntPtr Graph_GetInitializers; + public IntPtr Graph_GetNumNodes; + public IntPtr Graph_GetNodes; + public IntPtr Graph_GetParentNode; + public IntPtr Graph_GetGraphView; + public IntPtr Node_GetId; + public IntPtr Node_GetName; + public IntPtr Node_GetOperatorType; + public IntPtr Node_GetDomain; + public IntPtr Node_GetSinceVersion; + public IntPtr Node_GetNumInputs; + public IntPtr Node_GetInputs; + public IntPtr Node_GetNumOutputs; + public IntPtr Node_GetOutputs; + public IntPtr Node_GetNumImplicitInputs; + public IntPtr Node_GetImplicitInputs; + public IntPtr Node_GetNumAttributes; + public IntPtr Node_GetAttributes; + public IntPtr Node_GetAttributeByName; + public IntPtr Node_GetTensorAttributeAsOrtValue; + public IntPtr OpAttr_GetType; + public IntPtr OpAttr_GetName; + public IntPtr Node_GetNumSubgraphs; + public IntPtr Node_GetSubgraphs; + public IntPtr Node_GetGraph; + public IntPtr Node_GetEpName; + public IntPtr ReleaseExternalInitializerInfo; + public IntPtr ExternalInitializerInfo_GetFilePath; + public IntPtr ExternalInitializerInfo_GetFileOffset; + public IntPtr ExternalInitializerInfo_GetByteSize; + + public IntPtr GetRunConfigEntry; + + public IntPtr EpDevice_MemoryInfo; + + public IntPtr CreateSharedAllocator; + public IntPtr GetSharedAllocator; + public IntPtr ReleaseSharedAllocator; + + public IntPtr GetTensorData; + + public IntPtr GetSessionOptionsConfigEntries; + + public IntPtr SessionGetMemoryInfoForInputs; + public IntPtr SessionGetMemoryInfoForOutputs; + public IntPtr SessionGetEpDeviceForInputs; + + public IntPtr CreateSyncStreamForEpDevice; + public IntPtr SyncStream_GetHandle; + public IntPtr ReleaseSyncStream; + + public IntPtr CopyTensors; + + public IntPtr Graph_GetModelMetadata; + public IntPtr GetModelCompatibilityForEpDevices; } internal static class NativeMethods @@ -704,6 +786,10 @@ static NativeMethods() (DSessionOptionsSetEpSelectionPolicyDelegate)Marshal.GetDelegateForFunctionPointer( api_.SessionOptionsSetEpSelectionPolicyDelegate, typeof(DSessionOptionsSetEpSelectionPolicyDelegate)); + + OrtGetModelCompatibilityForEpDevices = (DOrtGetModelCompatibilityForEpDevices)Marshal.GetDelegateForFunctionPointer( + api_.GetModelCompatibilityForEpDevices, + typeof(DOrtGetModelCompatibilityForEpDevices)); } internal class NativeLib @@ -2456,6 +2542,18 @@ public delegate void DOrtRemoveKeyValuePair(IntPtr /* OrtKeyValuePairs* */ kvps, public static DOrtGetEpDevices OrtGetEpDevices; + /// + /// Validate compiled model compatibility for the provided EP devices. + /// + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* OrtStatus* */ DOrtGetModelCompatibilityForEpDevices( + IntPtr[] /* const OrtEpDevice* const* */ ep_devices, + UIntPtr /* size_t */ num_ep_devices, + byte[] /* const char* */ compatibility_info, + out int /* OrtCompiledModelCompatibility */ out_status); + + public static DOrtGetModelCompatibilityForEpDevices OrtGetModelCompatibilityForEpDevices; + /// /// Add execution provider devices to the session options. /// Priority is based on the order of the OrtEpDevice instances. Highest priority first. diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/OrtEnv.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/OrtEnv.shared.cs index 5c70808b82be1..052d5899b52c0 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/OrtEnv.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/OrtEnv.shared.cs @@ -7,6 +7,21 @@ namespace Microsoft.ML.OnnxRuntime { + /// + /// Represents the compatibility status of a pre-compiled model with one or more execution provider devices. + /// + /// + /// This enum is used to determine whether a pre-compiled model can be used with specific execution providers + /// and devices, or if recompilation is needed. + /// + public enum OrtCompiledModelCompatibility + { + EP_NOT_APPLICABLE = 0, + EP_SUPPORTED_OPTIMAL = 1, + EP_SUPPORTED_PREFER_RECOMPILATION = 2, + EP_UNSUPPORTED = 3, + } + /// /// Delegate for logging function callback. /// Supply your function and register it with the environment to receive logging callbacks via @@ -361,6 +376,31 @@ public string[] GetAvailableProviders() } } + /// + /// Validate a compiled model's compatibility information for one or more EP devices. + /// + /// The list of EP devices to validate against. + /// The compatibility string from the precompiled model to validate. + /// OrtCompiledModelCompatibility enum value denoting the compatibility status + public OrtCompiledModelCompatibility GetModelCompatibilityForEpDevices( + IReadOnlyList epDevices, string compatibilityInfo) + { + if (epDevices == null || epDevices.Count == 0) + throw new ArgumentException("epDevices must be non-empty", nameof(epDevices)); + + var devicePtrs = new IntPtr[epDevices.Count]; + for (int i = 0; i < epDevices.Count; ++i) + { + devicePtrs[i] = epDevices[i].Handle; + } + + var infoUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(compatibilityInfo); + NativeApiStatus.VerifySuccess( + NativeMethods.OrtGetModelCompatibilityForEpDevices( + devicePtrs, (UIntPtr)devicePtrs.Length, infoUtf8, out int status)); + return (OrtCompiledModelCompatibility)status; + } + /// /// Get/Set log level property of OrtEnv instance diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/EpCompatibilityTests.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/EpCompatibilityTests.cs new file mode 100644 index 0000000000000..103fe5bc10106 --- /dev/null +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/EpCompatibilityTests.cs @@ -0,0 +1,49 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// not supported on mobile platforms +#if !(ANDROID || IOS) + +namespace Microsoft.ML.OnnxRuntime.Tests; + +using System; +using System.Linq; +using Xunit; +using System.Collections.Generic; + +public class EpCompatibilityTests +{ + private readonly OrtEnv ortEnvInstance = OrtEnv.Instance(); + + private IReadOnlyList GetDevices() + { + var epDevices = ortEnvInstance.GetEpDevices(); + Assert.NotNull(epDevices); + Assert.NotEmpty(epDevices); + return epDevices; + } + + [Fact] + public void GetEpCompatibility_InvalidArgs() + { + Assert.Throws(() => ortEnvInstance.GetModelCompatibilityForEpDevices(null, "info")); + Assert.Throws(() => ortEnvInstance.GetModelCompatibilityForEpDevices(new List(), "info")); + } + + [Fact] + public void GetEpCompatibility_SingleDeviceCpuProvider() + { + var devices = GetDevices(); + var someInfo = "arbitrary-compat-string"; + + // Use CPU device + var cpu = devices.First(d => d.EpName == "CPUExecutionProvider"); + Assert.NotNull(cpu); + var selected = new List { cpu }; + var status = ortEnvInstance.GetModelCompatibilityForEpDevices(selected, someInfo); + + // CPU defaults to not applicable in this scenario + Assert.Equal(OrtCompiledModelCompatibility.EP_NOT_APPLICABLE, status); + } +} +#endif diff --git a/include/onnxruntime/core/providers/cann/cann_provider_options.h b/include/onnxruntime/core/providers/cann/cann_provider_options.h index 51b423e68110a..4b33ee77a892e 100644 --- a/include/onnxruntime/core/providers/cann/cann_provider_options.h +++ b/include/onnxruntime/core/providers/cann/cann_provider_options.h @@ -15,6 +15,8 @@ struct OrtCANNProviderOptions { onnxruntime::ArenaExtendStrategy arena_extend_strategy; // Strategy used to grow the memory arena int enable_cann_graph; // Flag indicating if prioritizing the use of // CANN's graph-running capabilities + int enable_cann_subgraph; // Flag indicating whether to generate subgraph + // automaticly int dump_graphs; // Flag indicating if dumping graphs int dump_om_model; // Flag indicating if dumping om model std::string precision_mode; // Operator Precision Mode diff --git a/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h b/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h index 21aa797ce16eb..28ce4439fdc7e 100644 --- a/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h +++ b/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h @@ -232,7 +232,7 @@ static Ort::Status GetOrtValueInfoTensorTypeShape(const OrtValueInfo& ort_value_ /*out*/ std::vector& dims, /*out*/ std::vector& symbolic_dims); static Ort::Status OrtValueInfoToProto(const OrtValueInfo& ort_value_info, onnx::ValueInfoProto& value_info_proto); -static Ort::Status OrtOpAttrToProto(const OrtNode& ort_node, const OrtOpAttr& ort_attr, onnx::AttributeProto& attr_proto); +static Ort::Status OrtOpAttrToProto(const OrtOpAttr& ort_attr, onnx::AttributeProto& attr_proto); Ort::Status OrtGraphToProto(const OrtGraph& ort_graph, onnx::GraphProto& graph_proto, @@ -379,7 +379,7 @@ Ort::Status OrtGraphToProto(const OrtGraph& ort_graph, } onnx::AttributeProto* attr_proto = node_proto->add_attribute(); - ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtOpAttrToProto(*ort_node, *ort_attr, *attr_proto)); + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtOpAttrToProto(*ort_attr, *attr_proto)); } } @@ -652,7 +652,7 @@ static Ort::Status OrtValueInfoToProto(const OrtValueInfo& ort_value_info, return Ort::Status{nullptr}; } -static Ort::Status OrtOpAttrToProto(const OrtNode& ort_node, const OrtOpAttr& ort_attr, onnx::AttributeProto& attr_proto) { +static Ort::Status OrtOpAttrToProto(const OrtOpAttr& ort_attr, onnx::AttributeProto& attr_proto) { const OrtApi& ort_api = Ort::GetApi(); const char* attr_name = nullptr; @@ -766,7 +766,7 @@ static Ort::Status OrtOpAttrToProto(const OrtNode& ort_node, const OrtOpAttr& or // TensorProto as an attribute value doesn't require a name. OrtValue* ort_value = nullptr; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetTensorAttributeAsOrtValue(&ort_node, &ort_attr, &ort_value)); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.OpAttr_GetTensorAttributeAsOrtValue(&ort_attr, &ort_value)); Ort::Value tensor(ort_value); diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 9ae6174817b7c..f137d88e5fb8a 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -6079,7 +6079,6 @@ struct OrtApi { /** \brief Get the OrtNode's 'TENSOR' attribute as an OrtValue. * - * \param[in] node The OrtNode instance. * \param[in] attribute The OrtOpAttr instance. * \param[out] attr_tensor If successful, contains the 'TENSOR' attribute as a newly created OrtValue. Must be freed with OrtApi::ReleaseValue. @@ -6088,7 +6087,7 @@ struct OrtApi { * * \since Version 1.23. */ - ORT_API2_STATUS(Node_GetTensorAttributeAsOrtValue, _In_ const OrtNode* node, _In_ const OrtOpAttr* attribute, + ORT_API2_STATUS(OpAttr_GetTensorAttributeAsOrtValue, _In_ const OrtOpAttr* attribute, _Outptr_result_maybenull_ OrtValue** attr_tensor); /** \brief Get the attribute type as OrtOpAttrType from an OrtOpAttr. diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index c39e27088e8bc..13675ab447ab1 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -79,22 +79,19 @@ struct Exception : std::exception { throw Ort::Exception(string, code) #endif -// This is used internally by the C++ API. This class holds the global variable that points to the OrtApi, -// it's in a template so that we can define a global variable in a header and make -// it transparent to the users of the API. -template -struct Global { - static const OrtApi* api_; -}; - -// If macro ORT_API_MANUAL_INIT is defined, no static initialization will be performed. Instead, user must call InitApi() before using it. -template #ifdef ORT_API_MANUAL_INIT -const OrtApi* Global::api_{}; -inline void InitApi() noexcept { Global::api_ = OrtGetApiBase()->GetApi(ORT_API_VERSION); } - -// Used by custom operator libraries that are not linked to onnxruntime. Sets the global API object, which is -// required by C++ APIs. +// If the macro ORT_API_MANUAL_INIT is defined, no static initialization +// will be performed. Instead, users must call InitApi() before using the +// ORT C++ APIs.. +// +// InitApi() sets the global API object using the default initialization +// logic. Users call this to initialize the ORT C++ APIs at a time that +// makes sense in their program. +inline void InitApi() noexcept; + +// InitApi(const OrtApi*) is used by custom operator libraries that are not +// linked to onnxruntime. It sets the global API object, which is required +// by the ORT C++ APIs. // // Example mycustomop.cc: // @@ -107,22 +104,88 @@ inline void InitApi() noexcept { Global::api_ = OrtGetApiBase()->GetApi(OR // // ... // } // -inline void InitApi(const OrtApi* api) noexcept { Global::api_ = api; } -#else -#if defined(_MSC_VER) && !defined(__clang__) -#pragma warning(push) -// "Global initializer calls a non-constexpr function." Therefore you can't use ORT APIs in the other global initializers. -// Please define ORT_API_MANUAL_INIT if it conerns you. -#pragma warning(disable : 26426) +inline void InitApi(const OrtApi* api) noexcept; #endif -const OrtApi* Global::api_ = OrtGetApiBase()->GetApi(ORT_API_VERSION); -#if defined(_MSC_VER) && !defined(__clang__) -#pragma warning(pop) + +namespace detail { +// This is used internally by the C++ API. This class holds the global +// variable that points to the OrtApi. +struct Global { + static const OrtApi* Api(const OrtApi* newValue = nullptr) noexcept { + // This block-level static will be initialized once when this function is + // first executed, delaying the call to DefaultInit() until it is first needed. + // + // When ORT_API_MANUAL_INIT is not defined, DefaultInit() calls + // OrtGetApiBase()->GetApi(), which may result in a shared library being + // loaded. + // + // Using a block-level static instead of a class-level static helps + // avoid issues with static initialization order and dynamic libraries + // loading other dynamic libraries. + // + // This makes it safe to include the C++ API headers in a shared library + // that is delay loaded or delay loads its dependencies. + // + // This DOES NOT make it safe to _use_ arbitrary ORT C++ APIs when + // initializing static members, however. + static const OrtApi* api = DefaultInit(); + + if (newValue) { + api = newValue; + } + + return api; + } + + private: + // Has different definitions based on ORT_API_MANUAL_INIT + static const OrtApi* DefaultInit() noexcept; + +#ifdef ORT_API_MANUAL_INIT + // Public APIs to set the OrtApi* to use. + friend void ::Ort::InitApi() noexcept; + friend void ::Ort::InitApi(const OrtApi*) noexcept; #endif +}; +} // namespace detail + +#ifdef ORT_API_MANUAL_INIT + +// See comments on declaration above for usage. +inline void InitApi(const OrtApi* api) noexcept { detail::Global::Api(api); } +inline void InitApi() noexcept { InitApi(OrtGetApiBase()->GetApi(ORT_API_VERSION)); } + +#ifdef _MSC_VER +// If you get a linker error about a mismatch here, you are trying to +// link two compilation units that have different definitions for +// ORT_API_MANUAL_INIT together. All compilation units must agree on the +// definition of ORT_API_MANUAL_INIT. +#pragma detect_mismatch("ORT_API_MANUAL_INIT", "enabled") #endif +inline const OrtApi* detail::Global::DefaultInit() noexcept { + // When ORT_API_MANUAL_INIT is defined, there's no default init that can + // be done. + return nullptr; +} + +#else // ORT_API_MANUAL_INIT + +#ifdef _MSC_VER +// If you get a linker error about a mismatch here, you are trying to link +// two compilation units that have different definitions for +// ORT_API_MANUAL_INIT together. All compilation units must agree on the +// definition of ORT_API_MANUAL_INIT. +#pragma detect_mismatch("ORT_API_MANUAL_INIT", "disabled") +#endif + +inline const OrtApi* detail::Global::DefaultInit() noexcept { + return OrtGetApiBase()->GetApi(ORT_API_VERSION); +} +#endif // ORT_API_MANUAL_INIT + /// This returns a reference to the ORT C API. -inline const OrtApi& GetApi() noexcept { return *Global::api_; } +inline const OrtApi& GetApi() noexcept { return *detail::Global::Api(); } /// /// This function returns the onnxruntime version string @@ -1013,6 +1076,16 @@ struct EpDevice : detail::EpDeviceImpl { ConstKeyValuePairs ep_metadata = {}, ConstKeyValuePairs ep_options = {}); }; +/** \brief Validate a compiled model's compatibility for one or more EP devices. + * + * Throws on error. Returns the resulting compatibility status. + * /// \param ep_devices The EP devices to check compatibility against. + * /// \param compatibility_info The compatibility string from the precompiled model to validate. + */ +OrtCompiledModelCompatibility GetModelCompatibilityForEpDevices( + const std::vector& ep_devices, + const char* compatibility_info); + /** \brief The Env (Environment) * * The Env holds the logging state used by all other objects. diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index d0089726812a3..05c86ae4e0c58 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -859,6 +859,26 @@ inline void CustomOpDomain::Add(const OrtCustomOp* op) { ThrowOnError(GetApi().CustomOpDomain_Add(p_, op)); } +inline OrtCompiledModelCompatibility GetModelCompatibilityForEpDevices( + const std::vector& ep_devices, + const char* compatibility_info) { + if (ep_devices.empty()) { + ORT_CXX_API_THROW("ep_devices is empty", ORT_INVALID_ARGUMENT); + } + + std::vector ptrs; + ptrs.reserve(ep_devices.size()); + for (const auto& d : ep_devices) ptrs.push_back(d); + + OrtCompiledModelCompatibility status = OrtCompiledModelCompatibility_EP_NOT_APPLICABLE; + ThrowOnError(GetApi().GetModelCompatibilityForEpDevices( + reinterpret_cast(ptrs.data()), + ptrs.size(), + compatibility_info, + &status)); + return status; +} + inline LoraAdapter LoraAdapter::CreateLoraAdapter(const std::basic_string& adapter_path, OrtAllocator* allocator) { OrtLoraAdapter* p; diff --git a/include/onnxruntime/core/session/onnxruntime_ep_device_ep_metadata_keys.h b/include/onnxruntime/core/session/onnxruntime_ep_device_ep_metadata_keys.h index 672103bedc437..bbd6a43bb7a41 100644 --- a/include/onnxruntime/core/session/onnxruntime_ep_device_ep_metadata_keys.h +++ b/include/onnxruntime/core/session/onnxruntime_ep_device_ep_metadata_keys.h @@ -12,4 +12,7 @@ static const char* const kOrtEpDevice_EpMetadataKey_Version = "version"; // Prefix for execution provider compatibility information stored in model metadata. // Used when generating EP context models to store compatibility strings for each EP. // Full key format: "ep_compatibility_info." -static const char* const kOrtModelMetadata_EpCompatibilityInfoPrefix = "ep_compatibility_info."; \ No newline at end of file +static const char* const kOrtModelMetadata_EpCompatibilityInfoPrefix = "ep_compatibility_info."; + +// Key for the execution provider library path (for dynamically loaded EPs) +static const char* const kOrtEpDevice_EpMetadataKey_LibraryPath = "library_path"; diff --git a/java/src/main/java/ai/onnxruntime/OnnxRuntime.java b/java/src/main/java/ai/onnxruntime/OnnxRuntime.java index 97423ffb37251..3bb61698f5da7 100644 --- a/java/src/main/java/ai/onnxruntime/OnnxRuntime.java +++ b/java/src/main/java/ai/onnxruntime/OnnxRuntime.java @@ -42,6 +42,8 @@ final class OnnxRuntime { private static final int ORT_API_VERSION_13 = 13; // Post 1.13 builds of the ORT API private static final int ORT_API_VERSION_14 = 14; + // Post 1.22 builds of the ORT API + private static final int ORT_API_VERSION_23 = 23; // The initial release of the ORT training API. private static final int ORT_TRAINING_API_VERSION_1 = 1; @@ -103,6 +105,9 @@ final class OnnxRuntime { /** The Training API handle. */ static long ortTrainingApiHandle; + /** The Compile API handle. */ + static long ortCompileApiHandle; + /** Is training enabled in the native library */ static boolean trainingEnabled; @@ -176,12 +181,13 @@ static synchronized void init() throws IOException { } load(ONNXRUNTIME_JNI_LIBRARY_NAME); - ortApiHandle = initialiseAPIBase(ORT_API_VERSION_14); + ortApiHandle = initialiseAPIBase(ORT_API_VERSION_23); if (ortApiHandle == 0L) { throw new IllegalStateException( "There is a mismatch between the ORT class files and the ORT native library, and the native library could not be loaded"); } - ortTrainingApiHandle = initialiseTrainingAPIBase(ortApiHandle, ORT_API_VERSION_14); + ortTrainingApiHandle = initialiseTrainingAPIBase(ortApiHandle, ORT_API_VERSION_23); + ortCompileApiHandle = initialiseCompileAPIBase(ortApiHandle); trainingEnabled = ortTrainingApiHandle != 0L; providers = initialiseProviders(ortApiHandle); version = initialiseVersion(); @@ -499,6 +505,14 @@ private static EnumSet initialiseProviders(long ortApiHandle) { */ private static native long initialiseTrainingAPIBase(long apiHandle, int apiVersionNumber); + /** + * Get a reference to the compile API struct. + * + * @param apiHandle The ORT API struct pointer. + * @return A pointer to the compile API struct. + */ + private static native long initialiseCompileAPIBase(long apiHandle); + /** * Gets the array of available providers. * diff --git a/java/src/main/java/ai/onnxruntime/OrtEnvironment.java b/java/src/main/java/ai/onnxruntime/OrtEnvironment.java index 8382ef06e26e5..497772baf5357 100644 --- a/java/src/main/java/ai/onnxruntime/OrtEnvironment.java +++ b/java/src/main/java/ai/onnxruntime/OrtEnvironment.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019, 2024 Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2019, 2025 Oracle and/or its affiliates. All rights reserved. * Licensed under the MIT License. */ package ai.onnxruntime; @@ -8,7 +8,11 @@ import ai.onnxruntime.OrtTrainingSession.OrtCheckpointState; import java.io.IOException; import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Collections; import java.util.EnumSet; +import java.util.List; +import java.util.Map; import java.util.Objects; import java.util.logging.Logger; @@ -442,6 +446,48 @@ public static EnumSet getAvailableProviders() { return OnnxRuntime.providers.clone(); } + /** + * Registers an execution provider library with this OrtEnvironment. + * + * @param registrationName The name to register the library with (used to remove it later with + * {@link #unregisterExecutionProviderLibrary(String)}). + * @param libraryPath The path to the library binary on disk. + * @throws OrtException If the library could not be registered. + */ + public void registerExecutionProviderLibrary(String registrationName, String libraryPath) + throws OrtException { + registerExecutionProviderLibrary( + OnnxRuntime.ortApiHandle, nativeHandle, registrationName, libraryPath); + } + + /** + * Unregisters an execution provider library from this OrtEnvironment. + * + * @param registrationName The name the library was registered under. + * @throws OrtException If the library could not be removed. + */ + public void unregisterExecutionProviderLibrary(String registrationName) throws OrtException { + unregisterExecutionProviderLibrary(OnnxRuntime.ortApiHandle, nativeHandle, registrationName); + } + + /** + * Get the list of all execution provider and device combinations that are available. + * + * @see OrtSession.SessionOptions#addExecutionProvider(List, Map) + * @return The list of execution provider and device combinations. + * @throws OrtException If the devices could not be listed. + */ + public List getEpDevices() throws OrtException { + long[] deviceHandles = getEpDevices(OnnxRuntime.ortApiHandle, nativeHandle); + + List devicesList = new ArrayList<>(); + for (long deviceHandle : deviceHandles) { + devicesList.add(new OrtEpDevice(deviceHandle)); + } + + return Collections.unmodifiableList(devicesList); + } + /** * Creates the native object. * @@ -476,6 +522,40 @@ private static native long createHandle( */ private static native long getDefaultAllocator(long apiHandle) throws OrtException; + /** + * Registers the specified execution provider with this OrtEnvironment. + * + * @param apiHandle The API handle. + * @param nativeHandle The OrtEnvironment handle. + * @param registrationName The name of the execution provider. + * @param libraryPath The path to the execution provider binary. + * @throws OrtException If the registration failed. + */ + private static native void registerExecutionProviderLibrary( + long apiHandle, long nativeHandle, String registrationName, String libraryPath) + throws OrtException; + + /** + * Removes the specified execution provider from this OrtEnvironment. + * + * @param apiHandle The API handle. + * @param nativeHandle The OrtEnvironment handle. + * @param registrationName The name of the execution provider. + * @throws OrtException If the removal failed. + */ + private static native void unregisterExecutionProviderLibrary( + long apiHandle, long nativeHandle, String registrationName) throws OrtException; + + /** + * Gets handles for the EP device tuples available in this OrtEnvironment. + * + * @param apiHandle The API handle to use. + * @param nativeHandle The OrtEnvironment handle. + * @return An array of OrtEpDevice handles. + * @throws OrtException If the call failed. + */ + private static native long[] getEpDevices(long apiHandle, long nativeHandle) throws OrtException; + /** * Closes the OrtEnvironment, frees the handle. * diff --git a/java/src/main/java/ai/onnxruntime/OrtEpDevice.java b/java/src/main/java/ai/onnxruntime/OrtEpDevice.java new file mode 100644 index 0000000000000..f63dec1dbaf83 --- /dev/null +++ b/java/src/main/java/ai/onnxruntime/OrtEpDevice.java @@ -0,0 +1,117 @@ +/* + * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved. + * Licensed under the MIT License. + */ +package ai.onnxruntime; + +import java.util.Map; + +/** A tuple of Execution Provider information and the hardware device. */ +public final class OrtEpDevice { + + private final long nativeHandle; + + private final String epName; + private final String epVendor; + private final Map epMetadata; + private final Map epOptions; + private final OrtHardwareDevice device; + + /** + * Construct an OrtEpDevice tuple from the native pointer. + * + * @param nativeHandle The native pointer. + */ + OrtEpDevice(long nativeHandle) { + this.nativeHandle = nativeHandle; + this.epName = getName(OnnxRuntime.ortApiHandle, nativeHandle); + this.epVendor = getVendor(OnnxRuntime.ortApiHandle, nativeHandle); + String[][] metadata = getMetadata(OnnxRuntime.ortApiHandle, nativeHandle); + this.epMetadata = OrtUtil.convertToMap(metadata); + String[][] options = getOptions(OnnxRuntime.ortApiHandle, nativeHandle); + this.epOptions = OrtUtil.convertToMap(options); + this.device = new OrtHardwareDevice(getDeviceHandle(OnnxRuntime.ortApiHandle, nativeHandle)); + } + + /** + * Return the native pointer. + * + * @return The native pointer. + */ + long getNativeHandle() { + return nativeHandle; + } + + /** + * Gets the EP name. + * + * @return The EP name. + */ + public String getName() { + return epName; + } + + /** + * Gets the vendor name. + * + * @return The vendor name. + */ + public String getVendor() { + return epVendor; + } + + /** + * Gets an unmodifiable view on the EP metadata. + * + * @return The EP metadata. + */ + public Map getMetadata() { + return epMetadata; + } + + /** + * Gets an unmodifiable view on the EP options. + * + * @return The EP options. + */ + public Map getOptions() { + return epOptions; + } + + /** + * Gets the device information. + * + * @return The device information. + */ + public OrtHardwareDevice getDevice() { + return device; + } + + @Override + public String toString() { + return "OrtEpDevice{" + + "epName='" + + epName + + '\'' + + ", epVendor='" + + epVendor + + '\'' + + ", epMetadata=" + + epMetadata + + ", epOptions=" + + epOptions + + ", device=" + + device + + '}'; + } + + private static native String getName(long apiHandle, long nativeHandle); + + private static native String getVendor(long apiHandle, long nativeHandle); + + private static native String[][] getMetadata(long apiHandle, long nativeHandle); + + private static native String[][] getOptions(long apiHandle, long nativeHandle); + + private static native long getDeviceHandle(long apiHandle, long nativeHandle); +} diff --git a/java/src/main/java/ai/onnxruntime/providers/OrtFlags.java b/java/src/main/java/ai/onnxruntime/OrtFlags.java similarity index 88% rename from java/src/main/java/ai/onnxruntime/providers/OrtFlags.java rename to java/src/main/java/ai/onnxruntime/OrtFlags.java index 73d3eeae6499c..f57fd945dbeec 100644 --- a/java/src/main/java/ai/onnxruntime/providers/OrtFlags.java +++ b/java/src/main/java/ai/onnxruntime/OrtFlags.java @@ -1,8 +1,8 @@ /* - * Copyright (c) 2021, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved. * Licensed under the MIT License. */ -package ai.onnxruntime.providers; +package ai.onnxruntime; import java.util.EnumSet; diff --git a/java/src/main/java/ai/onnxruntime/OrtHardwareDevice.java b/java/src/main/java/ai/onnxruntime/OrtHardwareDevice.java new file mode 100644 index 0000000000000..bd99f5599fd14 --- /dev/null +++ b/java/src/main/java/ai/onnxruntime/OrtHardwareDevice.java @@ -0,0 +1,156 @@ +/* + * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved. + * Licensed under the MIT License. + */ +package ai.onnxruntime; + +import java.util.Map; +import java.util.logging.Logger; + +/** Hardware information for a specific device. */ +public final class OrtHardwareDevice { + + /** The hardware device types. */ + // Must be updated in concert with the native OrtHardwareDeviceType enum in the C API + public enum OrtHardwareDeviceType { + /** A CPU device. */ + CPU(0), + /** A GPU device. */ + GPU(1), + /** A NPU (Neural Processing Unit) device. */ + NPU(2); + private final int value; + + private static final Logger logger = Logger.getLogger(OrtHardwareDeviceType.class.getName()); + private static final OrtHardwareDeviceType[] values = new OrtHardwareDeviceType[3]; + + static { + for (OrtHardwareDeviceType ot : OrtHardwareDeviceType.values()) { + values[ot.value] = ot; + } + } + + OrtHardwareDeviceType(int value) { + this.value = value; + } + + /** + * Gets the native value associated with this device type. + * + * @return The native value. + */ + public int getValue() { + return value; + } + + /** + * Maps from the C API's int enum to the Java enum. + * + * @param deviceType The index of the Java enum. + * @return The Java enum. + */ + public static OrtHardwareDeviceType mapFromInt(int deviceType) { + if ((deviceType >= 0) && (deviceType < values.length)) { + return values[deviceType]; + } else { + logger.warning("Unknown device type '" + deviceType + "' setting to CPU"); + return CPU; + } + } + } + + private final long nativeHandle; + + private final OrtHardwareDeviceType type; + private final int vendorId; + private final String vendor; + private final int deviceId; + private final Map metadata; + + OrtHardwareDevice(long nativeHandle) { + this.nativeHandle = nativeHandle; + this.type = + OrtHardwareDeviceType.mapFromInt(getDeviceType(OnnxRuntime.ortApiHandle, nativeHandle)); + this.vendorId = getVendorId(OnnxRuntime.ortApiHandle, nativeHandle); + this.vendor = getVendor(OnnxRuntime.ortApiHandle, nativeHandle); + this.deviceId = getDeviceId(OnnxRuntime.ortApiHandle, nativeHandle); + String[][] metadata = getMetadata(OnnxRuntime.ortApiHandle, nativeHandle); + this.metadata = OrtUtil.convertToMap(metadata); + } + + long getNativeHandle() { + return nativeHandle; + } + + /** + * Gets the device type. + * + * @return The device type. + */ + public OrtHardwareDeviceType getType() { + return type; + } + + /** + * Gets the vendor ID number. + * + * @return The vendor ID number. + */ + public int getVendorId() { + return vendorId; + } + + /** + * Gets the device ID number. + * + * @return The device ID number. + */ + public int getDeviceId() { + return deviceId; + } + + /** + * Gets an unmodifiable view on the device metadata. + * + * @return The device metadata. + */ + public Map getMetadata() { + return metadata; + } + + /** + * Gets the vendor name. + * + * @return The vendor name. + */ + public String getVendor() { + return vendor; + } + + @Override + public String toString() { + return "OrtHardwareDevice{" + + "type=" + + type + + ", vendorId=" + + vendorId + + ", vendor='" + + vendor + + '\'' + + ", deviceId=" + + deviceId + + ", metadata=" + + metadata + + '}'; + } + + private static native String getVendor(long apiHandle, long nativeHandle); + + private static native String[][] getMetadata(long apiHandle, long nativeHandle); + + private static native int getDeviceType(long apiHandle, long nativeHandle); + + private static native int getDeviceId(long apiHandle, long nativeHandle); + + private static native int getVendorId(long apiHandle, long nativeHandle); +} diff --git a/java/src/main/java/ai/onnxruntime/OrtModelCompilationOptions.java b/java/src/main/java/ai/onnxruntime/OrtModelCompilationOptions.java new file mode 100644 index 0000000000000..09b3064b72b93 --- /dev/null +++ b/java/src/main/java/ai/onnxruntime/OrtModelCompilationOptions.java @@ -0,0 +1,280 @@ +/* + * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved. + * Licensed under the MIT License. + */ +package ai.onnxruntime; + +import java.nio.ByteBuffer; +import java.util.EnumSet; + +/** Configuration options for compiling ONNX models. */ +public final class OrtModelCompilationOptions implements AutoCloseable { + /** Flags representing options when compiling a model. */ + public enum OrtCompileApiFlags implements OrtFlags { + /** Default. Do not enable any additional compilation options. */ + NONE(0), + + /** + * Force compilation to return an error (ORT_FAIL) if no nodes were compiled. Otherwise, a model + * with basic optimizations (ORT_ENABLE_BASIC) is still generated by default. + */ + ERROR_IF_NO_NODES_COMPILED(1), + + /** + * Force compilation to return an error (ORT_FAIL) if a file with the same filename as the + * output model exists. Otherwise, compilation will automatically overwrite the output file if + * it exists. + */ + ERROR_IF_OUTPUT_FILE_EXISTS(1 << 1); + + /** The native value of the enum. */ + public final int value; + + OrtCompileApiFlags(int value) { + this.value = value; + } + + @Override + public int getValue() { + return value; + } + } + + private final long nativeHandle; + private boolean closed = false; + + // Used to ensure the byte buffer doesn't get GC'd before the model is compiled. + private ByteBuffer buffer; + + OrtModelCompilationOptions(long nativeHandle) { + this.nativeHandle = nativeHandle; + } + + /** + * Creates a model compilation options from an existing SessionOptions. + * + *

An OrtModelCompilationOptions object contains the settings used to generate a compiled ONNX + * model. The OrtSessionOptions object has the execution providers with which the model will be + * compiled. + * + * @param env The OrtEnvironment. + * @param sessionOptions The session options to use. + * @return A constructed model compilation options instance. + * @throws OrtException If the construction failed. + */ + public static OrtModelCompilationOptions createFromSessionOptions( + OrtEnvironment env, OrtSession.SessionOptions sessionOptions) throws OrtException { + long handle = + createFromSessionOptions( + OnnxRuntime.ortApiHandle, + OnnxRuntime.ortCompileApiHandle, + env.getNativeHandle(), + sessionOptions.getNativeHandle()); + return new OrtModelCompilationOptions(handle); + } + + /** + * Checks if the OrtModelCompilationOptions is closed, if so throws {@link IllegalStateException}. + */ + private void checkClosed() { + if (closed) { + throw new IllegalStateException("Trying to use a closed OrtModelCompilationOptions."); + } + } + + @Override + public void close() { + if (!closed) { + close(OnnxRuntime.ortCompileApiHandle, nativeHandle); + closed = true; + } else { + throw new IllegalStateException("Trying to close a closed OrtModelCompilationOptions."); + } + } + + /** + * Sets the file path to the input ONNX model. + * + *

The input model's location must be set either to a path on disk with this method, or by + * supplying an in-memory reference with {@link #setInputModelFromBuffer}. + * + * @param inputModelPath The path to the model on disk. + * @throws OrtException If the set failed. + */ + public void setInputModelPath(String inputModelPath) throws OrtException { + checkClosed(); + setInputModelPath( + OnnxRuntime.ortApiHandle, OnnxRuntime.ortCompileApiHandle, nativeHandle, inputModelPath); + } + + /** + * Uses the supplied buffer as the input ONNX model. + * + *

The input model's location must be set either to an in-memory reference with this method, or + * by supplying a path on disk with {@link #setInputModelPath(String)}. + * + *

If the {@link ByteBuffer} is not direct it is copied into a direct buffer. In either case + * this object holds a reference to the buffer to prevent it from being GC'd. + * + * @param inputModelBuffer The buffer. + * @throws OrtException If the buffer could not be set. + */ + public void setInputModelFromBuffer(ByteBuffer inputModelBuffer) throws OrtException { + checkClosed(); + if (!inputModelBuffer.isDirect()) { + // if it's not a direct buffer, copy it. + buffer = ByteBuffer.allocateDirect(inputModelBuffer.remaining()); + int tmpPos = inputModelBuffer.position(); + buffer.put(inputModelBuffer); + buffer.rewind(); + inputModelBuffer.position(tmpPos); + } else { + buffer = inputModelBuffer; + } + int bufferPos = buffer.position(); + int bufferRemaining = buffer.remaining(); + setInputModelFromBuffer( + OnnxRuntime.ortApiHandle, + OnnxRuntime.ortCompileApiHandle, + nativeHandle, + buffer, + bufferPos, + bufferRemaining); + } + + /** + * Sets the file path for the output compiled ONNX model. + * + *

If this is unset it will append `_ctx` to the file name, e.g., my_model.onnx becomes + * my_model_ctx.onnx. + * + * @param outputModelPath The output model path. + * @throws OrtException If the path could not be set. + */ + public void setOutputModelPath(String outputModelPath) throws OrtException { + checkClosed(); + setOutputModelPath( + OnnxRuntime.ortApiHandle, OnnxRuntime.ortCompileApiHandle, nativeHandle, outputModelPath); + } + + /** + * Optionally sets the file that stores initializers for the compiled ONNX model. If unset then + * initializers are stored inside the model. + * + *

Only initializers for nodes that were not compiled are stored in the external initializers + * file. Compiled nodes contain their initializer data within the `ep_cache_context` attribute of + * EPContext nodes. + * + * @see OrtModelCompilationOptions#setEpContextEmbedMode + * @param outputExternalInitializersPath Path to the file. + * @param sizeThreshold Initializers larger than this threshold are stored in the file. + * @throws OrtException If the path could not be set. + */ + public void setOutputExternalInitializersPath( + String outputExternalInitializersPath, long sizeThreshold) throws OrtException { + checkClosed(); + // check positive + setOutputExternalInitializersPath( + OnnxRuntime.ortApiHandle, + OnnxRuntime.ortCompileApiHandle, + nativeHandle, + outputExternalInitializersPath, + sizeThreshold); + } + + /** + * Enables or disables the embedding of EPContext binary data into the ep_cache_context attribute + * of EPContext nodes. + * + *

Defaults to false. When enabled, the `ep_cache_context` attribute of EPContext nodes will + * store the context binary data, which may include weights for compiled subgraphs. When disabled, + * the `ep_cache_context` attribute of EPContext nodes will contain the path to the file + * containing the context binary data. The path is set by the execution provider creating the + * EPContext node. + * + *

For more details see the EPContext design + * document. + * + * @param embedEpContext True to embed EPContext binary data into the EPContext node's + * ep_cache_context attribute. + * @throws OrtException If the set operation failed. + */ + public void setEpContextEmbedMode(boolean embedEpContext) throws OrtException { + checkClosed(); + setEpContextEmbedMode( + OnnxRuntime.ortApiHandle, OnnxRuntime.ortCompileApiHandle, nativeHandle, embedEpContext); + } + + /** + * Sets the specified compilation flags. + * + * @param flags The compilation flags. + * @throws OrtException If the set operation failed. + */ + public void setCompilationFlags(EnumSet flags) throws OrtException { + checkClosed(); + setCompilationFlags( + OnnxRuntime.ortApiHandle, + OnnxRuntime.ortCompileApiHandle, + nativeHandle, + OrtFlags.aggregateToInt(flags)); + } + + /** + * Compiles the ONNX model with the configuration described by this instance of + * OrtModelCompilationOptions. + * + * @throws OrtException If the compilation failed. + */ + public void compileModel() throws OrtException { + checkClosed(); + // Safe as the environment must exist to create one of these objects. + OrtEnvironment env = OrtEnvironment.getEnvironment(); + compileModel( + OnnxRuntime.ortApiHandle, + OnnxRuntime.ortCompileApiHandle, + env.getNativeHandle(), + nativeHandle); + } + + private static native long createFromSessionOptions( + long apiHandle, long compileApiHandle, long envHandle, long nativeHandle) throws OrtException; + + private static native void close(long compileApiHandle, long nativeHandle); + + private static native void setInputModelPath( + long apiHandle, long compileApiHandle, long nativeHandle, String inputModelPath) + throws OrtException; + + private static native void setInputModelFromBuffer( + long apiHandle, + long compileApiHandle, + long nativeHandle, + ByteBuffer inputBuffer, + long bufferPos, + long bufferRemaining) + throws OrtException; + + private static native void setOutputModelPath( + long apiHandle, long compileApiHandle, long nativeHandle, String outputModelPath) + throws OrtException; + + private static native void setOutputExternalInitializersPath( + long apiHandle, + long compileApiHandle, + long nativeHandle, + String externalInitializersPath, + long sizeThreshold) + throws OrtException; + + private static native void setEpContextEmbedMode( + long apiHandle, long compileApiHandle, long nativeHandle, boolean embedEpContext) + throws OrtException; + + private static native void setCompilationFlags( + long apiHandle, long compileApiHandle, long nativeHandle, int flags) throws OrtException; + + private static native void compileModel( + long apiHandle, long compileApiHandle, long envHandle, long nativeHandle) throws OrtException; +} diff --git a/java/src/main/java/ai/onnxruntime/OrtSession.java b/java/src/main/java/ai/onnxruntime/OrtSession.java index a399d5080ca16..42dc90b71cb80 100644 --- a/java/src/main/java/ai/onnxruntime/OrtSession.java +++ b/java/src/main/java/ai/onnxruntime/OrtSession.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019, 2024, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2019, 2025, Oracle and/or its affiliates. All rights reserved. * SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates * Licensed under the MIT License. */ @@ -8,7 +8,6 @@ import ai.onnxruntime.providers.CoreMLFlags; import ai.onnxruntime.providers.NNAPIFlags; import ai.onnxruntime.providers.OrtCUDAProviderOptions; -import ai.onnxruntime.providers.OrtFlags; import ai.onnxruntime.providers.OrtTensorRTProviderOptions; import java.io.IOException; import java.nio.ByteBuffer; @@ -624,6 +623,10 @@ private native OnnxModelMetadata constructMetadata( *

Used to set the number of threads, optimisation level, computation backend and other * options. * + *

The order execution providers are added to an options instance is the order they will be + * considered for op node assignment, with the EP added first having priority. The CPU EP is a + * fallback and added by default. + * *

Modifying this after the session has been constructed will have no effect. * *

The SessionOptions object must not be closed until all sessions which use it are closed, as @@ -730,7 +733,7 @@ public SessionOptions() { @Override public void close() { if (!closed) { - if (customLibraryHandles.size() > 0) { + if (!customLibraryHandles.isEmpty()) { long[] longArray = new long[customLibraryHandles.size()]; for (int i = 0; i < customLibraryHandles.size(); i++) { longArray[i] = customLibraryHandles.get(i); @@ -917,10 +920,10 @@ public void registerCustomOpLibrary(String path) throws OrtException { * *

 OrtStatus* (*fn)(OrtSessionOptions* options, const OrtApiBase* api); * - *

See https://onnxruntime.ai/docs/reference/operators/add-custom-op.html for more - * information on custom ops. See - * https://github.com/microsoft/onnxruntime/blob/342a5bf2b756d1a1fc6fdc582cfeac15182632fe/onnxruntime/test/testdata/custom_op_library/custom_op_library.cc#L115 - * for an example of a custom op library registration function. + *

See Add + * Custom Op for more information on custom ops. See an example of a custom op library + * registration function here. * * @param registrationFuncName The name of the registration function to call. * @throws OrtException If there was an error finding or calling the registration function. @@ -1273,10 +1276,47 @@ public void addCoreML(EnumSet flags) throws OrtException { addCoreML(OnnxRuntime.ortApiHandle, nativeHandle, OrtFlags.aggregateToInt(flags)); } + /** + * Adds the specified execution provider and device tuples as an execution backend. + * + *

Execution provider priority is in the order added, i.e., the first provider added to a + * session options will be used first for op node assignment. + * + * @param devices The EP and device tuples. Each element must use the same EP, though they can + * use different devices. + * @param providerOptions Configuration options for the execution provider. Refer to the + * specific execution provider's documentation. + * @throws OrtException If there was an error in native code. + */ + public void addExecutionProvider(List devices, Map providerOptions) + throws OrtException { + checkClosed(); + if (devices.isEmpty()) { + throw new IllegalArgumentException("Must supply at least one OrtEpDevice"); + } + long[] deviceHandles = new long[devices.size()]; + for (int i = 0; i < devices.size(); i++) { + deviceHandles[i] = devices.get(i).getNativeHandle(); + } + String[][] optsArray = OrtUtil.unpackMap(providerOptions); + // This is valid as the environment must have been created to create the OrtEpDevice list. + long envHandle = OrtEnvironment.getEnvironment().getNativeHandle(); + addExecutionProvider( + OnnxRuntime.ortApiHandle, + envHandle, + nativeHandle, + deviceHandles, + optsArray[0], + optsArray[1]); + } + /** * Adds the named execution provider (backend) as an execution backend. This generic function * only allows a subset of execution providers. * + *

Execution provider priority is in the order added, i.e., the first provider added to a + * session options will be used first for op node assignment. + * * @param providerName The name of the execution provider. * @param providerOptions Configuration options for the execution provider. Refer to the * specific execution provider's documentation. @@ -1285,20 +1325,9 @@ public void addCoreML(EnumSet flags) throws OrtException { private void addExecutionProvider(String providerName, Map providerOptions) throws OrtException { checkClosed(); - String[] providerOptionKey = new String[providerOptions.size()]; - String[] providerOptionVal = new String[providerOptions.size()]; - int i = 0; - for (Map.Entry entry : providerOptions.entrySet()) { - providerOptionKey[i] = entry.getKey(); - providerOptionVal[i] = entry.getValue(); - i++; - } + String[][] optsArray = OrtUtil.unpackMap(providerOptions); addExecutionProvider( - OnnxRuntime.ortApiHandle, - nativeHandle, - providerName, - providerOptionKey, - providerOptionVal); + OnnxRuntime.ortApiHandle, nativeHandle, providerName, optsArray[0], optsArray[1]); } /** @@ -1484,6 +1513,15 @@ private native void addExecutionProvider( String[] providerOptionKey, String[] providerOptionVal) throws OrtException; + + private native void addExecutionProvider( + long apiHandle, + long envHandle, + long nativeHandle, + long[] deviceHandles, + String[] providerOptionKey, + String[] providerOptionVal) + throws OrtException; } /** Used to control logging and termination of a call to {@link OrtSession#run}. */ diff --git a/java/src/main/java/ai/onnxruntime/OrtUtil.java b/java/src/main/java/ai/onnxruntime/OrtUtil.java index 2f44236e4ef67..ee91fdb292baa 100644 --- a/java/src/main/java/ai/onnxruntime/OrtUtil.java +++ b/java/src/main/java/ai/onnxruntime/OrtUtil.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019, 2024, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2019, 2025, Oracle and/or its affiliates. All rights reserved. * Copyright (c) Microsoft Corporation. All rights reserved. * Licensed under the MIT License. */ @@ -16,6 +16,9 @@ import java.nio.ShortBuffer; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; import java.util.logging.Logger; /** Util code for interacting with Java arrays. */ @@ -370,6 +373,52 @@ public static boolean validateShape(long[] shape) { return valid && shape.length <= TensorInfo.MAX_DIMENSIONS; } + /** + * Converts the output of a OrtKeyValuePairs into a Java unmodifiable HashMap. + * + * @param zippedString The zipped keys and values. + * @return An unmodifiable Map. + */ + static Map convertToMap(String[][] zippedString) { + if (zippedString.length != 2) { + throw new IllegalArgumentException("Invalid zipped string, must have two arrays."); + } else if (zippedString[0].length != zippedString[1].length) { + throw new IllegalArgumentException( + "Invalid zipped string, must have two arrays of the same length."); + } + Map map = new HashMap<>(capacityFromSize(zippedString[0].length)); + for (int i = 0; i < zippedString[0].length; i++) { + map.put(zippedString[0][i], zippedString[1][i]); + } + return Collections.unmodifiableMap(map); + } + + /** + * Converts a Java string map into a pair of arrays suitable for constructing a native + * OrtKeyValuePairs object. + * + * @param map A map from string to string, with no null keys or values. + * @return A pair of String arrays. + */ + static String[][] unpackMap(Map map) { + String[] keys = new String[map.size()]; + String[] values = new String[map.size()]; + int i = 0; + for (Map.Entry entry : map.entrySet()) { + if (entry.getKey() == null || entry.getValue() == null) { + throw new IllegalArgumentException( + "Invalid map, keys and values must not be null, found key = " + + entry.getKey() + + ", value = " + + entry.getValue()); + } + keys[i] = entry.getKey(); + values[i] = entry.getValue(); + i++; + } + return new String[][] {keys, values}; + } + /** * Flatten a multidimensional String array into a single dimensional String array, reading it in a * multidimensional row-major order. diff --git a/java/src/main/java/ai/onnxruntime/providers/CoreMLFlags.java b/java/src/main/java/ai/onnxruntime/providers/CoreMLFlags.java index 22bf940844774..15fe459dad7c8 100644 --- a/java/src/main/java/ai/onnxruntime/providers/CoreMLFlags.java +++ b/java/src/main/java/ai/onnxruntime/providers/CoreMLFlags.java @@ -1,9 +1,11 @@ /* - * Copyright (c) 2021, 2024, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2021, 2025, Oracle and/or its affiliates. All rights reserved. * Licensed under the MIT License. */ package ai.onnxruntime.providers; +import ai.onnxruntime.OrtFlags; + /** Flags for the CoreML provider. */ public enum CoreMLFlags implements OrtFlags { /** diff --git a/java/src/main/java/ai/onnxruntime/providers/NNAPIFlags.java b/java/src/main/java/ai/onnxruntime/providers/NNAPIFlags.java index eeaf6cc8d53bc..dd30684078717 100644 --- a/java/src/main/java/ai/onnxruntime/providers/NNAPIFlags.java +++ b/java/src/main/java/ai/onnxruntime/providers/NNAPIFlags.java @@ -1,9 +1,11 @@ /* - * Copyright (c) 2021, 2023, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2021, 2025, Oracle and/or its affiliates. All rights reserved. * Licensed under the MIT License. */ package ai.onnxruntime.providers; +import ai.onnxruntime.OrtFlags; + /** Flags for the NNAPI provider. */ public enum NNAPIFlags implements OrtFlags { /** Enables fp16 support. */ diff --git a/java/src/main/native/OrtJniUtil.c b/java/src/main/native/OrtJniUtil.c index 5d8efd7b476cb..96ea8e79bc978 100644 --- a/java/src/main/native/OrtJniUtil.c +++ b/java/src/main/native/OrtJniUtil.c @@ -1014,6 +1014,36 @@ jobject convertOrtValueToONNXValue(JNIEnv *jniEnv, const OrtApi * api, OrtAlloca } } +jobjectArray convertOrtKeyValuePairsToArrays(JNIEnv *jniEnv, const OrtApi * api, const OrtKeyValuePairs * kvp) { + // extract pair arrays + const char* const* keys = NULL; + const char* const* values = NULL; + size_t numKeys = 0; + api->GetKeyValuePairs(kvp, &keys, &values, &numKeys); + jsize jNumKeys = safecast_size_t_to_jsize(numKeys); + + // create Java String[] + jclass stringClazz = (*jniEnv)->FindClass(jniEnv, "java/lang/String"); + jobjectArray keyArray = (*jniEnv)->NewObjectArray(jniEnv, jNumKeys, stringClazz, NULL); + jobjectArray valueArray = (*jniEnv)->NewObjectArray(jniEnv, jNumKeys, stringClazz, NULL); + + // populate Java arrays + for (jsize i = 0; i < jNumKeys; i++) { + jstring key = (*jniEnv)->NewStringUTF(jniEnv, keys[i]); + (*jniEnv)->SetObjectArrayElement(jniEnv, keyArray, i, key); + jstring value = (*jniEnv)->NewStringUTF(jniEnv, values[i]); + (*jniEnv)->SetObjectArrayElement(jniEnv, valueArray, i, value); + } + + // create Java String[][] + jclass stringArrClazz = (*jniEnv)->GetObjectClass(jniEnv, keyArray); + jobjectArray pair = (*jniEnv)->NewObjectArray(jniEnv, 2, stringArrClazz, 0); + (*jniEnv)->SetObjectArrayElement(jniEnv, pair, 0, keyArray); + (*jniEnv)->SetObjectArrayElement(jniEnv, pair, 1, valueArray); + + return pair; +} + jint throwOrtException(JNIEnv *jniEnv, int messageId, const char *message) { jstring messageStr = (*jniEnv)->NewStringUTF(jniEnv, message); diff --git a/java/src/main/native/OrtJniUtil.h b/java/src/main/native/OrtJniUtil.h index 7f41e06371f2a..040fd41264c10 100644 --- a/java/src/main/native/OrtJniUtil.h +++ b/java/src/main/native/OrtJniUtil.h @@ -78,6 +78,8 @@ jobject createMapInfoFromValue(JNIEnv *jniEnv, const OrtApi * api, OrtAllocator jobject convertOrtValueToONNXValue(JNIEnv *jniEnv, const OrtApi * api, OrtAllocator* allocator, OrtValue* onnxValue); +jobjectArray convertOrtKeyValuePairsToArrays(JNIEnv *jniEnv, const OrtApi * api, const OrtKeyValuePairs * kvp); + jint throwOrtException(JNIEnv *env, int messageId, const char *message); jint convertErrorCode(OrtErrorCode code); diff --git a/java/src/main/native/ai_onnxruntime_OnnxRuntime.c b/java/src/main/native/ai_onnxruntime_OnnxRuntime.c index 659f34e1fb66f..d8f5f1a3cb2db 100644 --- a/java/src/main/native/ai_onnxruntime_OnnxRuntime.c +++ b/java/src/main/native/ai_onnxruntime_OnnxRuntime.c @@ -32,6 +32,19 @@ JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OnnxRuntime_initialiseTrainingAPIBas return (jlong) trainingApi; } +/* + * Class: ai_onnxruntime_OnnxRuntime + * Method: initialiseCompileAPIBase + * Signature: (J)J + */ +JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OnnxRuntime_initialiseCompileAPIBase + (JNIEnv * jniEnv, jclass clazz, jlong apiHandle) { + (void)jniEnv; (void)clazz; // required JNI parameters not needed by functions which don't call back into Java. + const OrtApi* api = (const OrtApi*)apiHandle; + const OrtCompileApi* compileApi = api->GetCompileApi(); + return (jlong) compileApi; +} + /* * Class: ai_onnxruntime_OnnxRuntime * Method: getAvailableProviders diff --git a/java/src/main/native/ai_onnxruntime_OrtEnvironment.c b/java/src/main/native/ai_onnxruntime_OrtEnvironment.c index e1b1ff1c05fe1..77b096d62ec76 100644 --- a/java/src/main/native/ai_onnxruntime_OrtEnvironment.c +++ b/java/src/main/native/ai_onnxruntime_OrtEnvironment.c @@ -60,6 +60,76 @@ JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtEnvironment_getDefaultAllocator return (jlong)allocator; } +/* + * Class: ai_onnxruntime_OrtEnvironment + * Method: registerExecutionProviderLibrary + * Signature: (JJLjava/lang/String;Ljava/lang/String;)V + */ +JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtEnvironment_registerExecutionProviderLibrary + (JNIEnv * jniEnv, jclass jobj, jlong apiHandle, jlong nativeHandle, jstring name, jstring libraryPath) { + (void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object. + const OrtApi* api = (const OrtApi*) apiHandle; + OrtEnv* env = (OrtEnv*) nativeHandle; + const char* cName = (*jniEnv)->GetStringUTFChars(jniEnv, name, NULL); +#ifdef _WIN32 + const jchar* cPath = (*jniEnv)->GetStringChars(jniEnv, libraryPath, NULL); + size_t stringLength = (*jniEnv)->GetStringLength(jniEnv, libraryPath); + wchar_t* newString = (wchar_t*)calloc(stringLength + 1, sizeof(wchar_t)); + if (newString == NULL) { + (*jniEnv)->ReleaseStringChars(jniEnv, libraryPath, cPath); + throwOrtException(jniEnv, 1, "Not enough memory"); + return; + } + wcsncpy_s(newString, stringLength + 1, (const wchar_t*)cPath, stringLength); + checkOrtStatus(jniEnv, api, api->RegisterExecutionProviderLibrary(env, cName, newString)); + free(newString); + (*jniEnv)->ReleaseStringChars(jniEnv, libraryPath, cPath); +#else + const char* cPath = (*jniEnv)->GetStringUTFChars(jniEnv, libraryPath, NULL); + checkOrtStatus(jniEnv, api, api->RegisterExecutionProviderLibrary(env, cName, cPath)); + (*jniEnv)->ReleaseStringUTFChars(jniEnv, libraryPath, cPath); +#endif + (*jniEnv)->ReleaseStringUTFChars(jniEnv, name, cName); +} + +/* + * Class: ai_onnxruntime_OrtEnvironment + * Method: unregisterExecutionProviderLibrary + * Signature: (JJLjava/lang/String;)V + */ +JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtEnvironment_unregisterExecutionProviderLibrary + (JNIEnv * jniEnv, jclass jobj, jlong apiHandle, jlong nativeHandle, jstring name) { + (void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object. + const OrtApi* api = (const OrtApi*) apiHandle; + OrtEnv* env = (OrtEnv*) nativeHandle; + const char* cName = (*jniEnv)->GetStringUTFChars(jniEnv, name, NULL); + checkOrtStatus(jniEnv, api, api->UnregisterExecutionProviderLibrary(env, cName)); + (*jniEnv)->ReleaseStringUTFChars(jniEnv, name, cName); +} + +/* + * Class: ai_onnxruntime_OrtEnvironment + * Method: getEpDevices + * Signature: (JJ)[J + */ +JNIEXPORT jlongArray JNICALL Java_ai_onnxruntime_OrtEnvironment_getEpDevices + (JNIEnv * jniEnv, jclass jobj, jlong apiHandle, jlong nativeHandle) { + (void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object. + const OrtApi* api = (const OrtApi*) apiHandle; + OrtEnv* env = (OrtEnv*) nativeHandle; + size_t numDevices = 0; + const OrtEpDevice* const* devicesArr = NULL; + OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetEpDevices(env, &devicesArr, &numDevices)); + if (code != ORT_OK) { + return NULL; + } else { + jsize numDevicesInt = safecast_size_t_to_jsize(numDevices); + jlongArray outputArr = (*jniEnv)->NewLongArray(jniEnv, numDevicesInt); + (*jniEnv)->SetLongArrayRegion(jniEnv, outputArr, 0, numDevicesInt, (jlong*)devicesArr); + return outputArr; + } +} + /* * Class: ai_onnxruntime_OrtEnvironment * Method: close diff --git a/java/src/main/native/ai_onnxruntime_OrtEpDevice.c b/java/src/main/native/ai_onnxruntime_OrtEpDevice.c new file mode 100644 index 0000000000000..5a1e3092b0fb9 --- /dev/null +++ b/java/src/main/native/ai_onnxruntime_OrtEpDevice.c @@ -0,0 +1,82 @@ +/* + * Copyright (c) 2025 Oracle and/or its affiliates. All rights reserved. + * Licensed under the MIT License. + */ +#include +#include "onnxruntime/core/session/onnxruntime_c_api.h" +#include "OrtJniUtil.h" +#include "ai_onnxruntime_OrtEpDevice.h" + +/* + * Class: ai_onnxruntime_OrtEpDevice + * Method: getName + * Signature: (JJ)Ljava/lang/String; + */ +JNIEXPORT jstring JNICALL Java_ai_onnxruntime_OrtEpDevice_getName + (JNIEnv * jniEnv, jclass jclazz, jlong apiHandle, jlong nativeHandle) { + (void) jclazz; // Required JNI parameter not needed by functions which don't need to access their host object. + const OrtApi* api = (const OrtApi*) apiHandle; + OrtEpDevice* epDevice = (OrtEpDevice*) nativeHandle; + const char* name = api->EpDevice_EpName(epDevice); + jstring nameStr = (*jniEnv)->NewStringUTF(jniEnv, name); + return nameStr; +} + +/* + * Class: ai_onnxruntime_OrtEpDevice + * Method: getVendor + * Signature: (JJ)Ljava/lang/String; + */ +JNIEXPORT jstring JNICALL Java_ai_onnxruntime_OrtEpDevice_getVendor + (JNIEnv * jniEnv, jclass jclazz, jlong apiHandle, jlong nativeHandle) { + (void) jclazz; // Required JNI parameter not needed by functions which don't need to access their host object. + const OrtApi* api = (const OrtApi*) apiHandle; + OrtEpDevice* epDevice = (OrtEpDevice*) nativeHandle; + const char* vendor = api->EpDevice_EpVendor(epDevice); + jstring vendorStr = (*jniEnv)->NewStringUTF(jniEnv, vendor); + return vendorStr; +} + +/* + * Class: ai_onnxruntime_OrtEpDevice + * Method: getMetadata + * Signature: (JJ)[[Ljava/lang/String; + */ +JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtEpDevice_getMetadata + (JNIEnv * jniEnv, jclass jclazz, jlong apiHandle, jlong nativeHandle) { + (void) jclazz; // Required JNI parameter not needed by functions which don't need to access their host object. + const OrtApi* api = (const OrtApi*) apiHandle; + OrtEpDevice* epDevice = (OrtEpDevice*) nativeHandle; + const OrtKeyValuePairs* kvp = api->EpDevice_EpMetadata(epDevice); + jobjectArray pair = convertOrtKeyValuePairsToArrays(jniEnv, api, kvp); + return pair; +} + +/* + * Class: ai_onnxruntime_OrtEpDevice + * Method: getOptions + * Signature: (JJ)[[Ljava/lang/String; + */ +JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtEpDevice_getOptions + (JNIEnv * jniEnv, jclass jclazz, jlong apiHandle, jlong nativeHandle) { + (void) jclazz; // Required JNI parameter not needed by functions which don't need to access their host object. + const OrtApi* api = (const OrtApi*) apiHandle; + OrtEpDevice* epDevice = (OrtEpDevice*) nativeHandle; + const OrtKeyValuePairs* kvp = api->EpDevice_EpOptions(epDevice); + jobjectArray pair = convertOrtKeyValuePairsToArrays(jniEnv, api, kvp); + return pair; +} + +/* + * Class: ai_onnxruntime_OrtEpDevice + * Method: getDeviceHandle + * Signature: (JJ)J + */ +JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtEpDevice_getDeviceHandle + (JNIEnv * jniEnv, jclass jclazz, jlong apiHandle, jlong nativeHandle) { + (void) jniEnv; (void) jclazz; // Required JNI parameters not needed by functions which don't need to access their host object or the JVM. + const OrtApi* api = (const OrtApi*) apiHandle; + OrtEpDevice* epDevice = (OrtEpDevice*) nativeHandle; + const OrtHardwareDevice* device = api->EpDevice_Device(epDevice); + return (jlong) device; +} diff --git a/java/src/main/native/ai_onnxruntime_OrtHardwareDevice.c b/java/src/main/native/ai_onnxruntime_OrtHardwareDevice.c new file mode 100644 index 0000000000000..3191a89c26ba1 --- /dev/null +++ b/java/src/main/native/ai_onnxruntime_OrtHardwareDevice.c @@ -0,0 +1,96 @@ +/* + * Copyright (c) 2025 Oracle and/or its affiliates. All rights reserved. + * Licensed under the MIT License. + */ +#include +#include "onnxruntime/core/session/onnxruntime_c_api.h" +#include "OrtJniUtil.h" +#include "ai_onnxruntime_OrtHardwareDevice.h" + +/* + * Class: ai_onnxruntime_OrtHardwareDevice + * Method: getVendor + * Signature: (JJ)Ljava/lang/String; + */ +JNIEXPORT jstring JNICALL Java_ai_onnxruntime_OrtHardwareDevice_getVendor + (JNIEnv * jniEnv, jclass jclazz, jlong apiHandle, jlong nativeHandle) { + (void) jclazz; // Required JNI parameter not needed by functions which don't need to access their host object. + const OrtApi* api = (const OrtApi*) apiHandle; + OrtHardwareDevice* device = (OrtHardwareDevice*) nativeHandle; + const char* vendor = api->HardwareDevice_Vendor(device); + jstring vendorStr = (*jniEnv)->NewStringUTF(jniEnv, vendor); + return vendorStr; +} + +/* + * Class: ai_onnxruntime_OrtHardwareDevice + * Method: getMetadata + * Signature: (JJ)[[Ljava/lang/String; + */ +JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtHardwareDevice_getMetadata + (JNIEnv * jniEnv, jclass jclazz, jlong apiHandle, jlong nativeHandle) { + (void) jclazz; // Required JNI parameter not needed by functions which don't need to access their host object. + const OrtApi* api = (const OrtApi*) apiHandle; + OrtHardwareDevice* device = (OrtHardwareDevice*) nativeHandle; + const OrtKeyValuePairs* kvp = api->HardwareDevice_Metadata(device); + jobjectArray pair = convertOrtKeyValuePairsToArrays(jniEnv, api, kvp); + return pair; +} + +/* + * Class: ai_onnxruntime_OrtHardwareDevice + * Method: getDeviceType + * Signature: (JJ)I + */ +JNIEXPORT jint JNICALL Java_ai_onnxruntime_OrtHardwareDevice_getDeviceType + (JNIEnv * jniEnv, jclass jclazz, jlong apiHandle, jlong nativeHandle) { + (void) jclazz; // Required JNI parameter not needed by functions which don't need to access their host object. + const OrtApi* api = (const OrtApi*) apiHandle; + OrtHardwareDevice* device = (OrtHardwareDevice*) nativeHandle; + OrtHardwareDeviceType type = api->HardwareDevice_Type(device); + jint output = 0; + // Must be kept aligned with the Java OrtHardwareDeviceType enum. + switch (type) { + case OrtHardwareDeviceType_CPU: + output = 0; + break; + case OrtHardwareDeviceType_GPU: + output = 1; + break; + case OrtHardwareDeviceType_NPU: + output = 2; + break; + default: + throwOrtException(jniEnv, convertErrorCode(ORT_NOT_IMPLEMENTED), "Unexpected device type found. Only CPU, GPU and NPU are supported."); + break; + } + return output; +} + +/* + * Class: ai_onnxruntime_OrtHardwareDevice + * Method: getDeviceId + * Signature: (JJ)I + */ +JNIEXPORT jint JNICALL Java_ai_onnxruntime_OrtHardwareDevice_getDeviceId + (JNIEnv * jniEnv, jclass jclazz, jlong apiHandle, jlong nativeHandle) { + (void) jniEnv; (void) jclazz; // Required JNI parameters not needed by functions which don't need to access their host object or the JVM. + const OrtApi* api = (const OrtApi*) apiHandle; + OrtHardwareDevice* device = (OrtHardwareDevice*) nativeHandle; + uint32_t id = api->HardwareDevice_DeviceId(device); + return (jint) id; +} + +/* + * Class: ai_onnxruntime_OrtHardwareDevice + * Method: getVendorId + * Signature: (JJ)I + */ +JNIEXPORT jint JNICALL Java_ai_onnxruntime_OrtHardwareDevice_getVendorId + (JNIEnv * jniEnv, jclass jclazz, jlong apiHandle, jlong nativeHandle) { + (void) jniEnv; (void) jclazz; // Required JNI parameters not needed by functions which don't need to access their host object or the JVM. + const OrtApi* api = (const OrtApi*) apiHandle; + OrtHardwareDevice* device = (OrtHardwareDevice*) nativeHandle; + uint32_t id = api->HardwareDevice_VendorId(device); + return (jint) id; +} diff --git a/java/src/main/native/ai_onnxruntime_OrtModelCompilationOptions.c b/java/src/main/native/ai_onnxruntime_OrtModelCompilationOptions.c new file mode 100644 index 0000000000000..4f79383d09766 --- /dev/null +++ b/java/src/main/native/ai_onnxruntime_OrtModelCompilationOptions.c @@ -0,0 +1,193 @@ +/* + * Copyright (c) 2025 Oracle and/or its affiliates. All rights reserved. + * Licensed under the MIT License. + */ +#include +#include "onnxruntime/core/session/onnxruntime_c_api.h" +#include "OrtJniUtil.h" +#include "ai_onnxruntime_OrtModelCompilationOptions.h" + +/* + * Class: ai_onnxruntime_OrtModelCompilationOptions + * Method: createFromSessionOptions + * Signature: (JJJJ)J + */ +JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtModelCompilationOptions_createFromSessionOptions + (JNIEnv * jniEnv, jclass jclazz, jlong apiHandle, jlong compileApiHandle, jlong envHandle, jlong sessionOptionsHandle) { + (void)jclazz; // Required JNI parameter not needed by functions which don't need to access their host object. + const OrtApi* api = (const OrtApi*)apiHandle; + const OrtCompileApi* compileApi = (const OrtCompileApi*)compileApiHandle; + const OrtEnv* env = (const OrtEnv*)envHandle; + const OrtSessionOptions* sessionOptions = (const OrtSessionOptions*) sessionOptionsHandle; + OrtModelCompilationOptions* output = NULL; + checkOrtStatus(jniEnv, api, compileApi->CreateModelCompilationOptionsFromSessionOptions(env, sessionOptions, &output)); + return (jlong) output; +} + +/* + * Class: ai_onnxruntime_OrtModelCompilationOptions + * Method: close + * Signature: (JJ)V + */ +JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtModelCompilationOptions_close + (JNIEnv * jniEnv, jclass jclazz, jlong compileApiHandle, jlong nativeHandle) { + (void)jniEnv; (void)jclazz; // Required JNI parameters not needed by functions which don't need to access their host object or the JVM. + const OrtCompileApi* compileApi = (const OrtCompileApi*)compileApiHandle; + compileApi->ReleaseModelCompilationOptions((OrtModelCompilationOptions *)nativeHandle); +} + +/* + * Class: ai_onnxruntime_OrtModelCompilationOptions + * Method: setInputModelPath + * Signature: (JJJLjava/lang/String;)V + */ +JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtModelCompilationOptions_setInputModelPath + (JNIEnv * jniEnv, jclass jclazz, jlong apiHandle, jlong compileApiHandle, jlong nativeHandle, jstring modelPath) { + (void) jclazz; // Required JNI parameter not needed by functions which don't need to access their host object. + const OrtApi* api = (const OrtApi*) apiHandle; + const OrtCompileApi* compileApi = (const OrtCompileApi*) compileApiHandle; + OrtModelCompilationOptions* compOpts = (OrtModelCompilationOptions *) nativeHandle; +#ifdef _WIN32 + const jchar* cPath = (*jniEnv)->GetStringChars(jniEnv, modelPath, NULL); + size_t stringLength = (*jniEnv)->GetStringLength(jniEnv, modelPath); + wchar_t* newString = (wchar_t*)calloc(stringLength + 1, sizeof(wchar_t)); + if (newString == NULL) { + (*jniEnv)->ReleaseStringChars(jniEnv, modelPath, cPath); + throwOrtException(jniEnv, 1, "Not enough memory"); + return; + } + wcsncpy_s(newString, stringLength + 1, (const wchar_t*)cPath, stringLength); + checkOrtStatus(jniEnv, api, compileApi->ModelCompilationOptions_SetInputModelPath(compOpts, newString)); + free(newString); + (*jniEnv)->ReleaseStringChars(jniEnv, modelPath, cPath); +#else + const char* cPath = (*jniEnv)->GetStringUTFChars(jniEnv, modelPath, NULL); + checkOrtStatus(jniEnv, api, compileApi->ModelCompilationOptions_SetInputModelPath(compOpts, cPath)); + (*jniEnv)->ReleaseStringUTFChars(jniEnv, modelPath, cPath); +#endif +} + +/* + * Class: ai_onnxruntime_OrtModelCompilationOptions + * Method: setInputModelFromBuffer + * Signature: (JJJLjava/nio/ByteBuffer;JJ)V + */ +JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtModelCompilationOptions_setInputModelFromBuffer + (JNIEnv * jniEnv, jclass jclazz, jlong apiHandle, jlong compileApiHandle, jlong nativeHandle, jobject buffer, jlong bufferPos, jlong bufferRemaining) { + (void)jclazz; // Required JNI parameter not needed by functions which don't need to access their host object. + // Cast to pointers + const OrtApi* api = (const OrtApi*)apiHandle; + const OrtCompileApi* compileApi = (const OrtCompileApi*)compileApiHandle; + OrtModelCompilationOptions* compOpts = (OrtModelCompilationOptions *) nativeHandle; + + // Extract the buffer + char* bufferArr = (char*)(*jniEnv)->GetDirectBufferAddress(jniEnv, buffer); + // Increment by bufferPos bytes + bufferArr = bufferArr + bufferPos; + checkOrtStatus(jniEnv, api, compileApi->ModelCompilationOptions_SetInputModelFromBuffer(compOpts, bufferArr, bufferRemaining)); +} + +/* + * Class: ai_onnxruntime_OrtModelCompilationOptions + * Method: setOutputModelPath + * Signature: (JJJLjava/lang/String;)V + */ +JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtModelCompilationOptions_setOutputModelPath + (JNIEnv * jniEnv, jclass jclazz, jlong apiHandle, jlong compileApiHandle, jlong nativeHandle, jstring outputPath) { + (void) jclazz; // Required JNI parameter not needed by functions which don't need to access their host object. + const OrtApi* api = (const OrtApi*) apiHandle; + const OrtCompileApi* compileApi = (const OrtCompileApi*) compileApiHandle; + OrtModelCompilationOptions* compOpts = (OrtModelCompilationOptions *) nativeHandle; +#ifdef _WIN32 + const jchar* cPath = (*jniEnv)->GetStringChars(jniEnv, outputPath, NULL); + size_t stringLength = (*jniEnv)->GetStringLength(jniEnv, outputPath); + wchar_t* newString = (wchar_t*)calloc(stringLength + 1, sizeof(wchar_t)); + if (newString == NULL) { + (*jniEnv)->ReleaseStringChars(jniEnv, outputPath, cPath); + throwOrtException(jniEnv, 1, "Not enough memory"); + return; + } + wcsncpy_s(newString, stringLength + 1, (const wchar_t*)cPath, stringLength); + checkOrtStatus(jniEnv, api, compileApi->ModelCompilationOptions_SetOutputModelPath(compOpts, newString)); + free(newString); + (*jniEnv)->ReleaseStringChars(jniEnv, outputPath, cPath); +#else + const char* cPath = (*jniEnv)->GetStringUTFChars(jniEnv, outputPath, NULL); + checkOrtStatus(jniEnv, api, compileApi->ModelCompilationOptions_SetOutputModelPath(compOpts, cPath)); + (*jniEnv)->ReleaseStringUTFChars(jniEnv, outputPath, cPath); +#endif +} + +/* + * Class: ai_onnxruntime_OrtModelCompilationOptions + * Method: setOutputExternalInitializersPath + * Signature: (JJJLjava/lang/String;J)V + */ +JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtModelCompilationOptions_setOutputExternalInitializersPath + (JNIEnv * jniEnv, jclass jclazz, jlong apiHandle, jlong compileApiHandle, jlong nativeHandle, jstring initializersPath, jlong threshold) { + (void) jclazz; // Required JNI parameter not needed by functions which don't need to access their host object. + const OrtApi* api = (const OrtApi*) apiHandle; + const OrtCompileApi* compileApi = (const OrtCompileApi*) compileApiHandle; + OrtModelCompilationOptions* compOpts = (OrtModelCompilationOptions *) nativeHandle; +#ifdef _WIN32 + const jchar* cPath = (*jniEnv)->GetStringChars(jniEnv, initializersPath, NULL); + size_t stringLength = (*jniEnv)->GetStringLength(jniEnv, initializersPath); + wchar_t* newString = (wchar_t*)calloc(stringLength + 1, sizeof(wchar_t)); + if (newString == NULL) { + (*jniEnv)->ReleaseStringChars(jniEnv, initializersPath, cPath); + throwOrtException(jniEnv, 1, "Not enough memory"); + return; + } + wcsncpy_s(newString, stringLength + 1, (const wchar_t*)cPath, stringLength); + checkOrtStatus(jniEnv, api, compileApi->ModelCompilationOptions_SetOutputModelExternalInitializersFile(compOpts, newString, threshold)); + free(newString); + (*jniEnv)->ReleaseStringChars(jniEnv, initializersPath, cPath); +#else + const char* cPath = (*jniEnv)->GetStringUTFChars(jniEnv, initializersPath, NULL); + checkOrtStatus(jniEnv, api, compileApi->ModelCompilationOptions_SetOutputModelExternalInitializersFile(compOpts, cPath, threshold)); + (*jniEnv)->ReleaseStringUTFChars(jniEnv, initializersPath, cPath); +#endif +} + +/* + * Class: ai_onnxruntime_OrtModelCompilationOptions + * Method: setEpContextEmbedMode + * Signature: (JJJZ)V + */ +JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtModelCompilationOptions_setEpContextEmbedMode + (JNIEnv * jniEnv, jclass jclazz, jlong apiHandle, jlong compileApiHandle, jlong nativeHandle, jboolean embedMode) { + (void)jclazz; // Required JNI parameter not needed by functions which don't need to access their host object. + const OrtApi* api = (const OrtApi*)apiHandle; + const OrtCompileApi* compileApi = (const OrtCompileApi*)compileApiHandle; + OrtModelCompilationOptions* compOpts = (OrtModelCompilationOptions *) nativeHandle; + checkOrtStatus(jniEnv, api, compileApi->ModelCompilationOptions_SetEpContextEmbedMode(compOpts, (bool) embedMode)); +} + +/* + * Class: ai_onnxruntime_OrtModelCompilationOptions + * Method: setCompilationFlags + * Signature: (JJJI)V + */ +JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtModelCompilationOptions_setCompilationFlags + (JNIEnv * jniEnv, jclass jclazz, jlong apiHandle, jlong compileApiHandle, jlong nativeHandle, jint flags) { + (void)jclazz; // Required JNI parameter not needed by functions which don't need to access their host object. + const OrtApi* api = (const OrtApi*)apiHandle; + const OrtCompileApi* compileApi = (const OrtCompileApi*)compileApiHandle; + OrtModelCompilationOptions* compOpts = (OrtModelCompilationOptions *) nativeHandle; + checkOrtStatus(jniEnv, api, compileApi->ModelCompilationOptions_SetFlags(compOpts, flags)); +} + +/* + * Class: ai_onnxruntime_OrtModelCompilationOptions + * Method: compileModel + * Signature: (JJJJ)V + */ +JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtModelCompilationOptions_compileModel + (JNIEnv * jniEnv, jclass jclazz, jlong apiHandle, jlong compileApiHandle, jlong envHandle, jlong nativeHandle) { + (void)jclazz; // Required JNI parameter not needed by functions which don't need to access their host object. + const OrtApi* api = (const OrtApi*)apiHandle; + const OrtCompileApi* compileApi = (const OrtCompileApi*)compileApiHandle; + const OrtEnv* env = (const OrtEnv*)envHandle; + OrtModelCompilationOptions* compOpts = (OrtModelCompilationOptions *) nativeHandle; + checkOrtStatus(jniEnv, api, compileApi->CompileModel(env, compOpts)); +} diff --git a/java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c b/java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c index ff6b7fa703e6e..95bcdf7af9746 100644 --- a/java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c +++ b/java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c @@ -718,11 +718,11 @@ JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_addROC } /* - * Class:: ai_onnxruntime_OrtSession_SessionOptions + * Class: ai_onnxruntime_OrtSession_SessionOptions * Method: addExecutionProvider - * Signature: (JILjava/lang/String)V + * Signature: (JJLjava/lang/String;[Ljava/lang/String;[Ljava/lang/String;)V */ -JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_addExecutionProvider( +JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_addExecutionProvider__JJLjava_lang_String_2_3Ljava_lang_String_2_3Ljava_lang_String_2( JNIEnv* jniEnv, jobject jobj, jlong apiHandle, jlong optionsHandle, jstring jepName, jobjectArray configKeyArr, jobjectArray configValueArr) { (void)jobj; @@ -756,3 +756,50 @@ JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_addExe free((void*)jkeyArray); free((void*)jvalueArray); } + +/* + * Class: ai_onnxruntime_OrtSession_SessionOptions + * Method: addExecutionProvider + * Signature: (JJJ[J[Ljava/lang/String;[Ljava/lang/String;)V + */ +JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_addExecutionProvider__JJJ_3J_3Ljava_lang_String_2_3Ljava_lang_String_2 + (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong envHandle, jlong optionsHandle, jlongArray deviceHandleArr, jobjectArray configKeyArr, jobjectArray configValueArr) { + (void)jobj; + + const OrtApi* api = (const OrtApi*)apiHandle; + OrtEnv* env = (OrtEnv*) envHandle; + OrtSessionOptions* options = (OrtSessionOptions*)optionsHandle; + jsize deviceCount = (*jniEnv)->GetArrayLength(jniEnv, deviceHandleArr); + jsize keyCount = (*jniEnv)->GetArrayLength(jniEnv, configKeyArr); + + const char** keyArray = (const char**)allocarray(keyCount, sizeof(const char*)); + const char** valueArray = (const char**)allocarray(keyCount, sizeof(const char*)); + jstring* jkeyArray = (jstring*)allocarray(keyCount, sizeof(jstring)); + jstring* jvalueArray = (jstring*)allocarray(keyCount, sizeof(jstring)); + const OrtEpDevice** devicePtrs = allocarray(deviceCount, sizeof(OrtEpDevice *)); + + jlong* deviceHandleElements = (*jniEnv)->GetLongArrayElements(jniEnv, deviceHandleArr, NULL); + for (jsize i = 0; i < deviceCount; i++) { + devicePtrs[i] = (OrtEpDevice*) deviceHandleElements[i]; + } + (*jniEnv)->ReleaseLongArrayElements(jniEnv, deviceHandleArr, deviceHandleElements, JNI_ABORT); + + for (jsize i = 0; i < keyCount; i++) { + jkeyArray[i] = (jstring)((*jniEnv)->GetObjectArrayElement(jniEnv, configKeyArr, i)); + jvalueArray[i] = (jstring)((*jniEnv)->GetObjectArrayElement(jniEnv, configValueArr, i)); + keyArray[i] = (*jniEnv)->GetStringUTFChars(jniEnv, jkeyArray[i], NULL); + valueArray[i] = (*jniEnv)->GetStringUTFChars(jniEnv, jvalueArray[i], NULL); + } + + checkOrtStatus(jniEnv, api, api->SessionOptionsAppendExecutionProvider_V2(options, env, devicePtrs, deviceCount, keyArray, valueArray, keyCount)); + + for (jsize i = 0; i < keyCount; i++) { + (*jniEnv)->ReleaseStringUTFChars(jniEnv, jkeyArray[i], keyArray[i]); + (*jniEnv)->ReleaseStringUTFChars(jniEnv, jvalueArray[i], valueArray[i]); + } + free((void*)devicePtrs); + free((void*)keyArray); + free((void*)valueArray); + free((void*)jkeyArray); + free((void*)jvalueArray); +} diff --git a/java/src/test/java/ai/onnxruntime/CompileApiTest.java b/java/src/test/java/ai/onnxruntime/CompileApiTest.java new file mode 100644 index 0000000000000..b70f4dca5cbd0 --- /dev/null +++ b/java/src/test/java/ai/onnxruntime/CompileApiTest.java @@ -0,0 +1,53 @@ +/* + * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved. + * Licensed under the MIT License. + */ +package ai.onnxruntime; + +import ai.onnxruntime.OrtSession.SessionOptions; +import java.io.File; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.file.Files; +import java.nio.file.Path; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +/** Test for the compilation API. */ +public class CompileApiTest { + private final OrtEnvironment env = OrtEnvironment.getEnvironment(); + + @Test + public void basicUsage() throws OrtException, IOException { + SessionOptions so = new SessionOptions(); + try (OrtModelCompilationOptions compileOptions = + OrtModelCompilationOptions.createFromSessionOptions(env, so)) { + // mainly checking these don't throw which ensures all the plumbing for the binding works. + compileOptions.setInputModelPath("model.onnx"); + compileOptions.setOutputModelPath("compiled_model.onnx"); + + compileOptions.setOutputExternalInitializersPath("external_data.bin", 512); + compileOptions.setEpContextEmbedMode(true); + } + + try (OrtModelCompilationOptions compileOptions = + OrtModelCompilationOptions.createFromSessionOptions(env, so)) { + Path modelPath = TestHelpers.getResourcePath("/squeezenet.onnx"); + byte[] modelBytes = Files.readAllBytes(modelPath); + ByteBuffer modelBuffer = ByteBuffer.wrap(modelBytes); + compileOptions.setInputModelFromBuffer(modelBuffer); + compileOptions.setOutputModelPath("compiled_model.onnx"); + + File f = new File("compiled_model.onnx"); + + compileOptions.compileModel(); + + // Check the compiled model is valid + try (OrtSession session = env.createSession(f.toString(), so)) { + Assertions.assertNotNull(session); + } + + f.delete(); + } + } +} diff --git a/java/src/test/java/ai/onnxruntime/EpDeviceTest.java b/java/src/test/java/ai/onnxruntime/EpDeviceTest.java new file mode 100644 index 0000000000000..ec4c977508c8c --- /dev/null +++ b/java/src/test/java/ai/onnxruntime/EpDeviceTest.java @@ -0,0 +1,123 @@ +/* + * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved. + * Licensed under the MIT License. + */ +package ai.onnxruntime; + +import ai.onnxruntime.OrtHardwareDevice.OrtHardwareDeviceType; +import ai.onnxruntime.OrtSession.SessionOptions; +import java.io.File; +import java.nio.file.Path; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.function.Consumer; +import java.util.function.Supplier; +import java.util.stream.Collectors; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledOnOs; +import org.junit.jupiter.api.condition.OS; + +/** Tests for {@link OrtEpDevice} and {@link OrtHardwareDevice}. */ +@EnabledOnOs(value = OS.WINDOWS) +public class EpDeviceTest { + private final OrtEnvironment ortEnv = OrtEnvironment.getEnvironment(); + + private void readHardwareDeviceValues(OrtHardwareDevice device) { + OrtHardwareDeviceType type = device.getType(); + + Assertions.assertTrue( + type == OrtHardwareDeviceType.CPU + || type == OrtHardwareDeviceType.GPU + || type == OrtHardwareDeviceType.NPU); + + if (type == OrtHardwareDeviceType.CPU) { + Assertions.assertFalse(device.getVendor().isEmpty()); + } else { + Assertions.assertTrue(device.getVendorId() != 0); + Assertions.assertTrue(device.getDeviceId() != 0); + } + + Map metadata = device.getMetadata(); + Assertions.assertNotNull(metadata); + for (Map.Entry kvp : metadata.entrySet()) { + Assertions.assertFalse(kvp.getKey().isEmpty()); + } + } + + @Test + public void getEpDevices() throws OrtException { + List epDevices = ortEnv.getEpDevices(); + Assertions.assertNotNull(epDevices); + Assertions.assertFalse(epDevices.isEmpty()); + for (OrtEpDevice epDevice : epDevices) { + Assertions.assertFalse(epDevice.getName().isEmpty()); + Assertions.assertFalse(epDevice.getVendor().isEmpty()); + Map metadata = epDevice.getMetadata(); + Assertions.assertNotNull(metadata); + Map options = epDevice.getOptions(); + Assertions.assertNotNull(options); + readHardwareDeviceValues(epDevice.getDevice()); + } + } + + @Test + public void registerUnregisterLibrary() throws OrtException { + String libFullPath = TestHelpers.getResourcePath("/example_plugin_ep.dll").toString(); + Assertions.assertTrue( + new File(libFullPath).exists(), "Expected lib " + libFullPath + " does not exist."); + + // example plugin ep uses the registration name as the ep name + String epName = "java_ep"; + + // register. shouldn't throw + ortEnv.registerExecutionProviderLibrary(epName, libFullPath); + + // check OrtEpDevice was found + List epDevices = ortEnv.getEpDevices(); + boolean found = epDevices.stream().anyMatch(a -> a.getName().equals(epName)); + Assertions.assertTrue(found); + + // unregister + ortEnv.unregisterExecutionProviderLibrary(epName); + } + + @Test + public void appendToSessionOptionsV2() { + Consumer>> runTest = + (Supplier> options) -> { + try (SessionOptions sessionOptions = new SessionOptions()) { + sessionOptions.setSessionLogLevel(OrtLoggingLevel.ORT_LOGGING_LEVEL_VERBOSE); + + List epDevices = ortEnv.getEpDevices(); + + // cpu ep ignores the provider options so we can use any value in epOptions and it won't + // break. + List selectedEpDevices = + epDevices.stream() + .filter(a -> a.getName().equals("CPUExecutionProvider")) + .collect(Collectors.toList()); + + Map epOptions = options.get(); + sessionOptions.addExecutionProvider(selectedEpDevices, epOptions); + + Path model = TestHelpers.getResourcePath("/squeezenet.onnx"); + String modelPath = model.toString(); + + // session should load successfully + try (OrtSession session = ortEnv.createSession(modelPath, sessionOptions)) { + Assertions.assertNotNull(session); + } + } catch (OrtException e) { + throw new RuntimeException(e); + } + }; + + // empty options + runTest.accept(Collections::emptyMap); + + // dummy options + runTest.accept(() -> Collections.singletonMap("random_key", "value")); + } +} diff --git a/js/node/src/inference_session_wrap.cc b/js/node/src/inference_session_wrap.cc index 84ed3457a488b..8db91f792cb06 100644 --- a/js/node/src/inference_session_wrap.cc +++ b/js/node/src/inference_session_wrap.cc @@ -15,7 +15,7 @@ Napi::Object InferenceSessionWrap::Init(Napi::Env env, Napi::Object exports) { // create ONNX runtime env Ort::InitApi(); ORT_NAPI_THROW_ERROR_IF( - Ort::Global::api_ == nullptr, env, + &Ort::GetApi() == nullptr, env, "Failed to initialize ONNX Runtime API. It could happen when this nodejs binding was built with a higher version " "ONNX Runtime but now runs with a lower version ONNX Runtime DLL(or shared library)."); diff --git a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h index 0d5117709c18a..bfa450f4287f8 100644 --- a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h +++ b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h @@ -280,6 +280,18 @@ class GQAAttentionBase { output, static_cast(present_buffer_sequence_length), nullptr); } + // Pre-allocate buffer for attention mask to avoid allocating it for every processed token + float* attention_bias_thread_fp32 = nullptr; + if (attention_bias_thread != nullptr) { + if constexpr (!std::is_same_v) { + static_assert(std::is_same_v && std::is_same_v); + + size_t bytes = attention_total_seqlen * sizeof(float); + attention_bias_thread_fp32 = static_cast(allocator->Alloc(bytes)); + } + } + BufferUniquePtr scratch_buffer(attention_bias_thread_fp32, BufferDeleter(allocator)); + // compute Softmax U* output_softmax = output; for (size_t seq = 0; seq < sequence_length; seq++) { @@ -316,9 +328,6 @@ class GQAAttentionBase { static_cast(window_size)); } else { static_assert(std::is_same_v && std::is_same_v); - size_t bytes = window_size * sizeof(float); - auto attention_bias_thread_fp32 = static_cast(allocator->Alloc(bytes)); - BufferUniquePtr scratch_buffer(attention_bias_thread_fp32, BufferDeleter(allocator)); MlasConvertHalfToFloatBuffer(attention_bias_thread + start_offset, attention_bias_thread_fp32, window_size); ApplyAttentionBias(output_softmax + start_offset, attention_bias_thread_fp32, static_cast(window_size)); diff --git a/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc b/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc index 9b35a40f64f2a..5c6c3b919b572 100644 --- a/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc +++ b/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc @@ -331,7 +331,13 @@ Status QMoECPU::Compute(OpKernelContext* context) const { const int64_t token_idx = route_idx / k_; const float weight = route_scale[route_idx]; - float* dest = thread_local_outputs + static_cast(thread_id) * output_buffer_size + token_idx * hidden_size; + const size_t buffer_offset = static_cast(token_idx) * static_cast(hidden_size); + if (buffer_offset + static_cast(hidden_size) > output_buffer_size) { + // Skip this token to prevent buffer overflow + continue; + } + + float* dest = thread_local_outputs + static_cast(thread_id) * output_buffer_size + buffer_offset; const float* src = C2 + i * hidden_size; for (int64_t j = 0; j < hidden_size; ++j) { dest[j] += weight * (src[j] + (B2_bias ? bias2_float[j] : 0.0f)); @@ -344,8 +350,9 @@ Status QMoECPU::Compute(OpKernelContext* context) const { auto accumulate = [&](float* buffer) { memset(buffer, 0, output_buffer_size * sizeof(float)); for (int i = 0; i < num_expert_threads; ++i) { + const size_t thread_offset = static_cast(i) * output_buffer_size; for (size_t j = 0; j < output_buffer_size; ++j) { - buffer[j] += thread_local_outputs[static_cast(i) * output_buffer_size + j]; + buffer[j] += thread_local_outputs[thread_offset + j]; } } }; diff --git a/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_matmul.cc b/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_matmul.cc index 85a2cbaea0e44..36a6f70cc69d9 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_matmul.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_matmul.cc @@ -200,6 +200,19 @@ class DynamicQuantizeMatMul final : public MatMulIntegerToFloatBase { can_use_dynamic_quant_mlas_ = (!b_quantization_might_be_asymmetric && b_scale_available); + // Kleidi dynamic path requires strictly positive, finite scales. + // Disable if any invalid scale is detected. + if (can_use_dynamic_quant_mlas_) { + const auto bs = b_scale_tensor->DataAsSpan(); + const bool has_invalid = + std::any_of(bs.begin(), bs.end(), + [](float s) { return !std::isfinite(s) || s <= 0.0f; }); + + if (has_invalid) { + can_use_dynamic_quant_mlas_ = false; + } + } + // Currently, MlasDynamicQGemmBatch() and associated functions require SME or else they are no-ops. // We check that here too before attempting to use them. if (!CPUIDInfo::GetCPUIDInfo().HasArm_SME()) { @@ -379,7 +392,7 @@ Status DynamicQuantizeMatMul::Compute(OpKernelContext* ctx) const { if (y->Shape().Size() == 0) return Status::OK(); - auto a_data = static_cast(ctx->Input(IN_A)->DataRaw()); + const float* a_data = ctx->Input(IN_A)->Data(); auto* y_data = y->MutableData(); // batch gemm @@ -393,7 +406,7 @@ Status DynamicQuantizeMatMul::Compute(OpKernelContext* ctx) const { for (size_t gemm_idx = 0; gemm_idx < num_gemms; gemm_idx++) { auto& params = gemm_data_vec[gemm_idx]; - params.A = reinterpret_cast(a_data + helper.LeftOffsets()[gemm_idx]); + params.A = a_data + helper.LeftOffsets()[gemm_idx]; params.lda = gemm_shape.K; params.PackedB = packed_b_.get(); params.C = y_data + helper.OutputOffsets()[gemm_idx]; diff --git a/onnxruntime/core/common/cpuid_arch_definition.h b/onnxruntime/core/common/cpuid_arch_definition.h index a541eb66d8ba3..5946b8ca27067 100644 --- a/onnxruntime/core/common/cpuid_arch_definition.h +++ b/onnxruntime/core/common/cpuid_arch_definition.h @@ -9,6 +9,6 @@ #define CPUIDINFO_ARCH_X86 #endif -#if defined(_M_ARM64) || defined(__aarch64__) || defined(_M_ARM) || defined(__arm__) +#if defined(_M_ARM64) || defined(_M_ARM64EC) || defined(__aarch64__) || defined(_M_ARM) || defined(__arm__) #define CPUIDINFO_ARCH_ARM #endif // ARM or ARM64 diff --git a/onnxruntime/core/graph/abi_graph_types.h b/onnxruntime/core/graph/abi_graph_types.h index b99c22edb36c8..2ef7c4a9091f3 100644 --- a/onnxruntime/core/graph/abi_graph_types.h +++ b/onnxruntime/core/graph/abi_graph_types.h @@ -252,16 +252,6 @@ struct OrtNode { /// A status indicating success or an error. virtual onnxruntime::Status GetAttributes(gsl::span attrs) const = 0; - ///

- /// Gets the node's 'TENSOR' attribute as an OrtValue. - /// - /// Node's 'TENSOR' attribute. - /// Output parameter is set to a newly created OrtValue containing the 'TENSOR' attribute value, - /// only if the attribute is of type 'TENSOR' - /// A status indicating success or an error. - virtual onnxruntime::Status GetTensorAttributeAsOrtValue(const OrtOpAttr* attr, - OrtValue*& value) const = 0; - /// /// Gets the number of node subgraphs. /// diff --git a/onnxruntime/core/graph/ep_api_types.cc b/onnxruntime/core/graph/ep_api_types.cc index 759a2998ace3a..92eb31f0ad385 100644 --- a/onnxruntime/core/graph/ep_api_types.cc +++ b/onnxruntime/core/graph/ep_api_types.cc @@ -249,32 +249,6 @@ Status EpNode::GetAttributes(gsl::span dst) const { return Status::OK(); } -Status EpNode::GetTensorAttributeAsOrtValue(const OrtOpAttr* attribute, OrtValue*& result) const { - const auto* attr_proto = reinterpret_cast(attribute); - - if (attr_proto->type() != onnx::AttributeProto::TENSOR) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "This OrtOpAttr instance is not a 'TENSOR' attribute"); - } - - const auto& graph_viewer = ep_graph_->GetGraphViewer(); - const auto& tensor_proto = attr_proto->t(); - - // Check that TensorProto is valid. - ORT_ENFORCE(utils::HasDataType(tensor_proto), "Tensor proto doesn't have data type."); - ORT_ENFORCE(ONNX_NAMESPACE::TensorProto::DataType_IsValid(tensor_proto.data_type()), "Tensor proto has invalid data type."); - ORT_ENFORCE(!utils::HasExternalData(tensor_proto), - "Tensor proto with external data for value attribute is not supported."); - - // Initialize OrtValue for tensor attribute. - auto tensor_attribute_value = std::make_unique(); - AllocatorPtr tensor_attribute_allocator = CPUAllocator::DefaultInstance(); - ORT_RETURN_IF_ERROR(utils::TensorProtoToOrtValue(Env::Default(), graph_viewer.ModelPath(), tensor_proto, - tensor_attribute_allocator, *tensor_attribute_value)); - - result = tensor_attribute_value.release(); - return Status::OK(); -} - Status EpNode::GetNumSubgraphs(size_t& num_subgraphs) const { num_subgraphs = subgraphs_.size(); return Status::OK(); @@ -353,6 +327,9 @@ static Status GetInputIndices(const EpNode& consumer_node, [&found, &value_info_name, &indices](gsl::span input_value_infos, bool is_implicit) -> void { for (size_t i = 0; i < input_value_infos.size(); i++) { + if (input_value_infos[i] == nullptr) { // input_value_info == nullptr means the input is optional + continue; + } if (input_value_infos[i]->GetName() == value_info_name) { indices.push_back(is_implicit ? -1 : static_cast(i)); found = true; diff --git a/onnxruntime/core/graph/ep_api_types.h b/onnxruntime/core/graph/ep_api_types.h index 7f22e265129f7..e003f02a79a2d 100644 --- a/onnxruntime/core/graph/ep_api_types.h +++ b/onnxruntime/core/graph/ep_api_types.h @@ -183,9 +183,6 @@ struct EpNode : public OrtNode { // Gets the node's attributes. Status GetAttributes(gsl::span attrs) const override; - Status GetTensorAttributeAsOrtValue(const OrtOpAttr* attribute, - OrtValue*& attr_tensor) const override; - // Gets the number of subgraphs contained by this node. Status GetNumSubgraphs(size_t& num_subgraphs) const override; diff --git a/onnxruntime/core/graph/model_editor_api_types.h b/onnxruntime/core/graph/model_editor_api_types.h index e7ffcbc7e4c90..2c0f6d6174303 100644 --- a/onnxruntime/core/graph/model_editor_api_types.h +++ b/onnxruntime/core/graph/model_editor_api_types.h @@ -138,11 +138,6 @@ struct ModelEditorNode : public OrtNode { "OrtModelEditorApi does not support getting attribute OrtOpAttr for OrtNode"); } - Status GetTensorAttributeAsOrtValue(const OrtOpAttr* /*attribute*/, OrtValue*& /*attr_tensor*/) const override { - return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, - "OrtModelEditorApi does not support getting 'TENSOR' attribute for OrtNode"); - } - Status GetNumSubgraphs(size_t& /*num_subgraphs*/) const override { return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "OrtModelEditorApi does not support getting the subgraphs for OrtNode"); diff --git a/onnxruntime/core/mlas/lib/kleidiai/sgemm_kleidiai.cpp b/onnxruntime/core/mlas/lib/kleidiai/sgemm_kleidiai.cpp index caa445b71e2a5..c579ff1542eb9 100644 --- a/onnxruntime/core/mlas/lib/kleidiai/sgemm_kleidiai.cpp +++ b/onnxruntime/core/mlas/lib/kleidiai/sgemm_kleidiai.cpp @@ -153,28 +153,23 @@ ArmKleidiAI::MlasGemmBatch( MLAS_THREADPOOL* ThreadPool ) { - if(TransA == CblasTrans) - { - return false; + if (M == 0 || N == 0) { + return true; } - if (TransA == CblasNoTrans && K == 0) { - if (Data->beta != 1.0f) { + + if (Data->alpha == 0.0f || K == 0) { + if (Data->beta == 0.0f) { + for (size_t i = 0; i < M; ++i) { + std::fill_n(Data->C + i * Data->ldc, N, 0.0f); + } + } else if (Data->beta != 1.0f) { for (size_t i = 0; i < M; ++i) { for (size_t j = 0; j < N; ++j) { Data->C[i * Data->ldc + j] *= Data->beta; } } } - } - if (Data->beta == 0.0f){ - std::fill_n(Data->C, M * Data->ldc, 0.0f); - } - //Fallback in the case of unsupported cases - if (M == 0 || N == 0 || K == 0 || - TransA != CblasNoTrans || - (TransB != CblasNoTrans && !Data[0].BIsPacked)) - { - return false; + return true; } if (TransA == CblasNoTrans) { @@ -185,11 +180,9 @@ ArmKleidiAI::MlasGemmBatch( auto m_step = kai_get_m_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(); auto n_step = kai_get_n_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(); - if (M < m_step || N < n_step) { - if (GetMlasPlatform().MlasGemmBatchOverride != ArmKleidiAI::MlasGemmBatch){ - //Fallback to MLAS - return false; - } + if (M < m_step && N < n_step && !Data->BIsPacked) { + // Fallback to MLAS + return false; } std::vector KaiPackedData; @@ -316,7 +309,7 @@ ArmKleidiAI::MlasGemmBatch( float* dst_tile = reinterpret_cast(CTile); // quick copy of data in cases where we are not scaling or accumulating anything - // with bounds checking on tile sizing to ensure the data fits in the memory block + // with bounds checking on tile sizing to ensure the data fits in the memory block bool can_memcpy = ( Data[BIdx].alpha == 1.0f && Data[BIdx].beta == 0.0f && @@ -328,21 +321,37 @@ ArmKleidiAI::MlasGemmBatch( if (can_memcpy) { std::memcpy(dst_tile, temp_tile, TileSizeM * TileSizeN * sizeof(float)); - }else { - // apply alpha scaling and beta to output files - for (size_t i = 0; i < TileSizeM; ++i) { - for (size_t j = 0; j < TileSizeN; ++j) { - const size_t idx = i * TileSizeN + j; - const size_t dst_idx = i * Data[BIdx].ldc + j; - - float ab = temp_tile[idx]; - float c_orig = dst_tile[dst_idx]; + return; + } - dst_tile[dst_idx] = Data[BIdx].alpha * ab + Data[BIdx].beta * c_orig; + float alpha = Data[BIdx].alpha; + float beta = Data[BIdx].beta; + size_t ldc = Data[BIdx].ldc; + + for (size_t i = 0; i < TileSizeM; ++i) { + for (size_t j = 0; j < TileSizeN; ++j) { + const size_t temp_idx = i * TileSizeN + j; + const size_t dst_idx = i * ldc + j; + + float ab = temp_tile[temp_idx]; + float c_orig = dst_tile[dst_idx]; + + if (alpha == 1.0f && beta == 0.0f) { + dst_tile[dst_idx] = ab; + } else if (alpha == 1.0f) { + dst_tile[dst_idx] = ab + beta * c_orig; + } else if (beta == 0.0f) { + dst_tile[dst_idx] = alpha * ab; + } else { + dst_tile[dst_idx] = alpha * ab + beta * c_orig; } } } + return; }); + return true; + } + else { + return false; } - return true; } diff --git a/onnxruntime/core/providers/cann/cann_execution_provider.cc b/onnxruntime/core/providers/cann/cann_execution_provider.cc index 4bcf71335d15e..06c3628eb301d 100644 --- a/onnxruntime/core/providers/cann/cann_execution_provider.cc +++ b/onnxruntime/core/providers/cann/cann_execution_provider.cc @@ -1266,17 +1266,16 @@ CANNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewe // the single operator operation mode of CANN if (info_.enable_cann_graph) { std::vector&& unsupported_nodes = SupportONNXModel(graph_viewer); - - if (unsupported_nodes.empty()) { - auto sub_graph = GetSubGraph(graph_viewer.GetNodesInTopologicalOrder(), graph_viewer); - result.push_back(ComputeCapability::Create(std::move(sub_graph))); - } else { + if (info_.enable_cann_subgraph && !unsupported_nodes.empty()) { auto partitions = GetSubGraphPartition(graph_viewer.GetNodesInTopologicalOrder(), unsupported_nodes); for (const auto& partition : partitions) { auto sub_graph = GetSubGraph(partition, graph_viewer); result.push_back(ComputeCapability::Create(std::move(sub_graph))); } + } else { + auto sub_graph = GetSubGraph(graph_viewer.GetNodesInTopologicalOrder(), graph_viewer); + result.push_back(ComputeCapability::Create(std::move(sub_graph))); } } else { InlinedVector candidates; diff --git a/onnxruntime/core/providers/cann/cann_execution_provider_info.cc b/onnxruntime/core/providers/cann/cann_execution_provider_info.cc index d1ba7544bc09e..d6cf9fad70ae5 100644 --- a/onnxruntime/core/providers/cann/cann_execution_provider_info.cc +++ b/onnxruntime/core/providers/cann/cann_execution_provider_info.cc @@ -20,6 +20,7 @@ constexpr const char* kDeviceId = "device_id"; constexpr const char* kMemLimit = "npu_mem_limit"; constexpr const char* kArenaExtendStrategy = "arena_extend_strategy"; constexpr const char* kEnableCannGraph = "enable_cann_graph"; +constexpr const char* kEnableCannSubGraph = "enable_cann_subgraph"; constexpr const char* kDumpGraphs = "dump_graphs"; constexpr const char* kDumpOmModel = "dump_om_model"; constexpr const char* kPrecisionMode = "precision_mode"; @@ -58,6 +59,7 @@ CANNExecutionProviderInfo CANNExecutionProviderInfo::FromProviderOptions(const P cann::provider_option_names::kArenaExtendStrategy, arena_extend_strategy_mapping, info.arena_extend_strategy) .AddAssignmentToReference(cann::provider_option_names::kEnableCannGraph, info.enable_cann_graph) + .AddAssignmentToReference(cann::provider_option_names::kEnableCannSubGraph, info.enable_cann_subgraph) .AddAssignmentToReference(cann::provider_option_names::kDumpGraphs, info.dump_graphs) .AddAssignmentToReference(cann::provider_option_names::kDumpOmModel, info.dump_om_model) .AddAssignmentToReference(cann::provider_option_names::kPrecisionMode, info.precision_mode) @@ -74,6 +76,7 @@ ProviderOptions CANNExecutionProviderInfo::ToProviderOptions(const CANNExecution {cann::provider_option_names::kArenaExtendStrategy, EnumToName(arena_extend_strategy_mapping, info.arena_extend_strategy)}, {cann::provider_option_names::kEnableCannGraph, MakeStringWithClassicLocale(info.enable_cann_graph)}, + {cann::provider_option_names::kEnableCannSubGraph, MakeStringWithClassicLocale(info.enable_cann_subgraph)}, {cann::provider_option_names::kDumpGraphs, MakeStringWithClassicLocale(info.dump_graphs)}, {cann::provider_option_names::kDumpOmModel, MakeStringWithClassicLocale(info.dump_om_model)}, {cann::provider_option_names::kPrecisionMode, MakeStringWithClassicLocale(info.precision_mode)}, @@ -89,6 +92,7 @@ ProviderOptions CANNExecutionProviderInfo::ToProviderOptions(const OrtCANNProvid {cann::provider_option_names::kArenaExtendStrategy, EnumToName(arena_extend_strategy_mapping, ArenaExtendStrategy(info.arena_extend_strategy))}, {cann::provider_option_names::kEnableCannGraph, MakeStringWithClassicLocale(info.enable_cann_graph)}, + {cann::provider_option_names::kEnableCannSubGraph, MakeStringWithClassicLocale(info.enable_cann_subgraph)}, {cann::provider_option_names::kDumpGraphs, MakeStringWithClassicLocale(info.dump_graphs)}, {cann::provider_option_names::kDumpOmModel, MakeStringWithClassicLocale(info.dump_om_model)}, {cann::provider_option_names::kPrecisionMode, MakeStringWithClassicLocale(info.precision_mode)}, diff --git a/onnxruntime/core/providers/cann/cann_execution_provider_info.h b/onnxruntime/core/providers/cann/cann_execution_provider_info.h index 7ac43e9a8ed6f..9c1f9eb03b67e 100644 --- a/onnxruntime/core/providers/cann/cann_execution_provider_info.h +++ b/onnxruntime/core/providers/cann/cann_execution_provider_info.h @@ -18,6 +18,7 @@ struct CANNExecutionProviderInfo { size_t npu_mem_limit{std::numeric_limits::max()}; ArenaExtendStrategy arena_extend_strategy{ArenaExtendStrategy::kNextPowerOfTwo}; bool enable_cann_graph{true}; + bool enable_cann_subgraph{false}; bool dump_graphs{false}; bool dump_om_model{true}; std::string precision_mode; diff --git a/onnxruntime/core/providers/cann/cann_provider_factory.cc b/onnxruntime/core/providers/cann/cann_provider_factory.cc index 4a130b9b0ca20..d3dc86f588f1d 100644 --- a/onnxruntime/core/providers/cann/cann_provider_factory.cc +++ b/onnxruntime/core/providers/cann/cann_provider_factory.cc @@ -76,6 +76,7 @@ struct CANN_Provider : Provider { info.npu_mem_limit = params->npu_mem_limit; info.arena_extend_strategy = params->arena_extend_strategy; info.enable_cann_graph = params->enable_cann_graph != 0; + info.enable_cann_subgraph = params->enable_cann_subgraph != 0; info.dump_graphs = params->dump_graphs != 0; info.dump_om_model = params->dump_om_model != 0; info.precision_mode = params->precision_mode; @@ -94,6 +95,7 @@ struct CANN_Provider : Provider { cann_options.npu_mem_limit = internal_options.npu_mem_limit; cann_options.arena_extend_strategy = internal_options.arena_extend_strategy; cann_options.enable_cann_graph = internal_options.enable_cann_graph; + cann_options.enable_cann_subgraph = internal_options.enable_cann_subgraph; cann_options.dump_graphs = internal_options.dump_graphs; cann_options.dump_om_model = internal_options.dump_om_model; cann_options.precision_mode = internal_options.precision_mode; diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc index b7997ce86737a..93b673f2df5bd 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc @@ -20,7 +20,6 @@ #include "onnx_ctx_model_helper.h" #include "core/providers/cuda/shared_inc/cuda_call.h" #include "core/providers/cuda/cuda_graph.h" -#include "core/providers/cuda/math/unary_elementwise_ops_impl.h" #include "core/session/allocator_adapters.h" #include "cuda_runtime_api.h" #include "core/common/parse_string.h" @@ -85,40 +84,6 @@ struct ShutdownProtobuf { namespace onnxruntime { -namespace cuda { -template <> -void Impl_Cast( - cudaStream_t stream, - const int64_t* input_data, int32_t* output_data, - size_t count) { - return g_host->cuda__Impl_Cast(static_cast(stream), input_data, output_data, count); -} - -template <> -void Impl_Cast( - cudaStream_t stream, - const int32_t* input_data, int64_t* output_data, - size_t count) { - return g_host->cuda__Impl_Cast(static_cast(stream), input_data, output_data, count); -} - -template <> -void Impl_Cast( - cudaStream_t stream, - const double* input_data, float* output_data, - size_t count) { - return g_host->cuda__Impl_Cast(static_cast(stream), input_data, output_data, count); -} - -template <> -void Impl_Cast( - cudaStream_t stream, - const float* input_data, double* output_data, - size_t count) { - return g_host->cuda__Impl_Cast(static_cast(stream), input_data, output_data, count); -} -} // namespace cuda - void* OutputAllocator::reallocateOutputAsync(char const* /*tensorName*/, void* /*currentMemory*/, uint64_t size, uint64_t /*alignment*/, cudaStream_t /*stream*/) noexcept { // Some memory allocators return nullptr when allocating zero bytes, but TensorRT requires a non-null ptr @@ -372,51 +337,19 @@ bool ApplyProfileShapesFromProviderOptions(std::vector(); \ - skip_input_binding_allowed = false; \ - if (input_tensor_ptr != nullptr && elem_cnt > 0) { \ - scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, elem_cnt * sizeof(DstT))); \ - data = scratch_buffers.back().get(); \ - cuda::Impl_Cast(stream, input_tensor_ptr, reinterpret_cast(data), elem_cnt); \ - } else { \ - scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, 1)); \ - data = scratch_buffers.back().get(); \ - } \ - break; \ - } - #define CASE_GET_OUTPUT_TENSOR(DATA_TYPE, SrcT) \ case DATA_TYPE: { \ auto output_tensor_ptr = output_tensor.GetTensorMutableData(); \ data_ptr = output_tensor_ptr; \ if (output_tensor_ptr != nullptr && elem_cnt > 0) { \ - buffers[output_name] = output_tensor_ptr; \ + buffer = output_tensor_ptr; \ } else { \ scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, 1)); \ - buffers[output_name] = scratch_buffers.back().get(); \ + buffer = scratch_buffers.back().get(); \ } \ break; \ } -#define CASE_GET_CAST_OUTPUT_TENSOR(DATA_TYPE, SrcT, DstT) \ - case DATA_TYPE: { \ - auto output_tensor_ptr = output_tensor.GetTensorMutableData(); \ - data_ptr = output_tensor_ptr; \ - skip_output_binding_allowed = false; \ - if (output_tensor_ptr != nullptr && elem_cnt > 0) { \ - scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, elem_cnt * sizeof(DstT))); \ - buffers[output_name] = scratch_buffers.back().get(); \ - output_dim_sizes[i] = static_cast(elem_cnt); \ - } else { \ - scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, 1)); \ - buffers[output_name] = scratch_buffers.back().get(); \ - output_dim_sizes[i] = 1; \ - } \ - break; \ - } - #define CASE_COPY_TENSOR(DATA_TYPE, DstT) \ case DATA_TYPE: { \ auto output_tensor_ptr = output_tensor.GetTensorMutableData(); \ @@ -426,15 +359,6 @@ bool ApplyProfileShapesFromProviderOptions(std::vector(); \ - if (output_tensor_ptr != nullptr && elem_cnt > 0) { \ - cuda::Impl_Cast(stream, reinterpret_cast(allocator->getBuffer()), reinterpret_cast(output_tensor_ptr), elem_cnt); \ - } \ - break; \ - } - /* * Set Nv executio context input. * @@ -557,7 +481,6 @@ Status BindContextInput(Ort::KernelContext& ctx, CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, uint8_t) CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, int32_t) CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, int64_t) - CASE_GET_CAST_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, double, float) default: { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "NvTensorRTRTX EP input onnx tensor data type: " + std::to_string(tensor_type) + " not supported."); @@ -582,8 +505,6 @@ Status BindContextInput(Ort::KernelContext& ctx, * param output_type - Data type of the output * param i - Output iteration index * param output_tensors - Output iteration index to output's ORT value - * param output_dim_sizes - Output iteration index to the multiplocation of its shape's dimensions - * param dds_output_set - DDS output set * param dds_output_allocator_map - DDS output to its allocator * param scratch_buffer - The allocation buffer created by TRT EP * param allocator - ORT allocator @@ -595,16 +516,11 @@ Status BindContextOutput(Ort::KernelContext& ctx, const char* output_name, size_t output_index, size_t output_type, - size_t i, - std::unordered_map& output_tensors, - std::unordered_map& output_dim_sizes, DDSOutputAllocatorMap& dds_output_allocator_map, std::vector>& scratch_buffers, OrtAllocator* alloc, - std::unordered_map& buffers, nvinfer1::Dims& dims, - void*& data_ptr, - bool& skip_output_binding_allowed) { + void*& data_ptr) { // Get output shape dims = trt_context->getTensorShape(output_name); int nb_dims = dims.nbDims; @@ -634,10 +550,11 @@ Status BindContextOutput(Ort::KernelContext& ctx, data_ptr = nullptr; // Set data_ptr to nullptr for DDS output binding. } } else { - output_tensors[i] = ctx.GetOutput(output_index, dims.d, nb_dims); - auto& output_tensor = output_tensors[i]; + auto output_tensor = ctx.GetOutput(output_index, dims.d, nb_dims); const auto elem_cnt = output_tensor.GetTensorTypeAndShapeInfo().GetElementCount(); + void* buffer = nullptr; + switch (output_type) { // below macros set data_ptr and skip_output_binding_allowed variables CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, float) @@ -648,13 +565,12 @@ Status BindContextOutput(Ort::KernelContext& ctx, CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, uint8_t) CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, int32_t) CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, int64_t) - CASE_GET_CAST_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, double, float) default: { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "NvTensorRTRTX EP output tensor data type: " + std::to_string(output_type) + " not supported."); } } - trt_context->setTensorAddress(output_name, buffers[output_name]); + trt_context->setTensorAddress(output_name, buffer); } return Status::OK(); @@ -711,7 +627,6 @@ Status BindKernelOutput(Ort::KernelContext& ctx, CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, uint8_t) CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, int32_t) CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, int64_t) - CASE_CAST_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, float, double) default: { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "NvTensorRTRTX EP output tensor data type: " + std::to_string(output_type) + " not supported."); @@ -2837,7 +2752,6 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr } // Save TRT engine, other TRT objects and input/output info to map - parsers_.emplace(fused_node.Name(), std::move(trt_parser)); engines_.emplace(fused_node.Name(), std::move(trt_engine)); contexts_.emplace(fused_node.Name(), std::move(trt_context)); networks_.emplace(fused_node.Name(), std::move(trt_network)); @@ -2853,7 +2767,7 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr compute_info.create_state_func = [=](ComputeContext* context, FunctionState* state) { std::unique_ptr p = std::make_unique(); *p = {context->allocate_func, context->release_func, context->allocator_handle, context->node_name, builder_.get(), - &parsers_[context->node_name], &engines_[context->node_name], &contexts_[context->node_name], + &engines_[context->node_name], &contexts_[context->node_name], &networks_[context->node_name], input_info_[context->node_name], output_info_[context->node_name], input_shape_ranges_[context->node_name], &tensorrt_mu_, engine_cache_enable_, cache_path_, @@ -2891,7 +2805,6 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr auto trt_engine = trt_state->engine->get(); auto trt_context = trt_state->context->get(); auto trt_profiles = trt_state->profiles; - int num_outputs = static_cast(output_indexes.size()); std::unordered_set input_names; if (alloc_ == nullptr) { @@ -2966,16 +2879,7 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr /* * Set output shapes and bind output buffers */ - std::unordered_map buffers; - buffers.reserve(num_outputs); - using OutputOrtValue = Ort::UnownedValue; - std::unordered_map output_tensors; - output_tensors.reserve(num_outputs); - std::unordered_map output_dim_sizes; - output_dim_sizes.reserve(num_outputs); - if (require_io_binding) { - bool skip_output_binding_allowed = true; for (size_t i = 0, end = output_binding_names.size(); i < end; ++i) { char const* output_name = output_binding_names[i]; @@ -2993,16 +2897,15 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr nvinfer1::Dims dims; void* data_ptr = nullptr; - Status status = BindContextOutput(ctx, trt_context, output_name, output_index, output_type, i, output_tensors, output_dim_sizes, - dds_output_allocator_map, scratch_buffers, alloc, buffers, dims, data_ptr, skip_output_binding_allowed); + + Status status = BindContextOutput(ctx, trt_context, output_name, output_index, output_type, + dds_output_allocator_map, scratch_buffers, alloc, dims, data_ptr); if (status != Status::OK()) { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage()); } trt_state->output_tensors[output_index] = TensorParams{data_ptr, dims}; } - - trt_state->skip_io_binding_allowed = trt_state->skip_io_binding_allowed | skip_output_binding_allowed; } // Set execution context memory @@ -3082,14 +2985,6 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr if (status != Status::OK()) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, status.ErrorMessage()); } - } else { - auto& output_tensor = output_tensors[i]; - if (output_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE) { - auto output_tensor_ptr = output_tensor.GetTensorMutableData(); - if (output_tensor_ptr != nullptr) { - cuda::Impl_Cast(stream, reinterpret_cast(buffers[output_name]), output_tensor_ptr, output_dim_sizes[i]); - } - } } } @@ -3213,7 +3108,6 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(const Gra auto& dds_output_allocator_map = this->dds_output_allocator_maps_[fused_node_name]; auto trt_engine = trt_state->engine->get(); auto trt_context = trt_state->context->get(); - int num_outputs = static_cast(output_indexes.size()); std::unordered_map> shape_tensor_values; // This map holds "shape tensor -> shape values" for the shape tensor input across this inference run std::unordered_map> shape_tensor_values_int64; // same as above but for int64 shape tensor input @@ -3283,16 +3177,7 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(const Gra /* * Set output shapes and bind output buffers */ - std::unordered_map buffers; - buffers.reserve(num_outputs); - using OutputOrtValue = Ort::UnownedValue; - std::unordered_map output_tensors; - output_tensors.reserve(num_outputs); - std::unordered_map output_dim_sizes; - output_dim_sizes.reserve(num_outputs); - if (require_io_binding) { - bool skip_output_binding_allowed = true; for (size_t i = 0, end = output_binding_names.size(); i < end; ++i) { char const* output_name = output_binding_names[i]; @@ -3311,16 +3196,14 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(const Gra nvinfer1::Dims dims; void* data_ptr = nullptr; - Status status = BindContextOutput(ctx, trt_context, output_name, output_index, output_type, i, output_tensors, output_dim_sizes, - dds_output_allocator_map, scratch_buffers, alloc, buffers, dims, data_ptr, skip_output_binding_allowed); + Status status = BindContextOutput(ctx, trt_context, output_name, output_index, output_type, + dds_output_allocator_map, scratch_buffers, alloc, dims, data_ptr); if (status != Status::OK()) { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage()); } trt_state->output_tensors[output_index] = TensorParams{data_ptr, dims}; } - - trt_state->skip_io_binding_allowed = trt_state->skip_io_binding_allowed | skip_output_binding_allowed; } // Set execution context memory @@ -3401,14 +3284,6 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(const Gra if (status != Status::OK()) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, status.ErrorMessage()); } - } else { - auto& output_tensor = output_tensors[i]; - if (output_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE) { - auto output_tensor_ptr = output_tensor.GetTensorMutableData(); - if (output_tensor_ptr != nullptr) { - cuda::Impl_Cast(stream, reinterpret_cast(buffers[output_name]), output_tensor_ptr, output_dim_sizes[i]); - } - } } } diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h index 22b8314649757..9e5fd03756f02 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h @@ -195,7 +195,6 @@ struct TensorrtFuncState { AllocatorHandle allocator = nullptr; std::string fused_node_name; nvinfer1::IBuilder* builder; - tensorrt_ptr::unique_pointer* parser = nullptr; std::unique_ptr* engine = nullptr; std::unique_ptr* context = nullptr; std::unique_ptr* network = nullptr; @@ -386,7 +385,6 @@ class NvExecutionProvider : public IExecutionProvider { // In general, TensorRT objects are not thread safe; accesses to an object from different threads must be serialized by the client. // But there are still some thread safe operations, please see here https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading // For those non thread safe operations, TRT EP uses (1) lock_guard or (2) PerThreadContext to make sure synchronization. - std::unordered_map> parsers_; std::unordered_map> engines_; std::unordered_map> contexts_; std::unordered_map> builders_; diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/conv_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/conv_op_builder.cc index 541ca5ca7ab14..a994c936970f6 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/conv_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/conv_op_builder.cc @@ -245,6 +245,12 @@ Status ConvOpBuilder::ProcessConv2D3DInputs(QnnModelWrapper& qnn_model_wrapper, bool is_graph_input = qnn_model_wrapper.IsGraphInput(input1_name); LOGS(logger, VERBOSE) << "Add HWCN Transpose node after input: " << input1_name; + if (!qnn_model_wrapper.IsQnnTensorWrapperExist(input1_name)) { + QnnTensorWrapper weight_tensor_wrapper; + ORT_RETURN_IF_ERROR(qnn_model_wrapper.MakeTensorWrapper(inputs[1], weight_tensor_wrapper)); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(weight_tensor_wrapper)), "Failed to add weight tensor."); + } + if (conv_type == OnnxConvType::kConv) { ORT_RETURN_IF_ERROR(qnn_model_wrapper.AddNchwToHwcnTranspose(node_unit.Index(), input1_name, @@ -425,7 +431,7 @@ Status ConvOpBuilder::ProcessConv1DInputs(QnnModelWrapper& qnn_model_wrapper, // // Input 1: weight - // We need to first reshape the weight inorder to handle 1D convolutions with the Conv2d operator. + // We need to first reshape the weight in order to handle 1D convolutions with the Conv2d operator. // Next, we have to transpose the weight because ORT layout transformations do not change the weight layout. // { @@ -511,6 +517,12 @@ Status ConvOpBuilder::ProcessConv1DInputs(QnnModelWrapper& qnn_model_wrapper, ORT_RETURN_IF(input_info.quant_param.IsPerChannel(), "Non-constant Conv inputs only support per-tensor quantization"); + if (!qnn_model_wrapper.IsQnnTensorWrapperExist(input1_name)) { + QnnTensorWrapper weight_tensor_wrapper; + ORT_RETURN_IF_ERROR(qnn_model_wrapper.MakeTensorWrapper(inputs[1], weight_tensor_wrapper)); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(weight_tensor_wrapper)), "Failed to add weight tensor."); + } + bool is_graph_input = qnn_model_wrapper.IsGraphInput(input1_name); LOGS(logger, VERBOSE) << "Adding Reshape (to 2D) and HWCN Transpose node after input: " << input1_name; ORT_RETURN_IF_ERROR(qnn_model_wrapper.AddReshapeNode(input1_name, diff --git a/onnxruntime/core/providers/shared_library/provider_ort_api_init.cc b/onnxruntime/core/providers/shared_library/provider_ort_api_init.cc index 9fa2551e53c23..f8d88b07f6dd5 100644 --- a/onnxruntime/core/providers/shared_library/provider_ort_api_init.cc +++ b/onnxruntime/core/providers/shared_library/provider_ort_api_init.cc @@ -24,7 +24,7 @@ std::once_flag init; } // namespace void InitProviderOrtApi() { - std::call_once(init, []() { Ort::Global::api_ = Provider_GetHost()->OrtGetApiBase()->GetApi(ORT_API_VERSION); }); + std::call_once(init, []() { Ort::InitApi(Provider_GetHost()->OrtGetApiBase()->GetApi(ORT_API_VERSION)); }); } -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/vitisai/imp/global_api.cc b/onnxruntime/core/providers/vitisai/imp/global_api.cc index 5fc0b8900730b..580fbfbdba0b0 100644 --- a/onnxruntime/core/providers/vitisai/imp/global_api.cc +++ b/onnxruntime/core/providers/vitisai/imp/global_api.cc @@ -229,7 +229,7 @@ int vitisai_ep_set_ep_dynamic_options( struct MyCustomOpKernel : OpKernel { MyCustomOpKernel(const OpKernelInfo& info, const OrtCustomOp& op) : OpKernel(info), op_(op) { op_kernel_ = - op_.CreateKernel(&op_, Ort::Global::api_, reinterpret_cast(&info)); + op_.CreateKernel(&op_, &Ort::GetApi(), reinterpret_cast(&info)); } ~MyCustomOpKernel() override { op_.KernelDestroy(op_kernel_); } @@ -332,8 +332,8 @@ vaip_core::OrtApiForVaip* create_org_api_hook() { InitProviderOrtApi(); set_version_info(the_global_api); the_global_api.host_ = Provider_GetHost(); - assert(Ort::Global::api_ != nullptr); - the_global_api.ort_api_ = Ort::Global::api_; + assert(&Ort::GetApi() != nullptr); + the_global_api.ort_api_ = &Ort::GetApi(); the_global_api.model_load = [](const std::string& filename) -> Model* { auto model_proto = ONNX_NAMESPACE::ModelProto::Create(); auto& logger = logging::LoggingManager::DefaultLogger(); diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index ad0a1ad137f06..f3e2a8ce7ba7b 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -3036,7 +3036,7 @@ ORT_API_STATUS_IMPL(OrtApis::Node_GetAttributeByName, _In_ const OrtNode* node, API_IMPL_END } -ORT_API_STATUS_IMPL(OrtApis::Node_GetTensorAttributeAsOrtValue, _In_ const OrtNode* node, _In_ const OrtOpAttr* attribute, _Outptr_result_maybenull_ OrtValue** attr_tensor) { +ORT_API_STATUS_IMPL(OrtApis::OpAttr_GetTensorAttributeAsOrtValue, _In_ const OrtOpAttr* attribute, _Outptr_result_maybenull_ OrtValue** attr_tensor) { API_IMPL_BEGIN if (attr_tensor == nullptr) { return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "attr_tensor argument is null"); @@ -3045,7 +3045,39 @@ ORT_API_STATUS_IMPL(OrtApis::Node_GetTensorAttributeAsOrtValue, _In_ const OrtNo return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "attribute argument is null"); } - ORT_API_RETURN_IF_STATUS_NOT_OK(node->GetTensorAttributeAsOrtValue(attribute, *attr_tensor)); + const auto* attr_proto = reinterpret_cast(attribute); + + if (attr_proto->type() != onnx::AttributeProto::TENSOR) { + return OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "This OrtOpAttr instance is not a 'TENSOR' attribute"); + } + + const auto& tensor_proto = attr_proto->t(); + + // Check that TensorProto is valid. + if (!utils::HasDataType(tensor_proto)) { + return OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "Tensor proto doesn't have data type."); + } + + if (!ONNX_NAMESPACE::TensorProto::DataType_IsValid(tensor_proto.data_type())) { + return OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "Tensor proto has invalid data type."); + } + + if (utils::HasExternalData(tensor_proto)) { + return OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, + "Tensor proto with external data for value attribute is not supported."); + } + + // Initialize OrtValue for tensor attribute. + auto tensor_attribute_value = std::make_unique(); + AllocatorPtr tensor_attribute_allocator = CPUAllocator::DefaultInstance(); + // The tensor in the 'Tensor' attribute's TensorProto is stored inline, not in an external file. + // Therefore, the 'model_path' passed to TensorProtoToOrtValue() may be an empty path. + std::filesystem::path model_path; + ORT_API_RETURN_IF_STATUS_NOT_OK(utils::TensorProtoToOrtValue(Env::Default(), model_path, tensor_proto, + tensor_attribute_allocator, *tensor_attribute_value)); + + *attr_tensor = tensor_attribute_value.release(); + return nullptr; API_IMPL_END } @@ -4134,7 +4166,7 @@ static constexpr OrtApi ort_api_1_to_23 = { &OrtApis::Node_GetNumAttributes, &OrtApis::Node_GetAttributes, &OrtApis::Node_GetAttributeByName, - &OrtApis::Node_GetTensorAttributeAsOrtValue, + &OrtApis::OpAttr_GetTensorAttributeAsOrtValue, &OrtApis::OpAttr_GetType, &OrtApis::OpAttr_GetName, &OrtApis::Node_GetNumSubgraphs, diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index e62149d04a16c..6dc4cf9d195cc 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -687,7 +687,7 @@ ORT_API_STATUS_IMPL(Node_GetAttributes, _In_ const OrtNode* node, _Out_writes_(num_attributes) const OrtOpAttr** attributes, _In_ size_t num_attributes); ORT_API_STATUS_IMPL(Node_GetAttributeByName, _In_ const OrtNode* node, _In_ const char* attribute_name, _Outptr_result_maybenull_ const OrtOpAttr** attribute); -ORT_API_STATUS_IMPL(Node_GetTensorAttributeAsOrtValue, _In_ const OrtNode* node, _In_ const OrtOpAttr* attribute, +ORT_API_STATUS_IMPL(OpAttr_GetTensorAttributeAsOrtValue, _In_ const OrtOpAttr* attribute, _Outptr_result_maybenull_ OrtValue** attr_tensor); ORT_API_STATUS_IMPL(OpAttr_GetType, _In_ const OrtOpAttr* attribute, _Out_ OrtOpAttrType* type); ORT_API_STATUS_IMPL(OpAttr_GetName, _In_ const OrtOpAttr* attribute, _Outptr_ const char** name); diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.cc b/onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.cc index d6e51a44c1c69..42b65239de92c 100644 --- a/onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.cc +++ b/onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.cc @@ -4,6 +4,8 @@ #include "core/session/plugin_ep/ep_factory_provider_bridge.h" #include "core/providers/shared_library/provider_host_api.h" +#include "core/session/plugin_ep/ep_library_plugin.h" +#include "core/session/onnxruntime_ep_device_ep_metadata_keys.h" namespace onnxruntime { OrtStatus* ProviderBridgeEpFactory::GetSupportedDevices(EpFactoryInternal& ep_factory, @@ -20,6 +22,11 @@ OrtStatus* ProviderBridgeEpFactory::GetSupportedDevices(EpFactoryInternal& ep_fa auto* ep_device = ep_devices[i]; if (ep_device) { ep_device->ep_factory = &ep_factory; + + // Add library path to EP metadata if available + if (library_path_.has_value()) { + ep_device->ep_metadata.Add(kOrtEpDevice_EpMetadataKey_LibraryPath, library_path_->string()); + } } } diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.h b/onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.h index 437af62dc2c0c..8c5ef526baba1 100644 --- a/onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.h +++ b/onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.h @@ -3,6 +3,10 @@ #pragma once +#include +#include +#include + #include "core/framework/error_code_helper.h" #include "core/session/abi_devices.h" #include "core/session/abi_session_options_impl.h" @@ -12,12 +16,14 @@ namespace onnxruntime { class ProviderBridgeEpFactory : public EpFactoryInternalImpl { public: - ProviderBridgeEpFactory(OrtEpFactory& ep_factory, ProviderLibrary& provider_library) + ProviderBridgeEpFactory(OrtEpFactory& ep_factory, ProviderLibrary& provider_library, + std::optional library_path = std::nullopt) : EpFactoryInternalImpl(ep_factory.GetName(&ep_factory), ep_factory.GetVendor(&ep_factory), ep_factory.GetVendorId(&ep_factory)), ep_factory_{ep_factory}, - provider_library_{provider_library} { + provider_library_{provider_library}, + library_path_{std::move(library_path)} { } private: @@ -59,8 +65,9 @@ class ProviderBridgeEpFactory : public EpFactoryInternalImpl { return ep_factory_.CreateSyncStreamForDevice(&ep_factory_, device, stream_options, stream); } - OrtEpFactory& ep_factory_; // OrtEpFactory from the provider bridge EP - ProviderLibrary& provider_library_; // ProviderLibrary from the provider bridge EP + OrtEpFactory& ep_factory_; + ProviderLibrary& provider_library_; + std::optional library_path_; }; } // namespace onnxruntime diff --git a/onnxruntime/core/session/plugin_ep/ep_library.h b/onnxruntime/core/session/plugin_ep/ep_library.h index 24ab74e1c77fc..af5bc23143e33 100644 --- a/onnxruntime/core/session/plugin_ep/ep_library.h +++ b/onnxruntime/core/session/plugin_ep/ep_library.h @@ -23,6 +23,7 @@ class EpLibrary { virtual Status Load() { return Status::OK(); } virtual const std::vector& GetFactories() = 0; // valid after Load() virtual Status Unload() { return Status::OK(); } + virtual ~EpLibrary() = default; ORT_DISALLOW_COPY_AND_ASSIGNMENT(EpLibrary); diff --git a/onnxruntime/core/session/plugin_ep/ep_library_provider_bridge.cc b/onnxruntime/core/session/plugin_ep/ep_library_provider_bridge.cc index 06cf54aea4071..da94a9f12ba9d 100644 --- a/onnxruntime/core/session/plugin_ep/ep_library_provider_bridge.cc +++ b/onnxruntime/core/session/plugin_ep/ep_library_provider_bridge.cc @@ -4,6 +4,7 @@ #include "core/session/plugin_ep/ep_library_provider_bridge.h" #include "core/session/plugin_ep/ep_factory_provider_bridge.h" +#include "core/session/plugin_ep/ep_library_plugin.h" namespace onnxruntime { Status EpLibraryProviderBridge::Load() { @@ -26,8 +27,9 @@ Status EpLibraryProviderBridge::Load() { // to do this we need to capture `factory` and plug it in to is_supported_fn and create_fn. // we also need to update any returned OrtEpDevice instances to swap the wrapper EpFactoryInternal in so that we can // call Provider::CreateIExecutionProvider in EpFactoryInternal::CreateIExecutionProvider. + for (const auto& factory : ep_library_plugin_->GetFactories()) { - auto factory_impl = std::make_unique(*factory, *provider_library_); + auto factory_impl = std::make_unique(*factory, *provider_library_, library_path_); auto internal_factory = std::make_unique(std::move(factory_impl)); factory_ptrs_.push_back(internal_factory.get()); diff --git a/onnxruntime/core/session/plugin_ep/ep_library_provider_bridge.h b/onnxruntime/core/session/plugin_ep/ep_library_provider_bridge.h index c7e8ebefc3785..45277b2828f56 100644 --- a/onnxruntime/core/session/plugin_ep/ep_library_provider_bridge.h +++ b/onnxruntime/core/session/plugin_ep/ep_library_provider_bridge.h @@ -21,9 +21,11 @@ namespace onnxruntime { class EpLibraryProviderBridge : public EpLibrary { public: EpLibraryProviderBridge(std::unique_ptr provider_library, - std::unique_ptr ep_library_plugin) + std::unique_ptr ep_library_plugin, + std::optional library_path = std::nullopt) : provider_library_{std::move(provider_library)}, - ep_library_plugin_{std::move(ep_library_plugin)} { + ep_library_plugin_{std::move(ep_library_plugin)}, + library_path_{std::move(library_path)} { } const char* RegistrationName() const override { @@ -53,6 +55,9 @@ class EpLibraryProviderBridge : public EpLibrary { // implement EpFactoryInternal::CreateIExecutionProvider by calling Provider::CreateIExecutionProvider. std::unique_ptr ep_library_plugin_; + // Library path for EP metadata + std::optional library_path_; + std::vector> factories_; std::vector factory_ptrs_; // for convenience std::vector internal_factory_ptrs_; // for convenience diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index 41cf8be1d1412..f82cbcf63ca62 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -2902,6 +2902,7 @@ ORT_API_STATUS_IMPL(OrtApis::CreateCANNProviderOptions, _Outptr_ OrtCANNProvider options->npu_mem_limit = SIZE_MAX; options->arena_extend_strategy = static_cast(0); options->enable_cann_graph = 1; + options->enable_cann_subgraph = 0; options->dump_graphs = 0; options->dump_om_model = 1; options->default_memory_arena_cfg = nullptr; diff --git a/onnxruntime/core/session/utils.cc b/onnxruntime/core/session/utils.cc index d4041dfce5a7a..7da7fabb15b15 100644 --- a/onnxruntime/core/session/utils.cc +++ b/onnxruntime/core/session/utils.cc @@ -421,13 +421,14 @@ Status LoadPluginOrProviderBridge(const std::string& registration_name, << (is_provider_bridge ? " as a provider bridge" : " as a plugin"); // create EpLibraryPlugin to ensure CreateEpFactories and ReleaseEpFactory are available - auto ep_library_plugin = std::make_unique(registration_name, std::move(resolved_library_path)); + auto ep_library_plugin = std::make_unique(registration_name, resolved_library_path); ORT_RETURN_IF_ERROR(ep_library_plugin->Load()); if (is_provider_bridge) { // wrap the EpLibraryPlugin with EpLibraryProviderBridge to add to directly create an IExecutionProvider auto ep_library_provider_bridge = std::make_unique(std::move(provider_library), - std::move(ep_library_plugin)); + std::move(ep_library_plugin), + resolved_library_path); ORT_RETURN_IF_ERROR(ep_library_provider_bridge->Load()); internal_factories = ep_library_provider_bridge->GetInternalFactories(); ep_library = std::move(ep_library_provider_bridge); diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index 24554560b4dde..eb06a65ad5330 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -1575,6 +1575,17 @@ void addGlobalMethods(py::module& m) { R"pbdoc(Get the list of available OrtEpDevice instances.)pbdoc", py::return_value_policy::reference); + m.def( + "get_model_compatibility_for_ep_devices", + [](const std::vector& ep_devices, + const std::string& compatibility_info) -> OrtCompiledModelCompatibility { + OrtCompiledModelCompatibility status = OrtCompiledModelCompatibility_EP_NOT_APPLICABLE; + Ort::ThrowOnError(Ort::GetApi().GetModelCompatibilityForEpDevices( + ep_devices.data(), ep_devices.size(), compatibility_info.c_str(), &status)); + return status; + }, + R"pbdoc("Validate a compiled model's compatibility information for one or more EP devices.)pbdoc"); + #if defined(USE_OPENVINO) || defined(USE_OPENVINO_PROVIDER_INTERFACE) m.def( "get_available_openvino_device_ids", []() -> std::vector { @@ -1759,6 +1770,12 @@ void addObjectMethods(py::module& m, ExecutionProviderRegistrationFn ep_registra .value("PRIORITY_BASED", ExecutionOrder::PRIORITY_BASED) .value("MEMORY_EFFICIENT", ExecutionOrder::MEMORY_EFFICIENT); + py::enum_(m, "OrtCompiledModelCompatibility") + .value("EP_NOT_APPLICABLE", OrtCompiledModelCompatibility_EP_NOT_APPLICABLE) + .value("EP_SUPPORTED_OPTIMAL", OrtCompiledModelCompatibility_EP_SUPPORTED_OPTIMAL) + .value("EP_SUPPORTED_PREFER_RECOMPILATION", OrtCompiledModelCompatibility_EP_SUPPORTED_PREFER_RECOMPILATION) + .value("EP_UNSUPPORTED", OrtCompiledModelCompatibility_EP_UNSUPPORTED); + py::enum_(m, "OrtAllocatorType") .value("INVALID", OrtInvalidAllocator) .value("ORT_DEVICE_ALLOCATOR", OrtDeviceAllocator) @@ -1782,7 +1799,7 @@ void addObjectMethods(py::module& m, ExecutionProviderRegistrationFn ep_registra type = OrtDevice::GPU; vendor = OrtDevice::VendorIds::MICROSOFT; } else if (type == OrtDevice::GPU) { -#if USE_CUDA +#if USE_CUDA || USE_NV || USE_NV_PROVIDER_INTERFACE || USE_CUDA_PROVIDER_INTERFACE vendor = OrtDevice::VendorIds::NVIDIA; #elif USE_ROCM || USE_MIGRAPHX vendor = OrtDevice::VendorIds::AMD; diff --git a/onnxruntime/python/tools/quantization/execution_providers/qnn/preprocess.py b/onnxruntime/python/tools/quantization/execution_providers/qnn/preprocess.py index 191edc4c6390d..a12aca47f5b65 100644 --- a/onnxruntime/python/tools/quantization/execution_providers/qnn/preprocess.py +++ b/onnxruntime/python/tools/quantization/execution_providers/qnn/preprocess.py @@ -6,15 +6,15 @@ from __future__ import annotations import logging +import tempfile from pathlib import Path import onnx -from ....tools.onnx_model_utils import fix_output_shapes, make_input_shape_fixed +from ....tools.onnx_model_utils import fix_output_shapes, make_input_shape_fixed, optimize_model from ....tools.remove_initializer_from_input import remove_initializer_from_input from ...fusions import FusionGelu, FusionLayerNormalization from ...onnx_model import ONNXModel -from ...quant_utils import save_and_reload_model_with_shape_infer from .fusion_lpnorm import FusionLpNormalization from .fusion_spacetodepth import FusionSpaceToDepth @@ -93,7 +93,7 @@ def qnn_preprocess_model( """ modified = False model = model_input if isinstance(model_input, onnx.ModelProto) else onnx.load_model(model_input) - model = save_and_reload_model_with_shape_infer(model) + model = save_and_reload_optimize_model(model, shape_infer=True) onnx_model = ONNXModel(model) # Optionally, fix the dynamic input shapes. @@ -178,6 +178,24 @@ def qnn_preprocess_model( return modified +def save_and_reload_optimize_model(model: onnx.ModelProto, shape_infer: bool) -> onnx.ModelProto: + with tempfile.TemporaryDirectory(prefix="ort.qnn_preproc.") as qnn_preproc_tmp_dir: + model_in_path = Path(qnn_preproc_tmp_dir).joinpath("qnn_proc_input.onnx") + onnx.save_model(model, model_in_path, save_as_external_data=True) + if shape_infer: + model_infer_path = Path(qnn_preproc_tmp_dir).joinpath("qnn_proc_infer.onnx") + onnx.shape_inference.infer_shapes_path(str(model_in_path), str(model_infer_path)) + model_in_path = model_infer_path + model_out_path = Path(qnn_preproc_tmp_dir).joinpath("qnn_proc_output.onnx") + optimize_model(model_in_path, model_out_path) + ret_model = onnx.load_model(model_out_path) + ret_metaprops = {"onnx.infer": "onnxruntime.tools.qnn.preprocess"} + if ret_model.metadata_props: + ret_metaprops.update(ret_model.metadata_props) + onnx.helper.set_model_props(ret_model, ret_metaprops) + return ret_model + + class InputOutputNameMap: def __init__( self, diff --git a/onnxruntime/test/autoep/library/ep_arena.h b/onnxruntime/test/autoep/library/ep_arena.h index 641f3ce3f7b17..caa2c61db835f 100644 --- a/onnxruntime/test/autoep/library/ep_arena.h +++ b/onnxruntime/test/autoep/library/ep_arena.h @@ -21,7 +21,10 @@ limitations under the License. #include #include +#define ORT_API_MANUAL_INIT #include "onnxruntime_cxx_api.h" +#undef ORT_API_MANUAL_INIT + #include "ep_allocator.h" #include "example_plugin_ep_utils.h" diff --git a/onnxruntime/test/contrib_ops/moe_test.cc b/onnxruntime/test/contrib_ops/moe_test.cc index ed7ca998e0b86..0690b8894eb7a 100644 --- a/onnxruntime/test/contrib_ops/moe_test.cc +++ b/onnxruntime/test/contrib_ops/moe_test.cc @@ -144,6 +144,12 @@ static void RunQMoETest(const std::vector& input, const std::vector("k", static_cast(top_k)); cpu_tester.AddAttribute("activation_type", activation_type); @@ -1323,6 +1329,13 @@ TEST(MoETest, QMoETest_Mixtral_Int4) { // CPU-specific QMoE tests TEST(MoETest, QMoETest_CPU_Int4_MLAS) { +#ifdef USE_MLAS + // Skip this test if we're not testing CPU execution provider + auto cpu_ep = DefaultCpuExecutionProvider(); + if (!cpu_ep) { + GTEST_SKIP() << "CPU execution provider not available"; + } + int num_rows = 2; int num_experts = 2; int hidden_size = 32; @@ -1387,9 +1400,19 @@ TEST(MoETest, QMoETest_CPU_Int4_MLAS) { std::vector> cpu_execution_providers; cpu_execution_providers.push_back(DefaultCpuExecutionProvider()); cpu_tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &cpu_execution_providers); +#else + GTEST_SKIP() << "Skipping CPU QMoE test"; +#endif } TEST(MoETest, QMoETest_CPU_Int8_MLAS) { +#ifdef USE_MLAS + // Skip this test if we're not testing CPU execution provider + auto cpu_ep = DefaultCpuExecutionProvider(); + if (!cpu_ep) { + GTEST_SKIP() << "CPU execution provider not available"; + } + // Test CPU implementation with 8-bit quantization - CPU ONLY int num_rows = 1; int num_experts = 2; @@ -1446,9 +1469,19 @@ TEST(MoETest, QMoETest_CPU_Int8_MLAS) { std::vector> cpu_execution_providers; cpu_execution_providers.push_back(DefaultCpuExecutionProvider()); cpu_tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &cpu_execution_providers); +#else + GTEST_SKIP() << "Skipping CPU QMoE test"; +#endif } TEST(MoETest, QMoETest_CPU_FC3_Error) { +#ifdef USE_MLAS + // Skip this test if we're not testing CPU execution provider + auto cpu_ep = DefaultCpuExecutionProvider(); + if (!cpu_ep) { + GTEST_SKIP() << "CPU execution provider not available"; + } + // Test that CPU throws error when FC3 gating is provided - CPU ONLY int num_rows = 1; int num_experts = 2; @@ -1506,9 +1539,19 @@ TEST(MoETest, QMoETest_CPU_FC3_Error) { // Expect this to fail with FC3 not implemented error cpu_tester.Run(OpTester::ExpectResult::kExpectFailure, "FC3 gating is not yet implemented", {}, nullptr, &cpu_execution_providers); +#else + GTEST_SKIP() << "Skipping CPU QMoE test"; +#endif } TEST(MoETest, QMoETest_CPU_SwiGLU_Int4) { +#ifdef USE_MLAS + // Skip this test if we're not testing CPU execution provider + auto cpu_ep = DefaultCpuExecutionProvider(); + if (!cpu_ep) { + GTEST_SKIP() << "CPU execution provider not available"; + } + // Test CPU implementation with 4-bit quantization and SwiGLU activation int num_rows = 2; int num_experts = 2; @@ -1573,9 +1616,18 @@ TEST(MoETest, QMoETest_CPU_SwiGLU_Int4) { std::vector> cpu_execution_providers; cpu_execution_providers.push_back(DefaultCpuExecutionProvider()); cpu_tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &cpu_execution_providers); +#else + GTEST_SKIP() << "Skipping CPU QMoE test"; +#endif } TEST(MoETest, QMoETest_CPU_SwiGLU_Int8) { +#ifdef USE_MLAS + // Skip this test if we're not testing CPU execution provider + auto cpu_ep = DefaultCpuExecutionProvider(); + if (!cpu_ep) { + GTEST_SKIP() << "CPU execution provider not available"; + } // Test CPU implementation with 8-bit quantization and SwiGLU activation int num_rows = 1; int num_experts = 2; @@ -1633,6 +1685,9 @@ TEST(MoETest, QMoETest_CPU_SwiGLU_Int8) { std::vector> cpu_execution_providers; cpu_execution_providers.push_back(DefaultCpuExecutionProvider()); cpu_tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &cpu_execution_providers); +#else + GTEST_SKIP() << "Skipping CPU QMoE test"; +#endif } #endif diff --git a/onnxruntime/test/framework/ep_compatibility_test.cc b/onnxruntime/test/framework/ep_compatibility_test.cc index ee82d4683ab73..a8a83fbe5ceb6 100644 --- a/onnxruntime/test/framework/ep_compatibility_test.cc +++ b/onnxruntime/test/framework/ep_compatibility_test.cc @@ -15,6 +15,7 @@ #include "core/session/onnxruntime_ep_device_ep_metadata_keys.h" #include "core/session/utils.h" #include "core/session/onnxruntime_c_api.h" +#include "core/session/onnxruntime_cxx_api.h" #include "core/session/abi_session_options_impl.h" #include "core/framework/error_code_helper.h" #include "dummy_provider.h" @@ -499,3 +500,31 @@ TEST(EpCompatibilityCapiTest, CpuEpReturnsNotApplicableIfNoValidation) { api->ReleaseEnv(env); } + +// ----------------------------- +// C++ API unit tests +// ----------------------------- + +TEST(EpCompatibilityCxxApiTest, SingleDeviceCpuProvider) { + Ort::Env env{ORT_LOGGING_LEVEL_WARNING, "EpCompatCxx"}; + auto devices = env.GetEpDevices(); + ASSERT_FALSE(devices.empty()); + + std::vector selected; + for (const auto& d : devices) { + if (std::string{d.EpName()} == "CPUExecutionProvider") { + selected.push_back(d); + break; + } + } + + ASSERT_FALSE(selected.empty()); + + // Pick a status that the CPU EP would never return to ensure the value is set correctly. + OrtCompiledModelCompatibility status = OrtCompiledModelCompatibility_EP_SUPPORTED_PREFER_RECOMPILATION; + ASSERT_NO_FATAL_FAILURE({ + status = Ort::GetModelCompatibilityForEpDevices(selected, "arbitrary-compat-string"); + }); + + ASSERT_TRUE(status == OrtCompiledModelCompatibility_EP_NOT_APPLICABLE); +} \ No newline at end of file diff --git a/onnxruntime/test/platform/device_discovery_test.cc b/onnxruntime/test/platform/device_discovery_test.cc index 21ddf9a5b1cd7..6b43ccbc8f670 100644 --- a/onnxruntime/test/platform/device_discovery_test.cc +++ b/onnxruntime/test/platform/device_discovery_test.cc @@ -25,9 +25,9 @@ TEST(DeviceDiscoveryTest, HasCpuDevice) { const auto cpu_devices = GetDevicesByType(OrtHardwareDeviceType_CPU); ASSERT_GT(cpu_devices.size(), 0); -#if !defined(__wasm__) +#if defined(CPUINFO_SUPPORTED) ASSERT_NE(cpu_devices[0].vendor_id, 0); -#endif // !defined(__WASM__) +#endif // defined(CPUINFO_SUPPORTED) } } // namespace onnxruntime::test diff --git a/onnxruntime/test/python/onnxruntime_test_python_ep_compatibility.py b/onnxruntime/test/python/onnxruntime_test_python_ep_compatibility.py new file mode 100644 index 0000000000000..8e69fdf088103 --- /dev/null +++ b/onnxruntime/test/python/onnxruntime_test_python_ep_compatibility.py @@ -0,0 +1,46 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import os +import platform +import sys +import unittest + +from onnxruntime.capi.onnxruntime_pybind11_state import ( + OrtCompiledModelCompatibility, + get_ep_devices, + get_model_compatibility_for_ep_devices, +) + +# handle change from python 3.8 and on where loading a dll from the current directory needs to be explicitly allowed. +if platform.system() == "Windows" and sys.version_info.major >= 3 and sys.version_info.minor >= 8: # noqa: YTT204 + os.add_dll_directory(os.getcwd()) + + +class TestEpCompatibility(unittest.TestCase): + def test_invalid_args(self): + # empty devices + with self.assertRaises(RuntimeError): + get_model_compatibility_for_ep_devices([], "info") + # None compatibility info should raise TypeError before native call + with self.assertRaises(TypeError): + get_model_compatibility_for_ep_devices(get_ep_devices(), None) # type: ignore[arg-type] + + def test_basic_smoke(self): + devices = list(get_ep_devices()) + if not devices: + self.skipTest("No EP devices available in this build") + + # Always select CPUExecutionProvider; skip if not present. + cpu_devices = [d for d in devices if getattr(d, "ep_name", None) == "CPUExecutionProvider"] + if not cpu_devices: + self.skipTest("CPUExecutionProvider not available in this build") + selected = [cpu_devices[0]] + + # API requires all devices belong to the same EP; we pass only one. + status = get_model_compatibility_for_ep_devices(selected, "arbitrary-compat-string") + self.assertEqual(status, OrtCompiledModelCompatibility.EP_NOT_APPLICABLE) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxruntime/test/python/onnxruntime_test_python_nv_tensorrt_rtx_ep_tests.py b/onnxruntime/test/python/onnxruntime_test_python_nv_tensorrt_rtx_ep_tests.py new file mode 100644 index 0000000000000..d5c80a4a1f4ba --- /dev/null +++ b/onnxruntime/test/python/onnxruntime_test_python_nv_tensorrt_rtx_ep_tests.py @@ -0,0 +1,468 @@ +# Copyright (c) NVIDIA Corporation. All rights reserved. +# Licensed under the MIT License. +from __future__ import annotations + +import sys +import unittest +from collections.abc import Sequence + +import numpy as np +import torch +from autoep_helper import AutoEpTestCase +from helper import get_name +from numpy.testing import assert_almost_equal +from onnx import TensorProto, helper +from onnx.defs import onnx_opset_version + +import onnxruntime as onnxrt +from onnxruntime.capi._pybind_state import OrtDevice as C_OrtDevice +from onnxruntime.capi._pybind_state import OrtValue as C_OrtValue +from onnxruntime.capi._pybind_state import OrtValueVector, SessionIOBinding + + +class TestNvTensorRTRTXAutoEP(AutoEpTestCase): + """ + Test suite for the NvTensorRTRTX Execution Provider. + + This class contains tests for registering the NvTensorRTRTX EP, + selecting it using different policies, and running inference with various + I/O binding configurations. + """ + + ep_lib_path = "onnxruntime_providers_nv_tensorrt_rtx.dll" + ep_name = "NvTensorRTRTXExecutionProvider" + + def setUp(self): + if sys.platform != "win32": + self.skipTest("Skipping test because device discovery is only supported on Windows") + self.register_execution_provider_library(self.ep_name, self.ep_lib_path) + + def tearDown(self): + self.unregister_execution_provider_library(self.ep_name) + + def _create_ortvalue_input_on_gpu(self, device): + return onnxrt.OrtValue.ortvalue_from_numpy( + np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32), device, 0 + ) + + def _create_ortvalue_alternate_input_on_gpu(self, device): + return onnxrt.OrtValue.ortvalue_from_numpy( + np.array([[2.0, 4.0], [6.0, 8.0], [10.0, 12.0]], dtype=np.float32), + device, + 0, + ) + + def _create_uninitialized_ortvalue_input_on_gpu(self, device): + return onnxrt.OrtValue.ortvalue_from_shape_and_type([3, 2], np.float32, device, 0) + + def _create_numpy_input(self): + return np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32) + + def _create_expected_output(self): + return np.array([[1.0, 4.0], [9.0, 16.0], [25.0, 36.0]], dtype=np.float32) + + def _create_expected_output_alternate(self): + return np.array([[2.0, 8.0], [18.0, 32.0], [50.0, 72.0]], dtype=np.float32) + + def torch_to_onnx_type(self, torch_dtype): + if torch_dtype == torch.float32: + return TensorProto.FLOAT + elif torch_dtype == torch.float16: + return TensorProto.FLOAT16 + elif torch_dtype == torch.bfloat16: + return TensorProto.BFLOAT16 + elif torch_dtype == torch.int8: + return TensorProto.int8 + elif torch_dtype == torch.int32: + return TensorProto.INT32 + elif torch_dtype == torch.int64: + return TensorProto.INT64 + else: + raise TypeError(f"Unsupported dtype: {torch_dtype}") + + def test_nv_tensorrt_rtx_ep_register_and_inference(self): + """ + Test registration of NvTensorRTRTX EP, adding its OrtDevice to the SessionOptions, and running inference. + """ + ep_devices = onnxrt.get_ep_devices() + nv_tensorrt_rtx_ep_device = next((d for d in ep_devices if d.ep_name == self.ep_name), None) + self.assertIsNotNone(nv_tensorrt_rtx_ep_device) + self.assertEqual(nv_tensorrt_rtx_ep_device.ep_vendor, "NVIDIA") + + hw_device = nv_tensorrt_rtx_ep_device.device + self.assertEqual(hw_device.type, onnxrt.OrtHardwareDeviceType.GPU) + + # Run sample model and check output + sess = onnxrt.InferenceSession(get_name("mul_1.onnx")) + + x = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32) + input_name = sess.get_inputs()[0].name + res = sess.run([], {input_name: x}) + output_expected = np.array([[1.0, 4.0], [9.0, 16.0], [25.0, 36.0]], dtype=np.float32) + np.testing.assert_allclose(output_expected, res[0], rtol=1e-05, atol=1e-08) + + def test_nv_tensorrt_rtx_ep_prefer_gpu_and_inference(self): + """ + Test selecting NvTensorRTRTX EP via the PREFER_GPU policy and running inference. + """ + # Set a policy to prefer GPU. NvTensorRTRTX should be selected. + sess_options = onnxrt.SessionOptions() + sess_options.set_provider_selection_policy(onnxrt.OrtExecutionProviderDevicePolicy.PREFER_GPU) + self.assertTrue(sess_options.has_providers()) + + # Run sample model and check output + sess = onnxrt.InferenceSession(get_name("mul_1.onnx"), sess_options=sess_options) + + x = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32) + input_name = sess.get_inputs()[0].name + res = sess.run([], {input_name: x}) + output_expected = np.array([[1.0, 4.0], [9.0, 16.0], [25.0, 36.0]], dtype=np.float32) + np.testing.assert_allclose(output_expected, res[0], rtol=1e-05, atol=1e-08) + + def test_nv_tensorrt_rtx_ep_selection_delegate_and_inference(self): + """ + Test selecting NvTensorRTRTX EP via the custom EP selection delegate function and then run inference. + """ + + # User's custom EP selection function. + def my_delegate( + ep_devices: Sequence[onnxrt.OrtEpDevice], + model_metadata: dict[str, str], + runtime_metadata: dict[str, str], + max_selections: int, + ) -> Sequence[onnxrt.OrtEpDevice]: + self.assertGreater(len(model_metadata), 0) + self.assertGreaterEqual(len(ep_devices), 1) + self.assertGreaterEqual(max_selections, 2) + + nv_tensorrt_rtx_ep_device = next((d for d in ep_devices if d.ep_name == self.ep_name), None) + self.assertIsNotNone(nv_tensorrt_rtx_ep_device) + + # Select the NvTensorRTRTX device + return [nv_tensorrt_rtx_ep_device] + + sess_options = onnxrt.SessionOptions() + sess_options.set_provider_selection_policy_delegate(my_delegate) + self.assertTrue(sess_options.has_providers()) + + # Run sample model and check output + sess = onnxrt.InferenceSession(get_name("mul_1.onnx"), sess_options=sess_options) + + x = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32) + input_name = sess.get_inputs()[0].name + res = sess.run([], {input_name: x}) + output_expected = np.array([[1.0, 4.0], [9.0, 16.0], [25.0, 36.0]], dtype=np.float32) + np.testing.assert_allclose(output_expected, res[0], rtol=1e-05, atol=1e-08) + + def test_bind_input_only(self): + """ + Test I/O binding with input data only. + """ + # Set a policy to prefer GPU. NvTensorRTRTX should be selected. + sess_options = onnxrt.SessionOptions() + sess_options.set_provider_selection_policy(onnxrt.OrtExecutionProviderDevicePolicy.PREFER_GPU) + self.assertTrue(sess_options.has_providers()) + + input = self._create_ortvalue_input_on_gpu("cuda") + + session = onnxrt.InferenceSession(get_name("mul_1.onnx"), sess_options=sess_options) + io_binding = session.io_binding() + + # Bind input to the GPU + io_binding.bind_input("X", "cuda", 0, np.float32, [3, 2], input.data_ptr()) + + # Sync if different streams + io_binding.synchronize_inputs() + + # Bind output to CPU + io_binding.bind_output("Y") + + # Invoke Run + session.run_with_iobinding(io_binding) + + # Sync if different streams + io_binding.synchronize_outputs() + + # Get outputs over to CPU (the outputs which were bound to the GPU will get copied over to the host + # here) + ort_output = io_binding.copy_outputs_to_cpu()[0] + + # Validate results + self.assertTrue(np.array_equal(self._create_expected_output(), ort_output)) + + def test_bind_input_and_bind_output_with_ortvalues(self): + """ + Test I/O binding with OrtValues for both input and output. + """ + # Set a policy to prefer GPU. NvTensorRTRTX EP should be selected. + sess_options = onnxrt.SessionOptions() + sess_options.set_provider_selection_policy(onnxrt.OrtExecutionProviderDevicePolicy.PREFER_GPU) + self.assertTrue(sess_options.has_providers()) + + session = onnxrt.InferenceSession(get_name("mul_1.onnx"), sess_options=sess_options) + io_binding = session.io_binding() + + # Bind ortvalue as input + input_ortvalue = self._create_ortvalue_input_on_gpu("cuda") + io_binding.bind_ortvalue_input("X", input_ortvalue) + + # Bind ortvalue as output + output_ortvalue = self._create_uninitialized_ortvalue_input_on_gpu("cuda") + io_binding.bind_ortvalue_output("Y", output_ortvalue) + + # Sync if different streams + io_binding.synchronize_inputs() + + # Invoke Run + session.run_with_iobinding(io_binding) + + # Sync if different streams + io_binding.synchronize_outputs() + + # Inspect contents of output_ortvalue and make sure that it has the right contents + self.assertTrue(np.array_equal(self._create_expected_output(), output_ortvalue.numpy())) + + # Bind another ortvalue as input + input_ortvalue_2 = self._create_ortvalue_alternate_input_on_gpu("cuda") + io_binding.bind_ortvalue_input("X", input_ortvalue_2) + + # Sync if different streams + io_binding.synchronize_inputs() + + # Invoke Run + session.run_with_iobinding(io_binding) + + # Sync if different streams + io_binding.synchronize_outputs() + + # Inspect contents of output_ortvalue and make sure that it has the right contents + self.assertTrue(np.array_equal(self._create_expected_output_alternate(), output_ortvalue.numpy())) + + def test_bind_input_and_non_preallocated_output(self): + """ + Test I/O binding with non-preallocated output. + """ + sess_options = onnxrt.SessionOptions() + sess_options.set_provider_selection_policy(onnxrt.OrtExecutionProviderDevicePolicy.PREFER_GPU) + self.assertTrue(sess_options.has_providers()) + + session = onnxrt.InferenceSession(get_name("mul_1.onnx"), sess_options=sess_options) + io_binding = session.io_binding() + + input = self._create_ortvalue_input_on_gpu("cuda") + + # Bind input to the GPU + io_binding.bind_input("X", "cuda", 0, np.float32, [3, 2], input.data_ptr()) + + # Bind output to the GPU + io_binding.bind_output("Y", "cuda") + + # Sync if different streams + io_binding.synchronize_inputs() + + # Invoke Run + session.run_with_iobinding(io_binding) + + # Sync if different streams + io_binding.synchronize_outputs() + + # This call returns an OrtValue which has data allocated by ORT on the GPU + ort_outputs = io_binding.get_outputs() + self.assertEqual(len(ort_outputs), 1) + self.assertEqual(ort_outputs[0].device_name(), "cuda") + # Validate results (by copying results to CPU by creating a Numpy object) + self.assertTrue(np.array_equal(self._create_expected_output(), ort_outputs[0].numpy())) + + # We should be able to repeat the above process as many times as we want - try once more + ort_outputs = io_binding.get_outputs() + self.assertEqual(len(ort_outputs), 1) + self.assertEqual(ort_outputs[0].device_name(), "cuda") + # Validate results (by copying results to CPU by creating a Numpy object) + self.assertTrue(np.array_equal(self._create_expected_output(), ort_outputs[0].numpy())) + + input = self._create_ortvalue_alternate_input_on_gpu("cuda") + + # Change the bound input and validate the results in the same bound OrtValue + # Bind alternate input to the GPU + io_binding.bind_input( + "X", + "cuda", + 0, + np.float32, + [3, 2], + input.data_ptr(), + ) + + # Sync if different streams + io_binding.synchronize_inputs() + + # Invoke Run + session.run_with_iobinding(io_binding) + + # Sync if different streams + io_binding.synchronize_outputs() + + # This call returns an OrtValue which has data allocated by ORT on the GPU + ort_outputs = io_binding.get_outputs() + self.assertEqual(len(ort_outputs), 1) + self.assertEqual(ort_outputs[0].device_name(), "cuda") + # Validate results (by copying results to CPU by creating a Numpy object) + self.assertTrue(np.array_equal(self._create_expected_output_alternate(), ort_outputs[0].numpy())) + + def test_bind_input_and_preallocated_output(self): + """ + Test I/O binding with preallocated output. + """ + sess_options = onnxrt.SessionOptions() + sess_options.set_provider_selection_policy(onnxrt.OrtExecutionProviderDevicePolicy.PREFER_GPU) + self.assertTrue(sess_options.has_providers()) + + input = self._create_ortvalue_input_on_gpu("cuda") + + session = onnxrt.InferenceSession(get_name("mul_1.onnx"), sess_options=sess_options) + io_binding = session.io_binding() + + # Bind input to the GPU + io_binding.bind_input("X", "cuda", 0, np.float32, [3, 2], input.data_ptr()) + + # Bind output to the GPU + output = self._create_uninitialized_ortvalue_input_on_gpu("cuda") + io_binding.bind_output("Y", "cuda", 0, np.float32, [3, 2], output.data_ptr()) + + # Sync if different streams + io_binding.synchronize_inputs() + + # Invoke Run + session.run_with_iobinding(io_binding) + + # Sync if different streams + io_binding.synchronize_outputs() + + # Get outputs over to CPU (the outputs which were bound to the GPU will get copied over to the host + # here) + ort_output_vals = io_binding.copy_outputs_to_cpu()[0] + # Validate results + self.assertTrue(np.array_equal(self._create_expected_output(), ort_output_vals)) + + # Validate if ORT actually wrote to pre-allocated buffer by copying the allocated buffer + # to the host and validating its contents + ort_output_vals_in_cpu = output.numpy() + # Validate results + self.assertTrue(np.array_equal(self._create_expected_output(), ort_output_vals_in_cpu)) + + def test_bind_input_types(self): + """ + Test I/O binding with various input data types. + """ + sess_options = onnxrt.SessionOptions() + sess_options.set_provider_selection_policy(onnxrt.OrtExecutionProviderDevicePolicy.PREFER_GPU) + self.assertTrue(sess_options.has_providers()) + opset = onnx_opset_version() + device = C_OrtDevice(C_OrtDevice.cuda(), C_OrtDevice.default_memory(), 0) + + for dtype in [ + np.float32, + # np.float64, + np.int32, + # np.uint32, + np.int64, + # np.uint64, + # np.int16, + # np.uint16, + # np.int8, + np.uint8, + np.float16, + np.bool_, + ]: + with self.subTest(dtype=dtype, inner_device=str(device)): + x = np.arange(8).reshape((-1, 2)).astype(dtype) + proto_dtype = helper.np_dtype_to_tensor_dtype(x.dtype) + + X = helper.make_tensor_value_info("X", proto_dtype, [None, x.shape[1]]) # noqa: N806 + Y = helper.make_tensor_value_info("Y", proto_dtype, [None, x.shape[1]]) # noqa: N806 + + # inference + node_add = helper.make_node("Identity", ["X"], ["Y"]) + + # graph + graph_def = helper.make_graph([node_add], "lr", [X], [Y], []) + model_def = helper.make_model( + graph_def, + producer_name="dummy", + ir_version=7, + producer_version="0", + opset_imports=[helper.make_operatorsetid("", opset)], + ) + + sess = onnxrt.InferenceSession(model_def.SerializeToString(), sess_options=sess_options) + + bind = SessionIOBinding(sess._sess) + ort_value = C_OrtValue.ortvalue_from_numpy(x, device) + bind.bind_ortvalue_input("X", ort_value) + bind.bind_output("Y", device) + sess._sess.run_with_iobinding(bind, None) + ortvaluevector = bind.get_outputs() + self.assertIsInstance(ortvaluevector, OrtValueVector) + ortvalue = bind.get_outputs()[0] + y = ortvalue.numpy() + assert_almost_equal(x, y) + + bind = SessionIOBinding(sess._sess) + bind.bind_input("X", device, dtype, x.shape, ort_value.data_ptr()) + bind.bind_output("Y", device) + sess._sess.run_with_iobinding(bind, None) + ortvalue = bind.get_outputs()[0] + y = ortvalue.numpy() + assert_almost_equal(x, y) + + def test_bind_onnx_types_from_torch(self): + """ + Test I/O binding with various input data types. + """ + sess_options = onnxrt.SessionOptions() + sess_options.set_provider_selection_policy(onnxrt.OrtExecutionProviderDevicePolicy.PREFER_GPU) + self.assertTrue(sess_options.has_providers()) + opset = onnx_opset_version() + + for dtype in [ + torch.float32, + torch.float16, + torch.bfloat16, + torch.int32, + torch.int64, + ]: + with self.subTest(dtype=dtype): + proto_dtype = self.torch_to_onnx_type(dtype) + + x_ = helper.make_tensor_value_info("X", proto_dtype, [None]) + y_ = helper.make_tensor_value_info("Y", proto_dtype, [None]) + node_add = helper.make_node("Identity", ["X"], ["Y"]) + graph_def = helper.make_graph([node_add], "lr", [x_], [y_], []) + model_def = helper.make_model( + graph_def, + producer_name="dummy", + ir_version=10, + producer_version="0", + opset_imports=[helper.make_operatorsetid("", opset)], + ) + sess = onnxrt.InferenceSession(model_def.SerializeToString(), sess_options=sess_options) + + dev = "cuda" if torch.cuda.is_available() else "cpu" + device = ( + C_OrtDevice(C_OrtDevice.cuda(), C_OrtDevice.default_memory(), 0) + if dev == "cuda" + else C_OrtDevice(C_OrtDevice.cpu(), C_OrtDevice.default_memory(), 0) + ) + + x = torch.arange(8, dtype=dtype, device=dev) + y = torch.empty(8, dtype=dtype, device=dev) + + bind = SessionIOBinding(sess._sess) + bind.bind_input("X", device, proto_dtype, x.shape, x.data_ptr()) + bind.bind_output("Y", device, proto_dtype, y.shape, y.data_ptr()) + sess._sess.run_with_iobinding(bind, None) + self.assertTrue(torch.equal(x, y)) + + +if __name__ == "__main__": + unittest.main(verbosity=1) diff --git a/onnxruntime/test/testdata/custom_op_library/custom_op_library.cc b/onnxruntime/test/testdata/custom_op_library/custom_op_library.cc index 8ab58adbeeb74..bc22864304567 100644 --- a/onnxruntime/test/testdata/custom_op_library/custom_op_library.cc +++ b/onnxruntime/test/testdata/custom_op_library/custom_op_library.cc @@ -26,7 +26,7 @@ static void AddOrtCustomOpDomainToContainer(Ort::CustomOpDomain&& domain) { } OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptions* options, const OrtApiBase* api) { - Ort::Global::api_ = api->GetApi(ORT_API_VERSION); + Ort::InitApi(api->GetApi(ORT_API_VERSION)); OrtStatus* result = nullptr; ORT_TRY { diff --git a/tools/ci_build/github/windows/extract_nuget_files.ps1 b/tools/ci_build/github/windows/extract_nuget_files.ps1 index ff8f63a85b97a..20d6c1f2b63a5 100644 --- a/tools/ci_build/github/windows/extract_nuget_files.ps1 +++ b/tools/ci_build/github/windows/extract_nuget_files.ps1 @@ -1,105 +1,119 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -# This file is used by Zip-Nuget Packaging NoContribOps Pipeline,Zip-Nuget-Java Packaging Pipeline +# This file is used by Zip-Nuget-Java Packaging Pipeline -# Re-construct a build directory that contains binaries from all the different platforms we're including -# in the native ORT nuget package +# Define the directory for NuGet artifacts. $nuget_artifacts_dir = "$Env:BUILD_BINARIESDIRECTORY\RelWithDebInfo\RelWithDebInfo\nuget-artifacts" -New-Item -Path $nuget_artifacts_dir -ItemType directory +# Create the directory if it doesn't exist. +New-Item -Path $nuget_artifacts_dir -ItemType directory -ErrorAction SilentlyContinue ## .zip files -# unzip directly -# exclude the iOS xcframework as we need to leave that zipped up to preserve symlinks -Get-ChildItem -Path $Env:BUILD_BINARIESDIRECTORY\nuget-artifact\* -Include *.zip -Exclude onnxruntime_ios_xcframework.*.zip | +# Unzip files directly, excluding the iOS xcframework to preserve its symlinks. +Get-ChildItem -Path "$Env:BUILD_BINARIESDIRECTORY\nuget-artifact\*" -Include *.zip -Exclude onnxruntime_ios_xcframework.*.zip | Foreach-Object { - $cmd = "7z.exe x $($_.FullName) -y -o$nuget_artifacts_dir" - Write-Output $cmd - Invoke-Expression -Command $cmd + # The -snld20 flag is used to bypass security checks for creating symbolic links (added in 7-Zip 25.01). + $arguments = "x", "$($_.FullName)", "-y", "-o$nuget_artifacts_dir", "-snld20" + Write-Output "Executing: 7z.exe $arguments" + # Directly call 7z.exe using the call operator '&' + & 7z.exe $arguments + # Check the exit code of the last command. A non-zero code indicates an error. + if ($LASTEXITCODE -ne 0) { + throw "Error extracting '$($_.FullName)'. Exit code: $LASTEXITCODE" + } } ## .tgz files -# first extract the tar file from the tgz -Get-ChildItem $Env:BUILD_BINARIESDIRECTORY\nuget-artifact -Filter *.tgz | +# First, extract the .tar file from the .tgz archive. +Get-ChildItem "$Env:BUILD_BINARIESDIRECTORY\nuget-artifact" -Filter *.tgz | Foreach-Object { - $cmd = "7z.exe x $($_.FullName) -y -o$Env:BUILD_BINARIESDIRECTORY\nuget-artifact" - Write-Output $cmd - Invoke-Expression -Command $cmd + # The -snld20 flag is used to bypass security checks for creating symbolic links (added in 7-Zip 25.01). + $arguments = "x", "$($_.FullName)", "-y", "-o$Env:BUILD_BINARIESDIRECTORY\nuget-artifact", "-snld20" + Write-Output "Executing: 7z.exe $arguments" + & 7z.exe $arguments + if ($LASTEXITCODE -ne 0) { + throw "Error extracting '$($_.FullName)'. Exit code: $LASTEXITCODE" + } } -# now extract the actual folder structure from the tar file to the build dir -Get-ChildItem $Env:BUILD_BINARIESDIRECTORY\nuget-artifact -Filter *.tar | +# Now, extract the contents from the .tar file into the final directory. +Get-ChildItem "$Env:BUILD_BINARIESDIRECTORY\nuget-artifact" -Filter *.tar | Foreach-Object { - $cmd = "7z.exe x $($_.FullName) -y -o$nuget_artifacts_dir" - Write-Output $cmd - Invoke-Expression -Command $cmd + # The -snld20 flag is used to bypass security checks for creating symbolic links (added in 7-Zip 25.01). + $arguments = "x", "$($_.FullName)", "-y", "-o$nuget_artifacts_dir", "-snld20" + Write-Output "Executing: 7z.exe $arguments" + & 7z.exe $arguments + if ($LASTEXITCODE -ne 0) { + throw "Error extracting '$($_.FullName)'. Exit code: $LASTEXITCODE" + } } -# process iOS xcframework -$xcframeworks = Get-ChildItem $Env:BUILD_BINARIESDIRECTORY\nuget-artifact -Filter onnxruntime_ios_xcframework.*.zip +# Process iOS xcframework +$xcframeworks = Get-ChildItem "$Env:BUILD_BINARIESDIRECTORY\nuget-artifact" -Filter onnxruntime_ios_xcframework.*.zip if ($xcframeworks.Count -eq 1) { - $xcframework = $xcframeworks[0] - $target_dir = "$nuget_artifacts_dir\onnxruntime-ios-xcframework" - # remove version info from filename and use required filename format - $target_file = "$target_dir\onnxruntime.xcframework.zip" - New-Item -Path $target_dir -ItemType directory + $xcframework = $xcframeworks[0] + $target_dir = "$nuget_artifacts_dir\onnxruntime-ios-xcframework" + # Use the required filename format, removing version info. + $target_file = "$target_dir\onnxruntime.xcframework.zip" + New-Item -Path $target_dir -ItemType directory -ErrorAction SilentlyContinue - Write-Output "Copy-Item $($xcframework.FullName) $target_file" - Copy-Item $xcframework.FullName $target_file + Write-Output "Copying $($xcframework.FullName) to $target_file" + Copy-Item $xcframework.FullName $target_file } elseif ($xcframeworks.Count -gt 1) { - Write-Error "Expected at most one onnxruntime_ios_xcframework*.zip file but got: [$xcframeworks]" + Write-Error "Expected at most one onnxruntime_ios_xcframework*.zip file but got: [$xcframeworks]" } - -# copy android AAR. -# for full build of onnxruntime Android AAR, there should only be one .aar file -# called onnxruntime-android-x.y.z.aar or onnxruntime-training-android-x.y.z.aar but sanity check that -$aars = Get-ChildItem $Env:BUILD_BINARIESDIRECTORY\nuget-artifact -Filter *.aar +# Copy Android AAR file. +# There should only be one .aar file for a full build. +$aars = Get-ChildItem "$Env:BUILD_BINARIESDIRECTORY\nuget-artifact" -Filter *.aar if ($aars.Count -eq 1) { - $aar = $aars[0] - $aar_prefix = "onnxruntime" - if ($aar -like "onnxruntime-training*") { - $aar_prefix = "onnxruntime-training" - } - $target_dir = "$nuget_artifacts_dir\$aar_prefix-android-aar" - $target_file = "$target_dir\onnxruntime.aar" # remove '-mobile' and version info from filename - New-Item -Path $target_dir -ItemType directory + $aar = $aars[0] + $aar_prefix = "onnxruntime" + if ($aar.Name -like "onnxruntime-training*") { + $aar_prefix = "onnxruntime-training" + } + $target_dir = "$nuget_artifacts_dir\$aar_prefix-android-aar" + # Remove version info from the filename for consistency. + $target_file = "$target_dir\onnxruntime.aar" + New-Item -Path $target_dir -ItemType directory -ErrorAction SilentlyContinue - Write-Output "Copy-Item $($aar.FullName) $target_file" - Copy-Item $aar.FullName $target_file + Write-Output "Copying $($aar.FullName) to $target_file" + Copy-Item $aar.FullName $target_file } elseif ($aars.Count -gt 1) { - Write-Error "Expected at most one Android .aar file but got: [$aars]" + Write-Error "Expected at most one Android .aar file but got: [$aars]" } -# Check whether this is a training pipeline -$is_training_pipeline = $false -if (Test-Path -Path $nuget_artifacts_dir\onnxruntime-training-win-x64-*) { - $is_training_pipeline = $true - Write-Output "onnxruntime-training-win-x64-* dir exists. This is a training pipeline." +# Check if this is a training pipeline by looking for a specific directory. +$is_training_pipeline = Test-Path -Path "$nuget_artifacts_dir\onnxruntime-training-win-x64-*" +if ($is_training_pipeline) { + Write-Output "onnxruntime-training-win-x64-* dir exists. This is a training pipeline." } -# Copy onnxruntime and protoc binaries to the binaries dir as these are required -# by Microsoft.ML.OnnxRuntime.Tests.NetCoreApp +# Copy onnxruntime and protoc binaries required by tests. +$destinationDir = "$Env:BUILD_BINARIESDIRECTORY\RelWithDebInfo\RelWithDebInfo" if ($is_training_pipeline) { - Copy-Item -Path $nuget_artifacts_dir\onnxruntime-training-win-x64-*\lib\* -Destination $Env:BUILD_BINARIESDIRECTORY\RelWithDebInfo\RelWithDebInfo + Copy-Item -Path "$nuget_artifacts_dir\onnxruntime-training-win-x64-*\lib\*" -Destination $destinationDir -Recurse } else { - Copy-Item -Path $nuget_artifacts_dir\onnxruntime-win-x64-*\lib\* -Destination $Env:BUILD_BINARIESDIRECTORY\RelWithDebInfo\RelWithDebInfo + Copy-Item -Path "$nuget_artifacts_dir\onnxruntime-win-x64-*\lib\*" -Destination $destinationDir -Recurse } -"Get-ChildItem -Directory -Path $nuget_artifacts_dir\onnxruntime-*" -$ort_dirs = Get-ChildItem -Directory -Path $nuget_artifacts_dir\onnxruntime-* -foreach ($ort_dir in $ort_dirs) -{ - # remove the last '-xxx' segment from the dir name. typically that's the architecture. - $dirname = Split-Path -Path $ort_dir -Leaf - $dirname = $dirname.SubString(0,$dirname.LastIndexOf('-')) - Write-Output "Renaming $ort_dir to $dirname" - Rename-Item -Path $ort_dir -NewName $nuget_artifacts_dir\$dirname +# Rename directories to remove the architecture-specific suffix. +Write-Output "Renaming onnxruntime directories..." +Get-ChildItem -Directory -Path "$nuget_artifacts_dir\onnxruntime-*" | ForEach-Object { + $dirname = $_.Name + # Find the last hyphen and remove the suffix. + $lastHyphenIndex = $dirname.LastIndexOf('-') + if ($lastHyphenIndex -gt -1) { + $newName = $dirname.Substring(0, $lastHyphenIndex) + $newPath = Join-Path -Path $_.Parent.FullName -ChildPath $newName + Write-Output "Renaming '$($_.FullName)' to '$newPath'" + Rename-Item -Path $_.FullName -NewName $newName + } } -# List artifacts -"Post copy artifacts" -Get-ChildItem -Recurse $nuget_artifacts_dir\ +# List the final artifacts. +Write-Output "Post-copy artifacts:" +Get-ChildItem -Recurse $nuget_artifacts_dir \ No newline at end of file diff --git a/tools/ci_build/github/windows/extract_nuget_files_gpu.ps1 b/tools/ci_build/github/windows/extract_nuget_files_gpu.ps1 index 01a8eebe75df2..29946dcb73f8a 100644 --- a/tools/ci_build/github/windows/extract_nuget_files_gpu.ps1 +++ b/tools/ci_build/github/windows/extract_nuget_files_gpu.ps1 @@ -2,47 +2,81 @@ # Licensed under the MIT License. # This file is used by Zip-Nuget-Java Packaging Pipeline -New-Item -Path $Env:BUILD_BINARIESDIRECTORY\RelWithDebInfo\RelWithDebInfo\nuget-artifacts -ItemType directory +# Define the directory for NuGet artifacts. +$nuget_artifacts_dir = "$Env:BUILD_BINARIESDIRECTORY\RelWithDebInfo\RelWithDebInfo\nuget-artifacts" +# Create the directory if it doesn't exist. +New-Item -Path $nuget_artifacts_dir -ItemType directory -ErrorAction SilentlyContinue -Get-ChildItem $Env:BUILD_BINARIESDIRECTORY\nuget-artifact -Filter *.zip | +## .zip files +Get-ChildItem "$Env:BUILD_BINARIESDIRECTORY\nuget-artifact" -Filter *.zip | Foreach-Object { - $cmd = "7z.exe x $($_.FullName) -y -o$Env:BUILD_BINARIESDIRECTORY\RelWithDebInfo\RelWithDebInfo\nuget-artifacts" - Write-Output $cmd - Invoke-Expression -Command $cmd + # The -snld20 flag is used to bypass security checks for creating symbolic links (added in 7-Zip 25.01). + $arguments = "x", "$($_.FullName)", "-y", "-o$nuget_artifacts_dir", "-snld20" + Write-Output "Executing: 7z.exe $arguments" + & 7z.exe $arguments + if ($LASTEXITCODE -ne 0) { + throw "Error extracting '$($_.FullName)'. Exit code: $LASTEXITCODE" + } } -Get-ChildItem $Env:BUILD_BINARIESDIRECTORY\nuget-artifact -Filter *.tgz | +## .tgz files +Get-ChildItem "$Env:BUILD_BINARIESDIRECTORY\nuget-artifact" -Filter *.tgz | Foreach-Object { - $cmd = "7z.exe x $($_.FullName) -y -o$Env:BUILD_BINARIESDIRECTORY\nuget-artifact" # *.tar will be created after *.tgz is extracted - Write-Output $cmd - Invoke-Expression -Command $cmd + # The -snld20 flag is used to bypass security checks for creating symbolic links (added in 7-Zip 25.01). + # *.tar will be created after *.tgz is extracted + $arguments = "x", "$($_.FullName)", "-y", "-o$Env:BUILD_BINARIESDIRECTORY\nuget-artifact", "-snld20" + Write-Output "Executing: 7z.exe $arguments" + & 7z.exe $arguments + if ($LASTEXITCODE -ne 0) { + throw "Error extracting '$($_.FullName)'. Exit code: $LASTEXITCODE" + } } -Get-ChildItem $Env:BUILD_BINARIESDIRECTORY\nuget-artifact -Filter *.tar | +## .tar files +Get-ChildItem "$Env:BUILD_BINARIESDIRECTORY\nuget-artifact" -Filter *.tar | Foreach-Object { - $cmd = "7z.exe x $($_.FullName) -y -o$Env:BUILD_BINARIESDIRECTORY\RelWithDebInfo\RelWithDebInfo\nuget-artifacts" - Write-Output $cmd - Invoke-Expression -Command $cmd + # The -snld20 flag is used to bypass security checks for creating symbolic links (added in 7-Zip 25.01). + $arguments = "x", "$($_.FullName)", "-y", "-o$nuget_artifacts_dir", "-snld20" + Write-Output "Executing: 7z.exe $arguments" + & 7z.exe $arguments + if ($LASTEXITCODE -ne 0) { + throw "Error extracting '$($_.FullName)'. Exit code: $LASTEXITCODE" + } } +# Create directory for protobuf build dependencies. +New-Item -Path "$Env:BUILD_BINARIESDIRECTORY\RelWithDebInfo\_deps\protobuf-build\RelWithDebInfo" -ItemType directory -ErrorAction SilentlyContinue -New-Item -Path $Env:BUILD_BINARIESDIRECTORY\RelWithDebInfo\_deps\protobuf-build\RelWithDebInfo -ItemType directory - -Copy-Item -Path $Env:BUILD_BINARIESDIRECTORY\RelWithDebInfo\RelWithDebInfo\nuget-artifacts\onnxruntime-win-x64-cuda-*\lib\* -Destination $Env:BUILD_BINARIESDIRECTORY\RelWithDebInfo\RelWithDebInfo +# Copy CUDA libraries. +Copy-Item -Path "$Env:BUILD_BINARIESDIRECTORY\RelWithDebInfo\RelWithDebInfo\nuget-artifacts\onnxruntime-win-x64-cuda-*\lib\*" -Destination "$Env:BUILD_BINARIESDIRECTORY\RelWithDebInfo\RelWithDebInfo" +# Install protoc via dotnet. $protocInstallDir = "$Env:BUILD_BINARIESDIRECTORY\RelWithDebInfo\_deps\protobuf-build" dotnet new console dotnet add package Google.Protobuf.Tools --version 3.21.12 --package-directory $protocInstallDir +if ($LASTEXITCODE -ne 0) { + throw "Error adding Google.Protobuf.Tools package. Exit code: $LASTEXITCODE" +} + +# Find and copy the protoc executable. $protocDir = Get-ChildItem -Path $protocInstallDir -Recurse -Filter "protoc.exe" | Select-Object -ExpandProperty DirectoryName -First 1 -Write-Output $protocDir -Copy-Item -Path $protocDir -Destination $Env:BUILD_BINARIESDIRECTORY\RelWithDebInfo\_deps\protobuf-build\RelWithDebInfo - -$ort_dirs = Get-ChildItem -Path $Env:BUILD_BINARIESDIRECTORY\RelWithDebInfo\RelWithDebInfo\nuget-artifacts\onnxruntime-* -Directory -foreach ($ort_dir in $ort_dirs) -{ - $dirname = Split-Path -Path $ort_dir -Leaf - $dirname = $dirname.SubString(0,$dirname.LastIndexOf('-')) - Write-Output "Renaming $ort_dir to $dirname" - Rename-Item -Path $ort_dir -NewName $Env:BUILD_BINARIESDIRECTORY\RelWithDebInfo\RelWithDebInfo\nuget-artifacts\$dirname +if ($protocDir) { + Write-Output "Found protoc directory: $protocDir" + Copy-Item -Path $protocDir -Destination "$Env:BUILD_BINARIESDIRECTORY\RelWithDebInfo\_deps\protobuf-build\RelWithDebInfo" +} +else { + Write-Error "Could not find protoc.exe in $protocInstallDir" } +# Rename onnxruntime directories to a generic format. +$ort_dirs = Get-ChildItem -Path "$Env:BUILD_BINARIESDIRECTORY\RelWithDebInfo\RelWithDebInfo\nuget-artifacts\onnxruntime-*" -Directory +foreach ($ort_dir in $ort_dirs) { + $dirname = Split-Path -Path $ort_dir -Leaf + $lastHyphenIndex = $dirname.LastIndexOf('-') + if ($lastHyphenIndex -gt -1) { + $newName = $dirname.Substring(0, $lastHyphenIndex) + $newPath = Join-Path -Path $ort_dir.Parent.FullName -ChildPath $newName + Write-Output "Renaming '$($ort_dir.FullName)' to '$newPath'" + Rename-Item -Path $ort_dir.FullName -NewName $newName + } +} From cc8b267ce2270659d126c7f416a99a08cb9e4b53 Mon Sep 17 00:00:00 2001 From: Jaswanth51 Date: Tue, 2 Sep 2025 11:32:23 +0530 Subject: [PATCH 095/138] Revert "Sync with Microsoft ONNX Runtime - 01/09/2025 (#801)" This reverts commit 2f1ad9d04d6c900d3c2749838f8196e720456e81. --- cmake/CMakeLists.txt | 8 +- .../external/onnxruntime_external_deps.cmake | 61 ++- cmake/onnxruntime.cmake | 13 +- cmake/onnxruntime_common.cmake | 57 ++- cmake/onnxruntime_java.cmake | 4 +- cmake/onnxruntime_nodejs.cmake | 1 - cmake/onnxruntime_unittests.cmake | 4 - .../cpuinfo/patch_vcpkg_arm64ec_support.patch | 91 ---- .../cpuinfo/patch_vcpkg_arm64ec_support.patch | 91 ---- cmake/vcpkg-ports/cpuinfo/portfile.cmake | 1 - .../NativeMethods.shared.cs | 98 ---- .../Microsoft.ML.OnnxRuntime/OrtEnv.shared.cs | 40 -- .../EpCompatibilityTests.cs | 49 -- .../providers/cann/cann_provider_options.h | 2 - .../core/providers/utils/ort_graph_to_proto.h | 8 +- .../core/session/onnxruntime_c_api.h | 3 +- .../core/session/onnxruntime_cxx_api.h | 125 +---- .../core/session/onnxruntime_cxx_inline.h | 20 - .../onnxruntime_ep_device_ep_metadata_keys.h | 5 +- .../main/java/ai/onnxruntime/OnnxRuntime.java | 18 +- .../java/ai/onnxruntime/OrtEnvironment.java | 82 +-- .../main/java/ai/onnxruntime/OrtEpDevice.java | 117 ----- .../ai/onnxruntime/OrtHardwareDevice.java | 156 ------ .../OrtModelCompilationOptions.java | 280 ----------- .../main/java/ai/onnxruntime/OrtSession.java | 78 +-- .../src/main/java/ai/onnxruntime/OrtUtil.java | 51 +- .../ai/onnxruntime/providers/CoreMLFlags.java | 4 +- .../ai/onnxruntime/providers/NNAPIFlags.java | 4 +- .../onnxruntime/{ => providers}/OrtFlags.java | 4 +- java/src/main/native/OrtJniUtil.c | 30 -- java/src/main/native/OrtJniUtil.h | 2 - .../main/native/ai_onnxruntime_OnnxRuntime.c | 13 - .../native/ai_onnxruntime_OrtEnvironment.c | 70 --- .../main/native/ai_onnxruntime_OrtEpDevice.c | 82 --- .../native/ai_onnxruntime_OrtHardwareDevice.c | 96 ---- ...i_onnxruntime_OrtModelCompilationOptions.c | 193 -------- ...ai_onnxruntime_OrtSession_SessionOptions.c | 53 +- .../java/ai/onnxruntime/CompileApiTest.java | 53 -- .../java/ai/onnxruntime/EpDeviceTest.java | 123 ----- js/node/src/inference_session_wrap.cc | 2 +- .../contrib_ops/cpu/bert/gqa_attention_base.h | 15 +- .../cpu/moe/moe_quantization_cpu.cc | 11 +- .../quantization/dynamic_quantize_matmul.cc | 17 +- .../core/common/cpuid_arch_definition.h | 2 +- onnxruntime/core/graph/abi_graph_types.h | 10 + onnxruntime/core/graph/ep_api_types.cc | 29 +- onnxruntime/core/graph/ep_api_types.h | 3 + .../core/graph/model_editor_api_types.h | 5 + .../core/mlas/lib/kleidiai/sgemm_kleidiai.cpp | 73 ++- .../providers/cann/cann_execution_provider.cc | 9 +- .../cann/cann_execution_provider_info.cc | 4 - .../cann/cann_execution_provider_info.h | 1 - .../providers/cann/cann_provider_factory.cc | 2 - .../nv_tensorrt_rtx/nv_execution_provider.cc | 151 +++++- .../nv_tensorrt_rtx/nv_execution_provider.h | 2 + .../qnn/builder/opbuilder/conv_op_builder.cc | 14 +- .../shared_library/provider_ort_api_init.cc | 4 +- .../core/providers/vitisai/imp/global_api.cc | 6 +- onnxruntime/core/session/onnxruntime_c_api.cc | 38 +- onnxruntime/core/session/ort_apis.h | 2 +- .../plugin_ep/ep_factory_provider_bridge.cc | 7 - .../plugin_ep/ep_factory_provider_bridge.h | 15 +- .../core/session/plugin_ep/ep_library.h | 1 - .../plugin_ep/ep_library_provider_bridge.cc | 4 +- .../plugin_ep/ep_library_provider_bridge.h | 9 +- .../core/session/provider_bridge_ort.cc | 1 - onnxruntime/core/session/utils.cc | 5 +- .../python/onnxruntime_pybind_state.cc | 19 +- .../execution_providers/qnn/preprocess.py | 24 +- onnxruntime/test/autoep/library/ep_arena.h | 3 - onnxruntime/test/contrib_ops/moe_test.cc | 55 -- .../test/framework/ep_compatibility_test.cc | 29 -- .../test/platform/device_discovery_test.cc | 4 +- ...nnxruntime_test_python_ep_compatibility.py | 46 -- ...me_test_python_nv_tensorrt_rtx_ep_tests.py | 468 ------------------ .../custom_op_library/custom_op_library.cc | 2 +- .../github/windows/extract_nuget_files.ps1 | 148 +++--- .../windows/extract_nuget_files_gpu.ps1 | 86 +--- 78 files changed, 509 insertions(+), 3007 deletions(-) delete mode 100644 cmake/patches/cpuinfo/patch_vcpkg_arm64ec_support.patch delete mode 100644 cmake/vcpkg-ports/cpuinfo/patch_vcpkg_arm64ec_support.patch delete mode 100644 csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/EpCompatibilityTests.cs delete mode 100644 java/src/main/java/ai/onnxruntime/OrtEpDevice.java delete mode 100644 java/src/main/java/ai/onnxruntime/OrtHardwareDevice.java delete mode 100644 java/src/main/java/ai/onnxruntime/OrtModelCompilationOptions.java rename java/src/main/java/ai/onnxruntime/{ => providers}/OrtFlags.java (88%) delete mode 100644 java/src/main/native/ai_onnxruntime_OrtEpDevice.c delete mode 100644 java/src/main/native/ai_onnxruntime_OrtHardwareDevice.c delete mode 100644 java/src/main/native/ai_onnxruntime_OrtModelCompilationOptions.c delete mode 100644 java/src/test/java/ai/onnxruntime/CompileApiTest.java delete mode 100644 java/src/test/java/ai/onnxruntime/EpDeviceTest.java delete mode 100644 onnxruntime/test/python/onnxruntime_test_python_ep_compatibility.py delete mode 100644 onnxruntime/test/python/onnxruntime_test_python_nv_tensorrt_rtx_ep_tests.py diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index 40e6a8da28e45..98548957d0b42 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -1607,6 +1607,7 @@ if ("${CMAKE_SYSTEM_NAME}" STREQUAL "Linux") endif() endif() + #Now the 'onnxruntime_EXTERNAL_LIBRARIES' variable should be sealed. It will be used in onnxruntime.cmake which will be included in the next. #The order of the following targets matters. Right depends on left. If target A appears before target B. Then A.cmake can not use variables defined in B.cmake. set(ONNXRUNTIME_CMAKE_FILES onnxruntime_flatbuffers onnxruntime_common onnxruntime_mlas onnxruntime_graph onnxruntime_lora onnxruntime_framework onnxruntime_util onnxruntime_providers onnxruntime_optimizer onnxruntime_session ${ONNXRUNTIME_EAGER_CMAKE_FILE_NAME}) @@ -1622,6 +1623,9 @@ if (onnxruntime_USE_WINML) list(APPEND ONNXRUNTIME_CMAKE_FILES winml) endif() # if (onnxruntime_USE_WINML) +if (onnxruntime_BUILD_APPLE_FRAMEWORK AND NOT ${CMAKE_SYSTEM_NAME} MATCHES "Darwin|iOS|visionOS|tvOS") + message(FATAL_ERROR "onnxruntime_BUILD_APPLE_FRAMEWORK can only be enabled for macOS or iOS or visionOS or tvOS.") +endif() list(APPEND ONNXRUNTIME_CMAKE_FILES onnxruntime) if (onnxruntime_BUILD_JAVA) @@ -1686,8 +1690,8 @@ if (WIN32 AND NOT GDK_PLATFORM AND NOT CMAKE_CROSSCOMPILING) endif() endif() -foreach(onnxruntime_cmake_file ${ONNXRUNTIME_CMAKE_FILES}) - include(${onnxruntime_cmake_file}.cmake) +foreach(target_name ${ONNXRUNTIME_CMAKE_FILES}) + include(${target_name}.cmake) endforeach() if (UNIX) option(BUILD_PKGCONFIG_FILES "Build and install pkg-config files" ON) diff --git a/cmake/external/onnxruntime_external_deps.cmake b/cmake/external/onnxruntime_external_deps.cmake index 827be3e6dea2a..3095968795d1a 100644 --- a/cmake/external/onnxruntime_external_deps.cmake +++ b/cmake/external/onnxruntime_external_deps.cmake @@ -313,32 +313,41 @@ onnxruntime_fetchcontent_makeavailable(nlohmann_json) if (onnxruntime_ENABLE_CPUINFO) # Adding pytorch CPU info library # TODO!! need a better way to find out the supported architectures - set(CPUINFO_SUPPORTED FALSE) + list(LENGTH CMAKE_OSX_ARCHITECTURES CMAKE_OSX_ARCHITECTURES_LEN) if (APPLE) - list(LENGTH CMAKE_OSX_ARCHITECTURES CMAKE_OSX_ARCHITECTURES_LEN) if (CMAKE_OSX_ARCHITECTURES_LEN LESS_EQUAL 1) set(CPUINFO_SUPPORTED TRUE) - else() - message(WARNING "cpuinfo is not supported when CMAKE_OSX_ARCHITECTURES has more than one value.") + elseif (onnxruntime_BUILD_APPLE_FRAMEWORK) + # We stitch multiple static libraries together when onnxruntime_BUILD_APPLE_FRAMEWORK is true, + # but that would not work for universal static libraries + message(FATAL_ERROR "universal binary is not supported for apple framework") endif() - elseif (CMAKE_SYSTEM_NAME STREQUAL "Emscripten") - # if xnnpack is enabled in a wasm build it needs clog from cpuinfo, but we won't internally use cpuinfo. - if (onnxruntime_USE_XNNPACK) - set(CPUINFO_SUPPORTED TRUE) - endif() - elseif (WIN32) - set(CPUINFO_SUPPORTED TRUE) else() - if (onnxruntime_target_platform MATCHES "^(i[3-6]86|AMD64|x86(_64)?|armv[5-8].*|aarch64|arm64)$") - set(CPUINFO_SUPPORTED TRUE) + # if xnnpack is enabled in a wasm build it needs clog from cpuinfo, but we won't internally use cpuinfo + # so we don't set CPUINFO_SUPPORTED in the CXX flags below. + if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten" AND NOT onnxruntime_USE_XNNPACK) + set(CPUINFO_SUPPORTED FALSE) else() - message(WARNING "Target processor architecture \"${onnxruntime_target_platform}\" is not supported in cpuinfo.") + set(CPUINFO_SUPPORTED TRUE) + endif() + if (WIN32) + # There's an error when linking with cpuinfo on arm64ec with a vcpkg build (--use_vcpkg). + # TODO Fix it and then re-enable cpuinfo on arm64ec. + if (onnxruntime_target_platform STREQUAL "ARM64EC") + set(CPUINFO_SUPPORTED FALSE) + else() + set(CPUINFO_SUPPORTED TRUE) + endif() + elseif (NOT ${onnxruntime_target_platform} MATCHES "^(i[3-6]86|AMD64|x86(_64)?|armv[5-8].*|aarch64|arm64)$") + message(WARNING + "Target processor architecture \"${onnxruntime_target_platform}\" is not supported in cpuinfo. " + "cpuinfo not included." + ) + set(CPUINFO_SUPPORTED FALSE) endif() endif() - - if(NOT CPUINFO_SUPPORTED) - message(WARNING "onnxruntime_ENABLE_CPUINFO was set but cpuinfo is not supported.") - endif() +else() + set(CPUINFO_SUPPORTED FALSE) endif() if (CPUINFO_SUPPORTED) @@ -349,26 +358,23 @@ if (CPUINFO_SUPPORTED) # if this is a wasm build with xnnpack (only type of wasm build where cpuinfo is involved) # we do not use cpuinfo in ORT code, so don't define CPUINFO_SUPPORTED. - if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten" AND onnxruntime_USE_XNNPACK) - else() - add_compile_definitions(CPUINFO_SUPPORTED) + if (NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten") + string(APPEND CMAKE_CXX_FLAGS " -DCPUINFO_SUPPORTED") endif() + set(CPUINFO_BUILD_TOOLS OFF CACHE INTERNAL "") set(CPUINFO_BUILD_UNIT_TESTS OFF CACHE INTERNAL "") set(CPUINFO_BUILD_MOCK_TESTS OFF CACHE INTERNAL "") set(CPUINFO_BUILD_BENCHMARKS OFF CACHE INTERNAL "") if (onnxruntime_target_platform STREQUAL "ARM64EC" OR onnxruntime_target_platform STREQUAL "ARM64") - message(STATUS "Applying patches for Windows ARM64/ARM64EC in cpuinfo") + message(STATUS "Applying a patch for Windows ARM64/ARM64EC in cpuinfo") onnxruntime_fetchcontent_declare( pytorch_cpuinfo URL ${DEP_URL_pytorch_cpuinfo} URL_HASH SHA1=${DEP_SHA1_pytorch_cpuinfo} EXCLUDE_FROM_ALL - PATCH_COMMAND - ${Patch_EXECUTABLE} -p1 < ${PROJECT_SOURCE_DIR}/patches/cpuinfo/patch_cpuinfo_h_for_arm64ec.patch && - # https://github.com/pytorch/cpuinfo/pull/324 - ${Patch_EXECUTABLE} -p1 < ${PROJECT_SOURCE_DIR}/patches/cpuinfo/patch_vcpkg_arm64ec_support.patch + PATCH_COMMAND ${Patch_EXECUTABLE} -p1 < ${PROJECT_SOURCE_DIR}/patches/cpuinfo/patch_cpuinfo_h_for_arm64ec.patch FIND_PACKAGE_ARGS NAMES cpuinfo ) else() @@ -578,7 +584,8 @@ endif() set(onnxruntime_EXTERNAL_LIBRARIES ${onnxruntime_EXTERNAL_LIBRARIES_XNNPACK} ${WIL_TARGET} nlohmann_json::nlohmann_json onnx onnx_proto ${PROTOBUF_LIB} re2::re2 Boost::mp11 safeint_interface - flatbuffers::flatbuffers ${GSL_TARGET} ${ABSEIL_LIBS} date::date Eigen3::Eigen) + flatbuffers::flatbuffers ${GSL_TARGET} ${ABSEIL_LIBS} date::date + ${ONNXRUNTIME_CLOG_TARGET_NAME} Eigen3::Eigen) # The source code of onnx_proto is generated, we must build this lib first before starting to compile the other source code that uses ONNX protobuf types. # The other libs do not have the problem. All the sources are already there. We can compile them in any order. diff --git a/cmake/onnxruntime.cmake b/cmake/onnxruntime.cmake index e1d98109208d4..010696a61022c 100644 --- a/cmake/onnxruntime.cmake +++ b/cmake/onnxruntime.cmake @@ -350,19 +350,8 @@ if (winml_is_inbox) endif() endif() -# Assemble the Apple static framework +# Assemble the Apple static framework (iOS and macOS) if(onnxruntime_BUILD_APPLE_FRAMEWORK) - if (NOT CMAKE_SYSTEM_NAME MATCHES "Darwin|iOS|visionOS|tvOS") - message(FATAL_ERROR "onnxruntime_BUILD_APPLE_FRAMEWORK can only be enabled for macOS or iOS or visionOS or tvOS.") - endif() - - list(LENGTH CMAKE_OSX_ARCHITECTURES CMAKE_OSX_ARCHITECTURES_LEN) - if (CMAKE_OSX_ARCHITECTURES_LEN GREATER 1) - # We stitch multiple static libraries together when onnxruntime_BUILD_APPLE_FRAMEWORK is true, - # but that would not work for universal static libraries - message(FATAL_ERROR "universal binary is not supported for apple framework") - endif() - # when building for mac catalyst, the CMAKE_OSX_SYSROOT is set to MacOSX as well, to avoid duplication, # we specify as `-macabi` in the name of the output static apple framework directory. if (PLATFORM_NAME STREQUAL "macabi") diff --git a/cmake/onnxruntime_common.cmake b/cmake/onnxruntime_common.cmake index 0218994e537a0..d927489372e7c 100644 --- a/cmake/onnxruntime_common.cmake +++ b/cmake/onnxruntime_common.cmake @@ -194,10 +194,59 @@ if(APPLE) target_link_libraries(onnxruntime_common PRIVATE "-framework Foundation") endif() -if(CPUINFO_SUPPORTED) - # Link cpuinfo if supported - onnxruntime_add_include_to_target(onnxruntime_common cpuinfo::cpuinfo) - list(APPEND onnxruntime_EXTERNAL_LIBRARIES cpuinfo::cpuinfo) +if(MSVC) + if(onnxruntime_target_platform STREQUAL "ARM64") + set(ARM64 TRUE) + elseif (onnxruntime_target_platform STREQUAL "ARM") + set(ARM TRUE) + elseif(onnxruntime_target_platform STREQUAL "x64") + set(X64 TRUE) + elseif(onnxruntime_target_platform STREQUAL "x86") + set(X86 TRUE) + endif() +elseif(APPLE) + if(CMAKE_OSX_ARCHITECTURES_LEN LESS_EQUAL 1) + set(X64 TRUE) + endif() +elseif(NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten") + if (CMAKE_SYSTEM_NAME STREQUAL "Android") + if (CMAKE_ANDROID_ARCH_ABI STREQUAL "armeabi-v7a") + set(ARM TRUE) + elseif (CMAKE_ANDROID_ARCH_ABI STREQUAL "arm64-v8a") + set(ARM64 TRUE) + elseif (CMAKE_ANDROID_ARCH_ABI STREQUAL "x86_64") + set(X86_64 TRUE) + elseif (CMAKE_ANDROID_ARCH_ABI STREQUAL "x86") + set(X86 TRUE) + endif() + else() + execute_process( + COMMAND ${CMAKE_C_COMPILER} -dumpmachine + OUTPUT_VARIABLE dumpmachine_output + ERROR_QUIET + ) + if(dumpmachine_output MATCHES "^arm64.*") + set(ARM64 TRUE) + elseif(dumpmachine_output MATCHES "^arm.*") + set(ARM TRUE) + elseif(dumpmachine_output MATCHES "^aarch64.*") + set(ARM64 TRUE) + elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^riscv64.*") + set(RISCV64 TRUE) + elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^(i.86|x86?)$") + set(X86 TRUE) + elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^(x86_64|amd64)$") + set(X86_64 TRUE) + endif() + endif() +endif() + +if (RISCV64 OR ARM64 OR ARM OR X86 OR X64 OR X86_64) + # Link cpuinfo if supported + if (CPUINFO_SUPPORTED) + onnxruntime_add_include_to_target(onnxruntime_common cpuinfo::cpuinfo) + list(APPEND onnxruntime_EXTERNAL_LIBRARIES cpuinfo::cpuinfo ${ONNXRUNTIME_CLOG_TARGET_NAME}) + endif() endif() if (NOT onnxruntime_BUILD_SHARED_LIB) diff --git a/cmake/onnxruntime_java.cmake b/cmake/onnxruntime_java.cmake index 7da63b523be70..6b638b3e5d8bc 100644 --- a/cmake/onnxruntime_java.cmake +++ b/cmake/onnxruntime_java.cmake @@ -159,7 +159,7 @@ if (WIN32) if(NOT onnxruntime_ENABLE_STATIC_ANALYSIS) add_custom_command(TARGET onnxruntime4j_jni POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different $ ${JAVA_PACKAGE_LIB_DIR}/$) add_custom_command(TARGET onnxruntime4j_jni POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different $ ${JAVA_PACKAGE_JNI_DIR}/$) - if (TARGET onnxruntime_providers_shared) + if (onnxruntime_USE_CUDA OR onnxruntime_USE_DNNL OR onnxruntime_USE_OPENVINO OR onnxruntime_USE_TENSORRT OR (onnxruntime_USE_QNN AND NOT onnxruntime_BUILD_QNN_EP_STATIC_LIB)) add_custom_command(TARGET onnxruntime4j_jni POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different $ ${JAVA_PACKAGE_LIB_DIR}/$) endif() if (onnxruntime_USE_CUDA) @@ -207,7 +207,7 @@ if (WIN32) else() add_custom_command(TARGET onnxruntime4j_jni POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different $ ${JAVA_PACKAGE_LIB_DIR}/$) add_custom_command(TARGET onnxruntime4j_jni POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different $ ${JAVA_PACKAGE_JNI_DIR}/$) - if (TARGET onnxruntime_providers_shared) + if (onnxruntime_USE_CUDA OR onnxruntime_USE_DNNL OR onnxruntime_USE_OPENVINO OR onnxruntime_USE_TENSORRT OR (onnxruntime_USE_QNN AND NOT onnxruntime_BUILD_QNN_EP_STATIC_LIB)) add_custom_command(TARGET onnxruntime4j_jni POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different $ ${JAVA_PACKAGE_LIB_DIR}/$) endif() if (onnxruntime_USE_CUDA) diff --git a/cmake/onnxruntime_nodejs.cmake b/cmake/onnxruntime_nodejs.cmake index cce0810c5bbe8..b28bda6c94276 100644 --- a/cmake/onnxruntime_nodejs.cmake +++ b/cmake/onnxruntime_nodejs.cmake @@ -10,7 +10,6 @@ include(node_helper.cmake) # setup ARCH if (APPLE) - list(LENGTH CMAKE_OSX_ARCHITECTURES CMAKE_OSX_ARCHITECTURES_LEN) if (CMAKE_OSX_ARCHITECTURES_LEN GREATER 1) message(FATAL_ERROR "CMake.js does not support multi-architecture for macOS") endif() diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index b31849440c426..6847db64004ca 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -1640,10 +1640,6 @@ if (NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten") add_custom_command(TARGET onnxruntime_providers_qnn POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy ${QNN_LIB_FILES} ${JAVA_NATIVE_TEST_DIR}) endif() - if (WIN32) - set(EXAMPLE_PLUGIN_EP_DST_FILE_NAME $,$,$>) - add_custom_command(TARGET custom_op_library POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different $ ${JAVA_NATIVE_TEST_DIR}/${EXAMPLE_PLUGIN_EP_DST_FILE_NAME}) - endif() # delegate to gradle's test runner diff --git a/cmake/patches/cpuinfo/patch_vcpkg_arm64ec_support.patch b/cmake/patches/cpuinfo/patch_vcpkg_arm64ec_support.patch deleted file mode 100644 index af0f039b6c2a3..0000000000000 --- a/cmake/patches/cpuinfo/patch_vcpkg_arm64ec_support.patch +++ /dev/null @@ -1,91 +0,0 @@ -diff --git a/CMakeLists.txt b/CMakeLists.txt -index aedc983..dab589e 100644 ---- a/CMakeLists.txt -+++ b/CMakeLists.txt -@@ -72,6 +72,17 @@ IF(CMAKE_SYSTEM_NAME MATCHES "FreeBSD" AND CPUINFO_TARGET_PROCESSOR STREQUAL "am - ENDIF() - IF(IS_APPLE_OS AND CMAKE_OSX_ARCHITECTURES MATCHES "^(x86_64|arm64.*)$") - SET(CPUINFO_TARGET_PROCESSOR "${CMAKE_OSX_ARCHITECTURES}") -+ELSEIF(MSVC AND CMAKE_VERSION VERSION_GREATER_EQUAL "3.10") -+ # Use CMAKE_C_COMPILER_ARCHITECTURE_ID. MSVC values are documented as available since CMake 3.10. -+ IF(CMAKE_C_COMPILER_ARCHITECTURE_ID STREQUAL "X86") -+ SET(CPUINFO_TARGET_PROCESSOR "x86") -+ ELSEIF(CMAKE_C_COMPILER_ARCHITECTURE_ID STREQUAL "x64") -+ SET(CPUINFO_TARGET_PROCESSOR "x86_64") -+ ELSEIF(CMAKE_C_COMPILER_ARCHITECTURE_ID MATCHES "^(ARM64|ARM64EC)$") -+ SET(CPUINFO_TARGET_PROCESSOR "arm64") -+ ELSE() -+ MESSAGE(FATAL_ERROR "Unsupported MSVC compiler architecture ID \"${CMAKE_C_COMPILER_ARCHITECTURE_ID}\"") -+ ENDIF() - ELSEIF(CMAKE_GENERATOR MATCHES "^Visual Studio " AND CMAKE_VS_PLATFORM_NAME) - IF(CMAKE_VS_PLATFORM_NAME STREQUAL "Win32") - SET(CPUINFO_TARGET_PROCESSOR "x86") -@@ -88,7 +99,7 @@ ENDIF() - - # ---[ Build flags - SET(CPUINFO_SUPPORTED_PLATFORM TRUE) --IF(NOT CMAKE_SYSTEM_PROCESSOR) -+IF(NOT CPUINFO_TARGET_PROCESSOR) - IF(NOT IOS) - MESSAGE(WARNING - "Target processor architecture is not specified. " -@@ -201,12 +212,12 @@ IF(CPUINFO_SUPPORTED_PLATFORM) - src/arm/linux/chipset.c - src/arm/linux/midr.c - src/arm/linux/hwcap.c) -- IF(CMAKE_SYSTEM_PROCESSOR MATCHES "^armv[5-8]") -+ IF(CPUINFO_TARGET_PROCESSOR MATCHES "^armv[5-8]") - LIST(APPEND CPUINFO_SRCS src/arm/linux/aarch32-isa.c) - IF(CMAKE_SYSTEM_NAME STREQUAL "Android" AND ANDROID_ABI STREQUAL "armeabi") - SET_SOURCE_FILES_PROPERTIES(src/arm/linux/aarch32-isa.c PROPERTIES COMPILE_FLAGS -marm) - ENDIF() -- ELSEIF(CMAKE_SYSTEM_PROCESSOR MATCHES "^(aarch64|arm64)$") -+ ELSEIF(CPUINFO_TARGET_PROCESSOR MATCHES "^(aarch64|arm64)$") - LIST(APPEND CPUINFO_SRCS src/arm/linux/aarch64-isa.c) - ENDIF() - ELSEIF(IS_APPLE_OS AND CPUINFO_TARGET_PROCESSOR MATCHES "arm64.*") -@@ -395,7 +406,7 @@ IF(CPUINFO_SUPPORTED_PLATFORM AND CPUINFO_BUILD_MOCK_TESTS) - TARGET_COMPILE_DEFINITIONS(cpuinfo_mock PRIVATE _GNU_SOURCE=1) - ENDIF() - -- IF(CMAKE_SYSTEM_NAME STREQUAL "Android" AND CMAKE_SYSTEM_PROCESSOR MATCHES "^(armv5te|armv7-a)$") -+ IF(CMAKE_SYSTEM_NAME STREQUAL "Android" AND CPUINFO_TARGET_PROCESSOR MATCHES "^(armv5te|armv7-a)$") - ADD_EXECUTABLE(atm7029b-tablet-test test/mock/atm7029b-tablet.cc) - TARGET_INCLUDE_DIRECTORIES(atm7029b-tablet-test BEFORE PRIVATE test/mock) - TARGET_LINK_LIBRARIES(atm7029b-tablet-test PRIVATE cpuinfo_mock gtest) -@@ -577,7 +588,7 @@ IF(CPUINFO_SUPPORTED_PLATFORM AND CPUINFO_BUILD_MOCK_TESTS) - ADD_TEST(NAME xperia-sl-test COMMAND xperia-sl-test) - ENDIF() - -- IF(CMAKE_SYSTEM_NAME STREQUAL "Android" AND CMAKE_SYSTEM_PROCESSOR MATCHES "^(armv5te|armv7-a|aarch64)$") -+ IF(CMAKE_SYSTEM_NAME STREQUAL "Android" AND CPUINFO_TARGET_PROCESSOR MATCHES "^(armv5te|armv7-a|aarch64)$") - ADD_EXECUTABLE(alcatel-revvl-test test/mock/alcatel-revvl.cc) - TARGET_INCLUDE_DIRECTORIES(alcatel-revvl-test BEFORE PRIVATE test/mock) - TARGET_LINK_LIBRARIES(alcatel-revvl-test PRIVATE cpuinfo_mock gtest) -@@ -774,7 +785,7 @@ IF(CPUINFO_SUPPORTED_PLATFORM AND CPUINFO_BUILD_MOCK_TESTS) - ADD_TEST(NAME xperia-c4-dual-test COMMAND xperia-c4-dual-test) - ENDIF() - -- IF(CMAKE_SYSTEM_NAME STREQUAL "Android" AND CMAKE_SYSTEM_PROCESSOR MATCHES "^(i686|x86_64)$") -+ IF(CMAKE_SYSTEM_NAME STREQUAL "Android" AND CPUINFO_TARGET_PROCESSOR MATCHES "^(i686|x86_64)$") - ADD_EXECUTABLE(alldocube-iwork8-test test/mock/alldocube-iwork8.cc) - TARGET_INCLUDE_DIRECTORIES(alldocube-iwork8-test BEFORE PRIVATE test/mock) - TARGET_LINK_LIBRARIES(alldocube-iwork8-test PRIVATE cpuinfo_mock gtest) -@@ -831,7 +842,7 @@ IF(CPUINFO_SUPPORTED_PLATFORM AND CPUINFO_BUILD_UNIT_TESTS) - ADD_TEST(NAME brand-string-test COMMAND brand-string-test) - ENDIF() - -- IF(CMAKE_SYSTEM_NAME STREQUAL "Android" AND CMAKE_SYSTEM_PROCESSOR MATCHES "^(armv[5-8].*|aarch64)$") -+ IF(CMAKE_SYSTEM_NAME STREQUAL "Android" AND CPUINFO_TARGET_PROCESSOR MATCHES "^(armv[5-8].*|aarch64)$") - ADD_LIBRARY(android_properties_interface STATIC test/name/android-properties-interface.c) - CPUINFO_TARGET_ENABLE_C99(android_properties_interface) - CPUINFO_TARGET_RUNTIME_LIBRARY(android_properties_interface) -@@ -879,7 +890,7 @@ IF(CPUINFO_SUPPORTED_PLATFORM AND CPUINFO_BUILD_TOOLS) - TARGET_LINK_LIBRARIES(cache-info PRIVATE cpuinfo) - INSTALL(TARGETS cache-info RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}) - -- IF(CMAKE_SYSTEM_NAME MATCHES "^(Android|Linux)$" AND CMAKE_SYSTEM_PROCESSOR MATCHES "^(armv[5-8].*|aarch64)$") -+ IF(CMAKE_SYSTEM_NAME MATCHES "^(Android|Linux)$" AND CPUINFO_TARGET_PROCESSOR MATCHES "^(armv[5-8].*|aarch64)$") - ADD_EXECUTABLE(auxv-dump tools/auxv-dump.c) - CPUINFO_TARGET_ENABLE_C99(auxv-dump) - CPUINFO_TARGET_RUNTIME_LIBRARY(auxv-dump) diff --git a/cmake/vcpkg-ports/cpuinfo/patch_vcpkg_arm64ec_support.patch b/cmake/vcpkg-ports/cpuinfo/patch_vcpkg_arm64ec_support.patch deleted file mode 100644 index af0f039b6c2a3..0000000000000 --- a/cmake/vcpkg-ports/cpuinfo/patch_vcpkg_arm64ec_support.patch +++ /dev/null @@ -1,91 +0,0 @@ -diff --git a/CMakeLists.txt b/CMakeLists.txt -index aedc983..dab589e 100644 ---- a/CMakeLists.txt -+++ b/CMakeLists.txt -@@ -72,6 +72,17 @@ IF(CMAKE_SYSTEM_NAME MATCHES "FreeBSD" AND CPUINFO_TARGET_PROCESSOR STREQUAL "am - ENDIF() - IF(IS_APPLE_OS AND CMAKE_OSX_ARCHITECTURES MATCHES "^(x86_64|arm64.*)$") - SET(CPUINFO_TARGET_PROCESSOR "${CMAKE_OSX_ARCHITECTURES}") -+ELSEIF(MSVC AND CMAKE_VERSION VERSION_GREATER_EQUAL "3.10") -+ # Use CMAKE_C_COMPILER_ARCHITECTURE_ID. MSVC values are documented as available since CMake 3.10. -+ IF(CMAKE_C_COMPILER_ARCHITECTURE_ID STREQUAL "X86") -+ SET(CPUINFO_TARGET_PROCESSOR "x86") -+ ELSEIF(CMAKE_C_COMPILER_ARCHITECTURE_ID STREQUAL "x64") -+ SET(CPUINFO_TARGET_PROCESSOR "x86_64") -+ ELSEIF(CMAKE_C_COMPILER_ARCHITECTURE_ID MATCHES "^(ARM64|ARM64EC)$") -+ SET(CPUINFO_TARGET_PROCESSOR "arm64") -+ ELSE() -+ MESSAGE(FATAL_ERROR "Unsupported MSVC compiler architecture ID \"${CMAKE_C_COMPILER_ARCHITECTURE_ID}\"") -+ ENDIF() - ELSEIF(CMAKE_GENERATOR MATCHES "^Visual Studio " AND CMAKE_VS_PLATFORM_NAME) - IF(CMAKE_VS_PLATFORM_NAME STREQUAL "Win32") - SET(CPUINFO_TARGET_PROCESSOR "x86") -@@ -88,7 +99,7 @@ ENDIF() - - # ---[ Build flags - SET(CPUINFO_SUPPORTED_PLATFORM TRUE) --IF(NOT CMAKE_SYSTEM_PROCESSOR) -+IF(NOT CPUINFO_TARGET_PROCESSOR) - IF(NOT IOS) - MESSAGE(WARNING - "Target processor architecture is not specified. " -@@ -201,12 +212,12 @@ IF(CPUINFO_SUPPORTED_PLATFORM) - src/arm/linux/chipset.c - src/arm/linux/midr.c - src/arm/linux/hwcap.c) -- IF(CMAKE_SYSTEM_PROCESSOR MATCHES "^armv[5-8]") -+ IF(CPUINFO_TARGET_PROCESSOR MATCHES "^armv[5-8]") - LIST(APPEND CPUINFO_SRCS src/arm/linux/aarch32-isa.c) - IF(CMAKE_SYSTEM_NAME STREQUAL "Android" AND ANDROID_ABI STREQUAL "armeabi") - SET_SOURCE_FILES_PROPERTIES(src/arm/linux/aarch32-isa.c PROPERTIES COMPILE_FLAGS -marm) - ENDIF() -- ELSEIF(CMAKE_SYSTEM_PROCESSOR MATCHES "^(aarch64|arm64)$") -+ ELSEIF(CPUINFO_TARGET_PROCESSOR MATCHES "^(aarch64|arm64)$") - LIST(APPEND CPUINFO_SRCS src/arm/linux/aarch64-isa.c) - ENDIF() - ELSEIF(IS_APPLE_OS AND CPUINFO_TARGET_PROCESSOR MATCHES "arm64.*") -@@ -395,7 +406,7 @@ IF(CPUINFO_SUPPORTED_PLATFORM AND CPUINFO_BUILD_MOCK_TESTS) - TARGET_COMPILE_DEFINITIONS(cpuinfo_mock PRIVATE _GNU_SOURCE=1) - ENDIF() - -- IF(CMAKE_SYSTEM_NAME STREQUAL "Android" AND CMAKE_SYSTEM_PROCESSOR MATCHES "^(armv5te|armv7-a)$") -+ IF(CMAKE_SYSTEM_NAME STREQUAL "Android" AND CPUINFO_TARGET_PROCESSOR MATCHES "^(armv5te|armv7-a)$") - ADD_EXECUTABLE(atm7029b-tablet-test test/mock/atm7029b-tablet.cc) - TARGET_INCLUDE_DIRECTORIES(atm7029b-tablet-test BEFORE PRIVATE test/mock) - TARGET_LINK_LIBRARIES(atm7029b-tablet-test PRIVATE cpuinfo_mock gtest) -@@ -577,7 +588,7 @@ IF(CPUINFO_SUPPORTED_PLATFORM AND CPUINFO_BUILD_MOCK_TESTS) - ADD_TEST(NAME xperia-sl-test COMMAND xperia-sl-test) - ENDIF() - -- IF(CMAKE_SYSTEM_NAME STREQUAL "Android" AND CMAKE_SYSTEM_PROCESSOR MATCHES "^(armv5te|armv7-a|aarch64)$") -+ IF(CMAKE_SYSTEM_NAME STREQUAL "Android" AND CPUINFO_TARGET_PROCESSOR MATCHES "^(armv5te|armv7-a|aarch64)$") - ADD_EXECUTABLE(alcatel-revvl-test test/mock/alcatel-revvl.cc) - TARGET_INCLUDE_DIRECTORIES(alcatel-revvl-test BEFORE PRIVATE test/mock) - TARGET_LINK_LIBRARIES(alcatel-revvl-test PRIVATE cpuinfo_mock gtest) -@@ -774,7 +785,7 @@ IF(CPUINFO_SUPPORTED_PLATFORM AND CPUINFO_BUILD_MOCK_TESTS) - ADD_TEST(NAME xperia-c4-dual-test COMMAND xperia-c4-dual-test) - ENDIF() - -- IF(CMAKE_SYSTEM_NAME STREQUAL "Android" AND CMAKE_SYSTEM_PROCESSOR MATCHES "^(i686|x86_64)$") -+ IF(CMAKE_SYSTEM_NAME STREQUAL "Android" AND CPUINFO_TARGET_PROCESSOR MATCHES "^(i686|x86_64)$") - ADD_EXECUTABLE(alldocube-iwork8-test test/mock/alldocube-iwork8.cc) - TARGET_INCLUDE_DIRECTORIES(alldocube-iwork8-test BEFORE PRIVATE test/mock) - TARGET_LINK_LIBRARIES(alldocube-iwork8-test PRIVATE cpuinfo_mock gtest) -@@ -831,7 +842,7 @@ IF(CPUINFO_SUPPORTED_PLATFORM AND CPUINFO_BUILD_UNIT_TESTS) - ADD_TEST(NAME brand-string-test COMMAND brand-string-test) - ENDIF() - -- IF(CMAKE_SYSTEM_NAME STREQUAL "Android" AND CMAKE_SYSTEM_PROCESSOR MATCHES "^(armv[5-8].*|aarch64)$") -+ IF(CMAKE_SYSTEM_NAME STREQUAL "Android" AND CPUINFO_TARGET_PROCESSOR MATCHES "^(armv[5-8].*|aarch64)$") - ADD_LIBRARY(android_properties_interface STATIC test/name/android-properties-interface.c) - CPUINFO_TARGET_ENABLE_C99(android_properties_interface) - CPUINFO_TARGET_RUNTIME_LIBRARY(android_properties_interface) -@@ -879,7 +890,7 @@ IF(CPUINFO_SUPPORTED_PLATFORM AND CPUINFO_BUILD_TOOLS) - TARGET_LINK_LIBRARIES(cache-info PRIVATE cpuinfo) - INSTALL(TARGETS cache-info RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}) - -- IF(CMAKE_SYSTEM_NAME MATCHES "^(Android|Linux)$" AND CMAKE_SYSTEM_PROCESSOR MATCHES "^(armv[5-8].*|aarch64)$") -+ IF(CMAKE_SYSTEM_NAME MATCHES "^(Android|Linux)$" AND CPUINFO_TARGET_PROCESSOR MATCHES "^(armv[5-8].*|aarch64)$") - ADD_EXECUTABLE(auxv-dump tools/auxv-dump.c) - CPUINFO_TARGET_ENABLE_C99(auxv-dump) - CPUINFO_TARGET_RUNTIME_LIBRARY(auxv-dump) diff --git a/cmake/vcpkg-ports/cpuinfo/portfile.cmake b/cmake/vcpkg-ports/cpuinfo/portfile.cmake index eeb0007195ca3..3fcf76b7adafc 100644 --- a/cmake/vcpkg-ports/cpuinfo/portfile.cmake +++ b/cmake/vcpkg-ports/cpuinfo/portfile.cmake @@ -11,7 +11,6 @@ vcpkg_from_github( HEAD_REF master PATCHES patch_cpuinfo_h_for_arm64ec.patch - patch_vcpkg_arm64ec_support.patch # https://github.com/pytorch/cpuinfo/pull/324 ) vcpkg_check_features(OUT_FEATURE_OPTIONS FEATURE_OPTIONS diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs index 3c92400715740..8cca2b42e987a 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs @@ -368,88 +368,6 @@ public struct OrtApi public IntPtr EpDevice_Device; public IntPtr GetEpApi; public IntPtr GetTensorSizeInBytes; - - public IntPtr AllocatorGetStats; - - public IntPtr CreateMemoryInfo_V2; - public IntPtr MemoryInfoGetDeviceMemType; - public IntPtr MemoryInfoGetVendorId; - - public IntPtr ValueInfo_GetValueProducer; - public IntPtr ValueInfo_GetValueNumConsumers; - public IntPtr ValueInfo_GetValueConsumers; - public IntPtr ValueInfo_GetInitializerValue; - public IntPtr ValueInfo_GetExternalInitializerInfo; - public IntPtr ValueInfo_IsRequiredGraphInput; - public IntPtr ValueInfo_IsOptionalGraphInput; - public IntPtr ValueInfo_IsGraphOutput; - public IntPtr ValueInfo_IsConstantInitializer; - public IntPtr ValueInfo_IsFromOuterScope; - public IntPtr Graph_GetName; - public IntPtr Graph_GetModelPath; - public IntPtr Graph_GetOnnxIRVersion; - public IntPtr Graph_GetNumOperatorSets; - public IntPtr Graph_GetOperatorSets; - public IntPtr Graph_GetNumInputs; - public IntPtr Graph_GetInputs; - public IntPtr Graph_GetNumOutputs; - public IntPtr Graph_GetOutputs; - public IntPtr Graph_GetNumInitializers; - public IntPtr Graph_GetInitializers; - public IntPtr Graph_GetNumNodes; - public IntPtr Graph_GetNodes; - public IntPtr Graph_GetParentNode; - public IntPtr Graph_GetGraphView; - public IntPtr Node_GetId; - public IntPtr Node_GetName; - public IntPtr Node_GetOperatorType; - public IntPtr Node_GetDomain; - public IntPtr Node_GetSinceVersion; - public IntPtr Node_GetNumInputs; - public IntPtr Node_GetInputs; - public IntPtr Node_GetNumOutputs; - public IntPtr Node_GetOutputs; - public IntPtr Node_GetNumImplicitInputs; - public IntPtr Node_GetImplicitInputs; - public IntPtr Node_GetNumAttributes; - public IntPtr Node_GetAttributes; - public IntPtr Node_GetAttributeByName; - public IntPtr Node_GetTensorAttributeAsOrtValue; - public IntPtr OpAttr_GetType; - public IntPtr OpAttr_GetName; - public IntPtr Node_GetNumSubgraphs; - public IntPtr Node_GetSubgraphs; - public IntPtr Node_GetGraph; - public IntPtr Node_GetEpName; - public IntPtr ReleaseExternalInitializerInfo; - public IntPtr ExternalInitializerInfo_GetFilePath; - public IntPtr ExternalInitializerInfo_GetFileOffset; - public IntPtr ExternalInitializerInfo_GetByteSize; - - public IntPtr GetRunConfigEntry; - - public IntPtr EpDevice_MemoryInfo; - - public IntPtr CreateSharedAllocator; - public IntPtr GetSharedAllocator; - public IntPtr ReleaseSharedAllocator; - - public IntPtr GetTensorData; - - public IntPtr GetSessionOptionsConfigEntries; - - public IntPtr SessionGetMemoryInfoForInputs; - public IntPtr SessionGetMemoryInfoForOutputs; - public IntPtr SessionGetEpDeviceForInputs; - - public IntPtr CreateSyncStreamForEpDevice; - public IntPtr SyncStream_GetHandle; - public IntPtr ReleaseSyncStream; - - public IntPtr CopyTensors; - - public IntPtr Graph_GetModelMetadata; - public IntPtr GetModelCompatibilityForEpDevices; } internal static class NativeMethods @@ -786,10 +704,6 @@ static NativeMethods() (DSessionOptionsSetEpSelectionPolicyDelegate)Marshal.GetDelegateForFunctionPointer( api_.SessionOptionsSetEpSelectionPolicyDelegate, typeof(DSessionOptionsSetEpSelectionPolicyDelegate)); - - OrtGetModelCompatibilityForEpDevices = (DOrtGetModelCompatibilityForEpDevices)Marshal.GetDelegateForFunctionPointer( - api_.GetModelCompatibilityForEpDevices, - typeof(DOrtGetModelCompatibilityForEpDevices)); } internal class NativeLib @@ -2542,18 +2456,6 @@ public delegate void DOrtRemoveKeyValuePair(IntPtr /* OrtKeyValuePairs* */ kvps, public static DOrtGetEpDevices OrtGetEpDevices; - /// - /// Validate compiled model compatibility for the provided EP devices. - /// - [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /* OrtStatus* */ DOrtGetModelCompatibilityForEpDevices( - IntPtr[] /* const OrtEpDevice* const* */ ep_devices, - UIntPtr /* size_t */ num_ep_devices, - byte[] /* const char* */ compatibility_info, - out int /* OrtCompiledModelCompatibility */ out_status); - - public static DOrtGetModelCompatibilityForEpDevices OrtGetModelCompatibilityForEpDevices; - /// /// Add execution provider devices to the session options. /// Priority is based on the order of the OrtEpDevice instances. Highest priority first. diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/OrtEnv.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/OrtEnv.shared.cs index 052d5899b52c0..5c70808b82be1 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/OrtEnv.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/OrtEnv.shared.cs @@ -7,21 +7,6 @@ namespace Microsoft.ML.OnnxRuntime { - /// - /// Represents the compatibility status of a pre-compiled model with one or more execution provider devices. - /// - /// - /// This enum is used to determine whether a pre-compiled model can be used with specific execution providers - /// and devices, or if recompilation is needed. - /// - public enum OrtCompiledModelCompatibility - { - EP_NOT_APPLICABLE = 0, - EP_SUPPORTED_OPTIMAL = 1, - EP_SUPPORTED_PREFER_RECOMPILATION = 2, - EP_UNSUPPORTED = 3, - } - /// /// Delegate for logging function callback. /// Supply your function and register it with the environment to receive logging callbacks via @@ -376,31 +361,6 @@ public string[] GetAvailableProviders() } } - /// - /// Validate a compiled model's compatibility information for one or more EP devices. - /// - /// The list of EP devices to validate against. - /// The compatibility string from the precompiled model to validate. - /// OrtCompiledModelCompatibility enum value denoting the compatibility status - public OrtCompiledModelCompatibility GetModelCompatibilityForEpDevices( - IReadOnlyList epDevices, string compatibilityInfo) - { - if (epDevices == null || epDevices.Count == 0) - throw new ArgumentException("epDevices must be non-empty", nameof(epDevices)); - - var devicePtrs = new IntPtr[epDevices.Count]; - for (int i = 0; i < epDevices.Count; ++i) - { - devicePtrs[i] = epDevices[i].Handle; - } - - var infoUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(compatibilityInfo); - NativeApiStatus.VerifySuccess( - NativeMethods.OrtGetModelCompatibilityForEpDevices( - devicePtrs, (UIntPtr)devicePtrs.Length, infoUtf8, out int status)); - return (OrtCompiledModelCompatibility)status; - } - /// /// Get/Set log level property of OrtEnv instance diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/EpCompatibilityTests.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/EpCompatibilityTests.cs deleted file mode 100644 index 103fe5bc10106..0000000000000 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/EpCompatibilityTests.cs +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -// not supported on mobile platforms -#if !(ANDROID || IOS) - -namespace Microsoft.ML.OnnxRuntime.Tests; - -using System; -using System.Linq; -using Xunit; -using System.Collections.Generic; - -public class EpCompatibilityTests -{ - private readonly OrtEnv ortEnvInstance = OrtEnv.Instance(); - - private IReadOnlyList GetDevices() - { - var epDevices = ortEnvInstance.GetEpDevices(); - Assert.NotNull(epDevices); - Assert.NotEmpty(epDevices); - return epDevices; - } - - [Fact] - public void GetEpCompatibility_InvalidArgs() - { - Assert.Throws(() => ortEnvInstance.GetModelCompatibilityForEpDevices(null, "info")); - Assert.Throws(() => ortEnvInstance.GetModelCompatibilityForEpDevices(new List(), "info")); - } - - [Fact] - public void GetEpCompatibility_SingleDeviceCpuProvider() - { - var devices = GetDevices(); - var someInfo = "arbitrary-compat-string"; - - // Use CPU device - var cpu = devices.First(d => d.EpName == "CPUExecutionProvider"); - Assert.NotNull(cpu); - var selected = new List { cpu }; - var status = ortEnvInstance.GetModelCompatibilityForEpDevices(selected, someInfo); - - // CPU defaults to not applicable in this scenario - Assert.Equal(OrtCompiledModelCompatibility.EP_NOT_APPLICABLE, status); - } -} -#endif diff --git a/include/onnxruntime/core/providers/cann/cann_provider_options.h b/include/onnxruntime/core/providers/cann/cann_provider_options.h index 4b33ee77a892e..51b423e68110a 100644 --- a/include/onnxruntime/core/providers/cann/cann_provider_options.h +++ b/include/onnxruntime/core/providers/cann/cann_provider_options.h @@ -15,8 +15,6 @@ struct OrtCANNProviderOptions { onnxruntime::ArenaExtendStrategy arena_extend_strategy; // Strategy used to grow the memory arena int enable_cann_graph; // Flag indicating if prioritizing the use of // CANN's graph-running capabilities - int enable_cann_subgraph; // Flag indicating whether to generate subgraph - // automaticly int dump_graphs; // Flag indicating if dumping graphs int dump_om_model; // Flag indicating if dumping om model std::string precision_mode; // Operator Precision Mode diff --git a/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h b/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h index 28ce4439fdc7e..21aa797ce16eb 100644 --- a/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h +++ b/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h @@ -232,7 +232,7 @@ static Ort::Status GetOrtValueInfoTensorTypeShape(const OrtValueInfo& ort_value_ /*out*/ std::vector& dims, /*out*/ std::vector& symbolic_dims); static Ort::Status OrtValueInfoToProto(const OrtValueInfo& ort_value_info, onnx::ValueInfoProto& value_info_proto); -static Ort::Status OrtOpAttrToProto(const OrtOpAttr& ort_attr, onnx::AttributeProto& attr_proto); +static Ort::Status OrtOpAttrToProto(const OrtNode& ort_node, const OrtOpAttr& ort_attr, onnx::AttributeProto& attr_proto); Ort::Status OrtGraphToProto(const OrtGraph& ort_graph, onnx::GraphProto& graph_proto, @@ -379,7 +379,7 @@ Ort::Status OrtGraphToProto(const OrtGraph& ort_graph, } onnx::AttributeProto* attr_proto = node_proto->add_attribute(); - ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtOpAttrToProto(*ort_attr, *attr_proto)); + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtOpAttrToProto(*ort_node, *ort_attr, *attr_proto)); } } @@ -652,7 +652,7 @@ static Ort::Status OrtValueInfoToProto(const OrtValueInfo& ort_value_info, return Ort::Status{nullptr}; } -static Ort::Status OrtOpAttrToProto(const OrtOpAttr& ort_attr, onnx::AttributeProto& attr_proto) { +static Ort::Status OrtOpAttrToProto(const OrtNode& ort_node, const OrtOpAttr& ort_attr, onnx::AttributeProto& attr_proto) { const OrtApi& ort_api = Ort::GetApi(); const char* attr_name = nullptr; @@ -766,7 +766,7 @@ static Ort::Status OrtOpAttrToProto(const OrtOpAttr& ort_attr, onnx::AttributePr // TensorProto as an attribute value doesn't require a name. OrtValue* ort_value = nullptr; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.OpAttr_GetTensorAttributeAsOrtValue(&ort_attr, &ort_value)); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetTensorAttributeAsOrtValue(&ort_node, &ort_attr, &ort_value)); Ort::Value tensor(ort_value); diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index f137d88e5fb8a..9ae6174817b7c 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -6079,6 +6079,7 @@ struct OrtApi { /** \brief Get the OrtNode's 'TENSOR' attribute as an OrtValue. * + * \param[in] node The OrtNode instance. * \param[in] attribute The OrtOpAttr instance. * \param[out] attr_tensor If successful, contains the 'TENSOR' attribute as a newly created OrtValue. Must be freed with OrtApi::ReleaseValue. @@ -6087,7 +6088,7 @@ struct OrtApi { * * \since Version 1.23. */ - ORT_API2_STATUS(OpAttr_GetTensorAttributeAsOrtValue, _In_ const OrtOpAttr* attribute, + ORT_API2_STATUS(Node_GetTensorAttributeAsOrtValue, _In_ const OrtNode* node, _In_ const OrtOpAttr* attribute, _Outptr_result_maybenull_ OrtValue** attr_tensor); /** \brief Get the attribute type as OrtOpAttrType from an OrtOpAttr. diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index 13675ab447ab1..c39e27088e8bc 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -79,19 +79,22 @@ struct Exception : std::exception { throw Ort::Exception(string, code) #endif +// This is used internally by the C++ API. This class holds the global variable that points to the OrtApi, +// it's in a template so that we can define a global variable in a header and make +// it transparent to the users of the API. +template +struct Global { + static const OrtApi* api_; +}; + +// If macro ORT_API_MANUAL_INIT is defined, no static initialization will be performed. Instead, user must call InitApi() before using it. +template #ifdef ORT_API_MANUAL_INIT -// If the macro ORT_API_MANUAL_INIT is defined, no static initialization -// will be performed. Instead, users must call InitApi() before using the -// ORT C++ APIs.. -// -// InitApi() sets the global API object using the default initialization -// logic. Users call this to initialize the ORT C++ APIs at a time that -// makes sense in their program. -inline void InitApi() noexcept; - -// InitApi(const OrtApi*) is used by custom operator libraries that are not -// linked to onnxruntime. It sets the global API object, which is required -// by the ORT C++ APIs. +const OrtApi* Global::api_{}; +inline void InitApi() noexcept { Global::api_ = OrtGetApiBase()->GetApi(ORT_API_VERSION); } + +// Used by custom operator libraries that are not linked to onnxruntime. Sets the global API object, which is +// required by C++ APIs. // // Example mycustomop.cc: // @@ -104,88 +107,22 @@ inline void InitApi() noexcept; // // ... // } // -inline void InitApi(const OrtApi* api) noexcept; -#endif - -namespace detail { -// This is used internally by the C++ API. This class holds the global -// variable that points to the OrtApi. -struct Global { - static const OrtApi* Api(const OrtApi* newValue = nullptr) noexcept { - // This block-level static will be initialized once when this function is - // first executed, delaying the call to DefaultInit() until it is first needed. - // - // When ORT_API_MANUAL_INIT is not defined, DefaultInit() calls - // OrtGetApiBase()->GetApi(), which may result in a shared library being - // loaded. - // - // Using a block-level static instead of a class-level static helps - // avoid issues with static initialization order and dynamic libraries - // loading other dynamic libraries. - // - // This makes it safe to include the C++ API headers in a shared library - // that is delay loaded or delay loads its dependencies. - // - // This DOES NOT make it safe to _use_ arbitrary ORT C++ APIs when - // initializing static members, however. - static const OrtApi* api = DefaultInit(); - - if (newValue) { - api = newValue; - } - - return api; - } - - private: - // Has different definitions based on ORT_API_MANUAL_INIT - static const OrtApi* DefaultInit() noexcept; - -#ifdef ORT_API_MANUAL_INIT - // Public APIs to set the OrtApi* to use. - friend void ::Ort::InitApi() noexcept; - friend void ::Ort::InitApi(const OrtApi*) noexcept; +inline void InitApi(const OrtApi* api) noexcept { Global::api_ = api; } +#else +#if defined(_MSC_VER) && !defined(__clang__) +#pragma warning(push) +// "Global initializer calls a non-constexpr function." Therefore you can't use ORT APIs in the other global initializers. +// Please define ORT_API_MANUAL_INIT if it conerns you. +#pragma warning(disable : 26426) #endif -}; -} // namespace detail - -#ifdef ORT_API_MANUAL_INIT - -// See comments on declaration above for usage. -inline void InitApi(const OrtApi* api) noexcept { detail::Global::Api(api); } -inline void InitApi() noexcept { InitApi(OrtGetApiBase()->GetApi(ORT_API_VERSION)); } - -#ifdef _MSC_VER -// If you get a linker error about a mismatch here, you are trying to -// link two compilation units that have different definitions for -// ORT_API_MANUAL_INIT together. All compilation units must agree on the -// definition of ORT_API_MANUAL_INIT. -#pragma detect_mismatch("ORT_API_MANUAL_INIT", "enabled") +const OrtApi* Global::api_ = OrtGetApiBase()->GetApi(ORT_API_VERSION); +#if defined(_MSC_VER) && !defined(__clang__) +#pragma warning(pop) #endif - -inline const OrtApi* detail::Global::DefaultInit() noexcept { - // When ORT_API_MANUAL_INIT is defined, there's no default init that can - // be done. - return nullptr; -} - -#else // ORT_API_MANUAL_INIT - -#ifdef _MSC_VER -// If you get a linker error about a mismatch here, you are trying to link -// two compilation units that have different definitions for -// ORT_API_MANUAL_INIT together. All compilation units must agree on the -// definition of ORT_API_MANUAL_INIT. -#pragma detect_mismatch("ORT_API_MANUAL_INIT", "disabled") #endif -inline const OrtApi* detail::Global::DefaultInit() noexcept { - return OrtGetApiBase()->GetApi(ORT_API_VERSION); -} -#endif // ORT_API_MANUAL_INIT - /// This returns a reference to the ORT C API. -inline const OrtApi& GetApi() noexcept { return *detail::Global::Api(); } +inline const OrtApi& GetApi() noexcept { return *Global::api_; } /// /// This function returns the onnxruntime version string @@ -1076,16 +1013,6 @@ struct EpDevice : detail::EpDeviceImpl { ConstKeyValuePairs ep_metadata = {}, ConstKeyValuePairs ep_options = {}); }; -/** \brief Validate a compiled model's compatibility for one or more EP devices. - * - * Throws on error. Returns the resulting compatibility status. - * /// \param ep_devices The EP devices to check compatibility against. - * /// \param compatibility_info The compatibility string from the precompiled model to validate. - */ -OrtCompiledModelCompatibility GetModelCompatibilityForEpDevices( - const std::vector& ep_devices, - const char* compatibility_info); - /** \brief The Env (Environment) * * The Env holds the logging state used by all other objects. diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index 05c86ae4e0c58..d0089726812a3 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -859,26 +859,6 @@ inline void CustomOpDomain::Add(const OrtCustomOp* op) { ThrowOnError(GetApi().CustomOpDomain_Add(p_, op)); } -inline OrtCompiledModelCompatibility GetModelCompatibilityForEpDevices( - const std::vector& ep_devices, - const char* compatibility_info) { - if (ep_devices.empty()) { - ORT_CXX_API_THROW("ep_devices is empty", ORT_INVALID_ARGUMENT); - } - - std::vector ptrs; - ptrs.reserve(ep_devices.size()); - for (const auto& d : ep_devices) ptrs.push_back(d); - - OrtCompiledModelCompatibility status = OrtCompiledModelCompatibility_EP_NOT_APPLICABLE; - ThrowOnError(GetApi().GetModelCompatibilityForEpDevices( - reinterpret_cast(ptrs.data()), - ptrs.size(), - compatibility_info, - &status)); - return status; -} - inline LoraAdapter LoraAdapter::CreateLoraAdapter(const std::basic_string& adapter_path, OrtAllocator* allocator) { OrtLoraAdapter* p; diff --git a/include/onnxruntime/core/session/onnxruntime_ep_device_ep_metadata_keys.h b/include/onnxruntime/core/session/onnxruntime_ep_device_ep_metadata_keys.h index bbd6a43bb7a41..672103bedc437 100644 --- a/include/onnxruntime/core/session/onnxruntime_ep_device_ep_metadata_keys.h +++ b/include/onnxruntime/core/session/onnxruntime_ep_device_ep_metadata_keys.h @@ -12,7 +12,4 @@ static const char* const kOrtEpDevice_EpMetadataKey_Version = "version"; // Prefix for execution provider compatibility information stored in model metadata. // Used when generating EP context models to store compatibility strings for each EP. // Full key format: "ep_compatibility_info." -static const char* const kOrtModelMetadata_EpCompatibilityInfoPrefix = "ep_compatibility_info."; - -// Key for the execution provider library path (for dynamically loaded EPs) -static const char* const kOrtEpDevice_EpMetadataKey_LibraryPath = "library_path"; +static const char* const kOrtModelMetadata_EpCompatibilityInfoPrefix = "ep_compatibility_info."; \ No newline at end of file diff --git a/java/src/main/java/ai/onnxruntime/OnnxRuntime.java b/java/src/main/java/ai/onnxruntime/OnnxRuntime.java index 3bb61698f5da7..97423ffb37251 100644 --- a/java/src/main/java/ai/onnxruntime/OnnxRuntime.java +++ b/java/src/main/java/ai/onnxruntime/OnnxRuntime.java @@ -42,8 +42,6 @@ final class OnnxRuntime { private static final int ORT_API_VERSION_13 = 13; // Post 1.13 builds of the ORT API private static final int ORT_API_VERSION_14 = 14; - // Post 1.22 builds of the ORT API - private static final int ORT_API_VERSION_23 = 23; // The initial release of the ORT training API. private static final int ORT_TRAINING_API_VERSION_1 = 1; @@ -105,9 +103,6 @@ final class OnnxRuntime { /** The Training API handle. */ static long ortTrainingApiHandle; - /** The Compile API handle. */ - static long ortCompileApiHandle; - /** Is training enabled in the native library */ static boolean trainingEnabled; @@ -181,13 +176,12 @@ static synchronized void init() throws IOException { } load(ONNXRUNTIME_JNI_LIBRARY_NAME); - ortApiHandle = initialiseAPIBase(ORT_API_VERSION_23); + ortApiHandle = initialiseAPIBase(ORT_API_VERSION_14); if (ortApiHandle == 0L) { throw new IllegalStateException( "There is a mismatch between the ORT class files and the ORT native library, and the native library could not be loaded"); } - ortTrainingApiHandle = initialiseTrainingAPIBase(ortApiHandle, ORT_API_VERSION_23); - ortCompileApiHandle = initialiseCompileAPIBase(ortApiHandle); + ortTrainingApiHandle = initialiseTrainingAPIBase(ortApiHandle, ORT_API_VERSION_14); trainingEnabled = ortTrainingApiHandle != 0L; providers = initialiseProviders(ortApiHandle); version = initialiseVersion(); @@ -505,14 +499,6 @@ private static EnumSet initialiseProviders(long ortApiHandle) { */ private static native long initialiseTrainingAPIBase(long apiHandle, int apiVersionNumber); - /** - * Get a reference to the compile API struct. - * - * @param apiHandle The ORT API struct pointer. - * @return A pointer to the compile API struct. - */ - private static native long initialiseCompileAPIBase(long apiHandle); - /** * Gets the array of available providers. * diff --git a/java/src/main/java/ai/onnxruntime/OrtEnvironment.java b/java/src/main/java/ai/onnxruntime/OrtEnvironment.java index 497772baf5357..8382ef06e26e5 100644 --- a/java/src/main/java/ai/onnxruntime/OrtEnvironment.java +++ b/java/src/main/java/ai/onnxruntime/OrtEnvironment.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019, 2025 Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2019, 2024 Oracle and/or its affiliates. All rights reserved. * Licensed under the MIT License. */ package ai.onnxruntime; @@ -8,11 +8,7 @@ import ai.onnxruntime.OrtTrainingSession.OrtCheckpointState; import java.io.IOException; import java.nio.ByteBuffer; -import java.util.ArrayList; -import java.util.Collections; import java.util.EnumSet; -import java.util.List; -import java.util.Map; import java.util.Objects; import java.util.logging.Logger; @@ -446,48 +442,6 @@ public static EnumSet getAvailableProviders() { return OnnxRuntime.providers.clone(); } - /** - * Registers an execution provider library with this OrtEnvironment. - * - * @param registrationName The name to register the library with (used to remove it later with - * {@link #unregisterExecutionProviderLibrary(String)}). - * @param libraryPath The path to the library binary on disk. - * @throws OrtException If the library could not be registered. - */ - public void registerExecutionProviderLibrary(String registrationName, String libraryPath) - throws OrtException { - registerExecutionProviderLibrary( - OnnxRuntime.ortApiHandle, nativeHandle, registrationName, libraryPath); - } - - /** - * Unregisters an execution provider library from this OrtEnvironment. - * - * @param registrationName The name the library was registered under. - * @throws OrtException If the library could not be removed. - */ - public void unregisterExecutionProviderLibrary(String registrationName) throws OrtException { - unregisterExecutionProviderLibrary(OnnxRuntime.ortApiHandle, nativeHandle, registrationName); - } - - /** - * Get the list of all execution provider and device combinations that are available. - * - * @see OrtSession.SessionOptions#addExecutionProvider(List, Map) - * @return The list of execution provider and device combinations. - * @throws OrtException If the devices could not be listed. - */ - public List getEpDevices() throws OrtException { - long[] deviceHandles = getEpDevices(OnnxRuntime.ortApiHandle, nativeHandle); - - List devicesList = new ArrayList<>(); - for (long deviceHandle : deviceHandles) { - devicesList.add(new OrtEpDevice(deviceHandle)); - } - - return Collections.unmodifiableList(devicesList); - } - /** * Creates the native object. * @@ -522,40 +476,6 @@ private static native long createHandle( */ private static native long getDefaultAllocator(long apiHandle) throws OrtException; - /** - * Registers the specified execution provider with this OrtEnvironment. - * - * @param apiHandle The API handle. - * @param nativeHandle The OrtEnvironment handle. - * @param registrationName The name of the execution provider. - * @param libraryPath The path to the execution provider binary. - * @throws OrtException If the registration failed. - */ - private static native void registerExecutionProviderLibrary( - long apiHandle, long nativeHandle, String registrationName, String libraryPath) - throws OrtException; - - /** - * Removes the specified execution provider from this OrtEnvironment. - * - * @param apiHandle The API handle. - * @param nativeHandle The OrtEnvironment handle. - * @param registrationName The name of the execution provider. - * @throws OrtException If the removal failed. - */ - private static native void unregisterExecutionProviderLibrary( - long apiHandle, long nativeHandle, String registrationName) throws OrtException; - - /** - * Gets handles for the EP device tuples available in this OrtEnvironment. - * - * @param apiHandle The API handle to use. - * @param nativeHandle The OrtEnvironment handle. - * @return An array of OrtEpDevice handles. - * @throws OrtException If the call failed. - */ - private static native long[] getEpDevices(long apiHandle, long nativeHandle) throws OrtException; - /** * Closes the OrtEnvironment, frees the handle. * diff --git a/java/src/main/java/ai/onnxruntime/OrtEpDevice.java b/java/src/main/java/ai/onnxruntime/OrtEpDevice.java deleted file mode 100644 index f63dec1dbaf83..0000000000000 --- a/java/src/main/java/ai/onnxruntime/OrtEpDevice.java +++ /dev/null @@ -1,117 +0,0 @@ -/* - * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved. - * Licensed under the MIT License. - */ -package ai.onnxruntime; - -import java.util.Map; - -/** A tuple of Execution Provider information and the hardware device. */ -public final class OrtEpDevice { - - private final long nativeHandle; - - private final String epName; - private final String epVendor; - private final Map epMetadata; - private final Map epOptions; - private final OrtHardwareDevice device; - - /** - * Construct an OrtEpDevice tuple from the native pointer. - * - * @param nativeHandle The native pointer. - */ - OrtEpDevice(long nativeHandle) { - this.nativeHandle = nativeHandle; - this.epName = getName(OnnxRuntime.ortApiHandle, nativeHandle); - this.epVendor = getVendor(OnnxRuntime.ortApiHandle, nativeHandle); - String[][] metadata = getMetadata(OnnxRuntime.ortApiHandle, nativeHandle); - this.epMetadata = OrtUtil.convertToMap(metadata); - String[][] options = getOptions(OnnxRuntime.ortApiHandle, nativeHandle); - this.epOptions = OrtUtil.convertToMap(options); - this.device = new OrtHardwareDevice(getDeviceHandle(OnnxRuntime.ortApiHandle, nativeHandle)); - } - - /** - * Return the native pointer. - * - * @return The native pointer. - */ - long getNativeHandle() { - return nativeHandle; - } - - /** - * Gets the EP name. - * - * @return The EP name. - */ - public String getName() { - return epName; - } - - /** - * Gets the vendor name. - * - * @return The vendor name. - */ - public String getVendor() { - return epVendor; - } - - /** - * Gets an unmodifiable view on the EP metadata. - * - * @return The EP metadata. - */ - public Map getMetadata() { - return epMetadata; - } - - /** - * Gets an unmodifiable view on the EP options. - * - * @return The EP options. - */ - public Map getOptions() { - return epOptions; - } - - /** - * Gets the device information. - * - * @return The device information. - */ - public OrtHardwareDevice getDevice() { - return device; - } - - @Override - public String toString() { - return "OrtEpDevice{" - + "epName='" - + epName - + '\'' - + ", epVendor='" - + epVendor - + '\'' - + ", epMetadata=" - + epMetadata - + ", epOptions=" - + epOptions - + ", device=" - + device - + '}'; - } - - private static native String getName(long apiHandle, long nativeHandle); - - private static native String getVendor(long apiHandle, long nativeHandle); - - private static native String[][] getMetadata(long apiHandle, long nativeHandle); - - private static native String[][] getOptions(long apiHandle, long nativeHandle); - - private static native long getDeviceHandle(long apiHandle, long nativeHandle); -} diff --git a/java/src/main/java/ai/onnxruntime/OrtHardwareDevice.java b/java/src/main/java/ai/onnxruntime/OrtHardwareDevice.java deleted file mode 100644 index bd99f5599fd14..0000000000000 --- a/java/src/main/java/ai/onnxruntime/OrtHardwareDevice.java +++ /dev/null @@ -1,156 +0,0 @@ -/* - * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved. - * Licensed under the MIT License. - */ -package ai.onnxruntime; - -import java.util.Map; -import java.util.logging.Logger; - -/** Hardware information for a specific device. */ -public final class OrtHardwareDevice { - - /** The hardware device types. */ - // Must be updated in concert with the native OrtHardwareDeviceType enum in the C API - public enum OrtHardwareDeviceType { - /** A CPU device. */ - CPU(0), - /** A GPU device. */ - GPU(1), - /** A NPU (Neural Processing Unit) device. */ - NPU(2); - private final int value; - - private static final Logger logger = Logger.getLogger(OrtHardwareDeviceType.class.getName()); - private static final OrtHardwareDeviceType[] values = new OrtHardwareDeviceType[3]; - - static { - for (OrtHardwareDeviceType ot : OrtHardwareDeviceType.values()) { - values[ot.value] = ot; - } - } - - OrtHardwareDeviceType(int value) { - this.value = value; - } - - /** - * Gets the native value associated with this device type. - * - * @return The native value. - */ - public int getValue() { - return value; - } - - /** - * Maps from the C API's int enum to the Java enum. - * - * @param deviceType The index of the Java enum. - * @return The Java enum. - */ - public static OrtHardwareDeviceType mapFromInt(int deviceType) { - if ((deviceType >= 0) && (deviceType < values.length)) { - return values[deviceType]; - } else { - logger.warning("Unknown device type '" + deviceType + "' setting to CPU"); - return CPU; - } - } - } - - private final long nativeHandle; - - private final OrtHardwareDeviceType type; - private final int vendorId; - private final String vendor; - private final int deviceId; - private final Map metadata; - - OrtHardwareDevice(long nativeHandle) { - this.nativeHandle = nativeHandle; - this.type = - OrtHardwareDeviceType.mapFromInt(getDeviceType(OnnxRuntime.ortApiHandle, nativeHandle)); - this.vendorId = getVendorId(OnnxRuntime.ortApiHandle, nativeHandle); - this.vendor = getVendor(OnnxRuntime.ortApiHandle, nativeHandle); - this.deviceId = getDeviceId(OnnxRuntime.ortApiHandle, nativeHandle); - String[][] metadata = getMetadata(OnnxRuntime.ortApiHandle, nativeHandle); - this.metadata = OrtUtil.convertToMap(metadata); - } - - long getNativeHandle() { - return nativeHandle; - } - - /** - * Gets the device type. - * - * @return The device type. - */ - public OrtHardwareDeviceType getType() { - return type; - } - - /** - * Gets the vendor ID number. - * - * @return The vendor ID number. - */ - public int getVendorId() { - return vendorId; - } - - /** - * Gets the device ID number. - * - * @return The device ID number. - */ - public int getDeviceId() { - return deviceId; - } - - /** - * Gets an unmodifiable view on the device metadata. - * - * @return The device metadata. - */ - public Map getMetadata() { - return metadata; - } - - /** - * Gets the vendor name. - * - * @return The vendor name. - */ - public String getVendor() { - return vendor; - } - - @Override - public String toString() { - return "OrtHardwareDevice{" - + "type=" - + type - + ", vendorId=" - + vendorId - + ", vendor='" - + vendor - + '\'' - + ", deviceId=" - + deviceId - + ", metadata=" - + metadata - + '}'; - } - - private static native String getVendor(long apiHandle, long nativeHandle); - - private static native String[][] getMetadata(long apiHandle, long nativeHandle); - - private static native int getDeviceType(long apiHandle, long nativeHandle); - - private static native int getDeviceId(long apiHandle, long nativeHandle); - - private static native int getVendorId(long apiHandle, long nativeHandle); -} diff --git a/java/src/main/java/ai/onnxruntime/OrtModelCompilationOptions.java b/java/src/main/java/ai/onnxruntime/OrtModelCompilationOptions.java deleted file mode 100644 index 09b3064b72b93..0000000000000 --- a/java/src/main/java/ai/onnxruntime/OrtModelCompilationOptions.java +++ /dev/null @@ -1,280 +0,0 @@ -/* - * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved. - * Licensed under the MIT License. - */ -package ai.onnxruntime; - -import java.nio.ByteBuffer; -import java.util.EnumSet; - -/** Configuration options for compiling ONNX models. */ -public final class OrtModelCompilationOptions implements AutoCloseable { - /** Flags representing options when compiling a model. */ - public enum OrtCompileApiFlags implements OrtFlags { - /** Default. Do not enable any additional compilation options. */ - NONE(0), - - /** - * Force compilation to return an error (ORT_FAIL) if no nodes were compiled. Otherwise, a model - * with basic optimizations (ORT_ENABLE_BASIC) is still generated by default. - */ - ERROR_IF_NO_NODES_COMPILED(1), - - /** - * Force compilation to return an error (ORT_FAIL) if a file with the same filename as the - * output model exists. Otherwise, compilation will automatically overwrite the output file if - * it exists. - */ - ERROR_IF_OUTPUT_FILE_EXISTS(1 << 1); - - /** The native value of the enum. */ - public final int value; - - OrtCompileApiFlags(int value) { - this.value = value; - } - - @Override - public int getValue() { - return value; - } - } - - private final long nativeHandle; - private boolean closed = false; - - // Used to ensure the byte buffer doesn't get GC'd before the model is compiled. - private ByteBuffer buffer; - - OrtModelCompilationOptions(long nativeHandle) { - this.nativeHandle = nativeHandle; - } - - /** - * Creates a model compilation options from an existing SessionOptions. - * - *

An OrtModelCompilationOptions object contains the settings used to generate a compiled ONNX - * model. The OrtSessionOptions object has the execution providers with which the model will be - * compiled. - * - * @param env The OrtEnvironment. - * @param sessionOptions The session options to use. - * @return A constructed model compilation options instance. - * @throws OrtException If the construction failed. - */ - public static OrtModelCompilationOptions createFromSessionOptions( - OrtEnvironment env, OrtSession.SessionOptions sessionOptions) throws OrtException { - long handle = - createFromSessionOptions( - OnnxRuntime.ortApiHandle, - OnnxRuntime.ortCompileApiHandle, - env.getNativeHandle(), - sessionOptions.getNativeHandle()); - return new OrtModelCompilationOptions(handle); - } - - /** - * Checks if the OrtModelCompilationOptions is closed, if so throws {@link IllegalStateException}. - */ - private void checkClosed() { - if (closed) { - throw new IllegalStateException("Trying to use a closed OrtModelCompilationOptions."); - } - } - - @Override - public void close() { - if (!closed) { - close(OnnxRuntime.ortCompileApiHandle, nativeHandle); - closed = true; - } else { - throw new IllegalStateException("Trying to close a closed OrtModelCompilationOptions."); - } - } - - /** - * Sets the file path to the input ONNX model. - * - *

The input model's location must be set either to a path on disk with this method, or by - * supplying an in-memory reference with {@link #setInputModelFromBuffer}. - * - * @param inputModelPath The path to the model on disk. - * @throws OrtException If the set failed. - */ - public void setInputModelPath(String inputModelPath) throws OrtException { - checkClosed(); - setInputModelPath( - OnnxRuntime.ortApiHandle, OnnxRuntime.ortCompileApiHandle, nativeHandle, inputModelPath); - } - - /** - * Uses the supplied buffer as the input ONNX model. - * - *

The input model's location must be set either to an in-memory reference with this method, or - * by supplying a path on disk with {@link #setInputModelPath(String)}. - * - *

If the {@link ByteBuffer} is not direct it is copied into a direct buffer. In either case - * this object holds a reference to the buffer to prevent it from being GC'd. - * - * @param inputModelBuffer The buffer. - * @throws OrtException If the buffer could not be set. - */ - public void setInputModelFromBuffer(ByteBuffer inputModelBuffer) throws OrtException { - checkClosed(); - if (!inputModelBuffer.isDirect()) { - // if it's not a direct buffer, copy it. - buffer = ByteBuffer.allocateDirect(inputModelBuffer.remaining()); - int tmpPos = inputModelBuffer.position(); - buffer.put(inputModelBuffer); - buffer.rewind(); - inputModelBuffer.position(tmpPos); - } else { - buffer = inputModelBuffer; - } - int bufferPos = buffer.position(); - int bufferRemaining = buffer.remaining(); - setInputModelFromBuffer( - OnnxRuntime.ortApiHandle, - OnnxRuntime.ortCompileApiHandle, - nativeHandle, - buffer, - bufferPos, - bufferRemaining); - } - - /** - * Sets the file path for the output compiled ONNX model. - * - *

If this is unset it will append `_ctx` to the file name, e.g., my_model.onnx becomes - * my_model_ctx.onnx. - * - * @param outputModelPath The output model path. - * @throws OrtException If the path could not be set. - */ - public void setOutputModelPath(String outputModelPath) throws OrtException { - checkClosed(); - setOutputModelPath( - OnnxRuntime.ortApiHandle, OnnxRuntime.ortCompileApiHandle, nativeHandle, outputModelPath); - } - - /** - * Optionally sets the file that stores initializers for the compiled ONNX model. If unset then - * initializers are stored inside the model. - * - *

Only initializers for nodes that were not compiled are stored in the external initializers - * file. Compiled nodes contain their initializer data within the `ep_cache_context` attribute of - * EPContext nodes. - * - * @see OrtModelCompilationOptions#setEpContextEmbedMode - * @param outputExternalInitializersPath Path to the file. - * @param sizeThreshold Initializers larger than this threshold are stored in the file. - * @throws OrtException If the path could not be set. - */ - public void setOutputExternalInitializersPath( - String outputExternalInitializersPath, long sizeThreshold) throws OrtException { - checkClosed(); - // check positive - setOutputExternalInitializersPath( - OnnxRuntime.ortApiHandle, - OnnxRuntime.ortCompileApiHandle, - nativeHandle, - outputExternalInitializersPath, - sizeThreshold); - } - - /** - * Enables or disables the embedding of EPContext binary data into the ep_cache_context attribute - * of EPContext nodes. - * - *

Defaults to false. When enabled, the `ep_cache_context` attribute of EPContext nodes will - * store the context binary data, which may include weights for compiled subgraphs. When disabled, - * the `ep_cache_context` attribute of EPContext nodes will contain the path to the file - * containing the context binary data. The path is set by the execution provider creating the - * EPContext node. - * - *

For more details see the EPContext design - * document. - * - * @param embedEpContext True to embed EPContext binary data into the EPContext node's - * ep_cache_context attribute. - * @throws OrtException If the set operation failed. - */ - public void setEpContextEmbedMode(boolean embedEpContext) throws OrtException { - checkClosed(); - setEpContextEmbedMode( - OnnxRuntime.ortApiHandle, OnnxRuntime.ortCompileApiHandle, nativeHandle, embedEpContext); - } - - /** - * Sets the specified compilation flags. - * - * @param flags The compilation flags. - * @throws OrtException If the set operation failed. - */ - public void setCompilationFlags(EnumSet flags) throws OrtException { - checkClosed(); - setCompilationFlags( - OnnxRuntime.ortApiHandle, - OnnxRuntime.ortCompileApiHandle, - nativeHandle, - OrtFlags.aggregateToInt(flags)); - } - - /** - * Compiles the ONNX model with the configuration described by this instance of - * OrtModelCompilationOptions. - * - * @throws OrtException If the compilation failed. - */ - public void compileModel() throws OrtException { - checkClosed(); - // Safe as the environment must exist to create one of these objects. - OrtEnvironment env = OrtEnvironment.getEnvironment(); - compileModel( - OnnxRuntime.ortApiHandle, - OnnxRuntime.ortCompileApiHandle, - env.getNativeHandle(), - nativeHandle); - } - - private static native long createFromSessionOptions( - long apiHandle, long compileApiHandle, long envHandle, long nativeHandle) throws OrtException; - - private static native void close(long compileApiHandle, long nativeHandle); - - private static native void setInputModelPath( - long apiHandle, long compileApiHandle, long nativeHandle, String inputModelPath) - throws OrtException; - - private static native void setInputModelFromBuffer( - long apiHandle, - long compileApiHandle, - long nativeHandle, - ByteBuffer inputBuffer, - long bufferPos, - long bufferRemaining) - throws OrtException; - - private static native void setOutputModelPath( - long apiHandle, long compileApiHandle, long nativeHandle, String outputModelPath) - throws OrtException; - - private static native void setOutputExternalInitializersPath( - long apiHandle, - long compileApiHandle, - long nativeHandle, - String externalInitializersPath, - long sizeThreshold) - throws OrtException; - - private static native void setEpContextEmbedMode( - long apiHandle, long compileApiHandle, long nativeHandle, boolean embedEpContext) - throws OrtException; - - private static native void setCompilationFlags( - long apiHandle, long compileApiHandle, long nativeHandle, int flags) throws OrtException; - - private static native void compileModel( - long apiHandle, long compileApiHandle, long envHandle, long nativeHandle) throws OrtException; -} diff --git a/java/src/main/java/ai/onnxruntime/OrtSession.java b/java/src/main/java/ai/onnxruntime/OrtSession.java index 42dc90b71cb80..a399d5080ca16 100644 --- a/java/src/main/java/ai/onnxruntime/OrtSession.java +++ b/java/src/main/java/ai/onnxruntime/OrtSession.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019, 2025, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2019, 2024, Oracle and/or its affiliates. All rights reserved. * SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates * Licensed under the MIT License. */ @@ -8,6 +8,7 @@ import ai.onnxruntime.providers.CoreMLFlags; import ai.onnxruntime.providers.NNAPIFlags; import ai.onnxruntime.providers.OrtCUDAProviderOptions; +import ai.onnxruntime.providers.OrtFlags; import ai.onnxruntime.providers.OrtTensorRTProviderOptions; import java.io.IOException; import java.nio.ByteBuffer; @@ -623,10 +624,6 @@ private native OnnxModelMetadata constructMetadata( *

Used to set the number of threads, optimisation level, computation backend and other * options. * - *

The order execution providers are added to an options instance is the order they will be - * considered for op node assignment, with the EP added first having priority. The CPU EP is a - * fallback and added by default. - * *

Modifying this after the session has been constructed will have no effect. * *

The SessionOptions object must not be closed until all sessions which use it are closed, as @@ -733,7 +730,7 @@ public SessionOptions() { @Override public void close() { if (!closed) { - if (!customLibraryHandles.isEmpty()) { + if (customLibraryHandles.size() > 0) { long[] longArray = new long[customLibraryHandles.size()]; for (int i = 0; i < customLibraryHandles.size(); i++) { longArray[i] = customLibraryHandles.get(i); @@ -920,10 +917,10 @@ public void registerCustomOpLibrary(String path) throws OrtException { * *

 OrtStatus* (*fn)(OrtSessionOptions* options, const OrtApiBase* api); * - *

See Add - * Custom Op for more information on custom ops. See an example of a custom op library - * registration function here. + *

See https://onnxruntime.ai/docs/reference/operators/add-custom-op.html for more + * information on custom ops. See + * https://github.com/microsoft/onnxruntime/blob/342a5bf2b756d1a1fc6fdc582cfeac15182632fe/onnxruntime/test/testdata/custom_op_library/custom_op_library.cc#L115 + * for an example of a custom op library registration function. * * @param registrationFuncName The name of the registration function to call. * @throws OrtException If there was an error finding or calling the registration function. @@ -1276,47 +1273,10 @@ public void addCoreML(EnumSet flags) throws OrtException { addCoreML(OnnxRuntime.ortApiHandle, nativeHandle, OrtFlags.aggregateToInt(flags)); } - /** - * Adds the specified execution provider and device tuples as an execution backend. - * - *

Execution provider priority is in the order added, i.e., the first provider added to a - * session options will be used first for op node assignment. - * - * @param devices The EP and device tuples. Each element must use the same EP, though they can - * use different devices. - * @param providerOptions Configuration options for the execution provider. Refer to the - * specific execution provider's documentation. - * @throws OrtException If there was an error in native code. - */ - public void addExecutionProvider(List devices, Map providerOptions) - throws OrtException { - checkClosed(); - if (devices.isEmpty()) { - throw new IllegalArgumentException("Must supply at least one OrtEpDevice"); - } - long[] deviceHandles = new long[devices.size()]; - for (int i = 0; i < devices.size(); i++) { - deviceHandles[i] = devices.get(i).getNativeHandle(); - } - String[][] optsArray = OrtUtil.unpackMap(providerOptions); - // This is valid as the environment must have been created to create the OrtEpDevice list. - long envHandle = OrtEnvironment.getEnvironment().getNativeHandle(); - addExecutionProvider( - OnnxRuntime.ortApiHandle, - envHandle, - nativeHandle, - deviceHandles, - optsArray[0], - optsArray[1]); - } - /** * Adds the named execution provider (backend) as an execution backend. This generic function * only allows a subset of execution providers. * - *

Execution provider priority is in the order added, i.e., the first provider added to a - * session options will be used first for op node assignment. - * * @param providerName The name of the execution provider. * @param providerOptions Configuration options for the execution provider. Refer to the * specific execution provider's documentation. @@ -1325,9 +1285,20 @@ public void addExecutionProvider(List devices, Map private void addExecutionProvider(String providerName, Map providerOptions) throws OrtException { checkClosed(); - String[][] optsArray = OrtUtil.unpackMap(providerOptions); + String[] providerOptionKey = new String[providerOptions.size()]; + String[] providerOptionVal = new String[providerOptions.size()]; + int i = 0; + for (Map.Entry entry : providerOptions.entrySet()) { + providerOptionKey[i] = entry.getKey(); + providerOptionVal[i] = entry.getValue(); + i++; + } addExecutionProvider( - OnnxRuntime.ortApiHandle, nativeHandle, providerName, optsArray[0], optsArray[1]); + OnnxRuntime.ortApiHandle, + nativeHandle, + providerName, + providerOptionKey, + providerOptionVal); } /** @@ -1513,15 +1484,6 @@ private native void addExecutionProvider( String[] providerOptionKey, String[] providerOptionVal) throws OrtException; - - private native void addExecutionProvider( - long apiHandle, - long envHandle, - long nativeHandle, - long[] deviceHandles, - String[] providerOptionKey, - String[] providerOptionVal) - throws OrtException; } /** Used to control logging and termination of a call to {@link OrtSession#run}. */ diff --git a/java/src/main/java/ai/onnxruntime/OrtUtil.java b/java/src/main/java/ai/onnxruntime/OrtUtil.java index ee91fdb292baa..2f44236e4ef67 100644 --- a/java/src/main/java/ai/onnxruntime/OrtUtil.java +++ b/java/src/main/java/ai/onnxruntime/OrtUtil.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019, 2025, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2019, 2024, Oracle and/or its affiliates. All rights reserved. * Copyright (c) Microsoft Corporation. All rights reserved. * Licensed under the MIT License. */ @@ -16,9 +16,6 @@ import java.nio.ShortBuffer; import java.util.ArrayList; import java.util.Arrays; -import java.util.Collections; -import java.util.HashMap; -import java.util.Map; import java.util.logging.Logger; /** Util code for interacting with Java arrays. */ @@ -373,52 +370,6 @@ public static boolean validateShape(long[] shape) { return valid && shape.length <= TensorInfo.MAX_DIMENSIONS; } - /** - * Converts the output of a OrtKeyValuePairs into a Java unmodifiable HashMap. - * - * @param zippedString The zipped keys and values. - * @return An unmodifiable Map. - */ - static Map convertToMap(String[][] zippedString) { - if (zippedString.length != 2) { - throw new IllegalArgumentException("Invalid zipped string, must have two arrays."); - } else if (zippedString[0].length != zippedString[1].length) { - throw new IllegalArgumentException( - "Invalid zipped string, must have two arrays of the same length."); - } - Map map = new HashMap<>(capacityFromSize(zippedString[0].length)); - for (int i = 0; i < zippedString[0].length; i++) { - map.put(zippedString[0][i], zippedString[1][i]); - } - return Collections.unmodifiableMap(map); - } - - /** - * Converts a Java string map into a pair of arrays suitable for constructing a native - * OrtKeyValuePairs object. - * - * @param map A map from string to string, with no null keys or values. - * @return A pair of String arrays. - */ - static String[][] unpackMap(Map map) { - String[] keys = new String[map.size()]; - String[] values = new String[map.size()]; - int i = 0; - for (Map.Entry entry : map.entrySet()) { - if (entry.getKey() == null || entry.getValue() == null) { - throw new IllegalArgumentException( - "Invalid map, keys and values must not be null, found key = " - + entry.getKey() - + ", value = " - + entry.getValue()); - } - keys[i] = entry.getKey(); - values[i] = entry.getValue(); - i++; - } - return new String[][] {keys, values}; - } - /** * Flatten a multidimensional String array into a single dimensional String array, reading it in a * multidimensional row-major order. diff --git a/java/src/main/java/ai/onnxruntime/providers/CoreMLFlags.java b/java/src/main/java/ai/onnxruntime/providers/CoreMLFlags.java index 15fe459dad7c8..22bf940844774 100644 --- a/java/src/main/java/ai/onnxruntime/providers/CoreMLFlags.java +++ b/java/src/main/java/ai/onnxruntime/providers/CoreMLFlags.java @@ -1,11 +1,9 @@ /* - * Copyright (c) 2021, 2025, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2021, 2024, Oracle and/or its affiliates. All rights reserved. * Licensed under the MIT License. */ package ai.onnxruntime.providers; -import ai.onnxruntime.OrtFlags; - /** Flags for the CoreML provider. */ public enum CoreMLFlags implements OrtFlags { /** diff --git a/java/src/main/java/ai/onnxruntime/providers/NNAPIFlags.java b/java/src/main/java/ai/onnxruntime/providers/NNAPIFlags.java index dd30684078717..eeaf6cc8d53bc 100644 --- a/java/src/main/java/ai/onnxruntime/providers/NNAPIFlags.java +++ b/java/src/main/java/ai/onnxruntime/providers/NNAPIFlags.java @@ -1,11 +1,9 @@ /* - * Copyright (c) 2021, 2025, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2021, 2023, Oracle and/or its affiliates. All rights reserved. * Licensed under the MIT License. */ package ai.onnxruntime.providers; -import ai.onnxruntime.OrtFlags; - /** Flags for the NNAPI provider. */ public enum NNAPIFlags implements OrtFlags { /** Enables fp16 support. */ diff --git a/java/src/main/java/ai/onnxruntime/OrtFlags.java b/java/src/main/java/ai/onnxruntime/providers/OrtFlags.java similarity index 88% rename from java/src/main/java/ai/onnxruntime/OrtFlags.java rename to java/src/main/java/ai/onnxruntime/providers/OrtFlags.java index f57fd945dbeec..73d3eeae6499c 100644 --- a/java/src/main/java/ai/onnxruntime/OrtFlags.java +++ b/java/src/main/java/ai/onnxruntime/providers/OrtFlags.java @@ -1,8 +1,8 @@ /* - * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2021, Oracle and/or its affiliates. All rights reserved. * Licensed under the MIT License. */ -package ai.onnxruntime; +package ai.onnxruntime.providers; import java.util.EnumSet; diff --git a/java/src/main/native/OrtJniUtil.c b/java/src/main/native/OrtJniUtil.c index 96ea8e79bc978..5d8efd7b476cb 100644 --- a/java/src/main/native/OrtJniUtil.c +++ b/java/src/main/native/OrtJniUtil.c @@ -1014,36 +1014,6 @@ jobject convertOrtValueToONNXValue(JNIEnv *jniEnv, const OrtApi * api, OrtAlloca } } -jobjectArray convertOrtKeyValuePairsToArrays(JNIEnv *jniEnv, const OrtApi * api, const OrtKeyValuePairs * kvp) { - // extract pair arrays - const char* const* keys = NULL; - const char* const* values = NULL; - size_t numKeys = 0; - api->GetKeyValuePairs(kvp, &keys, &values, &numKeys); - jsize jNumKeys = safecast_size_t_to_jsize(numKeys); - - // create Java String[] - jclass stringClazz = (*jniEnv)->FindClass(jniEnv, "java/lang/String"); - jobjectArray keyArray = (*jniEnv)->NewObjectArray(jniEnv, jNumKeys, stringClazz, NULL); - jobjectArray valueArray = (*jniEnv)->NewObjectArray(jniEnv, jNumKeys, stringClazz, NULL); - - // populate Java arrays - for (jsize i = 0; i < jNumKeys; i++) { - jstring key = (*jniEnv)->NewStringUTF(jniEnv, keys[i]); - (*jniEnv)->SetObjectArrayElement(jniEnv, keyArray, i, key); - jstring value = (*jniEnv)->NewStringUTF(jniEnv, values[i]); - (*jniEnv)->SetObjectArrayElement(jniEnv, valueArray, i, value); - } - - // create Java String[][] - jclass stringArrClazz = (*jniEnv)->GetObjectClass(jniEnv, keyArray); - jobjectArray pair = (*jniEnv)->NewObjectArray(jniEnv, 2, stringArrClazz, 0); - (*jniEnv)->SetObjectArrayElement(jniEnv, pair, 0, keyArray); - (*jniEnv)->SetObjectArrayElement(jniEnv, pair, 1, valueArray); - - return pair; -} - jint throwOrtException(JNIEnv *jniEnv, int messageId, const char *message) { jstring messageStr = (*jniEnv)->NewStringUTF(jniEnv, message); diff --git a/java/src/main/native/OrtJniUtil.h b/java/src/main/native/OrtJniUtil.h index 040fd41264c10..7f41e06371f2a 100644 --- a/java/src/main/native/OrtJniUtil.h +++ b/java/src/main/native/OrtJniUtil.h @@ -78,8 +78,6 @@ jobject createMapInfoFromValue(JNIEnv *jniEnv, const OrtApi * api, OrtAllocator jobject convertOrtValueToONNXValue(JNIEnv *jniEnv, const OrtApi * api, OrtAllocator* allocator, OrtValue* onnxValue); -jobjectArray convertOrtKeyValuePairsToArrays(JNIEnv *jniEnv, const OrtApi * api, const OrtKeyValuePairs * kvp); - jint throwOrtException(JNIEnv *env, int messageId, const char *message); jint convertErrorCode(OrtErrorCode code); diff --git a/java/src/main/native/ai_onnxruntime_OnnxRuntime.c b/java/src/main/native/ai_onnxruntime_OnnxRuntime.c index d8f5f1a3cb2db..659f34e1fb66f 100644 --- a/java/src/main/native/ai_onnxruntime_OnnxRuntime.c +++ b/java/src/main/native/ai_onnxruntime_OnnxRuntime.c @@ -32,19 +32,6 @@ JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OnnxRuntime_initialiseTrainingAPIBas return (jlong) trainingApi; } -/* - * Class: ai_onnxruntime_OnnxRuntime - * Method: initialiseCompileAPIBase - * Signature: (J)J - */ -JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OnnxRuntime_initialiseCompileAPIBase - (JNIEnv * jniEnv, jclass clazz, jlong apiHandle) { - (void)jniEnv; (void)clazz; // required JNI parameters not needed by functions which don't call back into Java. - const OrtApi* api = (const OrtApi*)apiHandle; - const OrtCompileApi* compileApi = api->GetCompileApi(); - return (jlong) compileApi; -} - /* * Class: ai_onnxruntime_OnnxRuntime * Method: getAvailableProviders diff --git a/java/src/main/native/ai_onnxruntime_OrtEnvironment.c b/java/src/main/native/ai_onnxruntime_OrtEnvironment.c index 77b096d62ec76..e1b1ff1c05fe1 100644 --- a/java/src/main/native/ai_onnxruntime_OrtEnvironment.c +++ b/java/src/main/native/ai_onnxruntime_OrtEnvironment.c @@ -60,76 +60,6 @@ JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtEnvironment_getDefaultAllocator return (jlong)allocator; } -/* - * Class: ai_onnxruntime_OrtEnvironment - * Method: registerExecutionProviderLibrary - * Signature: (JJLjava/lang/String;Ljava/lang/String;)V - */ -JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtEnvironment_registerExecutionProviderLibrary - (JNIEnv * jniEnv, jclass jobj, jlong apiHandle, jlong nativeHandle, jstring name, jstring libraryPath) { - (void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object. - const OrtApi* api = (const OrtApi*) apiHandle; - OrtEnv* env = (OrtEnv*) nativeHandle; - const char* cName = (*jniEnv)->GetStringUTFChars(jniEnv, name, NULL); -#ifdef _WIN32 - const jchar* cPath = (*jniEnv)->GetStringChars(jniEnv, libraryPath, NULL); - size_t stringLength = (*jniEnv)->GetStringLength(jniEnv, libraryPath); - wchar_t* newString = (wchar_t*)calloc(stringLength + 1, sizeof(wchar_t)); - if (newString == NULL) { - (*jniEnv)->ReleaseStringChars(jniEnv, libraryPath, cPath); - throwOrtException(jniEnv, 1, "Not enough memory"); - return; - } - wcsncpy_s(newString, stringLength + 1, (const wchar_t*)cPath, stringLength); - checkOrtStatus(jniEnv, api, api->RegisterExecutionProviderLibrary(env, cName, newString)); - free(newString); - (*jniEnv)->ReleaseStringChars(jniEnv, libraryPath, cPath); -#else - const char* cPath = (*jniEnv)->GetStringUTFChars(jniEnv, libraryPath, NULL); - checkOrtStatus(jniEnv, api, api->RegisterExecutionProviderLibrary(env, cName, cPath)); - (*jniEnv)->ReleaseStringUTFChars(jniEnv, libraryPath, cPath); -#endif - (*jniEnv)->ReleaseStringUTFChars(jniEnv, name, cName); -} - -/* - * Class: ai_onnxruntime_OrtEnvironment - * Method: unregisterExecutionProviderLibrary - * Signature: (JJLjava/lang/String;)V - */ -JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtEnvironment_unregisterExecutionProviderLibrary - (JNIEnv * jniEnv, jclass jobj, jlong apiHandle, jlong nativeHandle, jstring name) { - (void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object. - const OrtApi* api = (const OrtApi*) apiHandle; - OrtEnv* env = (OrtEnv*) nativeHandle; - const char* cName = (*jniEnv)->GetStringUTFChars(jniEnv, name, NULL); - checkOrtStatus(jniEnv, api, api->UnregisterExecutionProviderLibrary(env, cName)); - (*jniEnv)->ReleaseStringUTFChars(jniEnv, name, cName); -} - -/* - * Class: ai_onnxruntime_OrtEnvironment - * Method: getEpDevices - * Signature: (JJ)[J - */ -JNIEXPORT jlongArray JNICALL Java_ai_onnxruntime_OrtEnvironment_getEpDevices - (JNIEnv * jniEnv, jclass jobj, jlong apiHandle, jlong nativeHandle) { - (void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object. - const OrtApi* api = (const OrtApi*) apiHandle; - OrtEnv* env = (OrtEnv*) nativeHandle; - size_t numDevices = 0; - const OrtEpDevice* const* devicesArr = NULL; - OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetEpDevices(env, &devicesArr, &numDevices)); - if (code != ORT_OK) { - return NULL; - } else { - jsize numDevicesInt = safecast_size_t_to_jsize(numDevices); - jlongArray outputArr = (*jniEnv)->NewLongArray(jniEnv, numDevicesInt); - (*jniEnv)->SetLongArrayRegion(jniEnv, outputArr, 0, numDevicesInt, (jlong*)devicesArr); - return outputArr; - } -} - /* * Class: ai_onnxruntime_OrtEnvironment * Method: close diff --git a/java/src/main/native/ai_onnxruntime_OrtEpDevice.c b/java/src/main/native/ai_onnxruntime_OrtEpDevice.c deleted file mode 100644 index 5a1e3092b0fb9..0000000000000 --- a/java/src/main/native/ai_onnxruntime_OrtEpDevice.c +++ /dev/null @@ -1,82 +0,0 @@ -/* - * Copyright (c) 2025 Oracle and/or its affiliates. All rights reserved. - * Licensed under the MIT License. - */ -#include -#include "onnxruntime/core/session/onnxruntime_c_api.h" -#include "OrtJniUtil.h" -#include "ai_onnxruntime_OrtEpDevice.h" - -/* - * Class: ai_onnxruntime_OrtEpDevice - * Method: getName - * Signature: (JJ)Ljava/lang/String; - */ -JNIEXPORT jstring JNICALL Java_ai_onnxruntime_OrtEpDevice_getName - (JNIEnv * jniEnv, jclass jclazz, jlong apiHandle, jlong nativeHandle) { - (void) jclazz; // Required JNI parameter not needed by functions which don't need to access their host object. - const OrtApi* api = (const OrtApi*) apiHandle; - OrtEpDevice* epDevice = (OrtEpDevice*) nativeHandle; - const char* name = api->EpDevice_EpName(epDevice); - jstring nameStr = (*jniEnv)->NewStringUTF(jniEnv, name); - return nameStr; -} - -/* - * Class: ai_onnxruntime_OrtEpDevice - * Method: getVendor - * Signature: (JJ)Ljava/lang/String; - */ -JNIEXPORT jstring JNICALL Java_ai_onnxruntime_OrtEpDevice_getVendor - (JNIEnv * jniEnv, jclass jclazz, jlong apiHandle, jlong nativeHandle) { - (void) jclazz; // Required JNI parameter not needed by functions which don't need to access their host object. - const OrtApi* api = (const OrtApi*) apiHandle; - OrtEpDevice* epDevice = (OrtEpDevice*) nativeHandle; - const char* vendor = api->EpDevice_EpVendor(epDevice); - jstring vendorStr = (*jniEnv)->NewStringUTF(jniEnv, vendor); - return vendorStr; -} - -/* - * Class: ai_onnxruntime_OrtEpDevice - * Method: getMetadata - * Signature: (JJ)[[Ljava/lang/String; - */ -JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtEpDevice_getMetadata - (JNIEnv * jniEnv, jclass jclazz, jlong apiHandle, jlong nativeHandle) { - (void) jclazz; // Required JNI parameter not needed by functions which don't need to access their host object. - const OrtApi* api = (const OrtApi*) apiHandle; - OrtEpDevice* epDevice = (OrtEpDevice*) nativeHandle; - const OrtKeyValuePairs* kvp = api->EpDevice_EpMetadata(epDevice); - jobjectArray pair = convertOrtKeyValuePairsToArrays(jniEnv, api, kvp); - return pair; -} - -/* - * Class: ai_onnxruntime_OrtEpDevice - * Method: getOptions - * Signature: (JJ)[[Ljava/lang/String; - */ -JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtEpDevice_getOptions - (JNIEnv * jniEnv, jclass jclazz, jlong apiHandle, jlong nativeHandle) { - (void) jclazz; // Required JNI parameter not needed by functions which don't need to access their host object. - const OrtApi* api = (const OrtApi*) apiHandle; - OrtEpDevice* epDevice = (OrtEpDevice*) nativeHandle; - const OrtKeyValuePairs* kvp = api->EpDevice_EpOptions(epDevice); - jobjectArray pair = convertOrtKeyValuePairsToArrays(jniEnv, api, kvp); - return pair; -} - -/* - * Class: ai_onnxruntime_OrtEpDevice - * Method: getDeviceHandle - * Signature: (JJ)J - */ -JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtEpDevice_getDeviceHandle - (JNIEnv * jniEnv, jclass jclazz, jlong apiHandle, jlong nativeHandle) { - (void) jniEnv; (void) jclazz; // Required JNI parameters not needed by functions which don't need to access their host object or the JVM. - const OrtApi* api = (const OrtApi*) apiHandle; - OrtEpDevice* epDevice = (OrtEpDevice*) nativeHandle; - const OrtHardwareDevice* device = api->EpDevice_Device(epDevice); - return (jlong) device; -} diff --git a/java/src/main/native/ai_onnxruntime_OrtHardwareDevice.c b/java/src/main/native/ai_onnxruntime_OrtHardwareDevice.c deleted file mode 100644 index 3191a89c26ba1..0000000000000 --- a/java/src/main/native/ai_onnxruntime_OrtHardwareDevice.c +++ /dev/null @@ -1,96 +0,0 @@ -/* - * Copyright (c) 2025 Oracle and/or its affiliates. All rights reserved. - * Licensed under the MIT License. - */ -#include -#include "onnxruntime/core/session/onnxruntime_c_api.h" -#include "OrtJniUtil.h" -#include "ai_onnxruntime_OrtHardwareDevice.h" - -/* - * Class: ai_onnxruntime_OrtHardwareDevice - * Method: getVendor - * Signature: (JJ)Ljava/lang/String; - */ -JNIEXPORT jstring JNICALL Java_ai_onnxruntime_OrtHardwareDevice_getVendor - (JNIEnv * jniEnv, jclass jclazz, jlong apiHandle, jlong nativeHandle) { - (void) jclazz; // Required JNI parameter not needed by functions which don't need to access their host object. - const OrtApi* api = (const OrtApi*) apiHandle; - OrtHardwareDevice* device = (OrtHardwareDevice*) nativeHandle; - const char* vendor = api->HardwareDevice_Vendor(device); - jstring vendorStr = (*jniEnv)->NewStringUTF(jniEnv, vendor); - return vendorStr; -} - -/* - * Class: ai_onnxruntime_OrtHardwareDevice - * Method: getMetadata - * Signature: (JJ)[[Ljava/lang/String; - */ -JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtHardwareDevice_getMetadata - (JNIEnv * jniEnv, jclass jclazz, jlong apiHandle, jlong nativeHandle) { - (void) jclazz; // Required JNI parameter not needed by functions which don't need to access their host object. - const OrtApi* api = (const OrtApi*) apiHandle; - OrtHardwareDevice* device = (OrtHardwareDevice*) nativeHandle; - const OrtKeyValuePairs* kvp = api->HardwareDevice_Metadata(device); - jobjectArray pair = convertOrtKeyValuePairsToArrays(jniEnv, api, kvp); - return pair; -} - -/* - * Class: ai_onnxruntime_OrtHardwareDevice - * Method: getDeviceType - * Signature: (JJ)I - */ -JNIEXPORT jint JNICALL Java_ai_onnxruntime_OrtHardwareDevice_getDeviceType - (JNIEnv * jniEnv, jclass jclazz, jlong apiHandle, jlong nativeHandle) { - (void) jclazz; // Required JNI parameter not needed by functions which don't need to access their host object. - const OrtApi* api = (const OrtApi*) apiHandle; - OrtHardwareDevice* device = (OrtHardwareDevice*) nativeHandle; - OrtHardwareDeviceType type = api->HardwareDevice_Type(device); - jint output = 0; - // Must be kept aligned with the Java OrtHardwareDeviceType enum. - switch (type) { - case OrtHardwareDeviceType_CPU: - output = 0; - break; - case OrtHardwareDeviceType_GPU: - output = 1; - break; - case OrtHardwareDeviceType_NPU: - output = 2; - break; - default: - throwOrtException(jniEnv, convertErrorCode(ORT_NOT_IMPLEMENTED), "Unexpected device type found. Only CPU, GPU and NPU are supported."); - break; - } - return output; -} - -/* - * Class: ai_onnxruntime_OrtHardwareDevice - * Method: getDeviceId - * Signature: (JJ)I - */ -JNIEXPORT jint JNICALL Java_ai_onnxruntime_OrtHardwareDevice_getDeviceId - (JNIEnv * jniEnv, jclass jclazz, jlong apiHandle, jlong nativeHandle) { - (void) jniEnv; (void) jclazz; // Required JNI parameters not needed by functions which don't need to access their host object or the JVM. - const OrtApi* api = (const OrtApi*) apiHandle; - OrtHardwareDevice* device = (OrtHardwareDevice*) nativeHandle; - uint32_t id = api->HardwareDevice_DeviceId(device); - return (jint) id; -} - -/* - * Class: ai_onnxruntime_OrtHardwareDevice - * Method: getVendorId - * Signature: (JJ)I - */ -JNIEXPORT jint JNICALL Java_ai_onnxruntime_OrtHardwareDevice_getVendorId - (JNIEnv * jniEnv, jclass jclazz, jlong apiHandle, jlong nativeHandle) { - (void) jniEnv; (void) jclazz; // Required JNI parameters not needed by functions which don't need to access their host object or the JVM. - const OrtApi* api = (const OrtApi*) apiHandle; - OrtHardwareDevice* device = (OrtHardwareDevice*) nativeHandle; - uint32_t id = api->HardwareDevice_VendorId(device); - return (jint) id; -} diff --git a/java/src/main/native/ai_onnxruntime_OrtModelCompilationOptions.c b/java/src/main/native/ai_onnxruntime_OrtModelCompilationOptions.c deleted file mode 100644 index 4f79383d09766..0000000000000 --- a/java/src/main/native/ai_onnxruntime_OrtModelCompilationOptions.c +++ /dev/null @@ -1,193 +0,0 @@ -/* - * Copyright (c) 2025 Oracle and/or its affiliates. All rights reserved. - * Licensed under the MIT License. - */ -#include -#include "onnxruntime/core/session/onnxruntime_c_api.h" -#include "OrtJniUtil.h" -#include "ai_onnxruntime_OrtModelCompilationOptions.h" - -/* - * Class: ai_onnxruntime_OrtModelCompilationOptions - * Method: createFromSessionOptions - * Signature: (JJJJ)J - */ -JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtModelCompilationOptions_createFromSessionOptions - (JNIEnv * jniEnv, jclass jclazz, jlong apiHandle, jlong compileApiHandle, jlong envHandle, jlong sessionOptionsHandle) { - (void)jclazz; // Required JNI parameter not needed by functions which don't need to access their host object. - const OrtApi* api = (const OrtApi*)apiHandle; - const OrtCompileApi* compileApi = (const OrtCompileApi*)compileApiHandle; - const OrtEnv* env = (const OrtEnv*)envHandle; - const OrtSessionOptions* sessionOptions = (const OrtSessionOptions*) sessionOptionsHandle; - OrtModelCompilationOptions* output = NULL; - checkOrtStatus(jniEnv, api, compileApi->CreateModelCompilationOptionsFromSessionOptions(env, sessionOptions, &output)); - return (jlong) output; -} - -/* - * Class: ai_onnxruntime_OrtModelCompilationOptions - * Method: close - * Signature: (JJ)V - */ -JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtModelCompilationOptions_close - (JNIEnv * jniEnv, jclass jclazz, jlong compileApiHandle, jlong nativeHandle) { - (void)jniEnv; (void)jclazz; // Required JNI parameters not needed by functions which don't need to access their host object or the JVM. - const OrtCompileApi* compileApi = (const OrtCompileApi*)compileApiHandle; - compileApi->ReleaseModelCompilationOptions((OrtModelCompilationOptions *)nativeHandle); -} - -/* - * Class: ai_onnxruntime_OrtModelCompilationOptions - * Method: setInputModelPath - * Signature: (JJJLjava/lang/String;)V - */ -JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtModelCompilationOptions_setInputModelPath - (JNIEnv * jniEnv, jclass jclazz, jlong apiHandle, jlong compileApiHandle, jlong nativeHandle, jstring modelPath) { - (void) jclazz; // Required JNI parameter not needed by functions which don't need to access their host object. - const OrtApi* api = (const OrtApi*) apiHandle; - const OrtCompileApi* compileApi = (const OrtCompileApi*) compileApiHandle; - OrtModelCompilationOptions* compOpts = (OrtModelCompilationOptions *) nativeHandle; -#ifdef _WIN32 - const jchar* cPath = (*jniEnv)->GetStringChars(jniEnv, modelPath, NULL); - size_t stringLength = (*jniEnv)->GetStringLength(jniEnv, modelPath); - wchar_t* newString = (wchar_t*)calloc(stringLength + 1, sizeof(wchar_t)); - if (newString == NULL) { - (*jniEnv)->ReleaseStringChars(jniEnv, modelPath, cPath); - throwOrtException(jniEnv, 1, "Not enough memory"); - return; - } - wcsncpy_s(newString, stringLength + 1, (const wchar_t*)cPath, stringLength); - checkOrtStatus(jniEnv, api, compileApi->ModelCompilationOptions_SetInputModelPath(compOpts, newString)); - free(newString); - (*jniEnv)->ReleaseStringChars(jniEnv, modelPath, cPath); -#else - const char* cPath = (*jniEnv)->GetStringUTFChars(jniEnv, modelPath, NULL); - checkOrtStatus(jniEnv, api, compileApi->ModelCompilationOptions_SetInputModelPath(compOpts, cPath)); - (*jniEnv)->ReleaseStringUTFChars(jniEnv, modelPath, cPath); -#endif -} - -/* - * Class: ai_onnxruntime_OrtModelCompilationOptions - * Method: setInputModelFromBuffer - * Signature: (JJJLjava/nio/ByteBuffer;JJ)V - */ -JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtModelCompilationOptions_setInputModelFromBuffer - (JNIEnv * jniEnv, jclass jclazz, jlong apiHandle, jlong compileApiHandle, jlong nativeHandle, jobject buffer, jlong bufferPos, jlong bufferRemaining) { - (void)jclazz; // Required JNI parameter not needed by functions which don't need to access their host object. - // Cast to pointers - const OrtApi* api = (const OrtApi*)apiHandle; - const OrtCompileApi* compileApi = (const OrtCompileApi*)compileApiHandle; - OrtModelCompilationOptions* compOpts = (OrtModelCompilationOptions *) nativeHandle; - - // Extract the buffer - char* bufferArr = (char*)(*jniEnv)->GetDirectBufferAddress(jniEnv, buffer); - // Increment by bufferPos bytes - bufferArr = bufferArr + bufferPos; - checkOrtStatus(jniEnv, api, compileApi->ModelCompilationOptions_SetInputModelFromBuffer(compOpts, bufferArr, bufferRemaining)); -} - -/* - * Class: ai_onnxruntime_OrtModelCompilationOptions - * Method: setOutputModelPath - * Signature: (JJJLjava/lang/String;)V - */ -JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtModelCompilationOptions_setOutputModelPath - (JNIEnv * jniEnv, jclass jclazz, jlong apiHandle, jlong compileApiHandle, jlong nativeHandle, jstring outputPath) { - (void) jclazz; // Required JNI parameter not needed by functions which don't need to access their host object. - const OrtApi* api = (const OrtApi*) apiHandle; - const OrtCompileApi* compileApi = (const OrtCompileApi*) compileApiHandle; - OrtModelCompilationOptions* compOpts = (OrtModelCompilationOptions *) nativeHandle; -#ifdef _WIN32 - const jchar* cPath = (*jniEnv)->GetStringChars(jniEnv, outputPath, NULL); - size_t stringLength = (*jniEnv)->GetStringLength(jniEnv, outputPath); - wchar_t* newString = (wchar_t*)calloc(stringLength + 1, sizeof(wchar_t)); - if (newString == NULL) { - (*jniEnv)->ReleaseStringChars(jniEnv, outputPath, cPath); - throwOrtException(jniEnv, 1, "Not enough memory"); - return; - } - wcsncpy_s(newString, stringLength + 1, (const wchar_t*)cPath, stringLength); - checkOrtStatus(jniEnv, api, compileApi->ModelCompilationOptions_SetOutputModelPath(compOpts, newString)); - free(newString); - (*jniEnv)->ReleaseStringChars(jniEnv, outputPath, cPath); -#else - const char* cPath = (*jniEnv)->GetStringUTFChars(jniEnv, outputPath, NULL); - checkOrtStatus(jniEnv, api, compileApi->ModelCompilationOptions_SetOutputModelPath(compOpts, cPath)); - (*jniEnv)->ReleaseStringUTFChars(jniEnv, outputPath, cPath); -#endif -} - -/* - * Class: ai_onnxruntime_OrtModelCompilationOptions - * Method: setOutputExternalInitializersPath - * Signature: (JJJLjava/lang/String;J)V - */ -JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtModelCompilationOptions_setOutputExternalInitializersPath - (JNIEnv * jniEnv, jclass jclazz, jlong apiHandle, jlong compileApiHandle, jlong nativeHandle, jstring initializersPath, jlong threshold) { - (void) jclazz; // Required JNI parameter not needed by functions which don't need to access their host object. - const OrtApi* api = (const OrtApi*) apiHandle; - const OrtCompileApi* compileApi = (const OrtCompileApi*) compileApiHandle; - OrtModelCompilationOptions* compOpts = (OrtModelCompilationOptions *) nativeHandle; -#ifdef _WIN32 - const jchar* cPath = (*jniEnv)->GetStringChars(jniEnv, initializersPath, NULL); - size_t stringLength = (*jniEnv)->GetStringLength(jniEnv, initializersPath); - wchar_t* newString = (wchar_t*)calloc(stringLength + 1, sizeof(wchar_t)); - if (newString == NULL) { - (*jniEnv)->ReleaseStringChars(jniEnv, initializersPath, cPath); - throwOrtException(jniEnv, 1, "Not enough memory"); - return; - } - wcsncpy_s(newString, stringLength + 1, (const wchar_t*)cPath, stringLength); - checkOrtStatus(jniEnv, api, compileApi->ModelCompilationOptions_SetOutputModelExternalInitializersFile(compOpts, newString, threshold)); - free(newString); - (*jniEnv)->ReleaseStringChars(jniEnv, initializersPath, cPath); -#else - const char* cPath = (*jniEnv)->GetStringUTFChars(jniEnv, initializersPath, NULL); - checkOrtStatus(jniEnv, api, compileApi->ModelCompilationOptions_SetOutputModelExternalInitializersFile(compOpts, cPath, threshold)); - (*jniEnv)->ReleaseStringUTFChars(jniEnv, initializersPath, cPath); -#endif -} - -/* - * Class: ai_onnxruntime_OrtModelCompilationOptions - * Method: setEpContextEmbedMode - * Signature: (JJJZ)V - */ -JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtModelCompilationOptions_setEpContextEmbedMode - (JNIEnv * jniEnv, jclass jclazz, jlong apiHandle, jlong compileApiHandle, jlong nativeHandle, jboolean embedMode) { - (void)jclazz; // Required JNI parameter not needed by functions which don't need to access their host object. - const OrtApi* api = (const OrtApi*)apiHandle; - const OrtCompileApi* compileApi = (const OrtCompileApi*)compileApiHandle; - OrtModelCompilationOptions* compOpts = (OrtModelCompilationOptions *) nativeHandle; - checkOrtStatus(jniEnv, api, compileApi->ModelCompilationOptions_SetEpContextEmbedMode(compOpts, (bool) embedMode)); -} - -/* - * Class: ai_onnxruntime_OrtModelCompilationOptions - * Method: setCompilationFlags - * Signature: (JJJI)V - */ -JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtModelCompilationOptions_setCompilationFlags - (JNIEnv * jniEnv, jclass jclazz, jlong apiHandle, jlong compileApiHandle, jlong nativeHandle, jint flags) { - (void)jclazz; // Required JNI parameter not needed by functions which don't need to access their host object. - const OrtApi* api = (const OrtApi*)apiHandle; - const OrtCompileApi* compileApi = (const OrtCompileApi*)compileApiHandle; - OrtModelCompilationOptions* compOpts = (OrtModelCompilationOptions *) nativeHandle; - checkOrtStatus(jniEnv, api, compileApi->ModelCompilationOptions_SetFlags(compOpts, flags)); -} - -/* - * Class: ai_onnxruntime_OrtModelCompilationOptions - * Method: compileModel - * Signature: (JJJJ)V - */ -JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtModelCompilationOptions_compileModel - (JNIEnv * jniEnv, jclass jclazz, jlong apiHandle, jlong compileApiHandle, jlong envHandle, jlong nativeHandle) { - (void)jclazz; // Required JNI parameter not needed by functions which don't need to access their host object. - const OrtApi* api = (const OrtApi*)apiHandle; - const OrtCompileApi* compileApi = (const OrtCompileApi*)compileApiHandle; - const OrtEnv* env = (const OrtEnv*)envHandle; - OrtModelCompilationOptions* compOpts = (OrtModelCompilationOptions *) nativeHandle; - checkOrtStatus(jniEnv, api, compileApi->CompileModel(env, compOpts)); -} diff --git a/java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c b/java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c index 95bcdf7af9746..ff6b7fa703e6e 100644 --- a/java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c +++ b/java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c @@ -718,11 +718,11 @@ JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_addROC } /* - * Class: ai_onnxruntime_OrtSession_SessionOptions + * Class:: ai_onnxruntime_OrtSession_SessionOptions * Method: addExecutionProvider - * Signature: (JJLjava/lang/String;[Ljava/lang/String;[Ljava/lang/String;)V + * Signature: (JILjava/lang/String)V */ -JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_addExecutionProvider__JJLjava_lang_String_2_3Ljava_lang_String_2_3Ljava_lang_String_2( +JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_addExecutionProvider( JNIEnv* jniEnv, jobject jobj, jlong apiHandle, jlong optionsHandle, jstring jepName, jobjectArray configKeyArr, jobjectArray configValueArr) { (void)jobj; @@ -756,50 +756,3 @@ JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_addExe free((void*)jkeyArray); free((void*)jvalueArray); } - -/* - * Class: ai_onnxruntime_OrtSession_SessionOptions - * Method: addExecutionProvider - * Signature: (JJJ[J[Ljava/lang/String;[Ljava/lang/String;)V - */ -JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_addExecutionProvider__JJJ_3J_3Ljava_lang_String_2_3Ljava_lang_String_2 - (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong envHandle, jlong optionsHandle, jlongArray deviceHandleArr, jobjectArray configKeyArr, jobjectArray configValueArr) { - (void)jobj; - - const OrtApi* api = (const OrtApi*)apiHandle; - OrtEnv* env = (OrtEnv*) envHandle; - OrtSessionOptions* options = (OrtSessionOptions*)optionsHandle; - jsize deviceCount = (*jniEnv)->GetArrayLength(jniEnv, deviceHandleArr); - jsize keyCount = (*jniEnv)->GetArrayLength(jniEnv, configKeyArr); - - const char** keyArray = (const char**)allocarray(keyCount, sizeof(const char*)); - const char** valueArray = (const char**)allocarray(keyCount, sizeof(const char*)); - jstring* jkeyArray = (jstring*)allocarray(keyCount, sizeof(jstring)); - jstring* jvalueArray = (jstring*)allocarray(keyCount, sizeof(jstring)); - const OrtEpDevice** devicePtrs = allocarray(deviceCount, sizeof(OrtEpDevice *)); - - jlong* deviceHandleElements = (*jniEnv)->GetLongArrayElements(jniEnv, deviceHandleArr, NULL); - for (jsize i = 0; i < deviceCount; i++) { - devicePtrs[i] = (OrtEpDevice*) deviceHandleElements[i]; - } - (*jniEnv)->ReleaseLongArrayElements(jniEnv, deviceHandleArr, deviceHandleElements, JNI_ABORT); - - for (jsize i = 0; i < keyCount; i++) { - jkeyArray[i] = (jstring)((*jniEnv)->GetObjectArrayElement(jniEnv, configKeyArr, i)); - jvalueArray[i] = (jstring)((*jniEnv)->GetObjectArrayElement(jniEnv, configValueArr, i)); - keyArray[i] = (*jniEnv)->GetStringUTFChars(jniEnv, jkeyArray[i], NULL); - valueArray[i] = (*jniEnv)->GetStringUTFChars(jniEnv, jvalueArray[i], NULL); - } - - checkOrtStatus(jniEnv, api, api->SessionOptionsAppendExecutionProvider_V2(options, env, devicePtrs, deviceCount, keyArray, valueArray, keyCount)); - - for (jsize i = 0; i < keyCount; i++) { - (*jniEnv)->ReleaseStringUTFChars(jniEnv, jkeyArray[i], keyArray[i]); - (*jniEnv)->ReleaseStringUTFChars(jniEnv, jvalueArray[i], valueArray[i]); - } - free((void*)devicePtrs); - free((void*)keyArray); - free((void*)valueArray); - free((void*)jkeyArray); - free((void*)jvalueArray); -} diff --git a/java/src/test/java/ai/onnxruntime/CompileApiTest.java b/java/src/test/java/ai/onnxruntime/CompileApiTest.java deleted file mode 100644 index b70f4dca5cbd0..0000000000000 --- a/java/src/test/java/ai/onnxruntime/CompileApiTest.java +++ /dev/null @@ -1,53 +0,0 @@ -/* - * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved. - * Licensed under the MIT License. - */ -package ai.onnxruntime; - -import ai.onnxruntime.OrtSession.SessionOptions; -import java.io.File; -import java.io.IOException; -import java.nio.ByteBuffer; -import java.nio.file.Files; -import java.nio.file.Path; -import org.junit.jupiter.api.Assertions; -import org.junit.jupiter.api.Test; - -/** Test for the compilation API. */ -public class CompileApiTest { - private final OrtEnvironment env = OrtEnvironment.getEnvironment(); - - @Test - public void basicUsage() throws OrtException, IOException { - SessionOptions so = new SessionOptions(); - try (OrtModelCompilationOptions compileOptions = - OrtModelCompilationOptions.createFromSessionOptions(env, so)) { - // mainly checking these don't throw which ensures all the plumbing for the binding works. - compileOptions.setInputModelPath("model.onnx"); - compileOptions.setOutputModelPath("compiled_model.onnx"); - - compileOptions.setOutputExternalInitializersPath("external_data.bin", 512); - compileOptions.setEpContextEmbedMode(true); - } - - try (OrtModelCompilationOptions compileOptions = - OrtModelCompilationOptions.createFromSessionOptions(env, so)) { - Path modelPath = TestHelpers.getResourcePath("/squeezenet.onnx"); - byte[] modelBytes = Files.readAllBytes(modelPath); - ByteBuffer modelBuffer = ByteBuffer.wrap(modelBytes); - compileOptions.setInputModelFromBuffer(modelBuffer); - compileOptions.setOutputModelPath("compiled_model.onnx"); - - File f = new File("compiled_model.onnx"); - - compileOptions.compileModel(); - - // Check the compiled model is valid - try (OrtSession session = env.createSession(f.toString(), so)) { - Assertions.assertNotNull(session); - } - - f.delete(); - } - } -} diff --git a/java/src/test/java/ai/onnxruntime/EpDeviceTest.java b/java/src/test/java/ai/onnxruntime/EpDeviceTest.java deleted file mode 100644 index ec4c977508c8c..0000000000000 --- a/java/src/test/java/ai/onnxruntime/EpDeviceTest.java +++ /dev/null @@ -1,123 +0,0 @@ -/* - * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved. - * Licensed under the MIT License. - */ -package ai.onnxruntime; - -import ai.onnxruntime.OrtHardwareDevice.OrtHardwareDeviceType; -import ai.onnxruntime.OrtSession.SessionOptions; -import java.io.File; -import java.nio.file.Path; -import java.util.Collections; -import java.util.List; -import java.util.Map; -import java.util.function.Consumer; -import java.util.function.Supplier; -import java.util.stream.Collectors; -import org.junit.jupiter.api.Assertions; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.condition.EnabledOnOs; -import org.junit.jupiter.api.condition.OS; - -/** Tests for {@link OrtEpDevice} and {@link OrtHardwareDevice}. */ -@EnabledOnOs(value = OS.WINDOWS) -public class EpDeviceTest { - private final OrtEnvironment ortEnv = OrtEnvironment.getEnvironment(); - - private void readHardwareDeviceValues(OrtHardwareDevice device) { - OrtHardwareDeviceType type = device.getType(); - - Assertions.assertTrue( - type == OrtHardwareDeviceType.CPU - || type == OrtHardwareDeviceType.GPU - || type == OrtHardwareDeviceType.NPU); - - if (type == OrtHardwareDeviceType.CPU) { - Assertions.assertFalse(device.getVendor().isEmpty()); - } else { - Assertions.assertTrue(device.getVendorId() != 0); - Assertions.assertTrue(device.getDeviceId() != 0); - } - - Map metadata = device.getMetadata(); - Assertions.assertNotNull(metadata); - for (Map.Entry kvp : metadata.entrySet()) { - Assertions.assertFalse(kvp.getKey().isEmpty()); - } - } - - @Test - public void getEpDevices() throws OrtException { - List epDevices = ortEnv.getEpDevices(); - Assertions.assertNotNull(epDevices); - Assertions.assertFalse(epDevices.isEmpty()); - for (OrtEpDevice epDevice : epDevices) { - Assertions.assertFalse(epDevice.getName().isEmpty()); - Assertions.assertFalse(epDevice.getVendor().isEmpty()); - Map metadata = epDevice.getMetadata(); - Assertions.assertNotNull(metadata); - Map options = epDevice.getOptions(); - Assertions.assertNotNull(options); - readHardwareDeviceValues(epDevice.getDevice()); - } - } - - @Test - public void registerUnregisterLibrary() throws OrtException { - String libFullPath = TestHelpers.getResourcePath("/example_plugin_ep.dll").toString(); - Assertions.assertTrue( - new File(libFullPath).exists(), "Expected lib " + libFullPath + " does not exist."); - - // example plugin ep uses the registration name as the ep name - String epName = "java_ep"; - - // register. shouldn't throw - ortEnv.registerExecutionProviderLibrary(epName, libFullPath); - - // check OrtEpDevice was found - List epDevices = ortEnv.getEpDevices(); - boolean found = epDevices.stream().anyMatch(a -> a.getName().equals(epName)); - Assertions.assertTrue(found); - - // unregister - ortEnv.unregisterExecutionProviderLibrary(epName); - } - - @Test - public void appendToSessionOptionsV2() { - Consumer>> runTest = - (Supplier> options) -> { - try (SessionOptions sessionOptions = new SessionOptions()) { - sessionOptions.setSessionLogLevel(OrtLoggingLevel.ORT_LOGGING_LEVEL_VERBOSE); - - List epDevices = ortEnv.getEpDevices(); - - // cpu ep ignores the provider options so we can use any value in epOptions and it won't - // break. - List selectedEpDevices = - epDevices.stream() - .filter(a -> a.getName().equals("CPUExecutionProvider")) - .collect(Collectors.toList()); - - Map epOptions = options.get(); - sessionOptions.addExecutionProvider(selectedEpDevices, epOptions); - - Path model = TestHelpers.getResourcePath("/squeezenet.onnx"); - String modelPath = model.toString(); - - // session should load successfully - try (OrtSession session = ortEnv.createSession(modelPath, sessionOptions)) { - Assertions.assertNotNull(session); - } - } catch (OrtException e) { - throw new RuntimeException(e); - } - }; - - // empty options - runTest.accept(Collections::emptyMap); - - // dummy options - runTest.accept(() -> Collections.singletonMap("random_key", "value")); - } -} diff --git a/js/node/src/inference_session_wrap.cc b/js/node/src/inference_session_wrap.cc index 8db91f792cb06..84ed3457a488b 100644 --- a/js/node/src/inference_session_wrap.cc +++ b/js/node/src/inference_session_wrap.cc @@ -15,7 +15,7 @@ Napi::Object InferenceSessionWrap::Init(Napi::Env env, Napi::Object exports) { // create ONNX runtime env Ort::InitApi(); ORT_NAPI_THROW_ERROR_IF( - &Ort::GetApi() == nullptr, env, + Ort::Global::api_ == nullptr, env, "Failed to initialize ONNX Runtime API. It could happen when this nodejs binding was built with a higher version " "ONNX Runtime but now runs with a lower version ONNX Runtime DLL(or shared library)."); diff --git a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h index bfa450f4287f8..0d5117709c18a 100644 --- a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h +++ b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h @@ -280,18 +280,6 @@ class GQAAttentionBase { output, static_cast(present_buffer_sequence_length), nullptr); } - // Pre-allocate buffer for attention mask to avoid allocating it for every processed token - float* attention_bias_thread_fp32 = nullptr; - if (attention_bias_thread != nullptr) { - if constexpr (!std::is_same_v) { - static_assert(std::is_same_v && std::is_same_v); - - size_t bytes = attention_total_seqlen * sizeof(float); - attention_bias_thread_fp32 = static_cast(allocator->Alloc(bytes)); - } - } - BufferUniquePtr scratch_buffer(attention_bias_thread_fp32, BufferDeleter(allocator)); - // compute Softmax U* output_softmax = output; for (size_t seq = 0; seq < sequence_length; seq++) { @@ -328,6 +316,9 @@ class GQAAttentionBase { static_cast(window_size)); } else { static_assert(std::is_same_v && std::is_same_v); + size_t bytes = window_size * sizeof(float); + auto attention_bias_thread_fp32 = static_cast(allocator->Alloc(bytes)); + BufferUniquePtr scratch_buffer(attention_bias_thread_fp32, BufferDeleter(allocator)); MlasConvertHalfToFloatBuffer(attention_bias_thread + start_offset, attention_bias_thread_fp32, window_size); ApplyAttentionBias(output_softmax + start_offset, attention_bias_thread_fp32, static_cast(window_size)); diff --git a/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc b/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc index 5c6c3b919b572..9b35a40f64f2a 100644 --- a/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc +++ b/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc @@ -331,13 +331,7 @@ Status QMoECPU::Compute(OpKernelContext* context) const { const int64_t token_idx = route_idx / k_; const float weight = route_scale[route_idx]; - const size_t buffer_offset = static_cast(token_idx) * static_cast(hidden_size); - if (buffer_offset + static_cast(hidden_size) > output_buffer_size) { - // Skip this token to prevent buffer overflow - continue; - } - - float* dest = thread_local_outputs + static_cast(thread_id) * output_buffer_size + buffer_offset; + float* dest = thread_local_outputs + static_cast(thread_id) * output_buffer_size + token_idx * hidden_size; const float* src = C2 + i * hidden_size; for (int64_t j = 0; j < hidden_size; ++j) { dest[j] += weight * (src[j] + (B2_bias ? bias2_float[j] : 0.0f)); @@ -350,9 +344,8 @@ Status QMoECPU::Compute(OpKernelContext* context) const { auto accumulate = [&](float* buffer) { memset(buffer, 0, output_buffer_size * sizeof(float)); for (int i = 0; i < num_expert_threads; ++i) { - const size_t thread_offset = static_cast(i) * output_buffer_size; for (size_t j = 0; j < output_buffer_size; ++j) { - buffer[j] += thread_local_outputs[thread_offset + j]; + buffer[j] += thread_local_outputs[static_cast(i) * output_buffer_size + j]; } } }; diff --git a/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_matmul.cc b/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_matmul.cc index 36a6f70cc69d9..85a2cbaea0e44 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_matmul.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_matmul.cc @@ -200,19 +200,6 @@ class DynamicQuantizeMatMul final : public MatMulIntegerToFloatBase { can_use_dynamic_quant_mlas_ = (!b_quantization_might_be_asymmetric && b_scale_available); - // Kleidi dynamic path requires strictly positive, finite scales. - // Disable if any invalid scale is detected. - if (can_use_dynamic_quant_mlas_) { - const auto bs = b_scale_tensor->DataAsSpan(); - const bool has_invalid = - std::any_of(bs.begin(), bs.end(), - [](float s) { return !std::isfinite(s) || s <= 0.0f; }); - - if (has_invalid) { - can_use_dynamic_quant_mlas_ = false; - } - } - // Currently, MlasDynamicQGemmBatch() and associated functions require SME or else they are no-ops. // We check that here too before attempting to use them. if (!CPUIDInfo::GetCPUIDInfo().HasArm_SME()) { @@ -392,7 +379,7 @@ Status DynamicQuantizeMatMul::Compute(OpKernelContext* ctx) const { if (y->Shape().Size() == 0) return Status::OK(); - const float* a_data = ctx->Input(IN_A)->Data(); + auto a_data = static_cast(ctx->Input(IN_A)->DataRaw()); auto* y_data = y->MutableData(); // batch gemm @@ -406,7 +393,7 @@ Status DynamicQuantizeMatMul::Compute(OpKernelContext* ctx) const { for (size_t gemm_idx = 0; gemm_idx < num_gemms; gemm_idx++) { auto& params = gemm_data_vec[gemm_idx]; - params.A = a_data + helper.LeftOffsets()[gemm_idx]; + params.A = reinterpret_cast(a_data + helper.LeftOffsets()[gemm_idx]); params.lda = gemm_shape.K; params.PackedB = packed_b_.get(); params.C = y_data + helper.OutputOffsets()[gemm_idx]; diff --git a/onnxruntime/core/common/cpuid_arch_definition.h b/onnxruntime/core/common/cpuid_arch_definition.h index 5946b8ca27067..a541eb66d8ba3 100644 --- a/onnxruntime/core/common/cpuid_arch_definition.h +++ b/onnxruntime/core/common/cpuid_arch_definition.h @@ -9,6 +9,6 @@ #define CPUIDINFO_ARCH_X86 #endif -#if defined(_M_ARM64) || defined(_M_ARM64EC) || defined(__aarch64__) || defined(_M_ARM) || defined(__arm__) +#if defined(_M_ARM64) || defined(__aarch64__) || defined(_M_ARM) || defined(__arm__) #define CPUIDINFO_ARCH_ARM #endif // ARM or ARM64 diff --git a/onnxruntime/core/graph/abi_graph_types.h b/onnxruntime/core/graph/abi_graph_types.h index 2ef7c4a9091f3..b99c22edb36c8 100644 --- a/onnxruntime/core/graph/abi_graph_types.h +++ b/onnxruntime/core/graph/abi_graph_types.h @@ -252,6 +252,16 @@ struct OrtNode { /// A status indicating success or an error. virtual onnxruntime::Status GetAttributes(gsl::span attrs) const = 0; + ///

+ /// Gets the node's 'TENSOR' attribute as an OrtValue. + /// + /// Node's 'TENSOR' attribute. + /// Output parameter is set to a newly created OrtValue containing the 'TENSOR' attribute value, + /// only if the attribute is of type 'TENSOR' + /// A status indicating success or an error. + virtual onnxruntime::Status GetTensorAttributeAsOrtValue(const OrtOpAttr* attr, + OrtValue*& value) const = 0; + /// /// Gets the number of node subgraphs. /// diff --git a/onnxruntime/core/graph/ep_api_types.cc b/onnxruntime/core/graph/ep_api_types.cc index 92eb31f0ad385..759a2998ace3a 100644 --- a/onnxruntime/core/graph/ep_api_types.cc +++ b/onnxruntime/core/graph/ep_api_types.cc @@ -249,6 +249,32 @@ Status EpNode::GetAttributes(gsl::span dst) const { return Status::OK(); } +Status EpNode::GetTensorAttributeAsOrtValue(const OrtOpAttr* attribute, OrtValue*& result) const { + const auto* attr_proto = reinterpret_cast(attribute); + + if (attr_proto->type() != onnx::AttributeProto::TENSOR) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "This OrtOpAttr instance is not a 'TENSOR' attribute"); + } + + const auto& graph_viewer = ep_graph_->GetGraphViewer(); + const auto& tensor_proto = attr_proto->t(); + + // Check that TensorProto is valid. + ORT_ENFORCE(utils::HasDataType(tensor_proto), "Tensor proto doesn't have data type."); + ORT_ENFORCE(ONNX_NAMESPACE::TensorProto::DataType_IsValid(tensor_proto.data_type()), "Tensor proto has invalid data type."); + ORT_ENFORCE(!utils::HasExternalData(tensor_proto), + "Tensor proto with external data for value attribute is not supported."); + + // Initialize OrtValue for tensor attribute. + auto tensor_attribute_value = std::make_unique(); + AllocatorPtr tensor_attribute_allocator = CPUAllocator::DefaultInstance(); + ORT_RETURN_IF_ERROR(utils::TensorProtoToOrtValue(Env::Default(), graph_viewer.ModelPath(), tensor_proto, + tensor_attribute_allocator, *tensor_attribute_value)); + + result = tensor_attribute_value.release(); + return Status::OK(); +} + Status EpNode::GetNumSubgraphs(size_t& num_subgraphs) const { num_subgraphs = subgraphs_.size(); return Status::OK(); @@ -327,9 +353,6 @@ static Status GetInputIndices(const EpNode& consumer_node, [&found, &value_info_name, &indices](gsl::span input_value_infos, bool is_implicit) -> void { for (size_t i = 0; i < input_value_infos.size(); i++) { - if (input_value_infos[i] == nullptr) { // input_value_info == nullptr means the input is optional - continue; - } if (input_value_infos[i]->GetName() == value_info_name) { indices.push_back(is_implicit ? -1 : static_cast(i)); found = true; diff --git a/onnxruntime/core/graph/ep_api_types.h b/onnxruntime/core/graph/ep_api_types.h index e003f02a79a2d..7f22e265129f7 100644 --- a/onnxruntime/core/graph/ep_api_types.h +++ b/onnxruntime/core/graph/ep_api_types.h @@ -183,6 +183,9 @@ struct EpNode : public OrtNode { // Gets the node's attributes. Status GetAttributes(gsl::span attrs) const override; + Status GetTensorAttributeAsOrtValue(const OrtOpAttr* attribute, + OrtValue*& attr_tensor) const override; + // Gets the number of subgraphs contained by this node. Status GetNumSubgraphs(size_t& num_subgraphs) const override; diff --git a/onnxruntime/core/graph/model_editor_api_types.h b/onnxruntime/core/graph/model_editor_api_types.h index 2c0f6d6174303..e7ffcbc7e4c90 100644 --- a/onnxruntime/core/graph/model_editor_api_types.h +++ b/onnxruntime/core/graph/model_editor_api_types.h @@ -138,6 +138,11 @@ struct ModelEditorNode : public OrtNode { "OrtModelEditorApi does not support getting attribute OrtOpAttr for OrtNode"); } + Status GetTensorAttributeAsOrtValue(const OrtOpAttr* /*attribute*/, OrtValue*& /*attr_tensor*/) const override { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, + "OrtModelEditorApi does not support getting 'TENSOR' attribute for OrtNode"); + } + Status GetNumSubgraphs(size_t& /*num_subgraphs*/) const override { return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "OrtModelEditorApi does not support getting the subgraphs for OrtNode"); diff --git a/onnxruntime/core/mlas/lib/kleidiai/sgemm_kleidiai.cpp b/onnxruntime/core/mlas/lib/kleidiai/sgemm_kleidiai.cpp index c579ff1542eb9..caa445b71e2a5 100644 --- a/onnxruntime/core/mlas/lib/kleidiai/sgemm_kleidiai.cpp +++ b/onnxruntime/core/mlas/lib/kleidiai/sgemm_kleidiai.cpp @@ -153,23 +153,28 @@ ArmKleidiAI::MlasGemmBatch( MLAS_THREADPOOL* ThreadPool ) { - if (M == 0 || N == 0) { - return true; + if(TransA == CblasTrans) + { + return false; } - - if (Data->alpha == 0.0f || K == 0) { - if (Data->beta == 0.0f) { - for (size_t i = 0; i < M; ++i) { - std::fill_n(Data->C + i * Data->ldc, N, 0.0f); - } - } else if (Data->beta != 1.0f) { + if (TransA == CblasNoTrans && K == 0) { + if (Data->beta != 1.0f) { for (size_t i = 0; i < M; ++i) { for (size_t j = 0; j < N; ++j) { Data->C[i * Data->ldc + j] *= Data->beta; } } } - return true; + } + if (Data->beta == 0.0f){ + std::fill_n(Data->C, M * Data->ldc, 0.0f); + } + //Fallback in the case of unsupported cases + if (M == 0 || N == 0 || K == 0 || + TransA != CblasNoTrans || + (TransB != CblasNoTrans && !Data[0].BIsPacked)) + { + return false; } if (TransA == CblasNoTrans) { @@ -180,9 +185,11 @@ ArmKleidiAI::MlasGemmBatch( auto m_step = kai_get_m_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(); auto n_step = kai_get_n_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(); - if (M < m_step && N < n_step && !Data->BIsPacked) { - // Fallback to MLAS - return false; + if (M < m_step || N < n_step) { + if (GetMlasPlatform().MlasGemmBatchOverride != ArmKleidiAI::MlasGemmBatch){ + //Fallback to MLAS + return false; + } } std::vector KaiPackedData; @@ -309,7 +316,7 @@ ArmKleidiAI::MlasGemmBatch( float* dst_tile = reinterpret_cast(CTile); // quick copy of data in cases where we are not scaling or accumulating anything - // with bounds checking on tile sizing to ensure the data fits in the memory block + // with bounds checking on tile sizing to ensure the data fits in the memory block bool can_memcpy = ( Data[BIdx].alpha == 1.0f && Data[BIdx].beta == 0.0f && @@ -321,37 +328,21 @@ ArmKleidiAI::MlasGemmBatch( if (can_memcpy) { std::memcpy(dst_tile, temp_tile, TileSizeM * TileSizeN * sizeof(float)); - return; - } + }else { + // apply alpha scaling and beta to output files + for (size_t i = 0; i < TileSizeM; ++i) { + for (size_t j = 0; j < TileSizeN; ++j) { + const size_t idx = i * TileSizeN + j; + const size_t dst_idx = i * Data[BIdx].ldc + j; - float alpha = Data[BIdx].alpha; - float beta = Data[BIdx].beta; - size_t ldc = Data[BIdx].ldc; - - for (size_t i = 0; i < TileSizeM; ++i) { - for (size_t j = 0; j < TileSizeN; ++j) { - const size_t temp_idx = i * TileSizeN + j; - const size_t dst_idx = i * ldc + j; - - float ab = temp_tile[temp_idx]; - float c_orig = dst_tile[dst_idx]; - - if (alpha == 1.0f && beta == 0.0f) { - dst_tile[dst_idx] = ab; - } else if (alpha == 1.0f) { - dst_tile[dst_idx] = ab + beta * c_orig; - } else if (beta == 0.0f) { - dst_tile[dst_idx] = alpha * ab; - } else { - dst_tile[dst_idx] = alpha * ab + beta * c_orig; + float ab = temp_tile[idx]; + float c_orig = dst_tile[dst_idx]; + + dst_tile[dst_idx] = Data[BIdx].alpha * ab + Data[BIdx].beta * c_orig; } } } - return; }); - return true; - } - else { - return false; } + return true; } diff --git a/onnxruntime/core/providers/cann/cann_execution_provider.cc b/onnxruntime/core/providers/cann/cann_execution_provider.cc index 06c3628eb301d..4bcf71335d15e 100644 --- a/onnxruntime/core/providers/cann/cann_execution_provider.cc +++ b/onnxruntime/core/providers/cann/cann_execution_provider.cc @@ -1266,16 +1266,17 @@ CANNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewe // the single operator operation mode of CANN if (info_.enable_cann_graph) { std::vector&& unsupported_nodes = SupportONNXModel(graph_viewer); - if (info_.enable_cann_subgraph && !unsupported_nodes.empty()) { + + if (unsupported_nodes.empty()) { + auto sub_graph = GetSubGraph(graph_viewer.GetNodesInTopologicalOrder(), graph_viewer); + result.push_back(ComputeCapability::Create(std::move(sub_graph))); + } else { auto partitions = GetSubGraphPartition(graph_viewer.GetNodesInTopologicalOrder(), unsupported_nodes); for (const auto& partition : partitions) { auto sub_graph = GetSubGraph(partition, graph_viewer); result.push_back(ComputeCapability::Create(std::move(sub_graph))); } - } else { - auto sub_graph = GetSubGraph(graph_viewer.GetNodesInTopologicalOrder(), graph_viewer); - result.push_back(ComputeCapability::Create(std::move(sub_graph))); } } else { InlinedVector candidates; diff --git a/onnxruntime/core/providers/cann/cann_execution_provider_info.cc b/onnxruntime/core/providers/cann/cann_execution_provider_info.cc index d6cf9fad70ae5..d1ba7544bc09e 100644 --- a/onnxruntime/core/providers/cann/cann_execution_provider_info.cc +++ b/onnxruntime/core/providers/cann/cann_execution_provider_info.cc @@ -20,7 +20,6 @@ constexpr const char* kDeviceId = "device_id"; constexpr const char* kMemLimit = "npu_mem_limit"; constexpr const char* kArenaExtendStrategy = "arena_extend_strategy"; constexpr const char* kEnableCannGraph = "enable_cann_graph"; -constexpr const char* kEnableCannSubGraph = "enable_cann_subgraph"; constexpr const char* kDumpGraphs = "dump_graphs"; constexpr const char* kDumpOmModel = "dump_om_model"; constexpr const char* kPrecisionMode = "precision_mode"; @@ -59,7 +58,6 @@ CANNExecutionProviderInfo CANNExecutionProviderInfo::FromProviderOptions(const P cann::provider_option_names::kArenaExtendStrategy, arena_extend_strategy_mapping, info.arena_extend_strategy) .AddAssignmentToReference(cann::provider_option_names::kEnableCannGraph, info.enable_cann_graph) - .AddAssignmentToReference(cann::provider_option_names::kEnableCannSubGraph, info.enable_cann_subgraph) .AddAssignmentToReference(cann::provider_option_names::kDumpGraphs, info.dump_graphs) .AddAssignmentToReference(cann::provider_option_names::kDumpOmModel, info.dump_om_model) .AddAssignmentToReference(cann::provider_option_names::kPrecisionMode, info.precision_mode) @@ -76,7 +74,6 @@ ProviderOptions CANNExecutionProviderInfo::ToProviderOptions(const CANNExecution {cann::provider_option_names::kArenaExtendStrategy, EnumToName(arena_extend_strategy_mapping, info.arena_extend_strategy)}, {cann::provider_option_names::kEnableCannGraph, MakeStringWithClassicLocale(info.enable_cann_graph)}, - {cann::provider_option_names::kEnableCannSubGraph, MakeStringWithClassicLocale(info.enable_cann_subgraph)}, {cann::provider_option_names::kDumpGraphs, MakeStringWithClassicLocale(info.dump_graphs)}, {cann::provider_option_names::kDumpOmModel, MakeStringWithClassicLocale(info.dump_om_model)}, {cann::provider_option_names::kPrecisionMode, MakeStringWithClassicLocale(info.precision_mode)}, @@ -92,7 +89,6 @@ ProviderOptions CANNExecutionProviderInfo::ToProviderOptions(const OrtCANNProvid {cann::provider_option_names::kArenaExtendStrategy, EnumToName(arena_extend_strategy_mapping, ArenaExtendStrategy(info.arena_extend_strategy))}, {cann::provider_option_names::kEnableCannGraph, MakeStringWithClassicLocale(info.enable_cann_graph)}, - {cann::provider_option_names::kEnableCannSubGraph, MakeStringWithClassicLocale(info.enable_cann_subgraph)}, {cann::provider_option_names::kDumpGraphs, MakeStringWithClassicLocale(info.dump_graphs)}, {cann::provider_option_names::kDumpOmModel, MakeStringWithClassicLocale(info.dump_om_model)}, {cann::provider_option_names::kPrecisionMode, MakeStringWithClassicLocale(info.precision_mode)}, diff --git a/onnxruntime/core/providers/cann/cann_execution_provider_info.h b/onnxruntime/core/providers/cann/cann_execution_provider_info.h index 9c1f9eb03b67e..7ac43e9a8ed6f 100644 --- a/onnxruntime/core/providers/cann/cann_execution_provider_info.h +++ b/onnxruntime/core/providers/cann/cann_execution_provider_info.h @@ -18,7 +18,6 @@ struct CANNExecutionProviderInfo { size_t npu_mem_limit{std::numeric_limits::max()}; ArenaExtendStrategy arena_extend_strategy{ArenaExtendStrategy::kNextPowerOfTwo}; bool enable_cann_graph{true}; - bool enable_cann_subgraph{false}; bool dump_graphs{false}; bool dump_om_model{true}; std::string precision_mode; diff --git a/onnxruntime/core/providers/cann/cann_provider_factory.cc b/onnxruntime/core/providers/cann/cann_provider_factory.cc index d3dc86f588f1d..4a130b9b0ca20 100644 --- a/onnxruntime/core/providers/cann/cann_provider_factory.cc +++ b/onnxruntime/core/providers/cann/cann_provider_factory.cc @@ -76,7 +76,6 @@ struct CANN_Provider : Provider { info.npu_mem_limit = params->npu_mem_limit; info.arena_extend_strategy = params->arena_extend_strategy; info.enable_cann_graph = params->enable_cann_graph != 0; - info.enable_cann_subgraph = params->enable_cann_subgraph != 0; info.dump_graphs = params->dump_graphs != 0; info.dump_om_model = params->dump_om_model != 0; info.precision_mode = params->precision_mode; @@ -95,7 +94,6 @@ struct CANN_Provider : Provider { cann_options.npu_mem_limit = internal_options.npu_mem_limit; cann_options.arena_extend_strategy = internal_options.arena_extend_strategy; cann_options.enable_cann_graph = internal_options.enable_cann_graph; - cann_options.enable_cann_subgraph = internal_options.enable_cann_subgraph; cann_options.dump_graphs = internal_options.dump_graphs; cann_options.dump_om_model = internal_options.dump_om_model; cann_options.precision_mode = internal_options.precision_mode; diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc index 93b673f2df5bd..b7997ce86737a 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc @@ -20,6 +20,7 @@ #include "onnx_ctx_model_helper.h" #include "core/providers/cuda/shared_inc/cuda_call.h" #include "core/providers/cuda/cuda_graph.h" +#include "core/providers/cuda/math/unary_elementwise_ops_impl.h" #include "core/session/allocator_adapters.h" #include "cuda_runtime_api.h" #include "core/common/parse_string.h" @@ -84,6 +85,40 @@ struct ShutdownProtobuf { namespace onnxruntime { +namespace cuda { +template <> +void Impl_Cast( + cudaStream_t stream, + const int64_t* input_data, int32_t* output_data, + size_t count) { + return g_host->cuda__Impl_Cast(static_cast(stream), input_data, output_data, count); +} + +template <> +void Impl_Cast( + cudaStream_t stream, + const int32_t* input_data, int64_t* output_data, + size_t count) { + return g_host->cuda__Impl_Cast(static_cast(stream), input_data, output_data, count); +} + +template <> +void Impl_Cast( + cudaStream_t stream, + const double* input_data, float* output_data, + size_t count) { + return g_host->cuda__Impl_Cast(static_cast(stream), input_data, output_data, count); +} + +template <> +void Impl_Cast( + cudaStream_t stream, + const float* input_data, double* output_data, + size_t count) { + return g_host->cuda__Impl_Cast(static_cast(stream), input_data, output_data, count); +} +} // namespace cuda + void* OutputAllocator::reallocateOutputAsync(char const* /*tensorName*/, void* /*currentMemory*/, uint64_t size, uint64_t /*alignment*/, cudaStream_t /*stream*/) noexcept { // Some memory allocators return nullptr when allocating zero bytes, but TensorRT requires a non-null ptr @@ -337,19 +372,51 @@ bool ApplyProfileShapesFromProviderOptions(std::vector(); \ + skip_input_binding_allowed = false; \ + if (input_tensor_ptr != nullptr && elem_cnt > 0) { \ + scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, elem_cnt * sizeof(DstT))); \ + data = scratch_buffers.back().get(); \ + cuda::Impl_Cast(stream, input_tensor_ptr, reinterpret_cast(data), elem_cnt); \ + } else { \ + scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, 1)); \ + data = scratch_buffers.back().get(); \ + } \ + break; \ + } + #define CASE_GET_OUTPUT_TENSOR(DATA_TYPE, SrcT) \ case DATA_TYPE: { \ auto output_tensor_ptr = output_tensor.GetTensorMutableData(); \ data_ptr = output_tensor_ptr; \ if (output_tensor_ptr != nullptr && elem_cnt > 0) { \ - buffer = output_tensor_ptr; \ + buffers[output_name] = output_tensor_ptr; \ } else { \ scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, 1)); \ - buffer = scratch_buffers.back().get(); \ + buffers[output_name] = scratch_buffers.back().get(); \ } \ break; \ } +#define CASE_GET_CAST_OUTPUT_TENSOR(DATA_TYPE, SrcT, DstT) \ + case DATA_TYPE: { \ + auto output_tensor_ptr = output_tensor.GetTensorMutableData(); \ + data_ptr = output_tensor_ptr; \ + skip_output_binding_allowed = false; \ + if (output_tensor_ptr != nullptr && elem_cnt > 0) { \ + scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, elem_cnt * sizeof(DstT))); \ + buffers[output_name] = scratch_buffers.back().get(); \ + output_dim_sizes[i] = static_cast(elem_cnt); \ + } else { \ + scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, 1)); \ + buffers[output_name] = scratch_buffers.back().get(); \ + output_dim_sizes[i] = 1; \ + } \ + break; \ + } + #define CASE_COPY_TENSOR(DATA_TYPE, DstT) \ case DATA_TYPE: { \ auto output_tensor_ptr = output_tensor.GetTensorMutableData(); \ @@ -359,6 +426,15 @@ bool ApplyProfileShapesFromProviderOptions(std::vector(); \ + if (output_tensor_ptr != nullptr && elem_cnt > 0) { \ + cuda::Impl_Cast(stream, reinterpret_cast(allocator->getBuffer()), reinterpret_cast(output_tensor_ptr), elem_cnt); \ + } \ + break; \ + } + /* * Set Nv executio context input. * @@ -481,6 +557,7 @@ Status BindContextInput(Ort::KernelContext& ctx, CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, uint8_t) CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, int32_t) CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, int64_t) + CASE_GET_CAST_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, double, float) default: { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "NvTensorRTRTX EP input onnx tensor data type: " + std::to_string(tensor_type) + " not supported."); @@ -505,6 +582,8 @@ Status BindContextInput(Ort::KernelContext& ctx, * param output_type - Data type of the output * param i - Output iteration index * param output_tensors - Output iteration index to output's ORT value + * param output_dim_sizes - Output iteration index to the multiplocation of its shape's dimensions + * param dds_output_set - DDS output set * param dds_output_allocator_map - DDS output to its allocator * param scratch_buffer - The allocation buffer created by TRT EP * param allocator - ORT allocator @@ -516,11 +595,16 @@ Status BindContextOutput(Ort::KernelContext& ctx, const char* output_name, size_t output_index, size_t output_type, + size_t i, + std::unordered_map& output_tensors, + std::unordered_map& output_dim_sizes, DDSOutputAllocatorMap& dds_output_allocator_map, std::vector>& scratch_buffers, OrtAllocator* alloc, + std::unordered_map& buffers, nvinfer1::Dims& dims, - void*& data_ptr) { + void*& data_ptr, + bool& skip_output_binding_allowed) { // Get output shape dims = trt_context->getTensorShape(output_name); int nb_dims = dims.nbDims; @@ -550,11 +634,10 @@ Status BindContextOutput(Ort::KernelContext& ctx, data_ptr = nullptr; // Set data_ptr to nullptr for DDS output binding. } } else { - auto output_tensor = ctx.GetOutput(output_index, dims.d, nb_dims); + output_tensors[i] = ctx.GetOutput(output_index, dims.d, nb_dims); + auto& output_tensor = output_tensors[i]; const auto elem_cnt = output_tensor.GetTensorTypeAndShapeInfo().GetElementCount(); - void* buffer = nullptr; - switch (output_type) { // below macros set data_ptr and skip_output_binding_allowed variables CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, float) @@ -565,12 +648,13 @@ Status BindContextOutput(Ort::KernelContext& ctx, CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, uint8_t) CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, int32_t) CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, int64_t) + CASE_GET_CAST_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, double, float) default: { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "NvTensorRTRTX EP output tensor data type: " + std::to_string(output_type) + " not supported."); } } - trt_context->setTensorAddress(output_name, buffer); + trt_context->setTensorAddress(output_name, buffers[output_name]); } return Status::OK(); @@ -627,6 +711,7 @@ Status BindKernelOutput(Ort::KernelContext& ctx, CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, uint8_t) CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, int32_t) CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, int64_t) + CASE_CAST_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, float, double) default: { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "NvTensorRTRTX EP output tensor data type: " + std::to_string(output_type) + " not supported."); @@ -2752,6 +2837,7 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr } // Save TRT engine, other TRT objects and input/output info to map + parsers_.emplace(fused_node.Name(), std::move(trt_parser)); engines_.emplace(fused_node.Name(), std::move(trt_engine)); contexts_.emplace(fused_node.Name(), std::move(trt_context)); networks_.emplace(fused_node.Name(), std::move(trt_network)); @@ -2767,7 +2853,7 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr compute_info.create_state_func = [=](ComputeContext* context, FunctionState* state) { std::unique_ptr p = std::make_unique(); *p = {context->allocate_func, context->release_func, context->allocator_handle, context->node_name, builder_.get(), - &engines_[context->node_name], &contexts_[context->node_name], + &parsers_[context->node_name], &engines_[context->node_name], &contexts_[context->node_name], &networks_[context->node_name], input_info_[context->node_name], output_info_[context->node_name], input_shape_ranges_[context->node_name], &tensorrt_mu_, engine_cache_enable_, cache_path_, @@ -2805,6 +2891,7 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr auto trt_engine = trt_state->engine->get(); auto trt_context = trt_state->context->get(); auto trt_profiles = trt_state->profiles; + int num_outputs = static_cast(output_indexes.size()); std::unordered_set input_names; if (alloc_ == nullptr) { @@ -2879,7 +2966,16 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr /* * Set output shapes and bind output buffers */ + std::unordered_map buffers; + buffers.reserve(num_outputs); + using OutputOrtValue = Ort::UnownedValue; + std::unordered_map output_tensors; + output_tensors.reserve(num_outputs); + std::unordered_map output_dim_sizes; + output_dim_sizes.reserve(num_outputs); + if (require_io_binding) { + bool skip_output_binding_allowed = true; for (size_t i = 0, end = output_binding_names.size(); i < end; ++i) { char const* output_name = output_binding_names[i]; @@ -2897,15 +2993,16 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr nvinfer1::Dims dims; void* data_ptr = nullptr; - - Status status = BindContextOutput(ctx, trt_context, output_name, output_index, output_type, - dds_output_allocator_map, scratch_buffers, alloc, dims, data_ptr); + Status status = BindContextOutput(ctx, trt_context, output_name, output_index, output_type, i, output_tensors, output_dim_sizes, + dds_output_allocator_map, scratch_buffers, alloc, buffers, dims, data_ptr, skip_output_binding_allowed); if (status != Status::OK()) { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage()); } trt_state->output_tensors[output_index] = TensorParams{data_ptr, dims}; } + + trt_state->skip_io_binding_allowed = trt_state->skip_io_binding_allowed | skip_output_binding_allowed; } // Set execution context memory @@ -2985,6 +3082,14 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr if (status != Status::OK()) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, status.ErrorMessage()); } + } else { + auto& output_tensor = output_tensors[i]; + if (output_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE) { + auto output_tensor_ptr = output_tensor.GetTensorMutableData(); + if (output_tensor_ptr != nullptr) { + cuda::Impl_Cast(stream, reinterpret_cast(buffers[output_name]), output_tensor_ptr, output_dim_sizes[i]); + } + } } } @@ -3108,6 +3213,7 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(const Gra auto& dds_output_allocator_map = this->dds_output_allocator_maps_[fused_node_name]; auto trt_engine = trt_state->engine->get(); auto trt_context = trt_state->context->get(); + int num_outputs = static_cast(output_indexes.size()); std::unordered_map> shape_tensor_values; // This map holds "shape tensor -> shape values" for the shape tensor input across this inference run std::unordered_map> shape_tensor_values_int64; // same as above but for int64 shape tensor input @@ -3177,7 +3283,16 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(const Gra /* * Set output shapes and bind output buffers */ + std::unordered_map buffers; + buffers.reserve(num_outputs); + using OutputOrtValue = Ort::UnownedValue; + std::unordered_map output_tensors; + output_tensors.reserve(num_outputs); + std::unordered_map output_dim_sizes; + output_dim_sizes.reserve(num_outputs); + if (require_io_binding) { + bool skip_output_binding_allowed = true; for (size_t i = 0, end = output_binding_names.size(); i < end; ++i) { char const* output_name = output_binding_names[i]; @@ -3196,14 +3311,16 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(const Gra nvinfer1::Dims dims; void* data_ptr = nullptr; - Status status = BindContextOutput(ctx, trt_context, output_name, output_index, output_type, - dds_output_allocator_map, scratch_buffers, alloc, dims, data_ptr); + Status status = BindContextOutput(ctx, trt_context, output_name, output_index, output_type, i, output_tensors, output_dim_sizes, + dds_output_allocator_map, scratch_buffers, alloc, buffers, dims, data_ptr, skip_output_binding_allowed); if (status != Status::OK()) { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage()); } trt_state->output_tensors[output_index] = TensorParams{data_ptr, dims}; } + + trt_state->skip_io_binding_allowed = trt_state->skip_io_binding_allowed | skip_output_binding_allowed; } // Set execution context memory @@ -3284,6 +3401,14 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(const Gra if (status != Status::OK()) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, status.ErrorMessage()); } + } else { + auto& output_tensor = output_tensors[i]; + if (output_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE) { + auto output_tensor_ptr = output_tensor.GetTensorMutableData(); + if (output_tensor_ptr != nullptr) { + cuda::Impl_Cast(stream, reinterpret_cast(buffers[output_name]), output_tensor_ptr, output_dim_sizes[i]); + } + } } } diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h index 9e5fd03756f02..22b8314649757 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h @@ -195,6 +195,7 @@ struct TensorrtFuncState { AllocatorHandle allocator = nullptr; std::string fused_node_name; nvinfer1::IBuilder* builder; + tensorrt_ptr::unique_pointer* parser = nullptr; std::unique_ptr* engine = nullptr; std::unique_ptr* context = nullptr; std::unique_ptr* network = nullptr; @@ -385,6 +386,7 @@ class NvExecutionProvider : public IExecutionProvider { // In general, TensorRT objects are not thread safe; accesses to an object from different threads must be serialized by the client. // But there are still some thread safe operations, please see here https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading // For those non thread safe operations, TRT EP uses (1) lock_guard or (2) PerThreadContext to make sure synchronization. + std::unordered_map> parsers_; std::unordered_map> engines_; std::unordered_map> contexts_; std::unordered_map> builders_; diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/conv_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/conv_op_builder.cc index a994c936970f6..541ca5ca7ab14 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/conv_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/conv_op_builder.cc @@ -245,12 +245,6 @@ Status ConvOpBuilder::ProcessConv2D3DInputs(QnnModelWrapper& qnn_model_wrapper, bool is_graph_input = qnn_model_wrapper.IsGraphInput(input1_name); LOGS(logger, VERBOSE) << "Add HWCN Transpose node after input: " << input1_name; - if (!qnn_model_wrapper.IsQnnTensorWrapperExist(input1_name)) { - QnnTensorWrapper weight_tensor_wrapper; - ORT_RETURN_IF_ERROR(qnn_model_wrapper.MakeTensorWrapper(inputs[1], weight_tensor_wrapper)); - ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(weight_tensor_wrapper)), "Failed to add weight tensor."); - } - if (conv_type == OnnxConvType::kConv) { ORT_RETURN_IF_ERROR(qnn_model_wrapper.AddNchwToHwcnTranspose(node_unit.Index(), input1_name, @@ -431,7 +425,7 @@ Status ConvOpBuilder::ProcessConv1DInputs(QnnModelWrapper& qnn_model_wrapper, // // Input 1: weight - // We need to first reshape the weight in order to handle 1D convolutions with the Conv2d operator. + // We need to first reshape the weight inorder to handle 1D convolutions with the Conv2d operator. // Next, we have to transpose the weight because ORT layout transformations do not change the weight layout. // { @@ -517,12 +511,6 @@ Status ConvOpBuilder::ProcessConv1DInputs(QnnModelWrapper& qnn_model_wrapper, ORT_RETURN_IF(input_info.quant_param.IsPerChannel(), "Non-constant Conv inputs only support per-tensor quantization"); - if (!qnn_model_wrapper.IsQnnTensorWrapperExist(input1_name)) { - QnnTensorWrapper weight_tensor_wrapper; - ORT_RETURN_IF_ERROR(qnn_model_wrapper.MakeTensorWrapper(inputs[1], weight_tensor_wrapper)); - ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(weight_tensor_wrapper)), "Failed to add weight tensor."); - } - bool is_graph_input = qnn_model_wrapper.IsGraphInput(input1_name); LOGS(logger, VERBOSE) << "Adding Reshape (to 2D) and HWCN Transpose node after input: " << input1_name; ORT_RETURN_IF_ERROR(qnn_model_wrapper.AddReshapeNode(input1_name, diff --git a/onnxruntime/core/providers/shared_library/provider_ort_api_init.cc b/onnxruntime/core/providers/shared_library/provider_ort_api_init.cc index f8d88b07f6dd5..9fa2551e53c23 100644 --- a/onnxruntime/core/providers/shared_library/provider_ort_api_init.cc +++ b/onnxruntime/core/providers/shared_library/provider_ort_api_init.cc @@ -24,7 +24,7 @@ std::once_flag init; } // namespace void InitProviderOrtApi() { - std::call_once(init, []() { Ort::InitApi(Provider_GetHost()->OrtGetApiBase()->GetApi(ORT_API_VERSION)); }); + std::call_once(init, []() { Ort::Global::api_ = Provider_GetHost()->OrtGetApiBase()->GetApi(ORT_API_VERSION); }); } -} // namespace onnxruntime +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/core/providers/vitisai/imp/global_api.cc b/onnxruntime/core/providers/vitisai/imp/global_api.cc index 580fbfbdba0b0..5fc0b8900730b 100644 --- a/onnxruntime/core/providers/vitisai/imp/global_api.cc +++ b/onnxruntime/core/providers/vitisai/imp/global_api.cc @@ -229,7 +229,7 @@ int vitisai_ep_set_ep_dynamic_options( struct MyCustomOpKernel : OpKernel { MyCustomOpKernel(const OpKernelInfo& info, const OrtCustomOp& op) : OpKernel(info), op_(op) { op_kernel_ = - op_.CreateKernel(&op_, &Ort::GetApi(), reinterpret_cast(&info)); + op_.CreateKernel(&op_, Ort::Global::api_, reinterpret_cast(&info)); } ~MyCustomOpKernel() override { op_.KernelDestroy(op_kernel_); } @@ -332,8 +332,8 @@ vaip_core::OrtApiForVaip* create_org_api_hook() { InitProviderOrtApi(); set_version_info(the_global_api); the_global_api.host_ = Provider_GetHost(); - assert(&Ort::GetApi() != nullptr); - the_global_api.ort_api_ = &Ort::GetApi(); + assert(Ort::Global::api_ != nullptr); + the_global_api.ort_api_ = Ort::Global::api_; the_global_api.model_load = [](const std::string& filename) -> Model* { auto model_proto = ONNX_NAMESPACE::ModelProto::Create(); auto& logger = logging::LoggingManager::DefaultLogger(); diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index f3e2a8ce7ba7b..ad0a1ad137f06 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -3036,7 +3036,7 @@ ORT_API_STATUS_IMPL(OrtApis::Node_GetAttributeByName, _In_ const OrtNode* node, API_IMPL_END } -ORT_API_STATUS_IMPL(OrtApis::OpAttr_GetTensorAttributeAsOrtValue, _In_ const OrtOpAttr* attribute, _Outptr_result_maybenull_ OrtValue** attr_tensor) { +ORT_API_STATUS_IMPL(OrtApis::Node_GetTensorAttributeAsOrtValue, _In_ const OrtNode* node, _In_ const OrtOpAttr* attribute, _Outptr_result_maybenull_ OrtValue** attr_tensor) { API_IMPL_BEGIN if (attr_tensor == nullptr) { return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "attr_tensor argument is null"); @@ -3045,39 +3045,7 @@ ORT_API_STATUS_IMPL(OrtApis::OpAttr_GetTensorAttributeAsOrtValue, _In_ const Ort return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "attribute argument is null"); } - const auto* attr_proto = reinterpret_cast(attribute); - - if (attr_proto->type() != onnx::AttributeProto::TENSOR) { - return OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "This OrtOpAttr instance is not a 'TENSOR' attribute"); - } - - const auto& tensor_proto = attr_proto->t(); - - // Check that TensorProto is valid. - if (!utils::HasDataType(tensor_proto)) { - return OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "Tensor proto doesn't have data type."); - } - - if (!ONNX_NAMESPACE::TensorProto::DataType_IsValid(tensor_proto.data_type())) { - return OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "Tensor proto has invalid data type."); - } - - if (utils::HasExternalData(tensor_proto)) { - return OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, - "Tensor proto with external data for value attribute is not supported."); - } - - // Initialize OrtValue for tensor attribute. - auto tensor_attribute_value = std::make_unique(); - AllocatorPtr tensor_attribute_allocator = CPUAllocator::DefaultInstance(); - // The tensor in the 'Tensor' attribute's TensorProto is stored inline, not in an external file. - // Therefore, the 'model_path' passed to TensorProtoToOrtValue() may be an empty path. - std::filesystem::path model_path; - ORT_API_RETURN_IF_STATUS_NOT_OK(utils::TensorProtoToOrtValue(Env::Default(), model_path, tensor_proto, - tensor_attribute_allocator, *tensor_attribute_value)); - - *attr_tensor = tensor_attribute_value.release(); - + ORT_API_RETURN_IF_STATUS_NOT_OK(node->GetTensorAttributeAsOrtValue(attribute, *attr_tensor)); return nullptr; API_IMPL_END } @@ -4166,7 +4134,7 @@ static constexpr OrtApi ort_api_1_to_23 = { &OrtApis::Node_GetNumAttributes, &OrtApis::Node_GetAttributes, &OrtApis::Node_GetAttributeByName, - &OrtApis::OpAttr_GetTensorAttributeAsOrtValue, + &OrtApis::Node_GetTensorAttributeAsOrtValue, &OrtApis::OpAttr_GetType, &OrtApis::OpAttr_GetName, &OrtApis::Node_GetNumSubgraphs, diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index 6dc4cf9d195cc..e62149d04a16c 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -687,7 +687,7 @@ ORT_API_STATUS_IMPL(Node_GetAttributes, _In_ const OrtNode* node, _Out_writes_(num_attributes) const OrtOpAttr** attributes, _In_ size_t num_attributes); ORT_API_STATUS_IMPL(Node_GetAttributeByName, _In_ const OrtNode* node, _In_ const char* attribute_name, _Outptr_result_maybenull_ const OrtOpAttr** attribute); -ORT_API_STATUS_IMPL(OpAttr_GetTensorAttributeAsOrtValue, _In_ const OrtOpAttr* attribute, +ORT_API_STATUS_IMPL(Node_GetTensorAttributeAsOrtValue, _In_ const OrtNode* node, _In_ const OrtOpAttr* attribute, _Outptr_result_maybenull_ OrtValue** attr_tensor); ORT_API_STATUS_IMPL(OpAttr_GetType, _In_ const OrtOpAttr* attribute, _Out_ OrtOpAttrType* type); ORT_API_STATUS_IMPL(OpAttr_GetName, _In_ const OrtOpAttr* attribute, _Outptr_ const char** name); diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.cc b/onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.cc index 42b65239de92c..d6e51a44c1c69 100644 --- a/onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.cc +++ b/onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.cc @@ -4,8 +4,6 @@ #include "core/session/plugin_ep/ep_factory_provider_bridge.h" #include "core/providers/shared_library/provider_host_api.h" -#include "core/session/plugin_ep/ep_library_plugin.h" -#include "core/session/onnxruntime_ep_device_ep_metadata_keys.h" namespace onnxruntime { OrtStatus* ProviderBridgeEpFactory::GetSupportedDevices(EpFactoryInternal& ep_factory, @@ -22,11 +20,6 @@ OrtStatus* ProviderBridgeEpFactory::GetSupportedDevices(EpFactoryInternal& ep_fa auto* ep_device = ep_devices[i]; if (ep_device) { ep_device->ep_factory = &ep_factory; - - // Add library path to EP metadata if available - if (library_path_.has_value()) { - ep_device->ep_metadata.Add(kOrtEpDevice_EpMetadataKey_LibraryPath, library_path_->string()); - } } } diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.h b/onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.h index 8c5ef526baba1..437af62dc2c0c 100644 --- a/onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.h +++ b/onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.h @@ -3,10 +3,6 @@ #pragma once -#include -#include -#include - #include "core/framework/error_code_helper.h" #include "core/session/abi_devices.h" #include "core/session/abi_session_options_impl.h" @@ -16,14 +12,12 @@ namespace onnxruntime { class ProviderBridgeEpFactory : public EpFactoryInternalImpl { public: - ProviderBridgeEpFactory(OrtEpFactory& ep_factory, ProviderLibrary& provider_library, - std::optional library_path = std::nullopt) + ProviderBridgeEpFactory(OrtEpFactory& ep_factory, ProviderLibrary& provider_library) : EpFactoryInternalImpl(ep_factory.GetName(&ep_factory), ep_factory.GetVendor(&ep_factory), ep_factory.GetVendorId(&ep_factory)), ep_factory_{ep_factory}, - provider_library_{provider_library}, - library_path_{std::move(library_path)} { + provider_library_{provider_library} { } private: @@ -65,9 +59,8 @@ class ProviderBridgeEpFactory : public EpFactoryInternalImpl { return ep_factory_.CreateSyncStreamForDevice(&ep_factory_, device, stream_options, stream); } - OrtEpFactory& ep_factory_; - ProviderLibrary& provider_library_; - std::optional library_path_; + OrtEpFactory& ep_factory_; // OrtEpFactory from the provider bridge EP + ProviderLibrary& provider_library_; // ProviderLibrary from the provider bridge EP }; } // namespace onnxruntime diff --git a/onnxruntime/core/session/plugin_ep/ep_library.h b/onnxruntime/core/session/plugin_ep/ep_library.h index af5bc23143e33..24ab74e1c77fc 100644 --- a/onnxruntime/core/session/plugin_ep/ep_library.h +++ b/onnxruntime/core/session/plugin_ep/ep_library.h @@ -23,7 +23,6 @@ class EpLibrary { virtual Status Load() { return Status::OK(); } virtual const std::vector& GetFactories() = 0; // valid after Load() virtual Status Unload() { return Status::OK(); } - virtual ~EpLibrary() = default; ORT_DISALLOW_COPY_AND_ASSIGNMENT(EpLibrary); diff --git a/onnxruntime/core/session/plugin_ep/ep_library_provider_bridge.cc b/onnxruntime/core/session/plugin_ep/ep_library_provider_bridge.cc index da94a9f12ba9d..06cf54aea4071 100644 --- a/onnxruntime/core/session/plugin_ep/ep_library_provider_bridge.cc +++ b/onnxruntime/core/session/plugin_ep/ep_library_provider_bridge.cc @@ -4,7 +4,6 @@ #include "core/session/plugin_ep/ep_library_provider_bridge.h" #include "core/session/plugin_ep/ep_factory_provider_bridge.h" -#include "core/session/plugin_ep/ep_library_plugin.h" namespace onnxruntime { Status EpLibraryProviderBridge::Load() { @@ -27,9 +26,8 @@ Status EpLibraryProviderBridge::Load() { // to do this we need to capture `factory` and plug it in to is_supported_fn and create_fn. // we also need to update any returned OrtEpDevice instances to swap the wrapper EpFactoryInternal in so that we can // call Provider::CreateIExecutionProvider in EpFactoryInternal::CreateIExecutionProvider. - for (const auto& factory : ep_library_plugin_->GetFactories()) { - auto factory_impl = std::make_unique(*factory, *provider_library_, library_path_); + auto factory_impl = std::make_unique(*factory, *provider_library_); auto internal_factory = std::make_unique(std::move(factory_impl)); factory_ptrs_.push_back(internal_factory.get()); diff --git a/onnxruntime/core/session/plugin_ep/ep_library_provider_bridge.h b/onnxruntime/core/session/plugin_ep/ep_library_provider_bridge.h index 45277b2828f56..c7e8ebefc3785 100644 --- a/onnxruntime/core/session/plugin_ep/ep_library_provider_bridge.h +++ b/onnxruntime/core/session/plugin_ep/ep_library_provider_bridge.h @@ -21,11 +21,9 @@ namespace onnxruntime { class EpLibraryProviderBridge : public EpLibrary { public: EpLibraryProviderBridge(std::unique_ptr provider_library, - std::unique_ptr ep_library_plugin, - std::optional library_path = std::nullopt) + std::unique_ptr ep_library_plugin) : provider_library_{std::move(provider_library)}, - ep_library_plugin_{std::move(ep_library_plugin)}, - library_path_{std::move(library_path)} { + ep_library_plugin_{std::move(ep_library_plugin)} { } const char* RegistrationName() const override { @@ -55,9 +53,6 @@ class EpLibraryProviderBridge : public EpLibrary { // implement EpFactoryInternal::CreateIExecutionProvider by calling Provider::CreateIExecutionProvider. std::unique_ptr ep_library_plugin_; - // Library path for EP metadata - std::optional library_path_; - std::vector> factories_; std::vector factory_ptrs_; // for convenience std::vector internal_factory_ptrs_; // for convenience diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index f82cbcf63ca62..41cf8be1d1412 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -2902,7 +2902,6 @@ ORT_API_STATUS_IMPL(OrtApis::CreateCANNProviderOptions, _Outptr_ OrtCANNProvider options->npu_mem_limit = SIZE_MAX; options->arena_extend_strategy = static_cast(0); options->enable_cann_graph = 1; - options->enable_cann_subgraph = 0; options->dump_graphs = 0; options->dump_om_model = 1; options->default_memory_arena_cfg = nullptr; diff --git a/onnxruntime/core/session/utils.cc b/onnxruntime/core/session/utils.cc index 7da7fabb15b15..d4041dfce5a7a 100644 --- a/onnxruntime/core/session/utils.cc +++ b/onnxruntime/core/session/utils.cc @@ -421,14 +421,13 @@ Status LoadPluginOrProviderBridge(const std::string& registration_name, << (is_provider_bridge ? " as a provider bridge" : " as a plugin"); // create EpLibraryPlugin to ensure CreateEpFactories and ReleaseEpFactory are available - auto ep_library_plugin = std::make_unique(registration_name, resolved_library_path); + auto ep_library_plugin = std::make_unique(registration_name, std::move(resolved_library_path)); ORT_RETURN_IF_ERROR(ep_library_plugin->Load()); if (is_provider_bridge) { // wrap the EpLibraryPlugin with EpLibraryProviderBridge to add to directly create an IExecutionProvider auto ep_library_provider_bridge = std::make_unique(std::move(provider_library), - std::move(ep_library_plugin), - resolved_library_path); + std::move(ep_library_plugin)); ORT_RETURN_IF_ERROR(ep_library_provider_bridge->Load()); internal_factories = ep_library_provider_bridge->GetInternalFactories(); ep_library = std::move(ep_library_provider_bridge); diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index eb06a65ad5330..24554560b4dde 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -1575,17 +1575,6 @@ void addGlobalMethods(py::module& m) { R"pbdoc(Get the list of available OrtEpDevice instances.)pbdoc", py::return_value_policy::reference); - m.def( - "get_model_compatibility_for_ep_devices", - [](const std::vector& ep_devices, - const std::string& compatibility_info) -> OrtCompiledModelCompatibility { - OrtCompiledModelCompatibility status = OrtCompiledModelCompatibility_EP_NOT_APPLICABLE; - Ort::ThrowOnError(Ort::GetApi().GetModelCompatibilityForEpDevices( - ep_devices.data(), ep_devices.size(), compatibility_info.c_str(), &status)); - return status; - }, - R"pbdoc("Validate a compiled model's compatibility information for one or more EP devices.)pbdoc"); - #if defined(USE_OPENVINO) || defined(USE_OPENVINO_PROVIDER_INTERFACE) m.def( "get_available_openvino_device_ids", []() -> std::vector { @@ -1770,12 +1759,6 @@ void addObjectMethods(py::module& m, ExecutionProviderRegistrationFn ep_registra .value("PRIORITY_BASED", ExecutionOrder::PRIORITY_BASED) .value("MEMORY_EFFICIENT", ExecutionOrder::MEMORY_EFFICIENT); - py::enum_(m, "OrtCompiledModelCompatibility") - .value("EP_NOT_APPLICABLE", OrtCompiledModelCompatibility_EP_NOT_APPLICABLE) - .value("EP_SUPPORTED_OPTIMAL", OrtCompiledModelCompatibility_EP_SUPPORTED_OPTIMAL) - .value("EP_SUPPORTED_PREFER_RECOMPILATION", OrtCompiledModelCompatibility_EP_SUPPORTED_PREFER_RECOMPILATION) - .value("EP_UNSUPPORTED", OrtCompiledModelCompatibility_EP_UNSUPPORTED); - py::enum_(m, "OrtAllocatorType") .value("INVALID", OrtInvalidAllocator) .value("ORT_DEVICE_ALLOCATOR", OrtDeviceAllocator) @@ -1799,7 +1782,7 @@ void addObjectMethods(py::module& m, ExecutionProviderRegistrationFn ep_registra type = OrtDevice::GPU; vendor = OrtDevice::VendorIds::MICROSOFT; } else if (type == OrtDevice::GPU) { -#if USE_CUDA || USE_NV || USE_NV_PROVIDER_INTERFACE || USE_CUDA_PROVIDER_INTERFACE +#if USE_CUDA vendor = OrtDevice::VendorIds::NVIDIA; #elif USE_ROCM || USE_MIGRAPHX vendor = OrtDevice::VendorIds::AMD; diff --git a/onnxruntime/python/tools/quantization/execution_providers/qnn/preprocess.py b/onnxruntime/python/tools/quantization/execution_providers/qnn/preprocess.py index a12aca47f5b65..191edc4c6390d 100644 --- a/onnxruntime/python/tools/quantization/execution_providers/qnn/preprocess.py +++ b/onnxruntime/python/tools/quantization/execution_providers/qnn/preprocess.py @@ -6,15 +6,15 @@ from __future__ import annotations import logging -import tempfile from pathlib import Path import onnx -from ....tools.onnx_model_utils import fix_output_shapes, make_input_shape_fixed, optimize_model +from ....tools.onnx_model_utils import fix_output_shapes, make_input_shape_fixed from ....tools.remove_initializer_from_input import remove_initializer_from_input from ...fusions import FusionGelu, FusionLayerNormalization from ...onnx_model import ONNXModel +from ...quant_utils import save_and_reload_model_with_shape_infer from .fusion_lpnorm import FusionLpNormalization from .fusion_spacetodepth import FusionSpaceToDepth @@ -93,7 +93,7 @@ def qnn_preprocess_model( """ modified = False model = model_input if isinstance(model_input, onnx.ModelProto) else onnx.load_model(model_input) - model = save_and_reload_optimize_model(model, shape_infer=True) + model = save_and_reload_model_with_shape_infer(model) onnx_model = ONNXModel(model) # Optionally, fix the dynamic input shapes. @@ -178,24 +178,6 @@ def qnn_preprocess_model( return modified -def save_and_reload_optimize_model(model: onnx.ModelProto, shape_infer: bool) -> onnx.ModelProto: - with tempfile.TemporaryDirectory(prefix="ort.qnn_preproc.") as qnn_preproc_tmp_dir: - model_in_path = Path(qnn_preproc_tmp_dir).joinpath("qnn_proc_input.onnx") - onnx.save_model(model, model_in_path, save_as_external_data=True) - if shape_infer: - model_infer_path = Path(qnn_preproc_tmp_dir).joinpath("qnn_proc_infer.onnx") - onnx.shape_inference.infer_shapes_path(str(model_in_path), str(model_infer_path)) - model_in_path = model_infer_path - model_out_path = Path(qnn_preproc_tmp_dir).joinpath("qnn_proc_output.onnx") - optimize_model(model_in_path, model_out_path) - ret_model = onnx.load_model(model_out_path) - ret_metaprops = {"onnx.infer": "onnxruntime.tools.qnn.preprocess"} - if ret_model.metadata_props: - ret_metaprops.update(ret_model.metadata_props) - onnx.helper.set_model_props(ret_model, ret_metaprops) - return ret_model - - class InputOutputNameMap: def __init__( self, diff --git a/onnxruntime/test/autoep/library/ep_arena.h b/onnxruntime/test/autoep/library/ep_arena.h index caa2c61db835f..641f3ce3f7b17 100644 --- a/onnxruntime/test/autoep/library/ep_arena.h +++ b/onnxruntime/test/autoep/library/ep_arena.h @@ -21,10 +21,7 @@ limitations under the License. #include #include -#define ORT_API_MANUAL_INIT #include "onnxruntime_cxx_api.h" -#undef ORT_API_MANUAL_INIT - #include "ep_allocator.h" #include "example_plugin_ep_utils.h" diff --git a/onnxruntime/test/contrib_ops/moe_test.cc b/onnxruntime/test/contrib_ops/moe_test.cc index 0690b8894eb7a..ed7ca998e0b86 100644 --- a/onnxruntime/test/contrib_ops/moe_test.cc +++ b/onnxruntime/test/contrib_ops/moe_test.cc @@ -144,12 +144,6 @@ static void RunQMoETest(const std::vector& input, const std::vector("k", static_cast(top_k)); cpu_tester.AddAttribute("activation_type", activation_type); @@ -1329,13 +1323,6 @@ TEST(MoETest, QMoETest_Mixtral_Int4) { // CPU-specific QMoE tests TEST(MoETest, QMoETest_CPU_Int4_MLAS) { -#ifdef USE_MLAS - // Skip this test if we're not testing CPU execution provider - auto cpu_ep = DefaultCpuExecutionProvider(); - if (!cpu_ep) { - GTEST_SKIP() << "CPU execution provider not available"; - } - int num_rows = 2; int num_experts = 2; int hidden_size = 32; @@ -1400,19 +1387,9 @@ TEST(MoETest, QMoETest_CPU_Int4_MLAS) { std::vector> cpu_execution_providers; cpu_execution_providers.push_back(DefaultCpuExecutionProvider()); cpu_tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &cpu_execution_providers); -#else - GTEST_SKIP() << "Skipping CPU QMoE test"; -#endif } TEST(MoETest, QMoETest_CPU_Int8_MLAS) { -#ifdef USE_MLAS - // Skip this test if we're not testing CPU execution provider - auto cpu_ep = DefaultCpuExecutionProvider(); - if (!cpu_ep) { - GTEST_SKIP() << "CPU execution provider not available"; - } - // Test CPU implementation with 8-bit quantization - CPU ONLY int num_rows = 1; int num_experts = 2; @@ -1469,19 +1446,9 @@ TEST(MoETest, QMoETest_CPU_Int8_MLAS) { std::vector> cpu_execution_providers; cpu_execution_providers.push_back(DefaultCpuExecutionProvider()); cpu_tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &cpu_execution_providers); -#else - GTEST_SKIP() << "Skipping CPU QMoE test"; -#endif } TEST(MoETest, QMoETest_CPU_FC3_Error) { -#ifdef USE_MLAS - // Skip this test if we're not testing CPU execution provider - auto cpu_ep = DefaultCpuExecutionProvider(); - if (!cpu_ep) { - GTEST_SKIP() << "CPU execution provider not available"; - } - // Test that CPU throws error when FC3 gating is provided - CPU ONLY int num_rows = 1; int num_experts = 2; @@ -1539,19 +1506,9 @@ TEST(MoETest, QMoETest_CPU_FC3_Error) { // Expect this to fail with FC3 not implemented error cpu_tester.Run(OpTester::ExpectResult::kExpectFailure, "FC3 gating is not yet implemented", {}, nullptr, &cpu_execution_providers); -#else - GTEST_SKIP() << "Skipping CPU QMoE test"; -#endif } TEST(MoETest, QMoETest_CPU_SwiGLU_Int4) { -#ifdef USE_MLAS - // Skip this test if we're not testing CPU execution provider - auto cpu_ep = DefaultCpuExecutionProvider(); - if (!cpu_ep) { - GTEST_SKIP() << "CPU execution provider not available"; - } - // Test CPU implementation with 4-bit quantization and SwiGLU activation int num_rows = 2; int num_experts = 2; @@ -1616,18 +1573,9 @@ TEST(MoETest, QMoETest_CPU_SwiGLU_Int4) { std::vector> cpu_execution_providers; cpu_execution_providers.push_back(DefaultCpuExecutionProvider()); cpu_tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &cpu_execution_providers); -#else - GTEST_SKIP() << "Skipping CPU QMoE test"; -#endif } TEST(MoETest, QMoETest_CPU_SwiGLU_Int8) { -#ifdef USE_MLAS - // Skip this test if we're not testing CPU execution provider - auto cpu_ep = DefaultCpuExecutionProvider(); - if (!cpu_ep) { - GTEST_SKIP() << "CPU execution provider not available"; - } // Test CPU implementation with 8-bit quantization and SwiGLU activation int num_rows = 1; int num_experts = 2; @@ -1685,9 +1633,6 @@ TEST(MoETest, QMoETest_CPU_SwiGLU_Int8) { std::vector> cpu_execution_providers; cpu_execution_providers.push_back(DefaultCpuExecutionProvider()); cpu_tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &cpu_execution_providers); -#else - GTEST_SKIP() << "Skipping CPU QMoE test"; -#endif } #endif diff --git a/onnxruntime/test/framework/ep_compatibility_test.cc b/onnxruntime/test/framework/ep_compatibility_test.cc index a8a83fbe5ceb6..ee82d4683ab73 100644 --- a/onnxruntime/test/framework/ep_compatibility_test.cc +++ b/onnxruntime/test/framework/ep_compatibility_test.cc @@ -15,7 +15,6 @@ #include "core/session/onnxruntime_ep_device_ep_metadata_keys.h" #include "core/session/utils.h" #include "core/session/onnxruntime_c_api.h" -#include "core/session/onnxruntime_cxx_api.h" #include "core/session/abi_session_options_impl.h" #include "core/framework/error_code_helper.h" #include "dummy_provider.h" @@ -500,31 +499,3 @@ TEST(EpCompatibilityCapiTest, CpuEpReturnsNotApplicableIfNoValidation) { api->ReleaseEnv(env); } - -// ----------------------------- -// C++ API unit tests -// ----------------------------- - -TEST(EpCompatibilityCxxApiTest, SingleDeviceCpuProvider) { - Ort::Env env{ORT_LOGGING_LEVEL_WARNING, "EpCompatCxx"}; - auto devices = env.GetEpDevices(); - ASSERT_FALSE(devices.empty()); - - std::vector selected; - for (const auto& d : devices) { - if (std::string{d.EpName()} == "CPUExecutionProvider") { - selected.push_back(d); - break; - } - } - - ASSERT_FALSE(selected.empty()); - - // Pick a status that the CPU EP would never return to ensure the value is set correctly. - OrtCompiledModelCompatibility status = OrtCompiledModelCompatibility_EP_SUPPORTED_PREFER_RECOMPILATION; - ASSERT_NO_FATAL_FAILURE({ - status = Ort::GetModelCompatibilityForEpDevices(selected, "arbitrary-compat-string"); - }); - - ASSERT_TRUE(status == OrtCompiledModelCompatibility_EP_NOT_APPLICABLE); -} \ No newline at end of file diff --git a/onnxruntime/test/platform/device_discovery_test.cc b/onnxruntime/test/platform/device_discovery_test.cc index 6b43ccbc8f670..21ddf9a5b1cd7 100644 --- a/onnxruntime/test/platform/device_discovery_test.cc +++ b/onnxruntime/test/platform/device_discovery_test.cc @@ -25,9 +25,9 @@ TEST(DeviceDiscoveryTest, HasCpuDevice) { const auto cpu_devices = GetDevicesByType(OrtHardwareDeviceType_CPU); ASSERT_GT(cpu_devices.size(), 0); -#if defined(CPUINFO_SUPPORTED) +#if !defined(__wasm__) ASSERT_NE(cpu_devices[0].vendor_id, 0); -#endif // defined(CPUINFO_SUPPORTED) +#endif // !defined(__WASM__) } } // namespace onnxruntime::test diff --git a/onnxruntime/test/python/onnxruntime_test_python_ep_compatibility.py b/onnxruntime/test/python/onnxruntime_test_python_ep_compatibility.py deleted file mode 100644 index 8e69fdf088103..0000000000000 --- a/onnxruntime/test/python/onnxruntime_test_python_ep_compatibility.py +++ /dev/null @@ -1,46 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. - -import os -import platform -import sys -import unittest - -from onnxruntime.capi.onnxruntime_pybind11_state import ( - OrtCompiledModelCompatibility, - get_ep_devices, - get_model_compatibility_for_ep_devices, -) - -# handle change from python 3.8 and on where loading a dll from the current directory needs to be explicitly allowed. -if platform.system() == "Windows" and sys.version_info.major >= 3 and sys.version_info.minor >= 8: # noqa: YTT204 - os.add_dll_directory(os.getcwd()) - - -class TestEpCompatibility(unittest.TestCase): - def test_invalid_args(self): - # empty devices - with self.assertRaises(RuntimeError): - get_model_compatibility_for_ep_devices([], "info") - # None compatibility info should raise TypeError before native call - with self.assertRaises(TypeError): - get_model_compatibility_for_ep_devices(get_ep_devices(), None) # type: ignore[arg-type] - - def test_basic_smoke(self): - devices = list(get_ep_devices()) - if not devices: - self.skipTest("No EP devices available in this build") - - # Always select CPUExecutionProvider; skip if not present. - cpu_devices = [d for d in devices if getattr(d, "ep_name", None) == "CPUExecutionProvider"] - if not cpu_devices: - self.skipTest("CPUExecutionProvider not available in this build") - selected = [cpu_devices[0]] - - # API requires all devices belong to the same EP; we pass only one. - status = get_model_compatibility_for_ep_devices(selected, "arbitrary-compat-string") - self.assertEqual(status, OrtCompiledModelCompatibility.EP_NOT_APPLICABLE) - - -if __name__ == "__main__": - unittest.main() diff --git a/onnxruntime/test/python/onnxruntime_test_python_nv_tensorrt_rtx_ep_tests.py b/onnxruntime/test/python/onnxruntime_test_python_nv_tensorrt_rtx_ep_tests.py deleted file mode 100644 index d5c80a4a1f4ba..0000000000000 --- a/onnxruntime/test/python/onnxruntime_test_python_nv_tensorrt_rtx_ep_tests.py +++ /dev/null @@ -1,468 +0,0 @@ -# Copyright (c) NVIDIA Corporation. All rights reserved. -# Licensed under the MIT License. -from __future__ import annotations - -import sys -import unittest -from collections.abc import Sequence - -import numpy as np -import torch -from autoep_helper import AutoEpTestCase -from helper import get_name -from numpy.testing import assert_almost_equal -from onnx import TensorProto, helper -from onnx.defs import onnx_opset_version - -import onnxruntime as onnxrt -from onnxruntime.capi._pybind_state import OrtDevice as C_OrtDevice -from onnxruntime.capi._pybind_state import OrtValue as C_OrtValue -from onnxruntime.capi._pybind_state import OrtValueVector, SessionIOBinding - - -class TestNvTensorRTRTXAutoEP(AutoEpTestCase): - """ - Test suite for the NvTensorRTRTX Execution Provider. - - This class contains tests for registering the NvTensorRTRTX EP, - selecting it using different policies, and running inference with various - I/O binding configurations. - """ - - ep_lib_path = "onnxruntime_providers_nv_tensorrt_rtx.dll" - ep_name = "NvTensorRTRTXExecutionProvider" - - def setUp(self): - if sys.platform != "win32": - self.skipTest("Skipping test because device discovery is only supported on Windows") - self.register_execution_provider_library(self.ep_name, self.ep_lib_path) - - def tearDown(self): - self.unregister_execution_provider_library(self.ep_name) - - def _create_ortvalue_input_on_gpu(self, device): - return onnxrt.OrtValue.ortvalue_from_numpy( - np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32), device, 0 - ) - - def _create_ortvalue_alternate_input_on_gpu(self, device): - return onnxrt.OrtValue.ortvalue_from_numpy( - np.array([[2.0, 4.0], [6.0, 8.0], [10.0, 12.0]], dtype=np.float32), - device, - 0, - ) - - def _create_uninitialized_ortvalue_input_on_gpu(self, device): - return onnxrt.OrtValue.ortvalue_from_shape_and_type([3, 2], np.float32, device, 0) - - def _create_numpy_input(self): - return np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32) - - def _create_expected_output(self): - return np.array([[1.0, 4.0], [9.0, 16.0], [25.0, 36.0]], dtype=np.float32) - - def _create_expected_output_alternate(self): - return np.array([[2.0, 8.0], [18.0, 32.0], [50.0, 72.0]], dtype=np.float32) - - def torch_to_onnx_type(self, torch_dtype): - if torch_dtype == torch.float32: - return TensorProto.FLOAT - elif torch_dtype == torch.float16: - return TensorProto.FLOAT16 - elif torch_dtype == torch.bfloat16: - return TensorProto.BFLOAT16 - elif torch_dtype == torch.int8: - return TensorProto.int8 - elif torch_dtype == torch.int32: - return TensorProto.INT32 - elif torch_dtype == torch.int64: - return TensorProto.INT64 - else: - raise TypeError(f"Unsupported dtype: {torch_dtype}") - - def test_nv_tensorrt_rtx_ep_register_and_inference(self): - """ - Test registration of NvTensorRTRTX EP, adding its OrtDevice to the SessionOptions, and running inference. - """ - ep_devices = onnxrt.get_ep_devices() - nv_tensorrt_rtx_ep_device = next((d for d in ep_devices if d.ep_name == self.ep_name), None) - self.assertIsNotNone(nv_tensorrt_rtx_ep_device) - self.assertEqual(nv_tensorrt_rtx_ep_device.ep_vendor, "NVIDIA") - - hw_device = nv_tensorrt_rtx_ep_device.device - self.assertEqual(hw_device.type, onnxrt.OrtHardwareDeviceType.GPU) - - # Run sample model and check output - sess = onnxrt.InferenceSession(get_name("mul_1.onnx")) - - x = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32) - input_name = sess.get_inputs()[0].name - res = sess.run([], {input_name: x}) - output_expected = np.array([[1.0, 4.0], [9.0, 16.0], [25.0, 36.0]], dtype=np.float32) - np.testing.assert_allclose(output_expected, res[0], rtol=1e-05, atol=1e-08) - - def test_nv_tensorrt_rtx_ep_prefer_gpu_and_inference(self): - """ - Test selecting NvTensorRTRTX EP via the PREFER_GPU policy and running inference. - """ - # Set a policy to prefer GPU. NvTensorRTRTX should be selected. - sess_options = onnxrt.SessionOptions() - sess_options.set_provider_selection_policy(onnxrt.OrtExecutionProviderDevicePolicy.PREFER_GPU) - self.assertTrue(sess_options.has_providers()) - - # Run sample model and check output - sess = onnxrt.InferenceSession(get_name("mul_1.onnx"), sess_options=sess_options) - - x = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32) - input_name = sess.get_inputs()[0].name - res = sess.run([], {input_name: x}) - output_expected = np.array([[1.0, 4.0], [9.0, 16.0], [25.0, 36.0]], dtype=np.float32) - np.testing.assert_allclose(output_expected, res[0], rtol=1e-05, atol=1e-08) - - def test_nv_tensorrt_rtx_ep_selection_delegate_and_inference(self): - """ - Test selecting NvTensorRTRTX EP via the custom EP selection delegate function and then run inference. - """ - - # User's custom EP selection function. - def my_delegate( - ep_devices: Sequence[onnxrt.OrtEpDevice], - model_metadata: dict[str, str], - runtime_metadata: dict[str, str], - max_selections: int, - ) -> Sequence[onnxrt.OrtEpDevice]: - self.assertGreater(len(model_metadata), 0) - self.assertGreaterEqual(len(ep_devices), 1) - self.assertGreaterEqual(max_selections, 2) - - nv_tensorrt_rtx_ep_device = next((d for d in ep_devices if d.ep_name == self.ep_name), None) - self.assertIsNotNone(nv_tensorrt_rtx_ep_device) - - # Select the NvTensorRTRTX device - return [nv_tensorrt_rtx_ep_device] - - sess_options = onnxrt.SessionOptions() - sess_options.set_provider_selection_policy_delegate(my_delegate) - self.assertTrue(sess_options.has_providers()) - - # Run sample model and check output - sess = onnxrt.InferenceSession(get_name("mul_1.onnx"), sess_options=sess_options) - - x = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32) - input_name = sess.get_inputs()[0].name - res = sess.run([], {input_name: x}) - output_expected = np.array([[1.0, 4.0], [9.0, 16.0], [25.0, 36.0]], dtype=np.float32) - np.testing.assert_allclose(output_expected, res[0], rtol=1e-05, atol=1e-08) - - def test_bind_input_only(self): - """ - Test I/O binding with input data only. - """ - # Set a policy to prefer GPU. NvTensorRTRTX should be selected. - sess_options = onnxrt.SessionOptions() - sess_options.set_provider_selection_policy(onnxrt.OrtExecutionProviderDevicePolicy.PREFER_GPU) - self.assertTrue(sess_options.has_providers()) - - input = self._create_ortvalue_input_on_gpu("cuda") - - session = onnxrt.InferenceSession(get_name("mul_1.onnx"), sess_options=sess_options) - io_binding = session.io_binding() - - # Bind input to the GPU - io_binding.bind_input("X", "cuda", 0, np.float32, [3, 2], input.data_ptr()) - - # Sync if different streams - io_binding.synchronize_inputs() - - # Bind output to CPU - io_binding.bind_output("Y") - - # Invoke Run - session.run_with_iobinding(io_binding) - - # Sync if different streams - io_binding.synchronize_outputs() - - # Get outputs over to CPU (the outputs which were bound to the GPU will get copied over to the host - # here) - ort_output = io_binding.copy_outputs_to_cpu()[0] - - # Validate results - self.assertTrue(np.array_equal(self._create_expected_output(), ort_output)) - - def test_bind_input_and_bind_output_with_ortvalues(self): - """ - Test I/O binding with OrtValues for both input and output. - """ - # Set a policy to prefer GPU. NvTensorRTRTX EP should be selected. - sess_options = onnxrt.SessionOptions() - sess_options.set_provider_selection_policy(onnxrt.OrtExecutionProviderDevicePolicy.PREFER_GPU) - self.assertTrue(sess_options.has_providers()) - - session = onnxrt.InferenceSession(get_name("mul_1.onnx"), sess_options=sess_options) - io_binding = session.io_binding() - - # Bind ortvalue as input - input_ortvalue = self._create_ortvalue_input_on_gpu("cuda") - io_binding.bind_ortvalue_input("X", input_ortvalue) - - # Bind ortvalue as output - output_ortvalue = self._create_uninitialized_ortvalue_input_on_gpu("cuda") - io_binding.bind_ortvalue_output("Y", output_ortvalue) - - # Sync if different streams - io_binding.synchronize_inputs() - - # Invoke Run - session.run_with_iobinding(io_binding) - - # Sync if different streams - io_binding.synchronize_outputs() - - # Inspect contents of output_ortvalue and make sure that it has the right contents - self.assertTrue(np.array_equal(self._create_expected_output(), output_ortvalue.numpy())) - - # Bind another ortvalue as input - input_ortvalue_2 = self._create_ortvalue_alternate_input_on_gpu("cuda") - io_binding.bind_ortvalue_input("X", input_ortvalue_2) - - # Sync if different streams - io_binding.synchronize_inputs() - - # Invoke Run - session.run_with_iobinding(io_binding) - - # Sync if different streams - io_binding.synchronize_outputs() - - # Inspect contents of output_ortvalue and make sure that it has the right contents - self.assertTrue(np.array_equal(self._create_expected_output_alternate(), output_ortvalue.numpy())) - - def test_bind_input_and_non_preallocated_output(self): - """ - Test I/O binding with non-preallocated output. - """ - sess_options = onnxrt.SessionOptions() - sess_options.set_provider_selection_policy(onnxrt.OrtExecutionProviderDevicePolicy.PREFER_GPU) - self.assertTrue(sess_options.has_providers()) - - session = onnxrt.InferenceSession(get_name("mul_1.onnx"), sess_options=sess_options) - io_binding = session.io_binding() - - input = self._create_ortvalue_input_on_gpu("cuda") - - # Bind input to the GPU - io_binding.bind_input("X", "cuda", 0, np.float32, [3, 2], input.data_ptr()) - - # Bind output to the GPU - io_binding.bind_output("Y", "cuda") - - # Sync if different streams - io_binding.synchronize_inputs() - - # Invoke Run - session.run_with_iobinding(io_binding) - - # Sync if different streams - io_binding.synchronize_outputs() - - # This call returns an OrtValue which has data allocated by ORT on the GPU - ort_outputs = io_binding.get_outputs() - self.assertEqual(len(ort_outputs), 1) - self.assertEqual(ort_outputs[0].device_name(), "cuda") - # Validate results (by copying results to CPU by creating a Numpy object) - self.assertTrue(np.array_equal(self._create_expected_output(), ort_outputs[0].numpy())) - - # We should be able to repeat the above process as many times as we want - try once more - ort_outputs = io_binding.get_outputs() - self.assertEqual(len(ort_outputs), 1) - self.assertEqual(ort_outputs[0].device_name(), "cuda") - # Validate results (by copying results to CPU by creating a Numpy object) - self.assertTrue(np.array_equal(self._create_expected_output(), ort_outputs[0].numpy())) - - input = self._create_ortvalue_alternate_input_on_gpu("cuda") - - # Change the bound input and validate the results in the same bound OrtValue - # Bind alternate input to the GPU - io_binding.bind_input( - "X", - "cuda", - 0, - np.float32, - [3, 2], - input.data_ptr(), - ) - - # Sync if different streams - io_binding.synchronize_inputs() - - # Invoke Run - session.run_with_iobinding(io_binding) - - # Sync if different streams - io_binding.synchronize_outputs() - - # This call returns an OrtValue which has data allocated by ORT on the GPU - ort_outputs = io_binding.get_outputs() - self.assertEqual(len(ort_outputs), 1) - self.assertEqual(ort_outputs[0].device_name(), "cuda") - # Validate results (by copying results to CPU by creating a Numpy object) - self.assertTrue(np.array_equal(self._create_expected_output_alternate(), ort_outputs[0].numpy())) - - def test_bind_input_and_preallocated_output(self): - """ - Test I/O binding with preallocated output. - """ - sess_options = onnxrt.SessionOptions() - sess_options.set_provider_selection_policy(onnxrt.OrtExecutionProviderDevicePolicy.PREFER_GPU) - self.assertTrue(sess_options.has_providers()) - - input = self._create_ortvalue_input_on_gpu("cuda") - - session = onnxrt.InferenceSession(get_name("mul_1.onnx"), sess_options=sess_options) - io_binding = session.io_binding() - - # Bind input to the GPU - io_binding.bind_input("X", "cuda", 0, np.float32, [3, 2], input.data_ptr()) - - # Bind output to the GPU - output = self._create_uninitialized_ortvalue_input_on_gpu("cuda") - io_binding.bind_output("Y", "cuda", 0, np.float32, [3, 2], output.data_ptr()) - - # Sync if different streams - io_binding.synchronize_inputs() - - # Invoke Run - session.run_with_iobinding(io_binding) - - # Sync if different streams - io_binding.synchronize_outputs() - - # Get outputs over to CPU (the outputs which were bound to the GPU will get copied over to the host - # here) - ort_output_vals = io_binding.copy_outputs_to_cpu()[0] - # Validate results - self.assertTrue(np.array_equal(self._create_expected_output(), ort_output_vals)) - - # Validate if ORT actually wrote to pre-allocated buffer by copying the allocated buffer - # to the host and validating its contents - ort_output_vals_in_cpu = output.numpy() - # Validate results - self.assertTrue(np.array_equal(self._create_expected_output(), ort_output_vals_in_cpu)) - - def test_bind_input_types(self): - """ - Test I/O binding with various input data types. - """ - sess_options = onnxrt.SessionOptions() - sess_options.set_provider_selection_policy(onnxrt.OrtExecutionProviderDevicePolicy.PREFER_GPU) - self.assertTrue(sess_options.has_providers()) - opset = onnx_opset_version() - device = C_OrtDevice(C_OrtDevice.cuda(), C_OrtDevice.default_memory(), 0) - - for dtype in [ - np.float32, - # np.float64, - np.int32, - # np.uint32, - np.int64, - # np.uint64, - # np.int16, - # np.uint16, - # np.int8, - np.uint8, - np.float16, - np.bool_, - ]: - with self.subTest(dtype=dtype, inner_device=str(device)): - x = np.arange(8).reshape((-1, 2)).astype(dtype) - proto_dtype = helper.np_dtype_to_tensor_dtype(x.dtype) - - X = helper.make_tensor_value_info("X", proto_dtype, [None, x.shape[1]]) # noqa: N806 - Y = helper.make_tensor_value_info("Y", proto_dtype, [None, x.shape[1]]) # noqa: N806 - - # inference - node_add = helper.make_node("Identity", ["X"], ["Y"]) - - # graph - graph_def = helper.make_graph([node_add], "lr", [X], [Y], []) - model_def = helper.make_model( - graph_def, - producer_name="dummy", - ir_version=7, - producer_version="0", - opset_imports=[helper.make_operatorsetid("", opset)], - ) - - sess = onnxrt.InferenceSession(model_def.SerializeToString(), sess_options=sess_options) - - bind = SessionIOBinding(sess._sess) - ort_value = C_OrtValue.ortvalue_from_numpy(x, device) - bind.bind_ortvalue_input("X", ort_value) - bind.bind_output("Y", device) - sess._sess.run_with_iobinding(bind, None) - ortvaluevector = bind.get_outputs() - self.assertIsInstance(ortvaluevector, OrtValueVector) - ortvalue = bind.get_outputs()[0] - y = ortvalue.numpy() - assert_almost_equal(x, y) - - bind = SessionIOBinding(sess._sess) - bind.bind_input("X", device, dtype, x.shape, ort_value.data_ptr()) - bind.bind_output("Y", device) - sess._sess.run_with_iobinding(bind, None) - ortvalue = bind.get_outputs()[0] - y = ortvalue.numpy() - assert_almost_equal(x, y) - - def test_bind_onnx_types_from_torch(self): - """ - Test I/O binding with various input data types. - """ - sess_options = onnxrt.SessionOptions() - sess_options.set_provider_selection_policy(onnxrt.OrtExecutionProviderDevicePolicy.PREFER_GPU) - self.assertTrue(sess_options.has_providers()) - opset = onnx_opset_version() - - for dtype in [ - torch.float32, - torch.float16, - torch.bfloat16, - torch.int32, - torch.int64, - ]: - with self.subTest(dtype=dtype): - proto_dtype = self.torch_to_onnx_type(dtype) - - x_ = helper.make_tensor_value_info("X", proto_dtype, [None]) - y_ = helper.make_tensor_value_info("Y", proto_dtype, [None]) - node_add = helper.make_node("Identity", ["X"], ["Y"]) - graph_def = helper.make_graph([node_add], "lr", [x_], [y_], []) - model_def = helper.make_model( - graph_def, - producer_name="dummy", - ir_version=10, - producer_version="0", - opset_imports=[helper.make_operatorsetid("", opset)], - ) - sess = onnxrt.InferenceSession(model_def.SerializeToString(), sess_options=sess_options) - - dev = "cuda" if torch.cuda.is_available() else "cpu" - device = ( - C_OrtDevice(C_OrtDevice.cuda(), C_OrtDevice.default_memory(), 0) - if dev == "cuda" - else C_OrtDevice(C_OrtDevice.cpu(), C_OrtDevice.default_memory(), 0) - ) - - x = torch.arange(8, dtype=dtype, device=dev) - y = torch.empty(8, dtype=dtype, device=dev) - - bind = SessionIOBinding(sess._sess) - bind.bind_input("X", device, proto_dtype, x.shape, x.data_ptr()) - bind.bind_output("Y", device, proto_dtype, y.shape, y.data_ptr()) - sess._sess.run_with_iobinding(bind, None) - self.assertTrue(torch.equal(x, y)) - - -if __name__ == "__main__": - unittest.main(verbosity=1) diff --git a/onnxruntime/test/testdata/custom_op_library/custom_op_library.cc b/onnxruntime/test/testdata/custom_op_library/custom_op_library.cc index bc22864304567..8ab58adbeeb74 100644 --- a/onnxruntime/test/testdata/custom_op_library/custom_op_library.cc +++ b/onnxruntime/test/testdata/custom_op_library/custom_op_library.cc @@ -26,7 +26,7 @@ static void AddOrtCustomOpDomainToContainer(Ort::CustomOpDomain&& domain) { } OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptions* options, const OrtApiBase* api) { - Ort::InitApi(api->GetApi(ORT_API_VERSION)); + Ort::Global::api_ = api->GetApi(ORT_API_VERSION); OrtStatus* result = nullptr; ORT_TRY { diff --git a/tools/ci_build/github/windows/extract_nuget_files.ps1 b/tools/ci_build/github/windows/extract_nuget_files.ps1 index 20d6c1f2b63a5..ff8f63a85b97a 100644 --- a/tools/ci_build/github/windows/extract_nuget_files.ps1 +++ b/tools/ci_build/github/windows/extract_nuget_files.ps1 @@ -1,119 +1,105 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -# This file is used by Zip-Nuget-Java Packaging Pipeline +# This file is used by Zip-Nuget Packaging NoContribOps Pipeline,Zip-Nuget-Java Packaging Pipeline -# Define the directory for NuGet artifacts. +# Re-construct a build directory that contains binaries from all the different platforms we're including +# in the native ORT nuget package $nuget_artifacts_dir = "$Env:BUILD_BINARIESDIRECTORY\RelWithDebInfo\RelWithDebInfo\nuget-artifacts" -# Create the directory if it doesn't exist. -New-Item -Path $nuget_artifacts_dir -ItemType directory -ErrorAction SilentlyContinue +New-Item -Path $nuget_artifacts_dir -ItemType directory ## .zip files -# Unzip files directly, excluding the iOS xcframework to preserve its symlinks. -Get-ChildItem -Path "$Env:BUILD_BINARIESDIRECTORY\nuget-artifact\*" -Include *.zip -Exclude onnxruntime_ios_xcframework.*.zip | +# unzip directly +# exclude the iOS xcframework as we need to leave that zipped up to preserve symlinks +Get-ChildItem -Path $Env:BUILD_BINARIESDIRECTORY\nuget-artifact\* -Include *.zip -Exclude onnxruntime_ios_xcframework.*.zip | Foreach-Object { - # The -snld20 flag is used to bypass security checks for creating symbolic links (added in 7-Zip 25.01). - $arguments = "x", "$($_.FullName)", "-y", "-o$nuget_artifacts_dir", "-snld20" - Write-Output "Executing: 7z.exe $arguments" - # Directly call 7z.exe using the call operator '&' - & 7z.exe $arguments - # Check the exit code of the last command. A non-zero code indicates an error. - if ($LASTEXITCODE -ne 0) { - throw "Error extracting '$($_.FullName)'. Exit code: $LASTEXITCODE" - } + $cmd = "7z.exe x $($_.FullName) -y -o$nuget_artifacts_dir" + Write-Output $cmd + Invoke-Expression -Command $cmd } ## .tgz files -# First, extract the .tar file from the .tgz archive. -Get-ChildItem "$Env:BUILD_BINARIESDIRECTORY\nuget-artifact" -Filter *.tgz | +# first extract the tar file from the tgz +Get-ChildItem $Env:BUILD_BINARIESDIRECTORY\nuget-artifact -Filter *.tgz | Foreach-Object { - # The -snld20 flag is used to bypass security checks for creating symbolic links (added in 7-Zip 25.01). - $arguments = "x", "$($_.FullName)", "-y", "-o$Env:BUILD_BINARIESDIRECTORY\nuget-artifact", "-snld20" - Write-Output "Executing: 7z.exe $arguments" - & 7z.exe $arguments - if ($LASTEXITCODE -ne 0) { - throw "Error extracting '$($_.FullName)'. Exit code: $LASTEXITCODE" - } + $cmd = "7z.exe x $($_.FullName) -y -o$Env:BUILD_BINARIESDIRECTORY\nuget-artifact" + Write-Output $cmd + Invoke-Expression -Command $cmd } -# Now, extract the contents from the .tar file into the final directory. -Get-ChildItem "$Env:BUILD_BINARIESDIRECTORY\nuget-artifact" -Filter *.tar | +# now extract the actual folder structure from the tar file to the build dir +Get-ChildItem $Env:BUILD_BINARIESDIRECTORY\nuget-artifact -Filter *.tar | Foreach-Object { - # The -snld20 flag is used to bypass security checks for creating symbolic links (added in 7-Zip 25.01). - $arguments = "x", "$($_.FullName)", "-y", "-o$nuget_artifacts_dir", "-snld20" - Write-Output "Executing: 7z.exe $arguments" - & 7z.exe $arguments - if ($LASTEXITCODE -ne 0) { - throw "Error extracting '$($_.FullName)'. Exit code: $LASTEXITCODE" - } + $cmd = "7z.exe x $($_.FullName) -y -o$nuget_artifacts_dir" + Write-Output $cmd + Invoke-Expression -Command $cmd } -# Process iOS xcframework -$xcframeworks = Get-ChildItem "$Env:BUILD_BINARIESDIRECTORY\nuget-artifact" -Filter onnxruntime_ios_xcframework.*.zip +# process iOS xcframework +$xcframeworks = Get-ChildItem $Env:BUILD_BINARIESDIRECTORY\nuget-artifact -Filter onnxruntime_ios_xcframework.*.zip if ($xcframeworks.Count -eq 1) { - $xcframework = $xcframeworks[0] - $target_dir = "$nuget_artifacts_dir\onnxruntime-ios-xcframework" - # Use the required filename format, removing version info. - $target_file = "$target_dir\onnxruntime.xcframework.zip" - New-Item -Path $target_dir -ItemType directory -ErrorAction SilentlyContinue + $xcframework = $xcframeworks[0] + $target_dir = "$nuget_artifacts_dir\onnxruntime-ios-xcframework" + # remove version info from filename and use required filename format + $target_file = "$target_dir\onnxruntime.xcframework.zip" + New-Item -Path $target_dir -ItemType directory - Write-Output "Copying $($xcframework.FullName) to $target_file" - Copy-Item $xcframework.FullName $target_file + Write-Output "Copy-Item $($xcframework.FullName) $target_file" + Copy-Item $xcframework.FullName $target_file } elseif ($xcframeworks.Count -gt 1) { - Write-Error "Expected at most one onnxruntime_ios_xcframework*.zip file but got: [$xcframeworks]" + Write-Error "Expected at most one onnxruntime_ios_xcframework*.zip file but got: [$xcframeworks]" } -# Copy Android AAR file. -# There should only be one .aar file for a full build. -$aars = Get-ChildItem "$Env:BUILD_BINARIESDIRECTORY\nuget-artifact" -Filter *.aar + +# copy android AAR. +# for full build of onnxruntime Android AAR, there should only be one .aar file +# called onnxruntime-android-x.y.z.aar or onnxruntime-training-android-x.y.z.aar but sanity check that +$aars = Get-ChildItem $Env:BUILD_BINARIESDIRECTORY\nuget-artifact -Filter *.aar if ($aars.Count -eq 1) { - $aar = $aars[0] - $aar_prefix = "onnxruntime" - if ($aar.Name -like "onnxruntime-training*") { - $aar_prefix = "onnxruntime-training" - } - $target_dir = "$nuget_artifacts_dir\$aar_prefix-android-aar" - # Remove version info from the filename for consistency. - $target_file = "$target_dir\onnxruntime.aar" - New-Item -Path $target_dir -ItemType directory -ErrorAction SilentlyContinue + $aar = $aars[0] + $aar_prefix = "onnxruntime" + if ($aar -like "onnxruntime-training*") { + $aar_prefix = "onnxruntime-training" + } + $target_dir = "$nuget_artifacts_dir\$aar_prefix-android-aar" + $target_file = "$target_dir\onnxruntime.aar" # remove '-mobile' and version info from filename + New-Item -Path $target_dir -ItemType directory - Write-Output "Copying $($aar.FullName) to $target_file" - Copy-Item $aar.FullName $target_file + Write-Output "Copy-Item $($aar.FullName) $target_file" + Copy-Item $aar.FullName $target_file } elseif ($aars.Count -gt 1) { - Write-Error "Expected at most one Android .aar file but got: [$aars]" + Write-Error "Expected at most one Android .aar file but got: [$aars]" } -# Check if this is a training pipeline by looking for a specific directory. -$is_training_pipeline = Test-Path -Path "$nuget_artifacts_dir\onnxruntime-training-win-x64-*" -if ($is_training_pipeline) { - Write-Output "onnxruntime-training-win-x64-* dir exists. This is a training pipeline." +# Check whether this is a training pipeline +$is_training_pipeline = $false +if (Test-Path -Path $nuget_artifacts_dir\onnxruntime-training-win-x64-*) { + $is_training_pipeline = $true + Write-Output "onnxruntime-training-win-x64-* dir exists. This is a training pipeline." } -# Copy onnxruntime and protoc binaries required by tests. -$destinationDir = "$Env:BUILD_BINARIESDIRECTORY\RelWithDebInfo\RelWithDebInfo" +# Copy onnxruntime and protoc binaries to the binaries dir as these are required +# by Microsoft.ML.OnnxRuntime.Tests.NetCoreApp if ($is_training_pipeline) { - Copy-Item -Path "$nuget_artifacts_dir\onnxruntime-training-win-x64-*\lib\*" -Destination $destinationDir -Recurse + Copy-Item -Path $nuget_artifacts_dir\onnxruntime-training-win-x64-*\lib\* -Destination $Env:BUILD_BINARIESDIRECTORY\RelWithDebInfo\RelWithDebInfo } else { - Copy-Item -Path "$nuget_artifacts_dir\onnxruntime-win-x64-*\lib\*" -Destination $destinationDir -Recurse + Copy-Item -Path $nuget_artifacts_dir\onnxruntime-win-x64-*\lib\* -Destination $Env:BUILD_BINARIESDIRECTORY\RelWithDebInfo\RelWithDebInfo } -# Rename directories to remove the architecture-specific suffix. -Write-Output "Renaming onnxruntime directories..." -Get-ChildItem -Directory -Path "$nuget_artifacts_dir\onnxruntime-*" | ForEach-Object { - $dirname = $_.Name - # Find the last hyphen and remove the suffix. - $lastHyphenIndex = $dirname.LastIndexOf('-') - if ($lastHyphenIndex -gt -1) { - $newName = $dirname.Substring(0, $lastHyphenIndex) - $newPath = Join-Path -Path $_.Parent.FullName -ChildPath $newName - Write-Output "Renaming '$($_.FullName)' to '$newPath'" - Rename-Item -Path $_.FullName -NewName $newName - } +"Get-ChildItem -Directory -Path $nuget_artifacts_dir\onnxruntime-*" +$ort_dirs = Get-ChildItem -Directory -Path $nuget_artifacts_dir\onnxruntime-* +foreach ($ort_dir in $ort_dirs) +{ + # remove the last '-xxx' segment from the dir name. typically that's the architecture. + $dirname = Split-Path -Path $ort_dir -Leaf + $dirname = $dirname.SubString(0,$dirname.LastIndexOf('-')) + Write-Output "Renaming $ort_dir to $dirname" + Rename-Item -Path $ort_dir -NewName $nuget_artifacts_dir\$dirname } -# List the final artifacts. -Write-Output "Post-copy artifacts:" -Get-ChildItem -Recurse $nuget_artifacts_dir \ No newline at end of file +# List artifacts +"Post copy artifacts" +Get-ChildItem -Recurse $nuget_artifacts_dir\ diff --git a/tools/ci_build/github/windows/extract_nuget_files_gpu.ps1 b/tools/ci_build/github/windows/extract_nuget_files_gpu.ps1 index 29946dcb73f8a..01a8eebe75df2 100644 --- a/tools/ci_build/github/windows/extract_nuget_files_gpu.ps1 +++ b/tools/ci_build/github/windows/extract_nuget_files_gpu.ps1 @@ -2,81 +2,47 @@ # Licensed under the MIT License. # This file is used by Zip-Nuget-Java Packaging Pipeline -# Define the directory for NuGet artifacts. -$nuget_artifacts_dir = "$Env:BUILD_BINARIESDIRECTORY\RelWithDebInfo\RelWithDebInfo\nuget-artifacts" -# Create the directory if it doesn't exist. -New-Item -Path $nuget_artifacts_dir -ItemType directory -ErrorAction SilentlyContinue +New-Item -Path $Env:BUILD_BINARIESDIRECTORY\RelWithDebInfo\RelWithDebInfo\nuget-artifacts -ItemType directory -## .zip files -Get-ChildItem "$Env:BUILD_BINARIESDIRECTORY\nuget-artifact" -Filter *.zip | +Get-ChildItem $Env:BUILD_BINARIESDIRECTORY\nuget-artifact -Filter *.zip | Foreach-Object { - # The -snld20 flag is used to bypass security checks for creating symbolic links (added in 7-Zip 25.01). - $arguments = "x", "$($_.FullName)", "-y", "-o$nuget_artifacts_dir", "-snld20" - Write-Output "Executing: 7z.exe $arguments" - & 7z.exe $arguments - if ($LASTEXITCODE -ne 0) { - throw "Error extracting '$($_.FullName)'. Exit code: $LASTEXITCODE" - } + $cmd = "7z.exe x $($_.FullName) -y -o$Env:BUILD_BINARIESDIRECTORY\RelWithDebInfo\RelWithDebInfo\nuget-artifacts" + Write-Output $cmd + Invoke-Expression -Command $cmd } -## .tgz files -Get-ChildItem "$Env:BUILD_BINARIESDIRECTORY\nuget-artifact" -Filter *.tgz | +Get-ChildItem $Env:BUILD_BINARIESDIRECTORY\nuget-artifact -Filter *.tgz | Foreach-Object { - # The -snld20 flag is used to bypass security checks for creating symbolic links (added in 7-Zip 25.01). - # *.tar will be created after *.tgz is extracted - $arguments = "x", "$($_.FullName)", "-y", "-o$Env:BUILD_BINARIESDIRECTORY\nuget-artifact", "-snld20" - Write-Output "Executing: 7z.exe $arguments" - & 7z.exe $arguments - if ($LASTEXITCODE -ne 0) { - throw "Error extracting '$($_.FullName)'. Exit code: $LASTEXITCODE" - } + $cmd = "7z.exe x $($_.FullName) -y -o$Env:BUILD_BINARIESDIRECTORY\nuget-artifact" # *.tar will be created after *.tgz is extracted + Write-Output $cmd + Invoke-Expression -Command $cmd } -## .tar files -Get-ChildItem "$Env:BUILD_BINARIESDIRECTORY\nuget-artifact" -Filter *.tar | +Get-ChildItem $Env:BUILD_BINARIESDIRECTORY\nuget-artifact -Filter *.tar | Foreach-Object { - # The -snld20 flag is used to bypass security checks for creating symbolic links (added in 7-Zip 25.01). - $arguments = "x", "$($_.FullName)", "-y", "-o$nuget_artifacts_dir", "-snld20" - Write-Output "Executing: 7z.exe $arguments" - & 7z.exe $arguments - if ($LASTEXITCODE -ne 0) { - throw "Error extracting '$($_.FullName)'. Exit code: $LASTEXITCODE" - } + $cmd = "7z.exe x $($_.FullName) -y -o$Env:BUILD_BINARIESDIRECTORY\RelWithDebInfo\RelWithDebInfo\nuget-artifacts" + Write-Output $cmd + Invoke-Expression -Command $cmd } -# Create directory for protobuf build dependencies. -New-Item -Path "$Env:BUILD_BINARIESDIRECTORY\RelWithDebInfo\_deps\protobuf-build\RelWithDebInfo" -ItemType directory -ErrorAction SilentlyContinue -# Copy CUDA libraries. -Copy-Item -Path "$Env:BUILD_BINARIESDIRECTORY\RelWithDebInfo\RelWithDebInfo\nuget-artifacts\onnxruntime-win-x64-cuda-*\lib\*" -Destination "$Env:BUILD_BINARIESDIRECTORY\RelWithDebInfo\RelWithDebInfo" +New-Item -Path $Env:BUILD_BINARIESDIRECTORY\RelWithDebInfo\_deps\protobuf-build\RelWithDebInfo -ItemType directory + +Copy-Item -Path $Env:BUILD_BINARIESDIRECTORY\RelWithDebInfo\RelWithDebInfo\nuget-artifacts\onnxruntime-win-x64-cuda-*\lib\* -Destination $Env:BUILD_BINARIESDIRECTORY\RelWithDebInfo\RelWithDebInfo -# Install protoc via dotnet. $protocInstallDir = "$Env:BUILD_BINARIESDIRECTORY\RelWithDebInfo\_deps\protobuf-build" dotnet new console dotnet add package Google.Protobuf.Tools --version 3.21.12 --package-directory $protocInstallDir -if ($LASTEXITCODE -ne 0) { - throw "Error adding Google.Protobuf.Tools package. Exit code: $LASTEXITCODE" -} - -# Find and copy the protoc executable. $protocDir = Get-ChildItem -Path $protocInstallDir -Recurse -Filter "protoc.exe" | Select-Object -ExpandProperty DirectoryName -First 1 -if ($protocDir) { - Write-Output "Found protoc directory: $protocDir" - Copy-Item -Path $protocDir -Destination "$Env:BUILD_BINARIESDIRECTORY\RelWithDebInfo\_deps\protobuf-build\RelWithDebInfo" -} -else { - Write-Error "Could not find protoc.exe in $protocInstallDir" +Write-Output $protocDir +Copy-Item -Path $protocDir -Destination $Env:BUILD_BINARIESDIRECTORY\RelWithDebInfo\_deps\protobuf-build\RelWithDebInfo + +$ort_dirs = Get-ChildItem -Path $Env:BUILD_BINARIESDIRECTORY\RelWithDebInfo\RelWithDebInfo\nuget-artifacts\onnxruntime-* -Directory +foreach ($ort_dir in $ort_dirs) +{ + $dirname = Split-Path -Path $ort_dir -Leaf + $dirname = $dirname.SubString(0,$dirname.LastIndexOf('-')) + Write-Output "Renaming $ort_dir to $dirname" + Rename-Item -Path $ort_dir -NewName $Env:BUILD_BINARIESDIRECTORY\RelWithDebInfo\RelWithDebInfo\nuget-artifacts\$dirname } -# Rename onnxruntime directories to a generic format. -$ort_dirs = Get-ChildItem -Path "$Env:BUILD_BINARIESDIRECTORY\RelWithDebInfo\RelWithDebInfo\nuget-artifacts\onnxruntime-*" -Directory -foreach ($ort_dir in $ort_dirs) { - $dirname = Split-Path -Path $ort_dir -Leaf - $lastHyphenIndex = $dirname.LastIndexOf('-') - if ($lastHyphenIndex -gt -1) { - $newName = $dirname.Substring(0, $lastHyphenIndex) - $newPath = Join-Path -Path $ort_dir.Parent.FullName -ChildPath $newName - Write-Output "Renaming '$($ort_dir.FullName)' to '$newPath'" - Rename-Item -Path $ort_dir.FullName -NewName $newName - } -} From c3276dea689ec89742aff7ad0a4242b4425741cb Mon Sep 17 00:00:00 2001 From: Preetha Veeramalai Date: Wed, 3 Sep 2025 02:12:14 +0530 Subject: [PATCH 096/138] Re-enable setting default precision for OV devices (#802) Co-authored-by: MayureshV1 <47039074+MayureshV1@users.noreply.github.com> --- .../core/providers/openvino/openvino_provider_factory.cc | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/onnxruntime/core/providers/openvino/openvino_provider_factory.cc b/onnxruntime/core/providers/openvino/openvino_provider_factory.cc index 480e4c068664e..1a10d9849d5cc 100644 --- a/onnxruntime/core/providers/openvino/openvino_provider_factory.cc +++ b/onnxruntime/core/providers/openvino/openvino_provider_factory.cc @@ -224,9 +224,7 @@ static void ParseProviderInfo(const ProviderOptions& provider_options, pi.cache_dir = provider_options.at("cache_dir"); } - if (provider_options.contains("precision")) { - pi.precision = OpenVINOParserUtils::ParsePrecision(provider_options, pi.device_type, "precision"); - } + pi.precision = OpenVINOParserUtils::ParsePrecision(provider_options, pi.device_type, "precision"); if (provider_options.contains("reshape_input")) { pi.reshape = OpenVINOParserUtils::ParseInputShape(provider_options.at("reshape_input")); From edc51ea859f25974ed5eb6e5fc64333fa12fb41e Mon Sep 17 00:00:00 2001 From: Susanta Bhattacharjee Date: Wed, 3 Sep 2025 17:29:24 +0530 Subject: [PATCH 097/138] bf16 tensor handling related fix (#805) Handled tensors with bf16 data type and external in memory data Co-authored-by: MayureshV1 <47039074+MayureshV1@users.noreply.github.com> --- .../openvino/qdq_transformations/qdq_scales_fix.cpp | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/onnxruntime/core/providers/openvino/qdq_transformations/qdq_scales_fix.cpp b/onnxruntime/core/providers/openvino/qdq_transformations/qdq_scales_fix.cpp index 3a39152b5d17d..4b862bdd7554b 100644 --- a/onnxruntime/core/providers/openvino/qdq_transformations/qdq_scales_fix.cpp +++ b/onnxruntime/core/providers/openvino/qdq_transformations/qdq_scales_fix.cpp @@ -951,6 +951,12 @@ void replace_bf16_with_fp16(qdq_scales_fix::CustomGraph& gen_graph) { } } + for (auto& node : gen_graph.original_graph.Nodes()) { + for (auto& input_def : node->InputDefs()) { + ORT_THROW_IF_ERROR(graph_utils::ConvertInMemoryDataToInline(gen_graph.original_graph, input_def->Name())); + } + } + const auto& init_set = gen_graph.original_graph.GetAllInitializedTensors(); for (auto& [key, const_tensor_proto] : init_set) { auto tensor_proto = const_cast(const_tensor_proto); From a5bd8eea88d7eaec7dc4a7f5f4f42129c24e02f7 Mon Sep 17 00:00:00 2001 From: jatinwadhwa921 <110383850+jatinwadhwa921@users.noreply.github.com> Date: Thu, 11 Sep 2025 00:20:41 -0700 Subject: [PATCH 098/138] [OVEP] Fix to increase provider value upto 2048 char (#807) --- onnxruntime/core/session/provider_bridge_ort.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index 960b9eff051be..3b9f5881a84ab 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -2504,9 +2504,9 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_OpenVINO_V2, // arbitrary length to validate the key/value. adjust if/when needed. // TODO: are any other input validation checks required here (and in the other functions that process // provider options)? - if (strlen(provider_options_keys[i]) > 1024 || strlen(provider_options_values[i]) > 1024) { + if (strlen(provider_options_keys[i]) > 1024 || strlen(provider_options_values[i]) > 2048) { return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, - "Maximum string length for a provider options key/value is 1024."); + "Maximum string length for a provider options key is 1024 and value is 2048."); } provider_options[provider_options_keys[i]] = provider_options_values[i]; From 02cf7e356ee2c436ef201750adeff3666625c8c5 Mon Sep 17 00:00:00 2001 From: "Klimenko, Mikhail" Date: Fri, 12 Sep 2025 08:49:29 +0200 Subject: [PATCH 099/138] Make cache_dir and num_stream options session-local (#809) Co-authored-by: Preetha Veeramalai --- .../providers/openvino/backends/basic_backend.cc | 12 ++++++------ .../core/providers/openvino/backends/basic_backend.h | 4 ++-- onnxruntime/core/providers/openvino/ov_interface.cc | 8 -------- onnxruntime/core/providers/openvino/ov_interface.h | 2 -- 4 files changed, 8 insertions(+), 18 deletions(-) diff --git a/onnxruntime/core/providers/openvino/backends/basic_backend.cc b/onnxruntime/core/providers/openvino/backends/basic_backend.cc index 2f174110dd31b..a950538c7c5fd 100644 --- a/onnxruntime/core/providers/openvino/backends/basic_backend.cc +++ b/onnxruntime/core/providers/openvino/backends/basic_backend.cc @@ -242,13 +242,13 @@ void BasicBackend::PopulateConfigValue(ov::AnyMap& device_config) { } } -void BasicBackend::EnableCaching() { +void BasicBackend::EnableCaching(ov::AnyMap& device_config) { // cache_dir argument has no effect when working with an embed-mode EPContext Graph if (subgraph_context_.is_ep_ctx_graph) return; if (!session_context_.cache_dir.empty() && !session_context_.so_context_enable) { LOGS_DEFAULT(INFO) << log_tag << "Enables Caching"; - OVCore::Get()->SetCache(session_context_.cache_dir.string()); + device_config.emplace(ov::cache_dir(session_context_.cache_dir.string())); } } @@ -262,7 +262,7 @@ void BasicBackend::EnableGPUThrottling(ov::AnyMap& device_config) { } } -void BasicBackend::EnableStreams() { +void BasicBackend::EnableStreams(ov::AnyMap& device_config) { // Return silently for NPU as it's currently treated as a read-only flag by the NPU plugin // and throws an exception for the same if (session_context_.device_type.find("NPU") != std::string::npos) @@ -279,7 +279,7 @@ void BasicBackend::EnableStreams() { } // Do nothing } else { - OVCore::Get()->SetStreams(session_context_.device_type, session_context_.num_streams); + device_config.emplace(ov::num_streams(session_context_.num_streams)); } } @@ -293,13 +293,13 @@ void BasicBackend::SetOVDeviceConfiguration(ov::AnyMap& device_config) { PopulateConfigValue(device_config); // Enable caching - EnableCaching(); + EnableCaching(device_config); // Setting OpenCL queue throttling for GPU EnableGPUThrottling(device_config); // Enable streams; default=1 unless overridden by user configuration - EnableStreams(); + EnableStreams(device_config); // Set the inference_num_threads property of the CPU SetNumThreads(device_config); diff --git a/onnxruntime/core/providers/openvino/backends/basic_backend.h b/onnxruntime/core/providers/openvino/backends/basic_backend.h index 5c75a9ae183e2..2cf3d3faa8b47 100644 --- a/onnxruntime/core/providers/openvino/backends/basic_backend.h +++ b/onnxruntime/core/providers/openvino/backends/basic_backend.h @@ -142,9 +142,9 @@ class BasicBackend : public IBackend { private: bool ValidateSubgraph(std::map>& const_outputs_map); void PopulateConfigValue(ov::AnyMap& device_config); - void EnableCaching(); + void EnableCaching(ov::AnyMap& device_config); void EnableGPUThrottling(ov::AnyMap& device_config); - void EnableStreams(); + void EnableStreams(ov::AnyMap& device_config); void SetNumThreads(ov::AnyMap& device_config); void SetOVDeviceConfiguration(ov::AnyMap& device_config); void ValidateOrtDimsAgainstPartialShape(const std::vector& ort_dims, diff --git a/onnxruntime/core/providers/openvino/ov_interface.cc b/onnxruntime/core/providers/openvino/ov_interface.cc index 899845d4890cf..7723ce0a6c7f7 100644 --- a/onnxruntime/core/providers/openvino/ov_interface.cc +++ b/onnxruntime/core/providers/openvino/ov_interface.cc @@ -270,10 +270,6 @@ OVExeNetwork OVCore::ImportEPCtxOVIREncapsulation(std::istream& model_stream, "Exception while Loading Network from OVIR model file: {}", model_file_path.string()); } -void OVCore::SetCache(const std::string& cache_dir_path) { - core.set_property(ov::cache_dir(cache_dir_path)); -} - std::vector OVCore::GetAvailableDevices() const { std::vector available_devices = core.get_available_devices(); return available_devices; @@ -312,10 +308,6 @@ std::vector OVCore::GetAvailableDevices(const std::string& device_t return available_devices; } -void OVCore::SetStreams(const std::string& device_type, int num_streams) { - core.set_property(device_type, {ov::num_streams(num_streams)}); -} - std::shared_ptr OVExeNetwork::CreateInferRequest() { return OvExceptionBoundary([&]() { auto infReq = compiled_model_obj.create_infer_request(); diff --git a/onnxruntime/core/providers/openvino/ov_interface.h b/onnxruntime/core/providers/openvino/ov_interface.h index 38ea883078e85..5f8fb36c1cbec 100644 --- a/onnxruntime/core/providers/openvino/ov_interface.h +++ b/onnxruntime/core/providers/openvino/ov_interface.h @@ -95,8 +95,6 @@ struct OVCore : WeakSingleton { std::vector GetAvailableDevices() const; std::vector GetAvailableDevices(const std::string& device_type) const; - void SetCache(const std::string& cache_dir_path); - void SetStreams(const std::string& device_type, int num_streams); }; class OVExeNetwork { From 58e83ef59b7cb018e9ebb46a56611d9780e95d0e Mon Sep 17 00:00:00 2001 From: Jozef Wludzik Date: Tue, 23 Sep 2025 08:30:35 +0200 Subject: [PATCH 100/138] Fix performance degradation in Ubuntu (#815) Fix issue with creating backend for every inference iteration. The issue was caused by using std::map operator[] that created a pair with key and empty value. In Ubuntu std::map insert method won't override the key value in backend_map if key exists in map (created by operator[]) --- onnxruntime/core/providers/openvino/backend_manager.cc | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/openvino/backend_manager.cc b/onnxruntime/core/providers/openvino/backend_manager.cc index 68d15bdfdcee0..99f28439db53a 100644 --- a/onnxruntime/core/providers/openvino/backend_manager.cc +++ b/onnxruntime/core/providers/openvino/backend_manager.cc @@ -672,7 +672,10 @@ void BackendManager::Compute(OrtKernelContext* context) { { std::unique_lock lock(mutex_); - dynamic_backend = backend_map_[key]; + auto it = backend_map_.find(key); + if (it != backend_map_.end()) { + dynamic_backend = it->second; + } } if (!dynamic_backend) { From f8b09041afd76013d2848d5d5b2d5f3194360a6c Mon Sep 17 00:00:00 2001 From: Bartlomiej Filipek Date: Mon, 6 Oct 2025 10:07:39 -0700 Subject: [PATCH 101/138] Don't embed external initializers into the proto to avoid 2GB limit (#817) * early version, it doesn't embed initializers into the proto, but then restores the metadata so OV can read them back Signed-off-by: bfilipek * improve code, refactor into smaller functions, run the logic when there are external initializers in memory (more than one) Signed-off-by: bfilipek * revert the wrongly merged code Signed-off-by: bfilipek * Updated the condition for the new logic based on the total size of ext initializers, comments, refactoring Signed-off-by: bfilipek * Update onnxruntime/core/providers/openvino/backend_manager.cc Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * make the condition less strict - 32MB threshold, move debug dump after the logic is executed, check for OV version Signed-off-by: bfilipek * unit test that uses ext initializers, early version Signed-off-by: bfilipek * used kOrtSessionOptionsDisableCPUEPFallback, cleanups, model is now over 2GB to show the proto limit (when the new logic for ext initializers is enabled, then the test passes) Signed-off-by: bfilipek * address code review comments Signed-off-by: bfilipek * Update onnxruntime/test/providers/openvino/openvino_ep_ext_init.cc Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * fix the Linux CI build, use PathString rather than wstring Signed-off-by: bfilipek * As agreed, disable the test as it requires OV 2025.4, while the current CI version is only 2025.2 Signed-off-by: bfilipek * add missing comment Signed-off-by: bfilipek --------- Signed-off-by: bfilipek Co-authored-by: MayureshV1 <47039074+MayureshV1@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- .../providers/openvino/backend_manager.cc | 163 ++++++++++++- .../openvino/openvino_ep_ext_init.cc | 215 ++++++++++++++++++ 2 files changed, 377 insertions(+), 1 deletion(-) create mode 100644 onnxruntime/test/providers/openvino/openvino_ep_ext_init.cc diff --git a/onnxruntime/core/providers/openvino/backend_manager.cc b/onnxruntime/core/providers/openvino/backend_manager.cc index 99f28439db53a..989d1022f1d7b 100644 --- a/onnxruntime/core/providers/openvino/backend_manager.cc +++ b/onnxruntime/core/providers/openvino/backend_manager.cc @@ -21,6 +21,7 @@ #include "core/providers/openvino/ov_versions/capability.h" #include "core/providers/openvino/qdq_transformations/qdq_stripping.h" #include "core/providers/openvino/qdq_transformations/qdq_scales_fix.h" +#include "../../framework/tensorprotoutils.h" namespace onnxruntime { namespace openvino_ep { @@ -453,6 +454,80 @@ static void DumpOpenVINOEPModel([[maybe_unused]] const std::filesystem::path& on #endif } +// this is a helper function to set the data fields, it duplicates ExternalDataInfo::SetExternalLocationToProto +// but we cannot use that function as it is not part of public provider api. +static void SetExternalDataFields(ONNX_NAMESPACE::TensorProto* proto_init, const void* data_ptr, int64_t data_size) { + static constexpr const char* ORT_INTERNAL_MEM_INITIALIZER = "*/_ORT_MEM_ADDR_/*"; + auto* external_data = proto_init->mutable_external_data(); + bool found_location = false, found_offset = false, found_length = false; + const int ext_data_size = external_data->size(); + proto_init->set_data_location(ONNX_NAMESPACE::TensorProto_DataLocation::TensorProto_DataLocation_EXTERNAL); + + for (int j = 0; j < ext_data_size; ++j) { + auto& ext_entry = external_data->at(j); + auto& key = *ext_entry.mutable_key(); + if (key == "location") { + *ext_entry.mutable_value() = ORT_INTERNAL_MEM_INITIALIZER; + found_location = true; + } else if (key == "offset") { + *ext_entry.mutable_value() = std::to_string(reinterpret_cast(data_ptr)); + found_offset = true; + } else if (key == "length") { + *ext_entry.mutable_value() = std::to_string(data_size); + found_length = true; + } + } + + if (!found_location) { + auto* new_entry = external_data->Add(); + *new_entry->mutable_key() = "location"; + *new_entry->mutable_value() = ORT_INTERNAL_MEM_INITIALIZER; + } + if (!found_offset) { + auto* new_entry = external_data->Add(); + *new_entry->mutable_key() = "offset"; + *new_entry->mutable_value() = std::to_string(reinterpret_cast(data_ptr)); + } + if (!found_length) { + auto* new_entry = external_data->Add(); + *new_entry->mutable_key() = "length"; + *new_entry->mutable_value() = std::to_string(data_size); + } +} + +static void ReadExternalDataFields(const ONNX_NAMESPACE::TensorProto* src_init, std::string& location, size_t& offset, size_t& length) { + // Remove constness as we need to use mutable_external_data() to get the entries to read. + // The entries themselves are not modified... + auto& mutable_proto = *const_cast(src_init); + auto* entry_protos = mutable_proto.mutable_external_data(); + for (int i = 0; i < entry_protos->size(); i++) { + auto& string_entry_proto{entry_protos->at(i)}; + const auto& pb_key{*(string_entry_proto.mutable_key())}; + const auto& pb_value{*(string_entry_proto.mutable_value())}; + if (pb_key == "location") { + location = pb_value; + } else if (pb_key == "offset") { + const auto res = std::from_chars(pb_value.data(), pb_value.data() + pb_value.size(), offset); + if (res.ec != std::errc()) { + std::ostringstream err_msg; + err_msg << "External data in memory has invalid offset field: " + << src_init->name() << "], location: " << location + << ", offset: " << pb_value; + ORT_THROW(err_msg.str()); + } + } else if (pb_key == "length") { + const auto res = std::from_chars(pb_value.data(), pb_value.data() + pb_value.size(), length); + if (res.ec != std::errc()) { + std::ostringstream err_msg; + err_msg << "External data in memory has invalid length field: " + << src_init->name() << "], location: " << location + << ", length: " << pb_value; + ORT_THROW(err_msg.str()); + } + } + } +} + std::unique_ptr BackendManager::GetModelProtoFromFusedNode(const onnxruntime::Node& fused_node, const onnxruntime::GraphViewer& subgraph, @@ -529,12 +604,98 @@ BackendManager::GetModelProtoFromFusedNode(const onnxruntime::Node& fused_node, return model_proto; } else { LOGS_DEFAULT(INFO) << "[OpenVINO-EP] OVEP QDQ optimization pass is disabled"; + + // scan ext initializers: + std::unordered_map> external_initializers_offset_and_length; + std::string tempLocation; + size_t extInitializerTotalSize = 0; + if (session_context_.has_external_weights) { + auto allInitializers = subgraph.GetAllInitializedTensors(); + for (auto& [name, tp] : allInitializers) { + if (utils::HasExternalDataInMemory(*tp)) { + size_t offset = 0; + size_t length = 0; + ReadExternalDataFields(tp, tempLocation, offset, length); + extInitializerTotalSize += length; + external_initializers_offset_and_length[name] = {offset, length}; + } + } + } + + // when we have external weights in memory, the model proto will actually embed those + // and bloat the serialized string. We can avoid that by not including the data in the proto + // but then we have to update those initializers and set the external_data fields to mem_addr tag... + // proto is limited to 2GB, but let's use 32MB as threshold to be conservative and still gain some memory reductions. +#if (((OPENVINO_VERSION_MAJOR == 2025) && (OPENVINO_VERSION_MINOR > 3)) || (OPENVINO_VERSION_MAJOR > 2025)) + constexpr size_t MAX_EMBEDDED_INITIALIZER_SIZE = 1024 * 1024 * 32; + const bool include_initializer_data_in_proto = !(session_context_.has_external_weights && + external_initializers_offset_and_length.size() > 1 && + extInitializerTotalSize >= MAX_EMBEDDED_INITIALIZER_SIZE); +#else + const bool include_initializer_data_in_proto = true; +#endif + + auto model = subgraph.CreateModel(logger); auto model_proto = model->ToProto(); model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); - subgraph.ToProto(*model_proto->mutable_graph(), true, true); + subgraph.ToProto(*model_proto->mutable_graph(), /*include_initializers*/true, + /*include_outer_scope_args*/true, /*execution_order*/0, /*include_initializer_data*/include_initializer_data_in_proto); + print_model_proto_duration(); + + if (!include_initializer_data_in_proto) { + LOGS(logger, INFO) << "Initializer data is not included in the model proto. Updating metadata..., total size " << extInitializerTotalSize / (1024 * 1024) << " MB in " << external_initializers_offset_and_length.size() << " initializers"; + auto* graph_proto = model_proto->mutable_graph(); + auto* proto_initializers = graph_proto->mutable_initializer(); + + std::unordered_map proto_initializer_map; + for (int i = 0, n = proto_initializers->size(); i < n; ++i) { + auto& proto_init = proto_initializers->at(i); + proto_initializer_map[proto_init.name()] = &proto_init; + } + + for (const auto& [name, src_init] : subgraph.GetAllInitializedTensors()) { + auto it = proto_initializer_map.find(name); + if (it == proto_initializer_map.end()) + continue; + + auto* proto_init = it->second; + + // If the proto initializer is missing data, fill it in + if (!proto_init->has_raw_data() && src_init->has_raw_data()) { + *proto_init->mutable_raw_data() = src_init->raw_data(); + } + + // Only set in-memory external_data fields if the data is in memory + if (src_init->has_raw_data()) { + LOGS(logger, VERBOSE) << "In-memory initializer RAW: " + << src_init->name() + << ", data_type: " << src_init->data_type() + << ", raw_data size: " << src_init->raw_data().size(); + + SetExternalDataFields(proto_init, src_init->raw_data().data(), src_init->raw_data().size()); + } else if (onnxruntime::utils::HasExternalDataInMemory(*src_init)) { + auto it_ext = external_initializers_offset_and_length.find(name); + if (it_ext == external_initializers_offset_and_length.end()) { + std::ostringstream err_msg; + err_msg << "Initializer marked as external in memory but missing offset/length info: " << src_init->name(); + ORT_THROW(err_msg.str()); + } + const size_t offset = it_ext->second.first; + const size_t length = it_ext->second.second; + + LOGS(logger, VERBOSE) << "In-memory initializer EXT: " << src_init->name() << ", size: " << length; + + SetExternalDataFields(proto_init, (const void*)offset, length); + } else { + LOGS(logger, VERBOSE) << "File-based initializer: " << src_init->name() << ", data_type: " << src_init->data_type(); + } + } + } + DumpOpenVINOEPModel(onnx_model_path_name, model_proto.get(), fused_node); + return model_proto; } } diff --git a/onnxruntime/test/providers/openvino/openvino_ep_ext_init.cc b/onnxruntime/test/providers/openvino/openvino_ep_ext_init.cc new file mode 100644 index 0000000000000..21ec61c2d2e3f --- /dev/null +++ b/onnxruntime/test/providers/openvino/openvino_ep_ext_init.cc @@ -0,0 +1,215 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include +#include + +#include "core/session/onnxruntime_cxx_api.h" + +#include "test/util/include/test/test_environment.h" +#include "test/unittest_util/qdq_test_utils.h" + +#include "gtest/gtest.h" +#include "gmock/gmock.h" +#include "onnxruntime_session_options_config_keys.h" + +using namespace ONNX_NAMESPACE; +using namespace onnxruntime::logging; + +extern std::unique_ptr ort_env; + +class OVEP_ExtInit_Tests : public ::testing::TestWithParam {}; + +namespace { + +std::vector LoadFileToMemory(const std::string& path) { + std::ifstream file(path, std::ios::binary | std::ios::ate); + if (!file.is_open()) { + return std::vector(); + } + std::streamsize size = file.tellg(); + file.seekg(0, std::ios::beg); + std::vector buffer(static_cast(size)); + if (!file.read(reinterpret_cast(buffer.data()), size)) { + return std::vector(); + } + return buffer; +} + +auto ProbeDevice(const std::string& device) { + static std::map is_present; + if (is_present.find(device) == is_present.end()) { + Ort::SessionOptions sessionOptions; + std::unordered_map ov_options; + ov_options["device_type"] = device; + try { + sessionOptions.AppendExecutionProvider_OpenVINO_V2(ov_options); + is_present[device] = true; + } catch (...) { + is_present[device] = false; + } + } + return is_present[device]; +} +} // namespace detail + +namespace onnxruntime { +namespace test { + +// this test requiresOV 2025.4+ to run, currently CI uses OV 2025.2, so the test will be disabled until OV is updated +TEST_P(OVEP_ExtInit_Tests, DISABLED_ModelFromExtInit) { + const auto& device = GetParam(); + if (!ProbeDevice(device)) + GTEST_SKIP() << device + " is not available on this machine"; + + // Model and weights file paths + const std::string model_path = "ovep_ext_init_test.onnx"; + const std::string weights_path = "ovep_ext_init_test.onnx.data"; + const size_t num_initializers = 8; + const size_t floats_per_initializer = 64 * 1024 * 1024; // 64 millions floats per initializer, 256MB + const size_t total_floats = num_initializers * floats_per_initializer; + const size_t total_bytes = total_floats * sizeof(float); + // min size threshold for new logic with ext initializers + ASSERT_GE(total_bytes, 32 * 1024 * 1024); + + // 1. Create initializers + std::vector> initializer_data; + for (size_t i = 0; i < num_initializers; ++i) + initializer_data.emplace_back(floats_per_initializer, static_cast(i + 1)); // W0:1, W1:2... + + // 2. Build ONNX model with 4 external initializers, and 4 ADD nodes + { + ModelProto model_proto; + model_proto.set_ir_version(7); + model_proto.set_producer_name("openvino_extinit_test"); + model_proto.set_producer_version("1.0"); + model_proto.set_domain(""); + model_proto.set_model_version(1); + + auto* graph = model_proto.mutable_graph(); + graph->set_name("TestGraph"); + + // Input: shape [floats_per_initializer] + auto* input = graph->add_input(); + input->set_name("X"); + auto* input_type = input->mutable_type()->mutable_tensor_type(); + input_type->set_elem_type(TensorProto_DataType_FLOAT); + input_type->mutable_shape()->add_dim()->set_dim_value(floats_per_initializer); + + // Output: shape [floats_per_initializer] + auto* output = graph->add_output(); + output->set_name("Y"); + auto* output_type = output->mutable_type()->mutable_tensor_type(); + output_type->set_elem_type(TensorProto_DataType_FLOAT); + output_type->mutable_shape()->add_dim()->set_dim_value(floats_per_initializer); + + auto* opset_import = model_proto.add_opset_import(); + opset_import->set_domain(""); + opset_import->set_version(19); + + // Add initializers as external data + size_t offset = 0; + std::vector initializer_names; + for (size_t i = 0; i < num_initializers; ++i) { + std::string name = "W" + std::to_string(i); + initializer_names.push_back(name); + TensorProto* initializer = graph->add_initializer(); + initializer->set_name(name); + initializer->set_data_type(TensorProto_DataType_FLOAT); + initializer->add_dims(floats_per_initializer); + initializer->set_data_location(ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL); + auto* ext = initializer->add_external_data(); + ext->set_key("location"); + ext->set_value(weights_path); + ext = initializer->add_external_data(); + ext->set_key("offset"); + ext->set_value(std::to_string(offset)); + ext = initializer->add_external_data(); + ext->set_key("length"); + ext->set_value(std::to_string(floats_per_initializer * sizeof(float))); + offset += floats_per_initializer * sizeof(float); + } + + // nodes: X -> Add with Init[0] -> ... -> output Y + std::string prev_output = "X"; + std::string node_output; + for (size_t i = 0; i < num_initializers; ++i) { + node_output = (i == num_initializers - 1) ? "Y" : "A" + std::to_string(i); + auto* add_node = graph->add_node(); + add_node->set_op_type("Add"); + add_node->add_input(prev_output); + add_node->add_input(initializer_names[i]); + add_node->add_output(node_output); + prev_output = node_output; + } + + // Save model + std::ofstream model_file(model_path, std::ios::binary); + ASSERT_TRUE(model_proto.SerializeToOstream(&model_file)); + model_file.close(); + } + + // 3. Save weights file (concatenate all initializers) + { + std::ofstream weights_file(weights_path, std::ios::binary); + ASSERT_TRUE(weights_file.is_open()); + for (const auto& w : initializer_data) { + weights_file.write(reinterpret_cast(w.data()), w.size() * sizeof(float)); + } + weights_file.close(); + } + + // 4. Load model and weights into memory + std::vector model_data = LoadFileToMemory(model_path); + std::vector weights_data = LoadFileToMemory(weights_path); + + // 5. Prepare external initializer info + PathString weights_name_path(weights_path.begin(), weights_path.end()); + std::vector names_path = {weights_name_path}; + std::vector buffers = {reinterpret_cast(weights_data.data())}; + std::vector buffer_sizes = {weights_data.size()}; + + // 6. Set up session options with OpenVINO + Ort::SessionOptions session_options; + session_options.AddConfigEntry(kOrtSessionOptionsDisableCPUEPFallback, "1"); + session_options.SetIntraOpNumThreads(1); + std::unordered_map ov_options = { {"device_type", device } }; + session_options.AppendExecutionProvider_OpenVINO_V2(ov_options); + session_options.AddExternalInitializersFromFilesInMemory(names_path, buffers, buffer_sizes); + + // 7. Create session from memory + Ort::Session session(*ort_env, model_data.data(), model_data.size(), session_options); + + // 8. Run inference to verify weights are loaded + std::vector input_data(floats_per_initializer, 2.0f); + std::vector input_shape = {static_cast(floats_per_initializer)}; + Ort::MemoryInfo mem_info = Ort::MemoryInfo::CreateCpu(OrtAllocatorType::OrtDeviceAllocator, OrtMemTypeDefault); + Ort::Value input_tensor = Ort::Value::CreateTensor(mem_info, input_data.data(), input_data.size(), input_shape.data(), input_shape.size()); + + std::vector input_names = {"X"}; + std::vector output_names = {"Y"}; + std::vector output_tensors(1); + + session.Run(Ort::RunOptions{nullptr}, input_names.data(), &input_tensor, 1, output_names.data(), output_tensors.data(), 1); + + // Check output: should be input + W0 + W1 + W2... + auto* out_data = output_tensors[0].GetTensorMutableData(); + float expected = input_data[0]; + for (size_t i = 0; i < num_initializers; ++i) { + expected += initializer_data[i][0]; + } + + for (size_t i = 0; i < floats_per_initializer; ++i) + ASSERT_FLOAT_EQ(out_data[i], expected); + + // Cleanup + std::filesystem::remove(model_path); + std::filesystem::remove(weights_path); +} +INSTANTIATE_TEST_SUITE_P(OVEP_Tests, + OVEP_ExtInit_Tests, + ::testing::Values("CPU", "GPU", "NPU")); + +} // namespace test +} // namespace onnxruntime From d102554793f340fcb77c2edba022b4ffcfd8cb89 Mon Sep 17 00:00:00 2001 From: Bartlomiej Filipek Date: Wed, 8 Oct 2025 21:54:30 -0700 Subject: [PATCH 102/138] Fix Regression in Model PSS and PSR, add check for zero-size initializers (#825) * check for the size of RAW data, skip if it's zero Signed-off-by: bfilipek * Update onnxruntime/core/providers/openvino/backend_manager.cc Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --------- Signed-off-by: bfilipek Co-authored-by: MayureshV1 <47039074+MayureshV1@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- onnxruntime/core/providers/openvino/backend_manager.cc | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/openvino/backend_manager.cc b/onnxruntime/core/providers/openvino/backend_manager.cc index 989d1022f1d7b..b1cc42cc66ce8 100644 --- a/onnxruntime/core/providers/openvino/backend_manager.cc +++ b/onnxruntime/core/providers/openvino/backend_manager.cc @@ -673,8 +673,10 @@ BackendManager::GetModelProtoFromFusedNode(const onnxruntime::Node& fused_node, << src_init->name() << ", data_type: " << src_init->data_type() << ", raw_data size: " << src_init->raw_data().size(); - - SetExternalDataFields(proto_init, src_init->raw_data().data(), src_init->raw_data().size()); + if (src_init->raw_data().size() > 0) + SetExternalDataFields(proto_init, src_init->raw_data().data(), src_init->raw_data().size()); + else + LOGS(logger, VERBOSE) << "Initializer has empty raw_data: skipping initializer '" << src_init->name() << "'..."; } else if (onnxruntime::utils::HasExternalDataInMemory(*src_init)) { auto it_ext = external_initializers_offset_and_length.find(name); if (it_ext == external_initializers_offset_and_length.end()) { From b685871cd4cbbc8c29cd1a1e1376420e1a76ad4f Mon Sep 17 00:00:00 2001 From: Bartlomiej Filipek Date: Fri, 10 Oct 2025 11:49:39 -0700 Subject: [PATCH 103/138] When dynamic shapes are used the proto might be overriden and offsets from initializers might become invalid (#828) Signed-off-by: bfilipek --- onnxruntime/core/providers/openvino/backend_manager.cc | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/onnxruntime/core/providers/openvino/backend_manager.cc b/onnxruntime/core/providers/openvino/backend_manager.cc index b1cc42cc66ce8..d679ac65720db 100644 --- a/onnxruntime/core/providers/openvino/backend_manager.cc +++ b/onnxruntime/core/providers/openvino/backend_manager.cc @@ -83,6 +83,10 @@ BackendManager::BackendManager(SessionContext& session_context, subgraph_context_.subgraph_name = fused_node.Name(); + if (ModelHasSymbolicInputDims(subgraph)) { + subgraph_context_.has_dynamic_input_shape = true; + } + ptr_stream_t model_stream; std::unique_ptr model_proto; if (subgraph_context_.is_ep_ctx_graph) { @@ -119,8 +123,7 @@ BackendManager::BackendManager(SessionContext& session_context, backend_utils::CreateOVTensors(session_context_.device_type, sw.metadata, *sw.mapped_weights); } - if (ModelHasSymbolicInputDims(subgraph)) { - subgraph_context_.has_dynamic_input_shape = true; + if (subgraph_context_.has_dynamic_input_shape) { LOGS_DEFAULT(INFO) << "[OpenVINO-EP] Model has symbolic input dims"; if ((!session_context_.disable_dynamic_shapes && (session_context_.device_type.find("CPU") != std::string::npos || @@ -609,7 +612,7 @@ BackendManager::GetModelProtoFromFusedNode(const onnxruntime::Node& fused_node, std::unordered_map> external_initializers_offset_and_length; std::string tempLocation; size_t extInitializerTotalSize = 0; - if (session_context_.has_external_weights) { + if (session_context_.has_external_weights && !subgraph_context_.has_dynamic_input_shape) { auto allInitializers = subgraph.GetAllInitializedTensors(); for (auto& [name, tp] : allInitializers) { if (utils::HasExternalDataInMemory(*tp)) { From 65bbecc0bf81d44287b485f4a94199a81a650bb9 Mon Sep 17 00:00:00 2001 From: Kotomi-Du Date: Thu, 2 Oct 2025 16:47:37 -0700 Subject: [PATCH 104/138] trigger stateful path for Phisilica model Co-author: Beheshti, Nazanin --- .../core/providers/openvino/ov_interface.cc | 8 ++--- .../openvino/ov_stateful_patch_utils.cc | 32 ++++++++++++------- 2 files changed, 25 insertions(+), 15 deletions(-) diff --git a/onnxruntime/core/providers/openvino/ov_interface.cc b/onnxruntime/core/providers/openvino/ov_interface.cc index 7723ce0a6c7f7..627183fb10d51 100644 --- a/onnxruntime/core/providers/openvino/ov_interface.cc +++ b/onnxruntime/core/providers/openvino/ov_interface.cc @@ -360,10 +360,10 @@ void OVInferRequest::Infer() { StatefulOVInferRequest::StatefulOVInferRequest(ov::InferRequest infer_request, std::string device) : OVInferRequest(std::move(infer_request)), target_device(device) { - bool gpu_or_npu = ((device.find("NPU") != std::string::npos) || (device.find("GPU") != std::string::npos)); - if (gpu_or_npu) { - prefill_use_full_chat_history = true; - } + // bool gpu_or_npu = ((device.find("NPU") != std::string::npos) || (device.find("GPU") != std::string::npos)); + // if (gpu_or_npu) { + // prefill_use_full_chat_history = true; + // } } void StatefulOVInferRequest::FillTensor(const std::string& tensor_name, const ov::element::Type& type, diff --git a/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc b/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc index b48b0efde7ab6..f86d2d54fc381 100644 --- a/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc +++ b/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc @@ -72,6 +72,14 @@ void FuseCacheReorder(std::shared_ptr ov_model, main_input_name = "input_ids"; } + if (ModelHasInputOutputNames(ov_model, "input_hidden_states")) { + main_input_name = "input_hidden_states"; + } + + if (ModelHasInputOutputNames(ov_model, "/model/embed_tokens/Gather_output_0")) { + main_input_name = "/model/embed_tokens/Gather_output_0"; + } + auto input_batch = ov_model->input(main_input_name).get_partial_shape()[0]; auto beam_idx = std::make_shared(ov::element::i32, ov::PartialShape({std::move(input_batch)})); @@ -121,20 +129,22 @@ void MakeStateful(std::shared_ptr& ov_model, void PatchStatefulDecoder(std::shared_ptr model) { std::vector key_value_input_names; std::vector not_kv_inputs; - for (const ov::Output& input : model->inputs()) { - auto& names = input.get_names(); - - bool found = false; - for (auto& name : names) { - if (name.find("key_values") != std::string::npos) { - key_value_input_names.push_back(name); + const auto& params = model->get_parameters(); + bool found = false; + for (auto i = 0; i < params.size(); i++) { + auto param_name = params.at(i)->output(0).get_any_name(); + if (param_name.find("key_values") != std::string::npos) { + key_value_input_names.push_back(param_name); + found = true; + } else if (param_name.find("key") != std::string::npos) { + key_value_input_names.push_back(param_name); + found = true; + } else if (param_name.find("value") != std::string::npos) { + key_value_input_names.push_back(param_name); found = true; - break; - } } - if (!found) { - not_kv_inputs.push_back(input.get_any_name()); + not_kv_inputs.push_back(param_name); } } From 1e132f362131bf5ed045d189f5d116dfd771287b Mon Sep 17 00:00:00 2001 From: Kotomi-Du Date: Fri, 10 Oct 2025 17:21:45 -0700 Subject: [PATCH 105/138] unify the code --- onnxruntime/core/providers/openvino/ov_interface.cc | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/onnxruntime/core/providers/openvino/ov_interface.cc b/onnxruntime/core/providers/openvino/ov_interface.cc index 627183fb10d51..99e310185e9e4 100644 --- a/onnxruntime/core/providers/openvino/ov_interface.cc +++ b/onnxruntime/core/providers/openvino/ov_interface.cc @@ -360,10 +360,14 @@ void OVInferRequest::Infer() { StatefulOVInferRequest::StatefulOVInferRequest(ov::InferRequest infer_request, std::string device) : OVInferRequest(std::move(infer_request)), target_device(device) { - // bool gpu_or_npu = ((device.find("NPU") != std::string::npos) || (device.find("GPU") != std::string::npos)); - // if (gpu_or_npu) { - // prefill_use_full_chat_history = true; - // } + bool gpu_or_npu = ((device.find("NPU") != std::string::npos) || (device.find("GPU") != std::string::npos)); + + // check if there is input_ids tensors and if the tensor type is int64, + // because logic prefill_use_full_chat_history is only for specific inputs and data type + auto input_ids_opt = FindTensor("input_ids"); + if (gpu_or_npu && input_ids_opt.has_value() && input_ids_opt->get_element_type() != ov::element::i64) { + prefill_use_full_chat_history = true; + } } void StatefulOVInferRequest::FillTensor(const std::string& tensor_name, const ov::element::Type& type, From 59f22e171169d69eec5d507bcfa84b5a113ab816 Mon Sep 17 00:00:00 2001 From: Rajeev Sekar Date: Tue, 14 Oct 2025 21:28:23 +0530 Subject: [PATCH 106/138] [OVEP] fixed NPU exception message when CPU fallback is disabled (#832) * fixed NPU exception message when CPU fallback is disabled * fixed lint issues --- .../providers/openvino/backend_manager.cc | 37 ++++++++++++++++--- 1 file changed, 31 insertions(+), 6 deletions(-) diff --git a/onnxruntime/core/providers/openvino/backend_manager.cc b/onnxruntime/core/providers/openvino/backend_manager.cc index d679ac65720db..4e9c0f912c825 100644 --- a/onnxruntime/core/providers/openvino/backend_manager.cc +++ b/onnxruntime/core/providers/openvino/backend_manager.cc @@ -7,6 +7,7 @@ #include #include #include +#include #include #include #include @@ -170,7 +171,10 @@ BackendManager::BackendManager(SessionContext& session_context, exception_str.find("intel_npu") != std::string::npos) { // Handle NPU device related errors #ifndef NDEBUG - ORT_THROW(exception_str + "\nModel needs to be recompiled\n"); + std::string suffix = session_context_.so_disable_cpu_ep_fallback ? + "\nModel failed to compile on NPU. Enable CPU fallback or try another device.\n" : + "\nModel needs to be recompiled\n"; + ORT_THROW(exception_str + suffix); #else std::string error_message = "UNKNOWN NPU ERROR"; std::string error_code = "code 0x0"; @@ -183,7 +187,10 @@ BackendManager::BackendManager(SessionContext& session_context, if (std::regex_search(exception_str, matches, error_code_pattern)) { error_code = matches[0]; } - throw std::runtime_error(error_message + ", " + error_code + "\nModel needs to be recompiled\n"); + std::string suffix = session_context_.so_disable_cpu_ep_fallback ? + "\nModel failed to compile on NPU. Enable CPU fallback or try another device.\n" : + "\nModel needs to be recompiled\n"; + throw std::runtime_error(error_message + ", " + error_code + suffix); #endif } else { ORT_THROW(exception_str); @@ -631,8 +638,8 @@ BackendManager::GetModelProtoFromFusedNode(const onnxruntime::Node& fused_node, // proto is limited to 2GB, but let's use 32MB as threshold to be conservative and still gain some memory reductions. #if (((OPENVINO_VERSION_MAJOR == 2025) && (OPENVINO_VERSION_MINOR > 3)) || (OPENVINO_VERSION_MAJOR > 2025)) constexpr size_t MAX_EMBEDDED_INITIALIZER_SIZE = 1024 * 1024 * 32; - const bool include_initializer_data_in_proto = !(session_context_.has_external_weights && - external_initializers_offset_and_length.size() > 1 && + const bool include_initializer_data_in_proto = !(session_context_.has_external_weights && + external_initializers_offset_and_length.size() > 1 && extInitializerTotalSize >= MAX_EMBEDDED_INITIALIZER_SIZE); #else const bool include_initializer_data_in_proto = true; @@ -642,7 +649,7 @@ BackendManager::GetModelProtoFromFusedNode(const onnxruntime::Node& fused_node, auto model = subgraph.CreateModel(logger); auto model_proto = model->ToProto(); model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); - subgraph.ToProto(*model_proto->mutable_graph(), /*include_initializers*/true, + subgraph.ToProto(*model_proto->mutable_graph(), /*include_initializers*/true, /*include_outer_scope_args*/true, /*execution_order*/0, /*include_initializer_data*/include_initializer_data_in_proto); print_model_proto_duration(); @@ -881,7 +888,25 @@ void BackendManager::Compute(OrtKernelContext* context) { ORT_THROW(msg); } } else { - ORT_THROW(ex.what()); + std::string exception_str = ex.what(); + if (session_context_.so_disable_cpu_ep_fallback){ + std::string error_message = "UNKNOWN NPU ERROR"; + std::string error_code = "code 0x0"; + std::regex error_message_pattern(R"(\bZE_\w*\b)"); + std::regex error_code_pattern("code 0x[0-9a-fA-F]+"); + std::smatch matches; + if (std::regex_search(exception_str, matches, error_message_pattern)) { + error_message = matches[0]; + } + if (std::regex_search(exception_str, matches, error_code_pattern)) { + error_code = matches[0]; + } + std::string suffix = "\nModel failed to compile on NPU. Enable CPU fallback or try another device.\n" ; + throw std::runtime_error(error_message + ", " + error_code + suffix); + } + else{ + ORT_THROW(exception_str); + } } #endif } From 5c40da54cfec2b7a9ca72342a76840772e2de711 Mon Sep 17 00:00:00 2001 From: liang Date: Wed, 15 Oct 2025 01:18:50 +0800 Subject: [PATCH 107/138] CVS-174886: Make onnxruntime tests pass on OpenVINO (#790) * Change the checker to support threshold for uint16/uint4/int4 * Disable loop case for OV-EP due to floating nodes * Disable the case for OV-EP due to input conflict * Add threshold for uint8/int8/uint16 cases * Fix build warning * New changes after rebase * Change code based on review --------- Co-authored-by: MayureshV1 <47039074+MayureshV1@users.noreply.github.com> --- .../test/contrib_ops/quantize_ops_test.cc | 2 + .../providers/cpu/controlflow/loop_test.cc | 3 +- .../test/providers/cpu/tensor/cast_op_test.cc | 3 + .../providers/cpu/tensor/concat_op_test.cc | 2 + .../cpu/tensor/quantize_linear_test.cc | 7 ++ .../providers/cpu/tensor/resize_op_test.cc | 2 + .../providers/cpu/tensor/slice_op.test.cc | 4 + onnxruntime/test/unittest_util/checkers.cc | 76 +++++++++++++++++-- 8 files changed, 92 insertions(+), 7 deletions(-) diff --git a/onnxruntime/test/contrib_ops/quantize_ops_test.cc b/onnxruntime/test/contrib_ops/quantize_ops_test.cc index db685967ae5ff..de10f14ef4538 100644 --- a/onnxruntime/test/contrib_ops/quantize_ops_test.cc +++ b/onnxruntime/test/contrib_ops/quantize_ops_test.cc @@ -287,6 +287,7 @@ TEST(QuantizeLinearContribOpTest, QuantizeLinear_per_tensor_float_int8) { 127, -127, 127, -128, 127, -128}); + test.SetOutputAbsErr("y", 1.0f); // Disable Tensorrt EP due to error: node1_quantize_scale_node: out of bounds channel axis 1. Number of input dimensions is 1. test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); } @@ -311,6 +312,7 @@ TEST(QuantizeLinearContribOpTest, QuantizeLinear_per_tensor_float_uint16) { 32769, 32765, 65535, 0, 65535, 0}); + test.SetOutputAbsErr("y", 1.0f); // Disable Tensorrt EP due to error: unsupported data type test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); diff --git a/onnxruntime/test/providers/cpu/controlflow/loop_test.cc b/onnxruntime/test/providers/cpu/controlflow/loop_test.cc index 07cd2114372dd..0bed6b6e9abee 100644 --- a/onnxruntime/test/providers/cpu/controlflow/loop_test.cc +++ b/onnxruntime/test/providers/cpu/controlflow/loop_test.cc @@ -828,7 +828,8 @@ TEST(Loop, Opset11WithNoVariadicInputsAndOutputs) { test.AddOutput("loop_scan_out", {1}, {1.0f}); // Disable TensorRT on unsupported data type BOOL - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); + // Disable OpenVino for floating nodes + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); } // Test a combination of things: diff --git a/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc b/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc index 8f4c4ff0896ba..8f2eac2d05792 100644 --- a/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc @@ -853,6 +853,9 @@ TEST(CastOpTest, Int32ToInt4x2OddNumberOfElements) { } TEST(CastOpTest, Int32ToInt4x2EmptyTensor) { + if (DefaultOpenVINOExecutionProvider().get() != nullptr) { + GTEST_SKIP() << "The OpenVINO not support 0 size input"; + } // GIVEN const std::vector empty_shape{0}; const std::vector empty_input = {}; diff --git a/onnxruntime/test/providers/cpu/tensor/concat_op_test.cc b/onnxruntime/test/providers/cpu/tensor/concat_op_test.cc index b5e13c6377ccb..5f08b6df6785d 100644 --- a/onnxruntime/test/providers/cpu/tensor/concat_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/concat_op_test.cc @@ -73,6 +73,7 @@ TEST(ConcatOpTest, Concat1D_2) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, // TensorRT: no support for dynamic shape tensor kNnapiExecutionProvider, // NNAPI: concat does not support 0 size input + kOpenVINOExecutionProvider, // OpenVINO: does not support 0 size input kQnnExecutionProvider}); // QNN: not support dynamic shape tensor } @@ -118,6 +119,7 @@ TEST(ConcatOpTest, Concat2D_3) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, // TensorRT: no support for dynamic shape tensor kNnapiExecutionProvider, // NNAPI: concat does not support 0 size input + kOpenVINOExecutionProvider, // OpenVINO: does not support 0 size input kQnnExecutionProvider}); // QNN: not support dynamic shape tensor } diff --git a/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc b/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc index 46acb5a730a78..18eec7d1b42a3 100644 --- a/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc @@ -448,6 +448,7 @@ TEST(QuantizeLinearOpTest, Uint16) { 32769, 32765, 65535, 0, 65535, 0}); + test.SetOutputAbsErr("y", 1.0f); // Disable Tensorrt EP due to error: unsupported data type test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); @@ -477,6 +478,7 @@ TEST(QuantizeLinearOpTest, Int16) { 32767, -32768, 32767, -32768, 32767, -32768}); + test.SetOutputAbsErr("y", 1.0f); // Disable Tensorrt EP due to error: unsupported data type test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); @@ -501,6 +503,7 @@ TEST(QuantizeLinearOpTest, Int4) { test.AddOutput("y", dims, {Int4x2(-8, -7), Int4x2(-1, 1), Int4x2(2, 7), Int4x2(7, unused_val)}); + test.SetOutputAbsErr("y", 1.0f); test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); } @@ -568,6 +571,7 @@ TEST(QuantizeLinearOpTest, OddLarge_Int4) { test.AddInput("scale", {}, {scale}, true); test.AddInput("zero_point", {}, {Int4x2(zp, unused_val)}, true); test.AddOutput("y", dims, output); + test.SetOutputAbsErr("y", 1.0f); test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); } @@ -594,6 +598,7 @@ TEST(QuantizeLinearOpTest, OddLarge_UInt4) { test.AddInput("scale", {}, {scale}, true); test.AddInput("zero_point", {}, {UInt4x2(zp, unused_val)}, true); test.AddOutput("y", dims, output); + test.SetOutputAbsErr("y", 1.0f); test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); } @@ -611,6 +616,7 @@ TEST(QuantizeLinearOpTest, Int8_NegativeZeroPoint) { test.AddInput("y_scale", {}, {.039215686f}); test.AddInput("y_zero_point", {}, {-23}); test.AddOutput("y", dims, {-23, 28, 53, 104, 127, -74, -128, -128}); + test.SetOutputAbsErr("y", 1.0f); // Disable Tensorrt EP due to the error, node1_quantize_scale_node: out of bounds channel axis 1. Number of input dimensions is 1. test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); } @@ -628,6 +634,7 @@ TEST(QuantizeLinearOpTest, Int8_PositiveZeroPoint) { test.AddInput("y_scale", {}, {.039215686f}); test.AddInput("y_zero_point", {}, {23}); test.AddOutput("y", dims, {23, 74, 99, 127, 127, -28, -104, -128}); + test.SetOutputAbsErr("y", 1.0f); // Disable Tensorrt EP due to error:node1_quantize_scale_node: out of bounds channel axis 1. Number of input dimensions is 1. test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); } diff --git a/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc b/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc index bb053bc37ce30..f3b0695bdbd9c 100644 --- a/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc @@ -308,6 +308,7 @@ TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_4DBilinear_uint8) { std::vector Y = {2, 4}; test.AddOutput("Y", {N, static_cast(H * scales[1]), static_cast(W * scales[2]), C}, Y); + test.SetOutputAbsErr("Y", 1.0f); // CUDA: result mismatch due to not implementing NHWC support // ROCm: results mismatch test.Run(OpTester::ExpectResult::kExpectSuccess, "", @@ -647,6 +648,7 @@ TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_4DBilinear_pytorch_half_pixe std::vector Y = {1, 7, 12}; test.AddOutput("Y", {N, sizes[1], sizes[2], C}, Y); + test.SetOutputAbsErr("Y", 1.0f); // CUDA: result mismatch due to not implementing NHWC support // ROCm: results mismatch // DML: results mismatch diff --git a/onnxruntime/test/providers/cpu/tensor/slice_op.test.cc b/onnxruntime/test/providers/cpu/tensor/slice_op.test.cc index 5b2865a3feed7..657f3fe9c127a 100644 --- a/onnxruntime/test/providers/cpu/tensor/slice_op.test.cc +++ b/onnxruntime/test/providers/cpu/tensor/slice_op.test.cc @@ -540,6 +540,10 @@ TEST(SliceTest, Slice1D_ReverseAllAxes_1) { GTEST_SKIP() << "Skipping because of the following error: Expected output shape [{4}] did not match run output shape [{0}] for output"; } + if (DefaultOpenVINOExecutionProvider().get() != nullptr) { + GTEST_SKIP() << "Skipping because of the following error: The input ends do not support int max when step is negative."; + } + RunSliceTest({4}, {1.0f, 2.0f, 3.0f, 4.0f}, {-1}, diff --git a/onnxruntime/test/unittest_util/checkers.cc b/onnxruntime/test/unittest_util/checkers.cc index 7b2a5a4a4ff2f..794bd24310cd1 100644 --- a/onnxruntime/test/unittest_util/checkers.cc +++ b/onnxruntime/test/unittest_util/checkers.cc @@ -225,17 +225,27 @@ template <> struct TensorCheck { void operator()(const Tensor& expected, const Tensor& actual, const ValidateOutputParams& params, const std::string& /*provider_type*/) const { - ORT_UNUSED_PARAMETER(params); + const bool has_abs_err = params.absolute_error.has_value(); + Tensor expected_sorted, actual_sorted; const Int4x2* cur_expected; const Int4x2* cur_actual; const auto size = narrow(actual.Shape().Size()); cur_expected = expected.Data(); cur_actual = actual.Data(); + double threshold = 0.0f; + if (has_abs_err) { + threshold = *(params.absolute_error); + } for (size_t i = 0; i < size; ++i) { size_t r = i >> 1; size_t c = i & 0x1; - EXPECT_EQ(cur_expected[r].GetElem(c), cur_actual[r].GetElem(c)) << "i:" << i; + // TODO: the relative error is not used for int4 yet. + if (has_abs_err) { + EXPECT_NEAR(cur_expected[r].GetElem(c), cur_actual[r].GetElem(c), threshold) << "i:" << i; + } else { + EXPECT_EQ(cur_expected[r].GetElem(c), cur_actual[r].GetElem(c)) << "i:" << i; + } } } }; @@ -244,17 +254,28 @@ template <> struct TensorCheck { void operator()(const Tensor& expected, const Tensor& actual, const ValidateOutputParams& params, const std::string& /*provider_type*/) const { - ORT_UNUSED_PARAMETER(params); + const bool has_abs_err = params.absolute_error.has_value(); + Tensor expected_sorted, actual_sorted; const UInt4x2* cur_expected; const UInt4x2* cur_actual; const auto size = narrow(actual.Shape().Size()); cur_expected = expected.Data(); cur_actual = actual.Data(); - for (size_t i = 0; i < size; ++i) { + double threshold = 0.0f; + if (has_abs_err) { + threshold = *(params.absolute_error); + } + + for (size_t i = 0; i < static_cast(size); ++i) { size_t r = i >> 1; size_t c = i & 0x1; - EXPECT_EQ(cur_expected[r].GetElem(c), cur_actual[r].GetElem(c)) << "i:" << i; + // TODO: the relative error is not used for int4 yet. + if (has_abs_err) { + EXPECT_NEAR(cur_expected[r].GetElem(c), cur_actual[r].GetElem(c), threshold) << "i:" << i; + } else { + EXPECT_EQ(cur_expected[r].GetElem(c), cur_actual[r].GetElem(c)) << "i:" << i; + } } } }; @@ -292,7 +313,7 @@ struct TensorCheck { // For any other EPs, we still expect an exact match for the results // TODO: Verify if DML can possibly have a ROUNDING_MODE parameter and conform to the other EPs #41968513 if ((provider_type == kNnapiExecutionProvider || provider_type == kDmlExecutionProvider || - provider_type == kXnnpackExecutionProvider) && + provider_type == kXnnpackExecutionProvider || provider_type == kOpenVINOExecutionProvider) && (has_abs_err || has_rel_err)) { double threshold = has_abs_err ? *(params.absolute_error) : 0.0; @@ -357,6 +378,49 @@ struct TensorCheck { } }; +template <> +struct TensorCheck { + void operator()(const Tensor& expected, + const Tensor& actual, + const ValidateOutputParams& params, + const std::string& ) const { + const bool has_abs_err = params.absolute_error.has_value(); + const bool has_rel_err = params.relative_error.has_value(); + + Tensor expected_sorted, actual_sorted; + const uint16_t* cur_expected; + const uint16_t* cur_actual; + const auto size = actual.Shape().Size(); + if (params.sort_output) { + sort_expected_and_actual_buffers(expected, expected_sorted, actual, actual_sorted); + cur_expected = expected_sorted.Data(); + cur_actual = actual_sorted.Data(); + } else { + cur_expected = expected.Data(); + cur_actual = actual.Data(); + } + + if (has_abs_err || has_rel_err) { + double threshold = has_abs_err ? *(params.absolute_error) + : 0.0; + + for (int64_t i = 0; i < size; ++i) { + if (has_rel_err) { + EXPECT_NEAR(cur_expected[i], cur_actual[i], + *(params.relative_error) * cur_expected[i]) // expected[i] is unsigned, can't be negative + << "i:" << i; + } else { // has_abs_err + EXPECT_NEAR(cur_expected[i], cur_actual[i], threshold) << "i:" << i; + } + } + } else { + for (int64_t i = 0; i < size; ++i) { + EXPECT_EQ(cur_expected[i], cur_actual[i]) << "i:" << i; + } + } + } +}; + template <> struct TensorCheck { void operator()(const Tensor& expected, From 914c07ab9f66f98b975bc0c84ea2479d52240cc4 Mon Sep 17 00:00:00 2001 From: n1harika Date: Tue, 14 Oct 2025 21:24:08 -0700 Subject: [PATCH 108/138] [OVEP] Enable Session Option to Stop Context Sharing (#822) This PR enables the session option: StopShareEpContexts ("ep.stop_share_ep_contexts" = 1/0), to explicitly clear the globally shared EP context after a session completes. This allows controlled cleanup of shared weights, metadata, its file path, and other cached artifacts stored in SharedContext. For example, for model abc and model xyz: model_abc(ep.share_ep_contexts=1, ep.stop_share_ep_contexts=1) After this session, all existing shared context data is cleared. model_xyz(ep.share_ep_contexts=1) This model starts with a fresh shared context, without any leftover metadata or weights from model_abc. Fixed format and unique pointer handling. --- onnxruntime/core/providers/openvino/contexts.h | 12 ++++++++++++ .../openvino/openvino_execution_provider.cc | 7 +++++++ .../providers/openvino/openvino_provider_factory.cc | 1 + 3 files changed, 20 insertions(+) diff --git a/onnxruntime/core/providers/openvino/contexts.h b/onnxruntime/core/providers/openvino/contexts.h index 07b09899ac214..a0dc33ae657c8 100644 --- a/onnxruntime/core/providers/openvino/contexts.h +++ b/onnxruntime/core/providers/openvino/contexts.h @@ -61,11 +61,22 @@ class SharedContext : public WeakSingleton { size_t weights_size_; }; + void clear() { + metadata.clear(); + metadata_filepath.clear(); + external_weight_filename.clear(); + mapped_weights.reset(); + } + fs::path external_weight_filename; std::unique_ptr mapped_weights; Metadata::Map metadata; fs::path metadata_filepath; } shared_weights; + + void clear() { + shared_weights.clear(); + } }; using config_t = std::map; @@ -109,6 +120,7 @@ struct ProviderInfo { bool so_context_embed_mode{false}; // ORT session option bool so_share_ep_contexts{false}; // ORT session option fs::path so_context_file_path{}; // ORT session option + bool so_stop_share_ep_contexts{false}; // ORT session option const ConfigOptions* config_options{NULL}; const std::unordered_set valid_provider_keys = {"device_type", "device_id", "device_luid", "cache_dir", "precision", "load_config", "context", "num_of_threads", "model_priority", "num_streams", "enable_opencl_throttling", "enable_qdq_optimizer", diff --git a/onnxruntime/core/providers/openvino/openvino_execution_provider.cc b/onnxruntime/core/providers/openvino/openvino_execution_provider.cc index a0fa885cbfc38..ee5298d5b08e2 100644 --- a/onnxruntime/core/providers/openvino/openvino_execution_provider.cc +++ b/onnxruntime/core/providers/openvino/openvino_execution_provider.cc @@ -65,6 +65,7 @@ OpenVINOExecutionProvider::~OpenVINOExecutionProvider() { backend_manager.ShutdownBackendManager(); } backend_managers_.clear(); + shared_context_.reset(); } std::vector> @@ -214,6 +215,12 @@ common::Status OpenVINOExecutionProvider::Compile( file << metadata; } + if (session_context_.so_stop_share_ep_contexts) { + if (shared_context_) { + shared_context_->clear(); + } + } + return status; } diff --git a/onnxruntime/core/providers/openvino/openvino_provider_factory.cc b/onnxruntime/core/providers/openvino/openvino_provider_factory.cc index 1a10d9849d5cc..f26da37fa7d7e 100644 --- a/onnxruntime/core/providers/openvino/openvino_provider_factory.cc +++ b/onnxruntime/core/providers/openvino/openvino_provider_factory.cc @@ -28,6 +28,7 @@ void ParseConfigOptions(ProviderInfo& pi) { pi.so_context_embed_mode = pi.config_options->GetConfigOrDefault(kOrtSessionOptionEpContextEmbedMode, "0") == "1"; pi.so_share_ep_contexts = pi.config_options->GetConfigOrDefault(kOrtSessionOptionShareEpContexts, "0") == "1"; pi.so_context_file_path = pi.config_options->GetConfigOrDefault(kOrtSessionOptionEpContextFilePath, ""); + pi.so_stop_share_ep_contexts = pi.config_options->GetConfigOrDefault(kOrtSessionOptionStopShareEpContexts, "0") == "1"; if (pi.so_share_ep_contexts) { ov::AnyMap map; From 19ebc1f2266358d9ae80ef79f75b1e5c59c83c17 Mon Sep 17 00:00:00 2001 From: Ankit Maheshkar Date: Fri, 17 Oct 2025 04:42:05 +0530 Subject: [PATCH 109/138] CVS-174008: Enable ETW Tracing for OVEP (#827) * Update: Enable OVEP ETW Tracing & Logging * fix: Refactor code & fix load_config tracing logic * fix: refactor logging logic * fix: amend review changes * fix: amend review changes. * fix: amend string parsing & add linking advapi32 in cmake * fix: refine options parsing * fix: perform atomic add for the global session counter * add: enable ep.stop_share_ep_contexts sess opt logging * fix:lint fixes * fix: ort wprp file * fix: OVTracing renamed * refactor logging of runtime options * lint fixes * fix: coverity fixes * fix: fix event options logging * Update onnxruntime/core/providers/openvino/ov_tracing.cc Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Revert "Update onnxruntime/core/providers/openvino/ov_tracing.cc" This reverts commit 05920e274ed9f5a5e4cf005de29df7216b110e32. --------- Co-authored-by: MayureshV1 <47039074+MayureshV1@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: Mayuresh M Varerkar --- cmake/onnxruntime_providers_openvino.cmake | 5 + .../core/providers/openvino/contexts.h | 21 +- .../openvino/openvino_execution_provider.cc | 7 + .../openvino/openvino_execution_provider.h | 9 + .../core/providers/openvino/ov_tracing.cc | 228 ++++++++++++++++++ .../core/providers/openvino/ov_tracing.h | 64 +++++ ort.wprp | 6 + 7 files changed, 339 insertions(+), 1 deletion(-) create mode 100644 onnxruntime/core/providers/openvino/ov_tracing.cc create mode 100644 onnxruntime/core/providers/openvino/ov_tracing.h diff --git a/cmake/onnxruntime_providers_openvino.cmake b/cmake/onnxruntime_providers_openvino.cmake index 5a831a106ae08..8c8d58c30f594 100644 --- a/cmake/onnxruntime_providers_openvino.cmake +++ b/cmake/onnxruntime_providers_openvino.cmake @@ -51,6 +51,11 @@ target_include_directories(onnxruntime_providers_openvino SYSTEM PUBLIC ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR} ${OpenVINO_INCLUDE_DIR} ${OPENVINO_INCLUDE_DIR_LIST} ${PYTHON_INCLUDE_DIRS} $ENV{OPENCL_INCS} $ENV{OPENCL_INCS}/../../cl_headers/) target_link_libraries(onnxruntime_providers_openvino ${ONNXRUNTIME_PROVIDERS_SHARED} Boost::mp11 ${OPENVINO_LIB_LIST} ${ABSEIL_LIBS} Eigen3::Eigen onnx_proto) + # ETW TraceLogging depends on Advapi32 on Windows + if(WIN32) + target_link_libraries(onnxruntime_providers_openvino advapi32) + endif() + target_compile_definitions(onnxruntime_providers_openvino PRIVATE FILE_NAME=\"onnxruntime_providers_openvino.dll\") if(MSVC) diff --git a/onnxruntime/core/providers/openvino/contexts.h b/onnxruntime/core/providers/openvino/contexts.h index a0dc33ae657c8..4a90fc2ccccc1 100644 --- a/onnxruntime/core/providers/openvino/contexts.h +++ b/onnxruntime/core/providers/openvino/contexts.h @@ -127,9 +127,20 @@ struct ProviderInfo { "enable_causallm", "disable_dynamic_shapes", "reshape_input", "layout"}; }; +struct RuntimeConfig { + std::unordered_map options; + std::optional Get(const std::string& key) const { + auto it = options.find(key); + return it != options.end() ? std::optional{it->second} : std::nullopt; + } +}; + // Holds context applicable to the entire EP instance. struct SessionContext : ProviderInfo { - SessionContext(const ProviderInfo& info) : ProviderInfo{info} {} + SessionContext(const ProviderInfo& info) : ProviderInfo{info} { + InitRuntimeConfig(); + } + std::vector deviceAvailableList = {true, true, true, true, true, true, true, true}; std::filesystem::path onnx_model_path_name; uint32_t onnx_opset_version{0}; @@ -137,6 +148,14 @@ struct SessionContext : ProviderInfo { mutable bool has_external_weights = false; // Value is set to mutable to modify from capability const std::vector OpenVINO_Version = {OPENVINO_VERSION_MAJOR, OPENVINO_VERSION_MINOR}; const std::string openvino_sdk_version = std::to_string(OPENVINO_VERSION_MAJOR) + "." + std::to_string(OPENVINO_VERSION_MINOR); + RuntimeConfig runtime_config; + + private: + void InitRuntimeConfig() { + if (config_options) { + runtime_config.options = config_options->GetConfigOptionsMap(); + } + } }; // Holds context specific to subgraph. diff --git a/onnxruntime/core/providers/openvino/openvino_execution_provider.cc b/onnxruntime/core/providers/openvino/openvino_execution_provider.cc index ee5298d5b08e2..049af81c9ffb2 100644 --- a/onnxruntime/core/providers/openvino/openvino_execution_provider.cc +++ b/onnxruntime/core/providers/openvino/openvino_execution_provider.cc @@ -21,6 +21,8 @@ namespace onnxruntime { namespace openvino_ep { +std::atomic OpenVINOExecutionProvider::global_session_counter_{0}; + // Parking this code here for now before it's moved to the factory #if defined OPENVINO_CONFIG_HETERO || defined OPENVINO_CONFIG_MULTI || defined OPENVINO_CONFIG_AUTO static std::vector parseDevices(const std::string& device_string, @@ -58,6 +60,11 @@ OpenVINOExecutionProvider::OpenVINOExecutionProvider(const ProviderInfo& info, s shared_context_{std::move(shared_context)}, ep_ctx_handle_{session_context_.openvino_sdk_version, *GetLogger()} { InitProviderOrtApi(); +#ifdef _WIN32 + session_id_ = global_session_counter_.fetch_add(1) + 1; + // Trace all runtime options (includes both session and provider options) + OVTracing::Instance().LogAllRuntimeOptions(session_id_, session_context_); +#endif } OpenVINOExecutionProvider::~OpenVINOExecutionProvider() { diff --git a/onnxruntime/core/providers/openvino/openvino_execution_provider.h b/onnxruntime/core/providers/openvino/openvino_execution_provider.h index 020aec16e507c..a375a9ee788bd 100644 --- a/onnxruntime/core/providers/openvino/openvino_execution_provider.h +++ b/onnxruntime/core/providers/openvino/openvino_execution_provider.h @@ -11,10 +11,15 @@ #include #include #include +#include #include "core/providers/openvino/backend_manager.h" #include "core/providers/openvino/contexts.h" +#ifdef _WIN32 +#include "core/providers/openvino/ov_tracing.h" +#endif + namespace onnxruntime { namespace openvino_ep { @@ -74,6 +79,10 @@ class OpenVINOExecutionProvider : public IExecutionProvider { std::shared_ptr shared_context_; std::list backend_managers_; // EP session owns the backend objects EPCtxHandler ep_ctx_handle_; + + // Tracing and session tracking + uint32_t session_id_{0}; + static std::atomic global_session_counter_; }; } // namespace openvino_ep diff --git a/onnxruntime/core/providers/openvino/ov_tracing.cc b/onnxruntime/core/providers/openvino/ov_tracing.cc new file mode 100644 index 0000000000000..79109552f3df6 --- /dev/null +++ b/onnxruntime/core/providers/openvino/ov_tracing.cc @@ -0,0 +1,228 @@ +// Copyright (c) Intel Corporation. All rights reserved. +// Licensed under the MIT License. +#include "core/providers/openvino/ov_tracing.h" + +#ifdef _WIN32 +#include +#ifdef _MSC_VER +#pragma warning(push) +#pragma warning(disable : 26440) +#endif +#include +#include +#include "core/platform/windows/TraceLoggingConfig.h" + +TRACELOGGING_DEFINE_PROVIDER( + ov_tracing_provider_handle, + "Intel.ML.ONNXRuntime.OpenVINO", + // {"b5a8c2e1-4d7f-4a3b-9c2e-1f8e5a6b7c9d"} + (0xb5a8c2e1, 0x4d7f, 0x4a3b, 0x9c, 0x2e, 0x1f, 0x8e, 0x5a, 0x6b, 0x7c, 0x9d), + TraceLoggingOptionMicrosoftTelemetry()); + +#ifdef _MSC_VER +#pragma warning(pop) +#endif + +namespace { +std::string EscapeJsonString(const std::string& input) { + std::string escaped; + // Reserve extra space for escaping + escaped.reserve(input.size() + input.size() / 5); + + for (char c : input) { + switch (c) { + case '\"': + escaped += "\\\""; + break; + case '\\': + escaped += "\\\\"; + break; + case '\b': + escaped += "\\b"; + break; + case '\f': + escaped += "\\f"; + break; + case '\n': + escaped += "\\n"; + break; + case '\r': + escaped += "\\r"; + break; + case '\t': + escaped += "\\t"; + break; + default: + if (static_cast(c) < 0x20) { + char unicode_escape[7]; + sprintf_s(unicode_escape, sizeof(unicode_escape), "\\u%04x", static_cast(c)); + escaped += unicode_escape; + } else { + escaped += c; + } + break; + } + } + return escaped; +} +} // namespace + +namespace onnxruntime { +namespace openvino_ep { + +std::mutex OVTracing::mutex_; +std::mutex OVTracing::provider_change_mutex_; +uint32_t OVTracing::global_register_count_ = 0; +bool OVTracing::enabled_ = true; +UCHAR OVTracing::level_ = 0; +UINT64 OVTracing::keyword_ = 0; +std::vector OVTracing::callbacks_; +std::mutex OVTracing::callbacks_mutex_; + +OVTracing::OVTracing() { + std::lock_guard lock(mutex_); + if (global_register_count_ == 0) { + HRESULT hr = TraceLoggingRegisterEx(ov_tracing_provider_handle, ORT_TL_EtwEnableCallback, nullptr); + if (SUCCEEDED(hr)) { + global_register_count_ += 1; + } + } +} + +OVTracing::~OVTracing() noexcept { + // Clean up TraceLogging, only hold mutex_ + try { + std::lock_guard lock(mutex_); + if (global_register_count_ > 0) { + global_register_count_ -= 1; + if (global_register_count_ == 0) { + TraceLoggingUnregister(ov_tracing_provider_handle); + } + } + } catch (...) { + // Suppress exceptions in destructor + } + + // Clean up callbacks, only hold callbacks_mutex_ + try { + std::lock_guard lock_callbacks(callbacks_mutex_); + callbacks_.clear(); + } catch (...) { + // Suppress exceptions in destructor + } +} + +OVTracing& OVTracing::Instance() { + static OVTracing instance; + return instance; +} + +bool OVTracing::IsEnabled() const { + std::lock_guard lock(provider_change_mutex_); + return enabled_; +} + +UCHAR OVTracing::Level() const { + std::lock_guard lock(provider_change_mutex_); + return level_; +} + +UINT64 OVTracing::Keyword() const { + std::lock_guard lock(provider_change_mutex_); + return keyword_; +} + +void OVTracing::LogAllRuntimeOptions(uint32_t session_id, const SessionContext& ctx) const { + if (!IsEnabled()) return; + + // Log OpenVINO SDK version separately + TraceLoggingWrite(ov_tracing_provider_handle, "OV.SDK.Version", + TraceLoggingLevel(WINEVENT_LEVEL_VERBOSE), + TraceLoggingUInt32(session_id, "session_id"), + TraceLoggingString(ctx.openvino_sdk_version.c_str(), "openvino_sdk_version")); + + constexpr std::string_view provider_prefix = "ep.openvinoexecutionprovider."; + std::ostringstream provider_opts; + std::ostringstream session_opts; + bool provider_first = true; + bool session_first = true; + + provider_opts << "{"; + session_opts << "{"; + + // Segregate options based on prefix + for (const auto& [key, value] : ctx.runtime_config.options) { + if (!value.empty()) { + if (key.starts_with(provider_prefix)) { + // Provider option + if (!provider_first) provider_opts << ","; + provider_opts << "\"" << key << "\":\"" << EscapeJsonString(value) << "\""; + provider_first = false; + } else { + // Session option + if (!session_first) session_opts << ","; + session_opts << "\"" << key << "\":\"" << EscapeJsonString(value) << "\""; + session_first = false; + } + } + } + + provider_opts << "}"; + session_opts << "}"; + + // Log provider options only if there are any + if (!provider_first) { + TraceLoggingWrite(ov_tracing_provider_handle, "OVEP.Provider.Options", + TraceLoggingLevel(WINEVENT_LEVEL_VERBOSE), + TraceLoggingUInt32(session_id, "session_id"), + TraceLoggingString(provider_opts.str().c_str(), "provider_options")); + } + + // Log session options only if there are any + if (!session_first) { + TraceLoggingWrite(ov_tracing_provider_handle, "OVEP.Session.Options", + TraceLoggingLevel(WINEVENT_LEVEL_VERBOSE), + TraceLoggingUInt32(session_id, "session_id"), + TraceLoggingString(session_opts.str().c_str(), "session_options")); + } +} + +void OVTracing::RegisterInternalCallback(const EtwInternalCallback& callback) { + std::lock_guard lock_callbacks(callbacks_mutex_); + callbacks_.push_back(&callback); +} + +void OVTracing::UnregisterInternalCallback(const EtwInternalCallback& callback) { + std::lock_guard lock_callbacks(callbacks_mutex_); + auto new_end = std::remove_if(callbacks_.begin(), callbacks_.end(), + [&callback](const EtwInternalCallback* ptr) { + return ptr == &callback; + }); + callbacks_.erase(new_end, callbacks_.end()); +} + +void NTAPI OVTracing::ORT_TL_EtwEnableCallback( + _In_ LPCGUID SourceId, _In_ ULONG IsEnabled, _In_ UCHAR Level, _In_ ULONGLONG MatchAnyKeyword, + _In_ ULONGLONG MatchAllKeyword, _In_opt_ PEVENT_FILTER_DESCRIPTOR FilterData, _In_opt_ PVOID CallbackContext) { + { + std::lock_guard lock(provider_change_mutex_); + enabled_ = (IsEnabled != 0); + level_ = Level; + keyword_ = MatchAnyKeyword; + } + // Release lock before invoking callbacks to prevent deadlock + InvokeCallbacks(SourceId, IsEnabled, Level, MatchAnyKeyword, MatchAllKeyword, FilterData, CallbackContext); +} + +void OVTracing::InvokeCallbacks(LPCGUID SourceId, ULONG IsEnabled, UCHAR Level, ULONGLONG MatchAnyKeyword, + ULONGLONG MatchAllKeyword, PEVENT_FILTER_DESCRIPTOR FilterData, PVOID CallbackContext) { + std::lock_guard lock_callbacks(callbacks_mutex_); + for (const auto& callback : callbacks_) { + (*callback)(SourceId, IsEnabled, Level, MatchAnyKeyword, MatchAllKeyword, FilterData, CallbackContext); + } +} + +} // namespace openvino_ep +} // namespace onnxruntime + +#endif // defined(_WIN32) diff --git a/onnxruntime/core/providers/openvino/ov_tracing.h b/onnxruntime/core/providers/openvino/ov_tracing.h new file mode 100644 index 0000000000000..b558695d6f7c7 --- /dev/null +++ b/onnxruntime/core/providers/openvino/ov_tracing.h @@ -0,0 +1,64 @@ +// Copyright (c) Intel Corporation. All rights reserved. +// Licensed under the MIT License. +#pragma once + +#ifdef _WIN32 +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include "core/providers/openvino/contexts.h" + +TRACELOGGING_DECLARE_PROVIDER(ov_tracing_provider_handle); + +namespace onnxruntime { +namespace openvino_ep { + +class OVTracing { + public: + static OVTracing& Instance(); + bool IsEnabled() const; + unsigned char Level() const; + UINT64 Keyword() const; + + void LogAllRuntimeOptions(uint32_t session_id, const SessionContext& ctx) const; + + using EtwInternalCallback = std::function; + static void RegisterInternalCallback(const EtwInternalCallback& callback); + static void UnregisterInternalCallback(const EtwInternalCallback& callback); + + private: + OVTracing(); + ~OVTracing(); + OVTracing(const OVTracing&) = delete; + OVTracing& operator=(const OVTracing&) = delete; + OVTracing(OVTracing&&) = delete; + OVTracing& operator=(OVTracing&&) = delete; + + static std::mutex mutex_; + static uint32_t global_register_count_; + static bool enabled_; + static std::vector callbacks_; + static std::mutex callbacks_mutex_; + static std::mutex provider_change_mutex_; + static UCHAR level_; + static ULONGLONG keyword_; + + static void InvokeCallbacks(LPCGUID, ULONG, UCHAR, ULONGLONG, ULONGLONG, PEVENT_FILTER_DESCRIPTOR, PVOID); + static void NTAPI ORT_TL_EtwEnableCallback(_In_ LPCGUID, _In_ ULONG, _In_ UCHAR, _In_ ULONGLONG, + _In_ ULONGLONG, _In_opt_ PEVENT_FILTER_DESCRIPTOR, _In_opt_ PVOID); +}; + +} // namespace openvino_ep +} // namespace onnxruntime + +#endif // defined(_WIN32) diff --git a/ort.wprp b/ort.wprp index 5dd2332cb1f9f..99a5d72e597e7 100644 --- a/ort.wprp +++ b/ort.wprp @@ -17,6 +17,11 @@ + + + + @@ -24,6 +29,7 @@ + From 25c6976b6edd224db7fa86f1985d45da2f49e116 Mon Sep 17 00:00:00 2001 From: Kotomi-Du Date: Fri, 17 Oct 2025 11:43:30 -0700 Subject: [PATCH 110/138] address PR review --- .../core/providers/openvino/ov_interface.cc | 2 +- .../openvino/ov_stateful_patch_utils.cc | 32 ++++++++++++------- 2 files changed, 21 insertions(+), 13 deletions(-) diff --git a/onnxruntime/core/providers/openvino/ov_interface.cc b/onnxruntime/core/providers/openvino/ov_interface.cc index 99e310185e9e4..e97bbaceee4e2 100644 --- a/onnxruntime/core/providers/openvino/ov_interface.cc +++ b/onnxruntime/core/providers/openvino/ov_interface.cc @@ -365,7 +365,7 @@ StatefulOVInferRequest::StatefulOVInferRequest(ov::InferRequest infer_request, s // check if there is input_ids tensors and if the tensor type is int64, // because logic prefill_use_full_chat_history is only for specific inputs and data type auto input_ids_opt = FindTensor("input_ids"); - if (gpu_or_npu && input_ids_opt.has_value() && input_ids_opt->get_element_type() != ov::element::i64) { + if (gpu_or_npu && input_ids_opt.has_value() && input_ids_opt->get_element_type() == ov::element::i64) { prefill_use_full_chat_history = true; } } diff --git a/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc b/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc index f86d2d54fc381..4c5edb8d4283e 100644 --- a/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc +++ b/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc @@ -59,6 +59,17 @@ bool ModelHasInputOutputNames(std::shared_ptr model, const std::strin return false; } +std::string GetInputOutputName(std::shared_ptr ov_model, + const std::vector& candidate_names) { + for (const auto& name : candidate_names) { + if (ModelHasInputOutputNames(ov_model, name)) { + return name; + } + } + // Return the first candidate as default if none are found + return candidate_names.empty() ? "" : candidate_names[0]; +} + void FuseCacheReorder(std::shared_ptr ov_model, std::vector& not_kv_inputs, const std::vector& key_value_input_names, @@ -67,18 +78,15 @@ void FuseCacheReorder(std::shared_ptr ov_model, throw std::runtime_error("Model already has fused cache"); } - std::string main_input_name = "inputs_embeds"; - if (ModelHasInputOutputNames(ov_model, "input_ids")) { - main_input_name = "input_ids"; - } - - if (ModelHasInputOutputNames(ov_model, "input_hidden_states")) { - main_input_name = "input_hidden_states"; - } + // Define input name candidates in priority order + const std::vector input_name_candidates = { + "inputs_embeds", // Default fallback + "input_ids", // Most common + "input_hidden_states", // Alternative + "/model/embed_tokens/Gather_output_0" // Specific model type + }; - if (ModelHasInputOutputNames(ov_model, "/model/embed_tokens/Gather_output_0")) { - main_input_name = "/model/embed_tokens/Gather_output_0"; - } + std::string main_input_name = GetInputOutputName(ov_model, input_name_candidates); auto input_batch = ov_model->input(main_input_name).get_partial_shape()[0]; @@ -131,7 +139,7 @@ void PatchStatefulDecoder(std::shared_ptr model) { std::vector not_kv_inputs; const auto& params = model->get_parameters(); bool found = false; - for (auto i = 0; i < params.size(); i++) { + for (size_t i = 0; i < params.size(); i++) { auto param_name = params.at(i)->output(0).get_any_name(); if (param_name.find("key_values") != std::string::npos) { key_value_input_names.push_back(param_name); From 397c61bb50ab45bfd722a1b2ca641f64eebf086b Mon Sep 17 00:00:00 2001 From: Eric Crawford Date: Wed, 22 Oct 2025 16:30:38 -0700 Subject: [PATCH 111/138] CVS-174585: Memory map shared weights when possible (#829) * Memory map shared weights when possible * Update onnxruntime/core/providers/openvino/backend_utils.cc Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update onnxruntime/core/providers/openvino/backend_utils.cc Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --------- Co-authored-by: MayureshV1 <47039074+MayureshV1@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- .../core/providers/openvino/backend_utils.cc | 78 ++++++++++++++----- .../core/providers/openvino/contexts.h | 8 ++ .../core/providers/openvino/ov_interface.h | 11 +++ 3 files changed, 78 insertions(+), 19 deletions(-) diff --git a/onnxruntime/core/providers/openvino/backend_utils.cc b/onnxruntime/core/providers/openvino/backend_utils.cc index 7027861f0c4dc..7201c47a805e3 100644 --- a/onnxruntime/core/providers/openvino/backend_utils.cc +++ b/onnxruntime/core/providers/openvino/backend_utils.cc @@ -20,11 +20,11 @@ using Exception = ov::Exception; namespace onnxruntime { namespace openvino_ep { -SharedContext::SharedWeights::WeightsFile::WeightsFile(std::filesystem::path filename) : file_(filename, std::ios::in | std::ios::binary) { +SharedContext::SharedWeights::WeightsFile::WeightsFile(std::filesystem::path filename) : file_(filename, std::ios::in | std::ios::binary), file_path_(filename) { try { file_.exceptions(std::ifstream::failbit | std::ifstream::badbit); - weights_size_ = file_.seekg(0, std::ios::end).tellg(); - } catch (std::ifstream::failure& e) { + weights_size_ = std::filesystem::file_size(filename); + } catch (const std::exception& e) { ORT_THROW("Error: Failed to open weight file at ", filename.string(), " ", e.what()); } } @@ -35,6 +35,32 @@ void SharedContext::SharedWeights::WeightsFile::load_weights(size_t file_offset, file_.read(reinterpret_cast(data), size); } +void* SharedContext::SharedWeights::WeightsFile::TryGetOrCreateDeviceMapping(std::optional& remote_context) { + std::string dev_name{}; + if (remote_context) { + dev_name = remote_context->get_device_name(); + } + + auto [it, inserted] = imported_device_tensors_.emplace(dev_name, MappingContainer{}); + if (inserted) { + if (dev_name == "NPU") { +#if OPENVINO_VERSION_AT_LEAST(2025, 3) + // try to import the memory mapped file to remote tensor + ORT_ENFORCE(remote_context, "Error: Remote context is required for NPU device."); + auto npu_context = remote_context->as(); + auto&& l0_tensor = npu_context.create_tensor(ov::element::Type_t::u8, {weights_size_}, ov::intel_npu::FileDescriptor(file_path_)); + it->second = MappingContainer{.ptr_ = l0_tensor.get(), .tensor_ = l0_tensor}; +#endif + } else if (dev_name.empty()) { + // CPU/virtual device case, create a CPU tensor memory mapped from file + auto&& mmaped_tensor = ov::read_tensor_data(file_path_); + it->second = MappingContainer{.ptr_ = mmaped_tensor.data(), .tensor_ = mmaped_tensor}; + } + } + + return it->second.ptr_; +} + std::ostream& operator<<(std::ostream& stream, const SharedContext::SharedWeights::Metadata::Map& metadata) { try { stream << metadata.size(); @@ -405,29 +431,43 @@ ov::element::Type GetOpenVINOElementType(ONNX_NAMESPACE::TensorProto_DataType dt void CreateOVTensors(const std::string& device_name, SharedContext::SharedWeights::Metadata::Map& metadata_map, SharedContext::SharedWeights::WeightsFile& weights) { + // Get remote context if available + std::optional opt_remote_ctx; + try { + opt_remote_ctx = OVCore::Get()->core.get_default_context(device_name); + } catch (const std::exception&) { + // Remote context not available + } + for (auto& [key, value] : metadata_map) { if (value.tensor) continue; // Get element data type auto onnx_element_type = (ONNX_NAMESPACE::TensorProto_DataType)value.element_type; - - ov::element::Type ov_elementType = GetOpenVINOElementType(onnx_element_type); // Map to OpenVINO data type - - // Create OpenVINO Tensor - if (device_name == "NPU") { - // Use remote tensors - auto npu_context = OVCore::Get()->core.get_default_context("NPU").as(); - auto&& remote_tensor = npu_context.create_l0_host_tensor(ov_elementType, value.dimensions, ov::intel_npu::TensorType::INPUT); - - // Copy data to remote tensor - weights.load_weights(value.data_offset, remote_tensor.get(), value.size); - value.tensor = std::make_shared(remote_tensor); + ov::element::Type ov_elementType = GetOpenVINOElementType(onnx_element_type); + + // Try to get memory-mapped weights + ov::Tensor tensor; + uint8_t* mmaped_weights = static_cast(weights.TryGetOrCreateDeviceMapping(opt_remote_ctx)); + + if (mmaped_weights) { + // We have memory mapped weights. Create a Tensor view into it for this value. + ORT_ENFORCE(value.data_offset < weights.Size() && + value.size <= weights.Size() && + (value.data_offset <= weights.Size() - value.size), + "File offset + size outside of external initializer file"); + void* mmapped_offset = static_cast(mmaped_weights + value.data_offset); + tensor = ov::Tensor(ov_elementType, value.dimensions, mmapped_offset); } else { - // Use vanilla tensors - value.tensor = std::make_shared(ov_elementType, value.dimensions); - weights.load_weights(value.data_offset, value.tensor->data(), value.size); + ORT_ENFORCE(opt_remote_ctx, "Expected either memory-mapped weights or a valid remote context, but neither is available for device: ", device_name); + // Can't mmap the file to device tensor, create a host tensor and copy the data + tensor = opt_remote_ctx->create_host_tensor(ov_elementType, value.dimensions); + ORT_ENFORCE(tensor.get_byte_size() == value.size, "Remote tensor size mismatch"); + weights.load_weights(value.data_offset, tensor.data(), value.size); } - ORT_ENFORCE(value.tensor->get_byte_size() == value.size, "Unexpected tensor size mismatch"); + + ORT_ENFORCE(tensor.get_byte_size() == value.size, "Unexpected tensor size mismatch"); + value.tensor = std::make_shared(std::move(tensor)); } } diff --git a/onnxruntime/core/providers/openvino/contexts.h b/onnxruntime/core/providers/openvino/contexts.h index 4a90fc2ccccc1..edd9f176658f8 100644 --- a/onnxruntime/core/providers/openvino/contexts.h +++ b/onnxruntime/core/providers/openvino/contexts.h @@ -55,10 +55,18 @@ class SharedContext : public WeakSingleton { explicit WeightsFile(std::filesystem::path filename); void load_weights(size_t file_offset, void* data, size_t size); + void* TryGetOrCreateDeviceMapping(std::optional& remote_context); + size_t Size() const { return weights_size_; } private: std::ifstream file_; + std::filesystem::path file_path_; size_t weights_size_; + struct MappingContainer { + void* ptr_{nullptr}; + ov::Tensor tensor_; + }; + std::map imported_device_tensors_; }; void clear() { diff --git a/onnxruntime/core/providers/openvino/ov_interface.h b/onnxruntime/core/providers/openvino/ov_interface.h index 5f8fb36c1cbec..44ec4a235c3e9 100644 --- a/onnxruntime/core/providers/openvino/ov_interface.h +++ b/onnxruntime/core/providers/openvino/ov_interface.h @@ -21,6 +21,17 @@ #include +// Helper macro to test OpenVINO version at compile time. +// Usage: #if OPENVINO_VERSION_AT_LEAST(2025, 3) +// Falls back to 0 if OPENVINO_VERSION_MAJOR/MINOR are not defined. +#if defined(OPENVINO_VERSION_MAJOR) && defined(OPENVINO_VERSION_MINOR) +#define OPENVINO_VERSION_AT_LEAST(major, minor) \ + ((OPENVINO_VERSION_MAJOR > (major)) || \ + (OPENVINO_VERSION_MAJOR == (major) && OPENVINO_VERSION_MINOR >= (minor))) +#else +#define OPENVINO_VERSION_AT_LEAST(major, minor) 0 +#endif + namespace onnxruntime { namespace openvino_ep { class OVCore; From 6d4106556b22c9ba52c2d3edd4c36bc9f511bf6f Mon Sep 17 00:00:00 2001 From: Jaswanth51 Date: Mon, 27 Oct 2025 01:02:22 -0700 Subject: [PATCH 112/138] Run lintrunner and fix formatting issues (#836) --- .../providers/openvino/backend_manager.cc | 26 +++++++------------ .../core/providers/openvino/ov_interface.h | 4 +-- .../openvino/openvino_ep_ext_init.cc | 6 ++--- onnxruntime/test/unittest_util/checkers.cc | 2 +- 4 files changed, 16 insertions(+), 22 deletions(-) diff --git a/onnxruntime/core/providers/openvino/backend_manager.cc b/onnxruntime/core/providers/openvino/backend_manager.cc index 4e9c0f912c825..74999ab10a67d 100644 --- a/onnxruntime/core/providers/openvino/backend_manager.cc +++ b/onnxruntime/core/providers/openvino/backend_manager.cc @@ -85,7 +85,7 @@ BackendManager::BackendManager(SessionContext& session_context, subgraph_context_.subgraph_name = fused_node.Name(); if (ModelHasSymbolicInputDims(subgraph)) { - subgraph_context_.has_dynamic_input_shape = true; + subgraph_context_.has_dynamic_input_shape = true; } ptr_stream_t model_stream; @@ -171,9 +171,7 @@ BackendManager::BackendManager(SessionContext& session_context, exception_str.find("intel_npu") != std::string::npos) { // Handle NPU device related errors #ifndef NDEBUG - std::string suffix = session_context_.so_disable_cpu_ep_fallback ? - "\nModel failed to compile on NPU. Enable CPU fallback or try another device.\n" : - "\nModel needs to be recompiled\n"; + std::string suffix = session_context_.so_disable_cpu_ep_fallback ? "\nModel failed to compile on NPU. Enable CPU fallback or try another device.\n" : "\nModel needs to be recompiled\n"; ORT_THROW(exception_str + suffix); #else std::string error_message = "UNKNOWN NPU ERROR"; @@ -187,9 +185,7 @@ BackendManager::BackendManager(SessionContext& session_context, if (std::regex_search(exception_str, matches, error_code_pattern)) { error_code = matches[0]; } - std::string suffix = session_context_.so_disable_cpu_ep_fallback ? - "\nModel failed to compile on NPU. Enable CPU fallback or try another device.\n" : - "\nModel needs to be recompiled\n"; + std::string suffix = session_context_.so_disable_cpu_ep_fallback ? "\nModel failed to compile on NPU. Enable CPU fallback or try another device.\n" : "\nModel needs to be recompiled\n"; throw std::runtime_error(error_message + ", " + error_code + suffix); #endif } else { @@ -645,12 +641,11 @@ BackendManager::GetModelProtoFromFusedNode(const onnxruntime::Node& fused_node, const bool include_initializer_data_in_proto = true; #endif - auto model = subgraph.CreateModel(logger); auto model_proto = model->ToProto(); model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); - subgraph.ToProto(*model_proto->mutable_graph(), /*include_initializers*/true, - /*include_outer_scope_args*/true, /*execution_order*/0, /*include_initializer_data*/include_initializer_data_in_proto); + subgraph.ToProto(*model_proto->mutable_graph(), /*include_initializers*/ true, + /*include_outer_scope_args*/ true, /*execution_order*/ 0, /*include_initializer_data*/ include_initializer_data_in_proto); print_model_proto_duration(); @@ -684,9 +679,9 @@ BackendManager::GetModelProtoFromFusedNode(const onnxruntime::Node& fused_node, << ", data_type: " << src_init->data_type() << ", raw_data size: " << src_init->raw_data().size(); if (src_init->raw_data().size() > 0) - SetExternalDataFields(proto_init, src_init->raw_data().data(), src_init->raw_data().size()); + SetExternalDataFields(proto_init, src_init->raw_data().data(), src_init->raw_data().size()); else - LOGS(logger, VERBOSE) << "Initializer has empty raw_data: skipping initializer '" << src_init->name() << "'..."; + LOGS(logger, VERBOSE) << "Initializer has empty raw_data: skipping initializer '" << src_init->name() << "'..."; } else if (onnxruntime::utils::HasExternalDataInMemory(*src_init)) { auto it_ext = external_initializers_offset_and_length.find(name); if (it_ext == external_initializers_offset_and_length.end()) { @@ -889,7 +884,7 @@ void BackendManager::Compute(OrtKernelContext* context) { } } else { std::string exception_str = ex.what(); - if (session_context_.so_disable_cpu_ep_fallback){ + if (session_context_.so_disable_cpu_ep_fallback) { std::string error_message = "UNKNOWN NPU ERROR"; std::string error_code = "code 0x0"; std::regex error_message_pattern(R"(\bZE_\w*\b)"); @@ -901,10 +896,9 @@ void BackendManager::Compute(OrtKernelContext* context) { if (std::regex_search(exception_str, matches, error_code_pattern)) { error_code = matches[0]; } - std::string suffix = "\nModel failed to compile on NPU. Enable CPU fallback or try another device.\n" ; + std::string suffix = "\nModel failed to compile on NPU. Enable CPU fallback or try another device.\n"; throw std::runtime_error(error_message + ", " + error_code + suffix); - } - else{ + } else { ORT_THROW(exception_str); } } diff --git a/onnxruntime/core/providers/openvino/ov_interface.h b/onnxruntime/core/providers/openvino/ov_interface.h index 44ec4a235c3e9..d5d4bd1af0c6a 100644 --- a/onnxruntime/core/providers/openvino/ov_interface.h +++ b/onnxruntime/core/providers/openvino/ov_interface.h @@ -25,8 +25,8 @@ // Usage: #if OPENVINO_VERSION_AT_LEAST(2025, 3) // Falls back to 0 if OPENVINO_VERSION_MAJOR/MINOR are not defined. #if defined(OPENVINO_VERSION_MAJOR) && defined(OPENVINO_VERSION_MINOR) -#define OPENVINO_VERSION_AT_LEAST(major, minor) \ - ((OPENVINO_VERSION_MAJOR > (major)) || \ +#define OPENVINO_VERSION_AT_LEAST(major, minor) \ + ((OPENVINO_VERSION_MAJOR > (major)) || \ (OPENVINO_VERSION_MAJOR == (major) && OPENVINO_VERSION_MINOR >= (minor))) #else #define OPENVINO_VERSION_AT_LEAST(major, minor) 0 diff --git a/onnxruntime/test/providers/openvino/openvino_ep_ext_init.cc b/onnxruntime/test/providers/openvino/openvino_ep_ext_init.cc index 21ec61c2d2e3f..139d9c0aaf2b1 100644 --- a/onnxruntime/test/providers/openvino/openvino_ep_ext_init.cc +++ b/onnxruntime/test/providers/openvino/openvino_ep_ext_init.cc @@ -52,7 +52,7 @@ auto ProbeDevice(const std::string& device) { } return is_present[device]; } -} // namespace detail +} // namespace namespace onnxruntime { namespace test { @@ -62,7 +62,7 @@ TEST_P(OVEP_ExtInit_Tests, DISABLED_ModelFromExtInit) { const auto& device = GetParam(); if (!ProbeDevice(device)) GTEST_SKIP() << device + " is not available on this machine"; - + // Model and weights file paths const std::string model_path = "ovep_ext_init_test.onnx"; const std::string weights_path = "ovep_ext_init_test.onnx.data"; @@ -174,7 +174,7 @@ TEST_P(OVEP_ExtInit_Tests, DISABLED_ModelFromExtInit) { Ort::SessionOptions session_options; session_options.AddConfigEntry(kOrtSessionOptionsDisableCPUEPFallback, "1"); session_options.SetIntraOpNumThreads(1); - std::unordered_map ov_options = { {"device_type", device } }; + std::unordered_map ov_options = {{"device_type", device}}; session_options.AppendExecutionProvider_OpenVINO_V2(ov_options); session_options.AddExternalInitializersFromFilesInMemory(names_path, buffers, buffer_sizes); diff --git a/onnxruntime/test/unittest_util/checkers.cc b/onnxruntime/test/unittest_util/checkers.cc index 794bd24310cd1..d4b30cd11f1a0 100644 --- a/onnxruntime/test/unittest_util/checkers.cc +++ b/onnxruntime/test/unittest_util/checkers.cc @@ -383,7 +383,7 @@ struct TensorCheck { void operator()(const Tensor& expected, const Tensor& actual, const ValidateOutputParams& params, - const std::string& ) const { + const std::string&) const { const bool has_abs_err = params.absolute_error.has_value(); const bool has_rel_err = params.relative_error.has_value(); From eff6cac4d56947525434e0313d03d95760aa4b10 Mon Sep 17 00:00:00 2001 From: Yaru Du Date: Mon, 27 Oct 2025 23:09:06 +0000 Subject: [PATCH 113/138] CVS-175734- [OVEP GPU] add GQA in support list for GPU backend (#830) * support GQA on GPU * remove CPU/NPU support Co-authored-by: MayureshV1 <47039074+MayureshV1@users.noreply.github.com> --- onnxruntime/core/providers/openvino/ov_versions/data_ops.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc b/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc index f848b89ed10c8..037cb6a1270ea 100644 --- a/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc +++ b/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc @@ -96,6 +96,7 @@ std::vector supported_op_mode = { {"Atanh", V_2020_4, {"CPU"}}, {"Atanh", V_2022_1, {"GPU"}}, {"Attention", V_2023_0, {"CPU", "GPU"}}, + {"GroupQueryAttention", V_2025_1, {"GPU"}}, {"AveragePool", V_2020_4, {"CPU", "GPU"}}, {"BatchNormalization", V_2020_4, {"CPU", "GPU"}}, {"BiasGelu", V_2023_0, {"CPU", "GPU"}}, From b1f77501d9906d881914543129b4aa409c17a625 Mon Sep 17 00:00:00 2001 From: Yaru Du Date: Tue, 28 Oct 2025 15:31:08 +0000 Subject: [PATCH 114/138] CVS-175737-[OVEP] Expose kvcache_rewind python api (#831) * expose rewind api through Python * address PR review --------- Co-authored-by: MayureshV1 <47039074+MayureshV1@users.noreply.github.com> --- .../onnxruntime_inference_collection.py | 10 ++++ .../python/onnxruntime_pybind_state.cc | 49 +++++++++++++++++++ 2 files changed, 59 insertions(+) diff --git a/onnxruntime/python/onnxruntime_inference_collection.py b/onnxruntime/python/onnxruntime_inference_collection.py index 4c3313046457c..91216473bcad2 100644 --- a/onnxruntime/python/onnxruntime_inference_collection.py +++ b/onnxruntime/python/onnxruntime_inference_collection.py @@ -397,6 +397,16 @@ def run_with_iobinding(self, iobinding, run_options=None): """ self._sess.run_with_iobinding(iobinding._iobinding, run_options) + def set_ep_dynamic_options(self, options: dict[str, str]): + """ + Set dynamic options for execution providers. + + :param options: Dictionary of key-value pairs where both keys and values are strings. + These options will be passed to the execution providers to modify + their runtime behavior. + """ + self._sess.set_ep_dynamic_options(options) + def get_tuning_results(self): return self._sess.get_tuning_results() diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index c548f3df4fb27..cdaea385f82ee 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -2810,6 +2810,55 @@ including arg name, arg type (contains both type and shape).)pbdoc") ORT_THROW("TunableOp and get_tuning_results are not supported in this build."); #endif }) + .def( + "set_ep_dynamic_options", [](PyInferenceSession* sess, const py::dict& options) { + std::vector keys; + std::vector values; + std::vector key_strings; + std::vector value_strings; + + // Reserve space to avoid reallocations + key_strings.reserve(options.size()); + value_strings.reserve(options.size()); + keys.reserve(options.size()); + values.reserve(options.size()); + + // Convert Python dict to C-style arrays + for (const auto& item : options) { + key_strings.emplace_back(py::str(item.first)); + value_strings.emplace_back(py::str(item.second)); + keys.push_back(key_strings.back().c_str()); + values.push_back(value_strings.back().c_str()); + } + + if (keys.empty()) { + ORT_THROW("No options were provided"); + } + + auto status = sess->GetSessionHandle()->SetEpDynamicOptions( + gsl::make_span(keys.data(), keys.size()), + gsl::make_span(values.data(), values.size())); + + if (!status.IsOK()) { + ORT_THROW("Failed to set EP dynamic options: " + status.ErrorMessage()); + } + }, + R"pbdoc(Set dynamic options for execution providers. + + Args: + options (dict): Dictionary of key-value pairs where both keys and values are strings. + These options will be passed to the execution providers to modify + their runtime behavior. + + Example: + session.set_ep_dynamic_options({ + "option1": "value1", + "option2": "value2" + }) + + Raises: + RuntimeError: If no options are provided or if setting the options fails. + )pbdoc") .def("set_tuning_results", [](PyInferenceSession* sess, py::list results, bool error_on_invalid) -> void { #if !defined(ORT_MINIMAL_BUILD) std::vector tuning_results; From 513e198117e27d259ba1202223664634ebb59f1a Mon Sep 17 00:00:00 2001 From: Kotomi-Du Date: Tue, 28 Oct 2025 11:14:54 -0700 Subject: [PATCH 115/138] update the keyword for matching key_value_input_names --- .../core/providers/openvino/ov_stateful_patch_utils.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc b/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc index 4c5edb8d4283e..d5c946745d822 100644 --- a/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc +++ b/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc @@ -144,10 +144,10 @@ void PatchStatefulDecoder(std::shared_ptr model) { if (param_name.find("key_values") != std::string::npos) { key_value_input_names.push_back(param_name); found = true; - } else if (param_name.find("key") != std::string::npos) { + } else if (param_name.find("keys") != std::string::npos) { key_value_input_names.push_back(param_name); found = true; - } else if (param_name.find("value") != std::string::npos) { + } else if (param_name.find("values") != std::string::npos) { key_value_input_names.push_back(param_name); found = true; } From 20de366c399955d559836693e9f55e2cc484c2fb Mon Sep 17 00:00:00 2001 From: Mikhail Dvoretckii Date: Wed, 29 Oct 2025 17:54:37 +0000 Subject: [PATCH 116/138] CVS-175447-[OVEP] Add a check for type mismatches in QDQ stripping (#834) * [OVEP] Add a check for type mismatches in QDQ stripping When rewiring the graph after eliminating QDQ pairs, the runtime now checks whether the type matches before and after the eliminated nodes and inserts a Cast node if there is a mismatch. * Expand type transform * Limit output types to f32/f16, add const_cast * Apply null check suggestion Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --------- Co-authored-by: MayureshV1 <47039074+MayureshV1@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- .../qdq_transformations/qdq_scales_fix.cpp | 30 +++++++++++++++++-- 1 file changed, 27 insertions(+), 3 deletions(-) diff --git a/onnxruntime/core/providers/openvino/qdq_transformations/qdq_scales_fix.cpp b/onnxruntime/core/providers/openvino/qdq_transformations/qdq_scales_fix.cpp index 47f6ab9a50a82..de0e8a97fb6b0 100644 --- a/onnxruntime/core/providers/openvino/qdq_transformations/qdq_scales_fix.cpp +++ b/onnxruntime/core/providers/openvino/qdq_transformations/qdq_scales_fix.cpp @@ -463,11 +463,35 @@ struct CustomGraph { } if (!is_prev_input) { - for (const auto& edge : output_edges) { + if (prev.node_ptr->OutputDefs()[0]->Type() != dq_node_ref.OutputDefs()[0]->Type()) { + NodeArg& output = original_graph.GetOrCreateNodeArg(prev.node_name + "_cast_0", dq_node_ref.OutputDefs()[0]->TypeAsProto()); + std::string cast_node_name = prev.node_ptr->OutputDefs()[0]->Name() + "_cast"; + InlinedVector input_args = {const_cast(prev.node_ptr->OutputDefs()[0])}; + InlinedVector output_args = {&output}; + Node& cast_node = original_graph.AddNode(cast_node_name, "Cast", "", input_args, output_args, nullptr, ""); + auto type_str = dq_node_ref.OutputDefs()[0]->Type(); + ORT_ENFORCE(type_str != nullptr, "Type string is null in QDQ scales fix."); + auto type_cast = type_str->find("tensor(float)") != std::string::npos ? onnx::TensorProto_DataType_FLOAT : onnx::TensorProto_DataType_FLOAT16; + ORT_ENFORCE((type_cast == onnx::TensorProto_DataType_FLOAT) || (type_str->find("tensor(float16)") != std::string::npos), + "QDQ type misalignment, expected float32 or float16 output"); + cast_node.AddAttribute("to", static_cast(type_cast)); original_graph.AddEdge(prev.node_ptr->Index(), - std::get<0>(edge), + cast_node.Index(), prev_output_index, - std::get<2>(edge)); + 0); + for (const auto& edge : output_edges) { + original_graph.AddEdge(cast_node.Index(), + std::get<0>(edge), + 0, + std::get<2>(edge)); + } + } else { + for (const auto& edge : output_edges) { + original_graph.AddEdge(prev.node_ptr->Index(), + std::get<0>(edge), + prev_output_index, + std::get<2>(edge)); + } } } } From d7ee534b90338d88806db8d08a6c0903dbc83d53 Mon Sep 17 00:00:00 2001 From: Kotomi-Du Date: Thu, 30 Oct 2025 14:08:18 -0700 Subject: [PATCH 117/138] optimize the code update --- .../openvino/ov_stateful_patch_utils.cc | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc b/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc index d5c946745d822..abc5d9446c043 100644 --- a/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc +++ b/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc @@ -138,20 +138,16 @@ void PatchStatefulDecoder(std::shared_ptr model) { std::vector key_value_input_names; std::vector not_kv_inputs; const auto& params = model->get_parameters(); - bool found = false; + for (size_t i = 0; i < params.size(); i++) { - auto param_name = params.at(i)->output(0).get_any_name(); + auto param_name = params[i]->output(0).get_any_name(); if (param_name.find("key_values") != std::string::npos) { - key_value_input_names.push_back(param_name); - found = true; + key_value_input_names.push_back(param_name); } else if (param_name.find("keys") != std::string::npos) { - key_value_input_names.push_back(param_name); - found = true; + key_value_input_names.push_back(param_name); } else if (param_name.find("values") != std::string::npos) { - key_value_input_names.push_back(param_name); - found = true; - } - if (!found) { + key_value_input_names.push_back(param_name); + } else{ not_kv_inputs.push_back(param_name); } } From 3c1c4c3e0e74a6c342aa000eaaa0be6fcdc13136 Mon Sep 17 00:00:00 2001 From: Kotomi-Du Date: Thu, 30 Oct 2025 15:15:48 -0700 Subject: [PATCH 118/138] revert original code which is functional --- .../openvino/ov_stateful_patch_utils.cc | 34 ++++++++++++------- 1 file changed, 22 insertions(+), 12 deletions(-) diff --git a/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc b/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc index abc5d9446c043..51160443aacc1 100644 --- a/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc +++ b/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc @@ -137,18 +137,28 @@ void MakeStateful(std::shared_ptr& ov_model, void PatchStatefulDecoder(std::shared_ptr model) { std::vector key_value_input_names; std::vector not_kv_inputs; - const auto& params = model->get_parameters(); - - for (size_t i = 0; i < params.size(); i++) { - auto param_name = params[i]->output(0).get_any_name(); - if (param_name.find("key_values") != std::string::npos) { - key_value_input_names.push_back(param_name); - } else if (param_name.find("keys") != std::string::npos) { - key_value_input_names.push_back(param_name); - } else if (param_name.find("values") != std::string::npos) { - key_value_input_names.push_back(param_name); - } else{ - not_kv_inputs.push_back(param_name); + + for (const ov::Output& input : model->inputs()) { + auto& names = input.get_names(); + + bool found = false; + for (auto& name : names) { + if (name.find("key_values") != std::string::npos) { + key_value_input_names.push_back(name); + found = true; + break; + } else if (name.find("keys") != std::string::npos) { + key_value_input_names.push_back(name); + found = true; + break; + } else if (name.find("values") != std::string::npos) { + key_value_input_names.push_back(name); + found = true; + break; + } + } + if (!found) { + not_kv_inputs.push_back(input.get_any_name()); } } From 7c1720d5505fa18e0c7b745b2b6b96e41c43a9cc Mon Sep 17 00:00:00 2001 From: Kotomi-Du Date: Thu, 30 Oct 2025 15:18:43 -0700 Subject: [PATCH 119/138] remove useless change --- onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc b/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc index 51160443aacc1..ca4867b7d8ae4 100644 --- a/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc +++ b/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc @@ -137,7 +137,6 @@ void MakeStateful(std::shared_ptr& ov_model, void PatchStatefulDecoder(std::shared_ptr model) { std::vector key_value_input_names; std::vector not_kv_inputs; - for (const ov::Output& input : model->inputs()) { auto& names = input.get_names(); @@ -157,6 +156,7 @@ void PatchStatefulDecoder(std::shared_ptr model) { break; } } + if (!found) { not_kv_inputs.push_back(input.get_any_name()); } From 7aa5363296878853dedca96d50d1cbc739e9bd9d Mon Sep 17 00:00:00 2001 From: jatinwadhwa921 <110383850+jatinwadhwa921@users.noreply.github.com> Date: Fri, 7 Nov 2025 15:10:48 +0530 Subject: [PATCH 120/138] [OVEP] support to run layout feature using python bindings (#846) --- onnxruntime/python/onnxruntime_pybind_state.cc | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index cdaea385f82ee..704716e80eb1d 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -1083,7 +1083,7 @@ static std::shared_ptr CreateExecutionProviderFactory ProviderOptions OV_provider_options_map; const std::unordered_set valid_provider_keys = {"device_type", "device_id", "device_luid", "cache_dir", "precision", "load_config", "context", "num_of_threads", "model_priority", "num_streams", "enable_opencl_throttling", "enable_qdq_optimizer", - "enable_causallm", "disable_dynamic_shapes", "reshape_input"}; + "enable_causallm", "disable_dynamic_shapes", "reshape_input", "layout"}; auto it = provider_options_map.find(type); if (it != provider_options_map.end()) { for (auto option : it->second) { @@ -1892,7 +1892,7 @@ void addObjectMethods(py::module& m, ExecutionProviderRegistrationFn ep_registra py::class_ py_sync_stream(m, "OrtSyncStream", R"pbdoc(Represents a synchronization stream for model inference.)pbdoc"); - py_sync_stream.def("get_handle", [](OrtSyncStream* stream) -> uintptr_t { + py_sync_stream.def("get_handle", [](OrtSyncStream* stream) -> uintptr_t { Ort::UnownedSyncStream ort_stream(stream); return reinterpret_cast(ort_stream.GetHandle()); }, R"pbdoc(SyncStream handle that can be converted to a string and added to SessionOptions)pbdoc"); @@ -2006,7 +2006,7 @@ for model inference.)pbdoc"); .def_property_readonly("allocator_type", [](const OrtMemoryInfo* mem_info) -> OrtAllocatorType { return mem_info->alloc_type; }, R"pbdoc(Allocator type)pbdoc") .def_property_readonly("device_mem_type", [](const OrtMemoryInfo* mem_info) -> OrtDeviceMemoryType { auto mem_type = mem_info->device.MemType(); - return (mem_type == OrtDevice::MemType::DEFAULT) ? + return (mem_type == OrtDevice::MemType::DEFAULT) ? OrtDeviceMemoryType_DEFAULT: OrtDeviceMemoryType_HOST_ACCESSIBLE ; }, R"pbdoc(Device memory type (Device or Host accessible).)pbdoc") .def_property_readonly("device_vendor_id", [](const OrtMemoryInfo* mem_info) -> uint32_t { return mem_info->device.Vendor(); }); @@ -2748,7 +2748,7 @@ including arg name, arg type (contains both type and shape).)pbdoc") auto res = sess->GetSessionHandle()->GetModelMetadata(); OrtPybindThrowIfError(res.first); return *(res.second); }, py::return_value_policy::reference_internal) - .def_property_readonly("input_meminfos", [](const PyInferenceSession* sess) -> py::list { + .def_property_readonly("input_meminfos", [](const PyInferenceSession* sess) -> py::list { Ort::ConstSession session(reinterpret_cast(sess->GetSessionHandle())); auto inputs_mem_info = session.GetMemoryInfoForInputs(); py::list result; @@ -2757,7 +2757,7 @@ including arg name, arg type (contains both type and shape).)pbdoc") result.append(py::cast(p_info, py::return_value_policy::reference)); } return result; }) - .def_property_readonly("output_meminfos", [](const PyInferenceSession* sess) -> py::list { + .def_property_readonly("output_meminfos", [](const PyInferenceSession* sess) -> py::list { Ort::ConstSession session(reinterpret_cast(sess->GetSessionHandle())); auto outputs_mem_info = session.GetMemoryInfoForOutputs(); py::list result; From fa68db19aed1322c1ca9ed6959ba1f495af9bd23 Mon Sep 17 00:00:00 2001 From: Ryan Metcalfe <107415876+RyanMetcalfeInt8@users.noreply.github.com> Date: Fri, 7 Nov 2025 13:26:05 -0800 Subject: [PATCH 121/138] CVS-176081: Add support for nested maps to load_config parsing (#844) * openvino_provider_factory: Add nested map support to load_config parsing * ParseInnerMap: Add warning that unsupported json types will become fatal in the future * ParseInnerMap: address review comments * load_config: Throw error for unsupported JSON types --------- Co-authored-by: MayureshV1 <47039074+MayureshV1@users.noreply.github.com> --- .../openvino/openvino_provider_factory.cc | 44 +++++++++++++------ 1 file changed, 31 insertions(+), 13 deletions(-) diff --git a/onnxruntime/core/providers/openvino/openvino_provider_factory.cc b/onnxruntime/core/providers/openvino/openvino_provider_factory.cc index f26da37fa7d7e..298eb25713bec 100644 --- a/onnxruntime/core/providers/openvino/openvino_provider_factory.cc +++ b/onnxruntime/core/providers/openvino/openvino_provider_factory.cc @@ -188,6 +188,36 @@ std::string ParseDeviceType(std::shared_ptr ov_core, const ProviderOptio void ParseProviderOptions([[maybe_unused]] ProviderInfo& result, [[maybe_unused]] const ProviderOptions& config_options) {} +static void ParseInnerMap(const nlohmann::json& json_map, ov::AnyMap& inner_map, size_t level = 0) { + const size_t max_levels = 8; + if (level >= max_levels) { + ORT_THROW("ParseInnerMap: load_config can have only up to " + std::to_string(max_levels) + + " levels of nested maps. Current level = " + std::to_string(level)); + } + + if (!json_map.is_object()) { + ORT_THROW("ParseInnerMap: Expected an object as input"); + } + + for (auto& [inner_key, inner_value] : json_map.items()) { + if (inner_value.is_string()) { + inner_map[inner_key] = ov::Any(inner_value.get()); + } else if (inner_value.is_number_integer()) { + inner_map[inner_key] = ov::Any(inner_value.get()); + } else if (inner_value.is_number_float()) { + inner_map[inner_key] = ov::Any(inner_value.get()); + } else if (inner_value.is_boolean()) { + inner_map[inner_key] = ov::Any(inner_value.get()); + } else if (inner_value.is_object()) { + auto inner_inner_map = ov::AnyMap(); + ParseInnerMap(inner_value, inner_inner_map, level + 1); + inner_map[inner_key] = std::move(inner_inner_map); + } else { + ORT_THROW("load_config: unsupported JSON value type=" + std::string(inner_value.type_name()) + ", for key=" + inner_key); + } + } +} + // Initializes a ProviderInfo struct from a ProviderOptions map and a ConfigOptions map. static void ParseProviderInfo(const ProviderOptions& provider_options, const ConfigOptions* config_options, @@ -267,19 +297,7 @@ static void ParseProviderInfo(const ProviderOptions& provider_options, ORT_THROW("Invalid JSON structure: Expected an object for device properties."); } - for (auto& [inner_key, inner_value] : value.items()) { - if (inner_value.is_string()) { - inner_map[inner_key] = inner_value.get(); - } else if (inner_value.is_number_integer()) { - inner_map[inner_key] = inner_value.get(); - } else if (inner_value.is_number_float()) { - inner_map[inner_key] = inner_value.get(); - } else if (inner_value.is_boolean()) { - inner_map[inner_key] = inner_value.get(); - } else { - LOGS_DEFAULT(WARNING) << "Unsupported JSON value type for key: " << inner_key << ". Skipping key."; - } - } + ParseInnerMap(value, inner_map); target_map[key] = std::move(inner_map); } } catch (const nlohmann::json::parse_error& e) { From 70acefe2489f44720778339baee1f86d06bb0355 Mon Sep 17 00:00:00 2001 From: bopeng1234 Date: Tue, 11 Nov 2025 14:26:00 +0800 Subject: [PATCH 122/138] [CVS-172796] fix bfloat16 conversion when single cast node to bfloat16 (#841) * disable bfloat16 conversion when single cast node to bfloat16, unit test case * Insert a Cast(To:BFloat16) before output node(bfloat16) to keep user use original bf16 outputs tensor * revert changes to add Cast Node, add statement to disable bfloat16 transform for OV CPU * remove bfloat16 silence conversion * remove bf16 testing and cpu support for openvino --------- Co-authored-by: MayureshV1 <47039074+MayureshV1@users.noreply.github.com> --- .../providers/openvino/backend_manager.cc | 22 ---- .../openvino/ov_versions/data_ops.cc | 4 +- .../qdq_transformations/qdq_scales_fix.cpp | 56 --------- .../qdq_transformations/qdq_scales_fix.h | 5 - .../openvino_ep_bfloat16_pass_test.cc | 116 ------------------ 5 files changed, 1 insertion(+), 202 deletions(-) delete mode 100644 onnxruntime/test/providers/openvino/openvino_ep_bfloat16_pass_test.cc diff --git a/onnxruntime/core/providers/openvino/backend_manager.cc b/onnxruntime/core/providers/openvino/backend_manager.cc index 74999ab10a67d..4a20847c0890c 100644 --- a/onnxruntime/core/providers/openvino/backend_manager.cc +++ b/onnxruntime/core/providers/openvino/backend_manager.cc @@ -389,18 +389,6 @@ static bool IsQDQGraph(const onnxruntime::GraphViewer& graph_viewer) { return false; } -static bool IsModelBF16(const onnxruntime::GraphViewer& graph_viewer) { - const auto& node_indices = graph_viewer.GetNodesInTopologicalOrder(); - for (std::size_t i = 0; i < node_indices.size(); i++) { - gsl::not_null node(graph_viewer.GetNode(node_indices[i])); - for (auto& output : node->OutputDefs()) { - if (output->ToProto().type().tensor_type().elem_type() == ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16) - return true; - } - } - return false; -} - static bool Is16BitTensor(const onnxruntime::NodeArg* node_arg) { const auto* type_proto = node_arg ? node_arg->TypeAsProto() : nullptr; return type_proto && type_proto->has_tensor_type() && @@ -598,16 +586,6 @@ BackendManager::GetModelProtoFromFusedNode(const onnxruntime::Node& fused_node, DumpOpenVINOEPModel(onnx_model_path_name, model_proto.get(), fused_node); ORT_ENFORCE(status.IsOK(), status.ErrorMessage()); return model_proto; - } else if (IsModelBF16(subgraph)) { - LOGS_DEFAULT(INFO) << "[OpenVINO-EP] OVEP bfloat16->float16 optimization pass is enabled"; - std::unique_ptr model; - Status status = bfloat16_fix::Transform(subgraph, logger, model); - auto model_proto = model->ToProto(); - model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); - print_model_proto_duration(); - DumpOpenVINOEPModel(onnx_model_path_name, model_proto.get(), fused_node); - ORT_ENFORCE(status.IsOK(), status.ErrorMessage()); - return model_proto; } else { LOGS_DEFAULT(INFO) << "[OpenVINO-EP] OVEP QDQ optimization pass is disabled"; diff --git a/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc b/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc index 037cb6a1270ea..4156b45cd638a 100644 --- a/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc +++ b/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc @@ -561,9 +561,7 @@ bool DataOps::type_is_supported(const NodeArg* node_arg, bool is_initializer) { } auto dtype = type_proto->tensor_type().elem_type(); - // Enable bfloat16 -> float16 on-the-fly conversion - if (dtype == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_BFLOAT16 || - dtype == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT16 || + if (dtype == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT16 || dtype == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT16) return true; if (is_initializer) { diff --git a/onnxruntime/core/providers/openvino/qdq_transformations/qdq_scales_fix.cpp b/onnxruntime/core/providers/openvino/qdq_transformations/qdq_scales_fix.cpp index de0e8a97fb6b0..a7b5c51882ff4 100644 --- a/onnxruntime/core/providers/openvino/qdq_transformations/qdq_scales_fix.cpp +++ b/onnxruntime/core/providers/openvino/qdq_transformations/qdq_scales_fix.cpp @@ -4,7 +4,6 @@ #include "qdq_scales_fix.h" #include "core/providers/openvino/ov_protobuf_utils.h" #include "core/framework/ort_value.h" -#include "core/common/float16.h" #include #include @@ -955,60 +954,5 @@ Status Transform(const GraphViewer& src_graph_viewer, return status; } } // namespace qdq_scales_fix - -namespace bfloat16_fix { -void replace_bf16_with_fp16(qdq_scales_fix::CustomGraph& gen_graph) { - for (auto& const_node : gen_graph.original_graph.Nodes()) { - auto node = const_cast(const_node); - if (node->OpType() == "Cast") { - for (auto& [name, const_attribute] : node->GetAttributes()) { - auto& attribute = const_cast(const_attribute); - if (name == "to" && attribute.type() == ONNX_NAMESPACE::AttributeProto_AttributeType_INT) - if (attribute.i() == ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16) - attribute.set_i(ONNX_NAMESPACE::TensorProto_DataType_FLOAT16); - } - } - for (auto& output : node->OutputDefs()) { - auto& output_proto = const_cast(output->ToProto().type()); - if (output_proto.mutable_tensor_type()->elem_type() == ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16) - output_proto.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT16); - } - } - - for (auto& node : gen_graph.original_graph.Nodes()) { - for (auto& input_def : node->InputDefs()) { - ORT_THROW_IF_ERROR(graph_utils::ConvertInMemoryDataToInline(gen_graph.original_graph, input_def->Name())); - } - } - - const auto& init_set = gen_graph.original_graph.GetAllInitializedTensors(); - for (auto& [key, const_tensor_proto] : init_set) { - auto tensor_proto = const_cast(const_tensor_proto); - auto dt = tensor_proto->data_type(); - if (dt == ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16) { - auto raw_data = tensor_proto->has_raw_data() ? reinterpret_cast(tensor_proto->mutable_raw_data()->data()) : nullptr; - if (raw_data) { - tensor_proto->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT16); - std::int64_t size = 1; - for (int i = 0; i < tensor_proto->dims_size(); ++i) - size *= tensor_proto->dims()[i]; - for (std::int64_t i = 0; i < size; ++i) { - raw_data[i] = onnxruntime::MLFloat16(onnxruntime::BFloat16::FromBits(raw_data[i])).val; - } - } - } - } -} - -Status Transform(const GraphViewer& src_graph_viewer, - const logging::Logger& logger, - /*out*/ std::unique_ptr& model) { - auto status = qdq_scales_fix::copy_model(src_graph_viewer, logger, model); - auto g = qdq_scales_fix::generate_graph_from_onnx(model->MainGraph()); - - replace_bf16_with_fp16(g); - return status; -} -} // namespace bfloat16_fix } // namespace openvino_ep } // namespace onnxruntime diff --git a/onnxruntime/core/providers/openvino/qdq_transformations/qdq_scales_fix.h b/onnxruntime/core/providers/openvino/qdq_transformations/qdq_scales_fix.h index 2182850d96c43..c54c531e1bd40 100644 --- a/onnxruntime/core/providers/openvino/qdq_transformations/qdq_scales_fix.h +++ b/onnxruntime/core/providers/openvino/qdq_transformations/qdq_scales_fix.h @@ -15,10 +15,5 @@ Status Transform(const GraphViewer& src_graph, const logging::Logger& logger, /*out*/ std::unique_ptr& model); } -namespace bfloat16_fix { -Status Transform(const GraphViewer& src_graph, - const logging::Logger& logger, - /*out*/ std::unique_ptr& model); -} } // namespace openvino_ep } // namespace onnxruntime diff --git a/onnxruntime/test/providers/openvino/openvino_ep_bfloat16_pass_test.cc b/onnxruntime/test/providers/openvino/openvino_ep_bfloat16_pass_test.cc deleted file mode 100644 index 105a35011a78d..0000000000000 --- a/onnxruntime/test/providers/openvino/openvino_ep_bfloat16_pass_test.cc +++ /dev/null @@ -1,116 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include -#include -#include - -#include "core/session/onnxruntime_cxx_api.h" -#include "core/common/float16.h" - -#include "test/util/include/test/test_environment.h" -#include "test/unittest_util/qdq_test_utils.h" - -#include "gtest/gtest.h" -#include "gmock/gmock.h" - -using namespace ONNX_NAMESPACE; -using namespace onnxruntime::logging; - -extern std::unique_ptr ort_env; - -class OVEP_BF16_Tests : public ::testing::TestWithParam {}; - -namespace detail { -auto ConstructModel() { - using namespace onnxruntime; - using namespace test; - - std::unordered_map domain_to_version; - domain_to_version[kOnnxDomain] = 19; - Model model("Bfloat16Tester", true, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), - domain_to_version, {}, DefaultLoggingManager().DefaultLogger()); - - Graph& graph = model.MainGraph(); - ModelTestBuilder builder(graph); - auto dim = 4; - std::vector input_data(dim, 1.0f); - auto* input = builder.MakeInput({dim}, input_data); - builder.graph_.SetInputs({input}); - - auto* cast_to_bf16 = builder.MakeIntermediate(); - Node& cast_node = builder.AddNode("Cast", {input}, {cast_to_bf16}, ""); - cast_node.AddAttribute("to", static_cast(ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16)); - - std::vector weight_data(dim * dim); - for (std::size_t i = 0; i < weight_data.size(); ++i) - weight_data[i] = onnxruntime::BFloat16(static_cast(i % dim) / dim); - auto* weights = builder.MakeInitializer({dim, dim}, weight_data); - - auto* matmul_out = builder.MakeIntermediate(); - builder.AddNode("MatMul", {cast_to_bf16, weights}, {matmul_out}); - - std::vector weight_data_2(dim * dim); - for (std::size_t i = 0; i < weight_data_2.size(); ++i) - weight_data_2[i] = onnxruntime::BFloat16(static_cast(i % dim) / dim); - auto* weights_2 = builder.MakeInitializer({dim, dim}, weight_data_2); - - auto* matmul_out_2 = builder.MakeIntermediate(); - builder.AddNode("MatMul", {matmul_out, weights_2}, {matmul_out_2}); - - auto* output = builder.MakeOutput(); - Node& cast2_node = builder.AddNode("Cast", {matmul_out_2}, {output}); - cast2_node.AddAttribute("to", static_cast(ONNX_NAMESPACE::TensorProto_DataType_FLOAT)); - - builder.SetGraphOutputs(); - auto st = model.MainGraph().Resolve(); - if (st != Status::OK()) - throw std::runtime_error(st.ErrorMessage()); - return model; -} - -auto ProbeDevice(const std::string& device) { - static std::map is_present; - if (is_present.find(device) == is_present.end()) { - Ort::SessionOptions sessionOptions; - std::unordered_map ov_options; - ov_options["device_type"] = device; - try { - sessionOptions.AppendExecutionProvider_OpenVINO_V2(ov_options); - is_present[device] = true; - } catch (...) { - is_present[device] = false; - } - } - return is_present[device]; -} -} // namespace detail - -namespace onnxruntime { -namespace test { - -TEST_P(OVEP_BF16_Tests, TestModelConversion) { - Ort::SessionOptions sessionOptions; - std::unordered_map ov_options; - const auto& device = GetParam(); - if (!::detail::ProbeDevice(device)) - GTEST_SKIP() << device + " is not available on this machine"; - - ov_options["device_type"] = device; - auto model = ::detail::ConstructModel(); - sessionOptions.AppendExecutionProvider_OpenVINO_V2(ov_options); - - std::string model_data; - model.ToProto().SerializeToString(&model_data); - auto model_data_span = AsByteSpan(model_data.data(), model_data.size()); - try { - Ort::Session session(*ort_env, model_data_span.data(), model_data_span.size(), sessionOptions); - } catch (...) { - FAIL(); - } -} -INSTANTIATE_TEST_SUITE_P(OVEP_Tests, - OVEP_BF16_Tests, - ::testing::Values("CPU", "GPU", "NPU")); -} // namespace test -} // namespace onnxruntime From 57431973da9529b167aab57d2f8216bab955cc10 Mon Sep 17 00:00:00 2001 From: Ryan Metcalfe <107415876+RyanMetcalfeInt8@users.noreply.github.com> Date: Wed, 12 Nov 2025 22:08:58 -0800 Subject: [PATCH 123/138] ov_stateful_patch_utils: Remove NPUW WA for avoiding SinCos when context_len >= 2048 (#840) Co-authored-by: MayureshV1 <47039074+MayureshV1@users.noreply.github.com> --- .../core/providers/openvino/ov_stateful_patch_utils.cc | 7 ------- 1 file changed, 7 deletions(-) diff --git a/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc b/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc index ca4867b7d8ae4..7f276f565f795 100644 --- a/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc +++ b/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc @@ -319,13 +319,6 @@ void UpdateNPUConfig(ov::AnyMap& config, const KVAxesPosition& kv_pos, const KVD RenameKey(config, "PREFILL_HINT", "NPUW_LLM_PREFILL_HINT"); RenameKey(config, "GENERATE_CONFIG", "NPUW_LLM_GENERATE_CONFIG"); RenameKey(config, "GENERATE_HINT", "NPUW_LLM_GENERATE_HINT"); - - const size_t npuw_context_len_threshold = 2048; - if ((kv_desc.max_prompt_len + kv_desc.min_response_len) >= npuw_context_len_threshold) { - // This improves accuracy for generation sequences that exceed 2k tokens. - config["++NPUW_LLM_PREFILL_CONFIG"] = ov::AnyMap{{"NPUW_DEVICES", "NPU,CPU"}, {"NPUW_ONLINE_AVOID", "P:SinCos/NPU"}}; - config["++NPUW_LLM_GENERATE_CONFIG"] = ov::AnyMap{{"NPUW_DEVICES", "NPU,CPU"}, {"NPUW_ONLINE_AVOID", "P:SinCos/NPU"}}; - } } std::optional PopOptionNew(ov::AnyMap& config, const std::string& option_name) { From 10af800e9c4a924ec16fb416f9e0e14683bffdd6 Mon Sep 17 00:00:00 2001 From: Eric Crawford Date: Thu, 13 Nov 2025 10:41:43 -0800 Subject: [PATCH 124/138] CVS-175880 Implement single bin (#847) * Implement single bin * Fix size mismatch for larger blobs * disallow embed mode + sharing * Tweak main context usage * Remove redundant stop share setting Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Only reject if share context generated * Fix up bin manager lifetimes * Fix ep context node path * Remove unnecessary initialized flag from BinManager * Refactor BackendManager and EPCtxHandler to use SharedContextManager, removing SharedBinManager references * tweak lock ordering * Tweak when we use the active shared context * Ensure all blobs are available at epctx export * Update onnxruntime/core/providers/openvino/openvino_execution_provider.cc Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: MayureshV1 <47039074+MayureshV1@users.noreply.github.com> --- .../providers/openvino/backend_manager.cc | 109 ++--- .../core/providers/openvino/backend_manager.h | 7 +- .../core/providers/openvino/backend_utils.cc | 212 +-------- .../core/providers/openvino/backend_utils.h | 5 - .../openvino/backends/basic_backend.cc | 15 +- .../core/providers/openvino/contexts.h | 88 +--- .../openvino/onnx_ctx_model_helper.cc | 136 ++++-- .../openvino/onnx_ctx_model_helper.h | 25 +- .../openvino/openvino_execution_provider.cc | 90 ++-- .../openvino/openvino_execution_provider.h | 9 +- .../openvino/openvino_provider_factory.cc | 17 +- .../core/providers/openvino/ov_bin_manager.cc | 440 ++++++++++++++++++ .../core/providers/openvino/ov_bin_manager.h | 77 +++ .../core/providers/openvino/ov_factory.cc | 2 +- .../core/providers/openvino/ov_interface.cc | 4 +- .../core/providers/openvino/ov_interface.h | 28 +- .../providers/openvino/ov_shared_context.cc | 145 ++++++ .../providers/openvino/ov_shared_context.h | 159 +++++++ .../qdq_transformations/qdq_stripping.cc | 28 +- .../qdq_transformations/qdq_stripping.h | 5 +- .../core/providers/openvino/weak_singleton.h | 40 ++ 21 files changed, 1126 insertions(+), 515 deletions(-) create mode 100644 onnxruntime/core/providers/openvino/ov_bin_manager.cc create mode 100644 onnxruntime/core/providers/openvino/ov_bin_manager.h create mode 100644 onnxruntime/core/providers/openvino/ov_shared_context.cc create mode 100644 onnxruntime/core/providers/openvino/ov_shared_context.h create mode 100644 onnxruntime/core/providers/openvino/weak_singleton.h diff --git a/onnxruntime/core/providers/openvino/backend_manager.cc b/onnxruntime/core/providers/openvino/backend_manager.cc index 4a20847c0890c..abb5b31b76e44 100644 --- a/onnxruntime/core/providers/openvino/backend_manager.cc +++ b/onnxruntime/core/providers/openvino/backend_manager.cc @@ -37,19 +37,35 @@ ov::CompiledModel BackendManager::GetOVCompiledModel() { return ov::CompiledModel(); } +static bool ShouldExportEpContext(const SessionContext& session_context, const SubGraphContext& subgraph_context) { + return session_context.so_context_enable && (subgraph_context.is_ep_ctx_ovir_encapsulated || !subgraph_context.is_ep_ctx_graph); +} + BackendManager::BackendManager(SessionContext& session_context, - SharedContext& shared_context, + SharedContextManager& shared_context_manager, const onnxruntime::Node& fused_node, const onnxruntime::GraphViewer& subgraph, const logging::Logger& logger, EPCtxHandler& ep_ctx_handle) : ep_ctx_handle_(ep_ctx_handle), session_context_(session_context), - shared_context_{shared_context} { + shared_context_manager_(shared_context_manager) { subgraph_context_.is_ep_ctx_graph = ep_ctx_handle_.CheckForOVEPCtxNodeInGraph(subgraph); // If the graph contains a OVIR wrapped node, we check if it has matching xml file name attribute subgraph_context_.is_ep_ctx_ovir_encapsulated = ep_ctx_handle_.CheckEPCacheContextAttribute(subgraph, session_context_.onnx_model_path_name.filename().replace_extension("xml").string()); + if (subgraph_context_.is_ep_ctx_graph && !subgraph_context_.is_ep_ctx_ovir_encapsulated) { + shared_context_ = ep_ctx_handle.GetSharedContextForEpContextSubgraph(subgraph, + session_context_.GetModelPath()); + } else if (session_context_.so_context_enable && session_context_.so_share_ep_contexts) { + shared_context_ = shared_context_manager_.GetOrCreateActiveSharedContext(session_context_.GetOutputBinPath()); + } else { + // Creating a shared context to satisfy backend. It won't be used for weight sharing. + // Don't make it the active share context since we don't actually want to share it. + shared_context_ = shared_context_manager_.GetOrCreateSharedContext(session_context_.GetOutputBinPath()); + } + ORT_ENFORCE(shared_context_, "Could not create a shared context."); + subgraph_context_.model_precision = [&](const GraphViewer& graph_viewer) { // return empty if graph has no inputs or if types are not one of FP32/FP16 // else assume the type of the first input @@ -107,23 +123,6 @@ BackendManager::BackendManager(SessionContext& session_context, } std::string device_type = session_context_.device_type; - auto& sw = shared_context_.shared_weights; - if (session_context_.so_share_ep_contexts && !sw.metadata.empty()) { - std::filesystem::path weight_filename = session_context_.onnx_model_path_name.parent_path(); - if (sw.external_weight_filename.empty()) { - // Reasonable assumption that all metadata entries have the same external file location - sw.external_weight_filename = sw.metadata.begin()->second.location; - } - weight_filename /= sw.external_weight_filename; - std::ifstream weight_file(weight_filename); - - ORT_ENFORCE(weight_file, "Initializer file not found: ", weight_filename.string()); - if (!sw.mapped_weights) { - sw.mapped_weights = std::make_unique(weight_filename); - } - backend_utils::CreateOVTensors(session_context_.device_type, sw.metadata, *sw.mapped_weights); - } - if (subgraph_context_.has_dynamic_input_shape) { LOGS_DEFAULT(INFO) << "[OpenVINO-EP] Model has symbolic input dims"; if ((!session_context_.disable_dynamic_shapes && @@ -138,7 +137,7 @@ BackendManager::BackendManager(SessionContext& session_context, concrete_backend_ = BackendFactory::MakeBackend(model_proto, session_context_, subgraph_context_, - shared_context_, + *shared_context_, model_stream); } catch (std::string const& msg) { ORT_THROW(msg); @@ -162,7 +161,7 @@ BackendManager::BackendManager(SessionContext& session_context, concrete_backend_ = BackendFactory::MakeBackend(model_proto, session_context_, subgraph_context_, - shared_context_, + *shared_context_, model_stream); } catch (const OnnxRuntimeException& ex) { std::string exception_str = ex.what(); @@ -193,15 +192,15 @@ BackendManager::BackendManager(SessionContext& session_context, } } } - if (session_context_.so_context_enable && - (subgraph_context_.is_ep_ctx_ovir_encapsulated || !subgraph_context_.is_ep_ctx_graph)) { + + if (ShouldExportEpContext(session_context_, subgraph_context_)) { if (concrete_backend_) { - auto status = onnxruntime::openvino_ep::BackendManager::ExportCompiledBlobAsEPCtxNode(subgraph); - if (!status.IsOK()) { - ORT_THROW(status); - } + shared_context_->AddNativeBlob(subgraph_context_.subgraph_name, concrete_backend_->GetOVCompiledModel()); } else { - ORT_THROW("[OpenVINO-EP] Cannot export compiled blob as EPCtx Node: Backend not initialized."); + ORT_THROW( + "Exporting dynamically compiled models at runtime is not supported. " + "Cannot export blobs of dynamic models that request static shape inference. " + "To export this model, set disable_dynamic_shapes to False"); } } } @@ -210,13 +209,9 @@ BackendManager::BackendManager(SessionContext& session_context, // precompiled blob is set. If that's the case: // By default, create model in embed mode where the blob stream is exported as data within // the EPContext node. -Status BackendManager::ExportCompiledBlobAsEPCtxNode(const onnxruntime::GraphViewer& graph_body_viewer) { - if (session_context_.disable_dynamic_shapes && subgraph_context_.has_dynamic_input_shape) { - std::string exception_str = - "Exporting dynamically compiled models at runtime is not supported. " - "Cannot export blobs of dynamic models that request static shape inference. " - "To export this model, set disable_dynamic_shapes to False"; - ORT_THROW(exception_str); +void BackendManager::TryExportCompiledBlobAsEPCtxNode(const onnxruntime::GraphViewer& graph_body_viewer, bool include_embed_data) { + if (!ShouldExportEpContext(session_context_, subgraph_context_) || !concrete_backend_) { + return; } // If embed_mode, then pass on the serialized blob @@ -224,11 +219,10 @@ Status BackendManager::ExportCompiledBlobAsEPCtxNode(const onnxruntime::GraphVie std::string model_blob_str; auto compiled_model = concrete_backend_->GetOVCompiledModel(); if (session_context_.so_context_embed_mode) { // Internal blob - std::ostringstream model_blob_stream; - compiled_model.export_model(model_blob_stream); - model_blob_str = std::move(model_blob_stream).str(); - if (model_blob_str.empty()) { - ORT_THROW("Model blob stream is empty after exporting the compiled model."); + if (include_embed_data) { + std::stringstream ss; + shared_context_->Serialize(ss); + model_blob_str = std::move(ss).str(); } } else { // External blob // Build name by combining EpCtx model name (if available) and subgraph name. Model @@ -238,30 +232,17 @@ Status BackendManager::ExportCompiledBlobAsEPCtxNode(const onnxruntime::GraphVie name = graph_body_viewer.ModelPath().stem().string(); } ORT_ENFORCE(!name.empty()); - name += "_" + subgraph_context_.subgraph_name; - std::filesystem::path blob_filename = session_context_.so_context_file_path; - if (blob_filename.empty()) { - blob_filename = session_context_.onnx_model_path_name; - } - blob_filename = blob_filename.parent_path() / (name + ".blob"); - std::ofstream blob_file(blob_filename, - std::ios::out | std::ios::trunc | std::ios::binary); - if (!blob_file) { - std::ostringstream err_msg; - err_msg << "Unable to open file for epctx model dump: " << blob_filename; - ORT_THROW(err_msg.str()); - } - compiled_model.export_model(blob_file); - model_blob_str = blob_filename.filename().string(); + model_blob_str = shared_context_->GetBinPath().filename().string(); } - ORT_RETURN_IF_ERROR(ep_ctx_handle_.AddOVEPCtxNodeToGraph(graph_body_viewer, - subgraph_context_.subgraph_name, - session_context_.so_context_embed_mode, - std::move(model_blob_str))); - - return Status::OK(); + auto status = ep_ctx_handle_.AddOVEPCtxNodeToGraph(graph_body_viewer, + subgraph_context_.subgraph_name, + session_context_.so_context_embed_mode, + std::move(model_blob_str)); + if (!status.IsOK()) { + ORT_THROW("[OpenVINO-EP] Failed to add OVEP EPContext node to the graph: " + status.ErrorMessage()); + } } bool BackendManager::ModelHasBatchedInputs(const ONNX_NAMESPACE::ModelProto& model_proto) const { @@ -568,7 +549,7 @@ BackendManager::GetModelProtoFromFusedNode(const onnxruntime::Node& fused_node, if ((session_context_.device_type.find("NPU") != std::string::npos) && (enable_ovep_qdq_optimizer || session_context_.so_share_ep_contexts)) { std::unique_ptr model; - Status status = CreateModelWithStrippedQDQNodes(subgraph, logger, session_context_.so_share_ep_contexts, enable_ovep_qdq_optimizer, model, shared_context_.shared_weights); + Status status = CreateModelWithStrippedQDQNodes(subgraph, logger, session_context_.so_share_ep_contexts, enable_ovep_qdq_optimizer, model, *shared_context_); auto model_proto = model->ToProto(); model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); print_model_proto_duration(); @@ -835,7 +816,7 @@ void BackendManager::Compute(OrtKernelContext* context) { dynamic_backend = BackendFactory::MakeBackend(modelproto_with_concrete_shapes, session_context_, subgraph_context_, - shared_context_, + *shared_context_, model_stream); } catch (const OnnxRuntimeException& ex) { // Build option disables fallback to CPU on compilation failures with NPU. @@ -855,7 +836,7 @@ void BackendManager::Compute(OrtKernelContext* context) { dynamic_backend = BackendFactory::MakeBackend(modelproto_with_concrete_shapes, session_context_, subgraph_context_, - shared_context_, + *shared_context_, model_stream); } catch (std::string const& msg) { ORT_THROW(msg); diff --git a/onnxruntime/core/providers/openvino/backend_manager.h b/onnxruntime/core/providers/openvino/backend_manager.h index f091f95fe1c16..64dadb6c2151b 100644 --- a/onnxruntime/core/providers/openvino/backend_manager.h +++ b/onnxruntime/core/providers/openvino/backend_manager.h @@ -20,7 +20,7 @@ namespace openvino_ep { class BackendManager { public: BackendManager(SessionContext& session_context, - SharedContext& shared_context, + SharedContextManager& shared_context_manager, const onnxruntime::Node& fused_node, const onnxruntime::GraphViewer& subgraph, const logging::Logger& logger, @@ -28,7 +28,7 @@ class BackendManager { void Compute(OrtKernelContext* context); void ShutdownBackendManager(); SessionContext& GetSessionContext(); - Status ExportCompiledBlobAsEPCtxNode(const onnxruntime::GraphViewer& subgraph); + void TryExportCompiledBlobAsEPCtxNode(const onnxruntime::GraphViewer& subgraph, bool include_embed_data); ov::CompiledModel GetOVCompiledModel(); void RewindKVCache(size_t index); @@ -59,7 +59,8 @@ class BackendManager { SubGraphContext subgraph_context_; EPCtxHandler& ep_ctx_handle_; SessionContext& session_context_; - SharedContext& shared_context_; + SharedContextManager& shared_context_manager_; + std::shared_ptr shared_context_; }; } // namespace openvino_ep diff --git a/onnxruntime/core/providers/openvino/backend_utils.cc b/onnxruntime/core/providers/openvino/backend_utils.cc index 7201c47a805e3..45e518d16686e 100644 --- a/onnxruntime/core/providers/openvino/backend_utils.cc +++ b/onnxruntime/core/providers/openvino/backend_utils.cc @@ -20,130 +20,6 @@ using Exception = ov::Exception; namespace onnxruntime { namespace openvino_ep { -SharedContext::SharedWeights::WeightsFile::WeightsFile(std::filesystem::path filename) : file_(filename, std::ios::in | std::ios::binary), file_path_(filename) { - try { - file_.exceptions(std::ifstream::failbit | std::ifstream::badbit); - weights_size_ = std::filesystem::file_size(filename); - } catch (const std::exception& e) { - ORT_THROW("Error: Failed to open weight file at ", filename.string(), " ", e.what()); - } -} - -void SharedContext::SharedWeights::WeightsFile::load_weights(size_t file_offset, void* data, size_t size) { - ORT_ENFORCE(file_offset < weights_size_ && size <= weights_size_ && (file_offset <= weights_size_ - size), "Error: File offset is out of bounds."); - file_.seekg(file_offset); - file_.read(reinterpret_cast(data), size); -} - -void* SharedContext::SharedWeights::WeightsFile::TryGetOrCreateDeviceMapping(std::optional& remote_context) { - std::string dev_name{}; - if (remote_context) { - dev_name = remote_context->get_device_name(); - } - - auto [it, inserted] = imported_device_tensors_.emplace(dev_name, MappingContainer{}); - if (inserted) { - if (dev_name == "NPU") { -#if OPENVINO_VERSION_AT_LEAST(2025, 3) - // try to import the memory mapped file to remote tensor - ORT_ENFORCE(remote_context, "Error: Remote context is required for NPU device."); - auto npu_context = remote_context->as(); - auto&& l0_tensor = npu_context.create_tensor(ov::element::Type_t::u8, {weights_size_}, ov::intel_npu::FileDescriptor(file_path_)); - it->second = MappingContainer{.ptr_ = l0_tensor.get(), .tensor_ = l0_tensor}; -#endif - } else if (dev_name.empty()) { - // CPU/virtual device case, create a CPU tensor memory mapped from file - auto&& mmaped_tensor = ov::read_tensor_data(file_path_); - it->second = MappingContainer{.ptr_ = mmaped_tensor.data(), .tensor_ = mmaped_tensor}; - } - } - - return it->second.ptr_; -} - -std::ostream& operator<<(std::ostream& stream, const SharedContext::SharedWeights::Metadata::Map& metadata) { - try { - stream << metadata.size(); - - // Write each key-value pair - // Put elements in separate lines to facilitate reading - for (const auto& [key, value] : metadata) { - stream << std::endl - << key.name; - stream << std::endl - << value.location; - stream << std::endl - << value.data_offset; - stream << std::endl - << value.size; - stream << std::endl - << value.dimensions.size(); - for (const auto& dim : value.dimensions) { - stream << std::endl - << dim; - } - stream << std::endl - << value.element_type; - } - } catch (const Exception& e) { - ORT_THROW("Error: Failed to write map data.", e.what()); - } catch (...) { - ORT_THROW("Error: Failed to write map data."); - } - - ORT_ENFORCE(stream.good(), "Error: Failed to write map data."); - return stream; -} - -std::istream& operator>>(std::istream& stream, SharedContext::SharedWeights::Metadata::Map& metadata) { - size_t map_size{0}; - try { - stream >> map_size; - - while (!stream.eof()) { - SharedContext::SharedWeights::Metadata::Key key; - SharedContext::SharedWeights::Metadata::Value value; - stream >> key.name; - stream >> value.location; - stream >> value.data_offset; - stream >> value.size; - size_t num_dimensions; - stream >> num_dimensions; - - if (stream.fail()) { - ORT_THROW("Error: Failed to read num_dimensions from stream."); - } - - constexpr size_t MAX_SAFE_DIMENSIONS = 1024; - - size_t safe_num_dimensions = num_dimensions; - - if (num_dimensions == 0 || safe_num_dimensions > MAX_SAFE_DIMENSIONS) { - ORT_THROW("Invalid number of dimensions provided."); - } - try { - value.dimensions.resize(safe_num_dimensions); - } catch (const std::bad_alloc&) { - ORT_THROW("Error: Memory allocation failed while resizing dimensions."); - } - - for (auto& dim : value.dimensions) { - stream >> dim; - } - stream >> value.element_type; - metadata.emplace(key, value); - } - } catch (const Exception& e) { - ORT_THROW("Error: Failed to read map data.", e.what()); - } catch (...) { - ORT_THROW("Error: Failed to read map data."); - } - - ORT_ENFORCE(metadata.size() == map_size, "Error: Inconsistent map data."); - - return stream; -} - namespace backend_utils { bool IsDebugEnabled() { @@ -390,96 +266,10 @@ void printPerformanceCounts(const std::vector& performanceMap, } void printPerformanceCounts(OVInferRequestPtr request, std::ostream& stream, std::string deviceName) { - auto performanceMap = request->GetNewObj().get_profiling_info(); + auto performanceMap = request->GetInfReq().get_profiling_info(); printPerformanceCounts(performanceMap, stream, std::move(deviceName)); } -ov::element::Type GetOpenVINOElementType(ONNX_NAMESPACE::TensorProto_DataType dt) { - static std::unordered_map map{ - {ONNX_NAMESPACE::TensorProto_DataType_FLOAT, ov::element::f32}, - {ONNX_NAMESPACE::TensorProto_DataType_UINT8, ov::element::u8}, - {ONNX_NAMESPACE::TensorProto_DataType_INT8, ov::element::i8}, - {ONNX_NAMESPACE::TensorProto_DataType_UINT16, ov::element::u16}, - {ONNX_NAMESPACE::TensorProto_DataType_INT16, ov::element::i16}, - {ONNX_NAMESPACE::TensorProto_DataType_INT32, ov::element::i32}, - {ONNX_NAMESPACE::TensorProto_DataType_INT64, ov::element::i64}, - {ONNX_NAMESPACE::TensorProto_DataType_STRING, ov::element::string}, - {ONNX_NAMESPACE::TensorProto_DataType_BOOL, ov::element::boolean}, - {ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, ov::element::f16}, - {ONNX_NAMESPACE::TensorProto_DataType_DOUBLE, ov::element::f64}, - {ONNX_NAMESPACE::TensorProto_DataType_UINT32, ov::element::u32}, - {ONNX_NAMESPACE::TensorProto_DataType_UINT64, ov::element::u64}, - //{ONNX_NAMESPACE::TensorProto_DataType_COMPLEX64, ov::element::undefined}, - //{ONNX_NAMESPACE::TensorProto_DataType_COMPLEX128, ov::element::undefined}, - {ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16, ov::element::bf16}, - //{ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN, ov::element::undefined}, - //{ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FNUZ, ov::element::undefined}, - {ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2, ov::element::f8e5m2}, - //{ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2FNUZ, ov::element::undefined}, - {ONNX_NAMESPACE::TensorProto_DataType_UINT4, ov::element::u4}, - {ONNX_NAMESPACE::TensorProto_DataType_INT4, ov::element::i4}, - }; - - if (auto result = map.find(dt); result != map.end()) { - return result->second; - } else { - throw std::runtime_error("Unsupported ONNX data type: " + std::to_string(dt)); - } -} - -// Function to handle tensor creation from external data -void CreateOVTensors(const std::string& device_name, - SharedContext::SharedWeights::Metadata::Map& metadata_map, - SharedContext::SharedWeights::WeightsFile& weights) { - // Get remote context if available - std::optional opt_remote_ctx; - try { - opt_remote_ctx = OVCore::Get()->core.get_default_context(device_name); - } catch (const std::exception&) { - // Remote context not available - } - - for (auto& [key, value] : metadata_map) { - if (value.tensor) continue; - - // Get element data type - auto onnx_element_type = (ONNX_NAMESPACE::TensorProto_DataType)value.element_type; - ov::element::Type ov_elementType = GetOpenVINOElementType(onnx_element_type); - - // Try to get memory-mapped weights - ov::Tensor tensor; - uint8_t* mmaped_weights = static_cast(weights.TryGetOrCreateDeviceMapping(opt_remote_ctx)); - - if (mmaped_weights) { - // We have memory mapped weights. Create a Tensor view into it for this value. - ORT_ENFORCE(value.data_offset < weights.Size() && - value.size <= weights.Size() && - (value.data_offset <= weights.Size() - value.size), - "File offset + size outside of external initializer file"); - void* mmapped_offset = static_cast(mmaped_weights + value.data_offset); - tensor = ov::Tensor(ov_elementType, value.dimensions, mmapped_offset); - } else { - ORT_ENFORCE(opt_remote_ctx, "Expected either memory-mapped weights or a valid remote context, but neither is available for device: ", device_name); - // Can't mmap the file to device tensor, create a host tensor and copy the data - tensor = opt_remote_ctx->create_host_tensor(ov_elementType, value.dimensions); - ORT_ENFORCE(tensor.get_byte_size() == value.size, "Remote tensor size mismatch"); - weights.load_weights(value.data_offset, tensor.data(), value.size); - } - - ORT_ENFORCE(tensor.get_byte_size() == value.size, "Unexpected tensor size mismatch"); - value.tensor = std::make_shared(std::move(tensor)); - } -} - -void DestroyOVTensors(SharedContext::SharedWeights::Metadata::Map& metadata_map) { - for (auto& [key, value] : metadata_map) { - if (value.tensor) { - value.tensor.reset(); - } - } - metadata_map.clear(); -} - bool IsModelStreamXML(std::istream& model_stream) { std::streampos originalPos = model_stream.tellg(); diff --git a/onnxruntime/core/providers/openvino/backend_utils.h b/onnxruntime/core/providers/openvino/backend_utils.h index 27f791c7a5bd1..8ba35e0abd1bc 100644 --- a/onnxruntime/core/providers/openvino/backend_utils.h +++ b/onnxruntime/core/providers/openvino/backend_utils.h @@ -99,11 +99,6 @@ CreateOVModel(std::string&& model, const SessionContext& session_context, std::map>& const_outputs_map); -void CreateOVTensors(const std::string& device_name, - SharedContext::SharedWeights::Metadata::Map& metadata_map, - SharedContext::SharedWeights::WeightsFile& weights); -void DestroyOVTensors(SharedContext::SharedWeights::Metadata::Map& metadata_map); - void printPerformanceCounts(const std::vector& performanceMap, std::ostream& stream, std::string deviceName); diff --git a/onnxruntime/core/providers/openvino/backends/basic_backend.cc b/onnxruntime/core/providers/openvino/backends/basic_backend.cc index a950538c7c5fd..d7fc0553fb1d4 100644 --- a/onnxruntime/core/providers/openvino/backends/basic_backend.cc +++ b/onnxruntime/core/providers/openvino/backends/basic_backend.cc @@ -138,20 +138,13 @@ BasicBackend::BasicBackend(std::unique_ptr& model_pr } int num_infer_req = (session_context_.num_of_threads > 0) ? session_context_.num_of_threads : 1; std::function initializer = [](OVInferRequestPtr) {}; - auto metadata = shared_context_.shared_weights.metadata; if (session_context_.so_share_ep_contexts) { - initializer = [&metadata](OVInferRequestPtr ir_ptr) { - const auto input_count = ir_ptr->GetNumInputs(); - for (auto i = 0u; i < input_count; i++) { - using Key = SharedContext::SharedWeights::Metadata::Key; - const auto tensor_key = Key{ir_ptr->GetInputTensorName(i)}; - if (metadata.contains(tensor_key)) { - auto& value = metadata.at(tensor_key); - ir_ptr->SetTensor(tensor_key.name, value.tensor); - } - } + auto model_dir = session_context_.GetModelPath().parent_path(); + initializer = [this, model_dir = std::move(model_dir)](OVInferRequestPtr ir_ptr) { + shared_context_.SetSharedWeightsOnInferRequest(ir_ptr->GetInfReq(), model_dir); }; } + infer_req_pool_ = std::make_unique(exe_network_, num_infer_req, std::move(initializer)); bindings_ = std::make_unique(exe_network_, subgraph_context_, session_context_); } diff --git a/onnxruntime/core/providers/openvino/contexts.h b/onnxruntime/core/providers/openvino/contexts.h index edd9f176658f8..b14e05191dfaa 100644 --- a/onnxruntime/core/providers/openvino/contexts.h +++ b/onnxruntime/core/providers/openvino/contexts.h @@ -13,80 +13,14 @@ #include "core/common/common.h" #include "core/providers/openvino/ov_interface.h" #include "core/providers/shared_library/provider_api.h" +#include "ov_bin_manager.h" +#include "ov_shared_context.h" namespace onnxruntime { namespace openvino_ep { namespace fs = std::filesystem; -class SharedContext : public WeakSingleton { - // Keep the core alive as long as the shared SharedContext are alive. - std::shared_ptr OVCore_; - - public: - SharedContext() : OVCore_(OVCore::Get()) {} - struct SharedWeights { - struct Metadata { - struct Key { - std::string name; - bool operator==(const Key&) const = default; - }; - struct Hash { - std::size_t operator()(const Key& key) const noexcept { - return std::hash()(key.name); - } - }; - struct Value { - std::string location; - unsigned int data_offset; - unsigned int size; - std::vector dimensions; - std::int32_t element_type; - std::shared_ptr tensor; - }; - using Map = std::unordered_map; - friend std::ostream& operator<<(std::ostream& right, const Metadata::Map& metadata); - friend std::istream& operator>>(std::istream& right, Metadata::Map& metadata); - }; - - struct WeightsFile { - ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(WeightsFile); - WeightsFile() = delete; - explicit WeightsFile(std::filesystem::path filename); - - void load_weights(size_t file_offset, void* data, size_t size); - void* TryGetOrCreateDeviceMapping(std::optional& remote_context); - size_t Size() const { return weights_size_; } - - private: - std::ifstream file_; - std::filesystem::path file_path_; - size_t weights_size_; - struct MappingContainer { - void* ptr_{nullptr}; - ov::Tensor tensor_; - }; - std::map imported_device_tensors_; - }; - - void clear() { - metadata.clear(); - metadata_filepath.clear(); - external_weight_filename.clear(); - mapped_weights.reset(); - } - - fs::path external_weight_filename; - std::unique_ptr mapped_weights; - Metadata::Map metadata; - fs::path metadata_filepath; - } shared_weights; - - void clear() { - shared_weights.clear(); - } -}; - using config_t = std::map; using reshape_t = std::map; using layout_t = std::map; @@ -127,8 +61,8 @@ struct ProviderInfo { bool so_disable_cpu_ep_fallback{false}; // ORT session option bool so_context_embed_mode{false}; // ORT session option bool so_share_ep_contexts{false}; // ORT session option - fs::path so_context_file_path{}; // ORT session option bool so_stop_share_ep_contexts{false}; // ORT session option + fs::path so_context_file_path{}; // ORT session option const ConfigOptions* config_options{NULL}; const std::unordered_set valid_provider_keys = {"device_type", "device_id", "device_luid", "cache_dir", "precision", "load_config", "context", "num_of_threads", "model_priority", "num_streams", "enable_opencl_throttling", "enable_qdq_optimizer", @@ -156,8 +90,24 @@ struct SessionContext : ProviderInfo { mutable bool has_external_weights = false; // Value is set to mutable to modify from capability const std::vector OpenVINO_Version = {OPENVINO_VERSION_MAJOR, OPENVINO_VERSION_MINOR}; const std::string openvino_sdk_version = std::to_string(OPENVINO_VERSION_MAJOR) + "." + std::to_string(OPENVINO_VERSION_MINOR); + RuntimeConfig runtime_config; + const std::filesystem::path& GetModelPath() const { + return onnx_model_path_name.empty() ? so_context_file_path : onnx_model_path_name; + } + + const std::filesystem::path GetOutputBinPath() const { + std::filesystem::path bin_file_name = so_context_file_path; + if (bin_file_name.empty()) { + bin_file_name = onnx_model_path_name; + } + if (bin_file_name.empty()) { + return {}; + } + return BinManager::GetBinPathForModel(bin_file_name); + } + private: void InitRuntimeConfig() { if (config_options) { diff --git a/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.cc b/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.cc index 051a39bd4f205..3260d18e9f43c 100644 --- a/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.cc +++ b/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.cc @@ -12,32 +12,11 @@ namespace onnxruntime { namespace openvino_ep { -EPCtxHandler::EPCtxHandler(std::string ov_sdk_version, const logging::Logger& logger) : openvino_sdk_version_(std::move(ov_sdk_version)), logger_(logger) { - epctx_model_ = Model::Create("ovep_context_model", false, logger_); -} - -/* Export the serialized blob string embedded onto an EPContext Node - * along with other metadata necessary to validate the graph on import - */ - -Status EPCtxHandler::ExportEPCtxModel(const std::string& model_name) { - // Serialize modelproto to string - auto model_proto = epctx_model_->ToProto(); - model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); - - // Finally, dump the model - std::ofstream epctx_onnx_model(model_name, - std::ios::out | std::ios::trunc | std::ios::binary); - if (!epctx_onnx_model) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unable to create epctx onnx model file"); - } +EPCtxHandler::EPCtxHandler(std::string ov_sdk_version, const logging::Logger& logger, std::shared_ptr shared_context_manager) + : openvino_sdk_version_(std::move(ov_sdk_version)), logger_(logger), shared_context_manager_(std::move(shared_context_manager)) { + ORT_ENFORCE(shared_context_manager_ != nullptr, "SharedContextManager pointer is null in EPCtxHandler constructor."); - if (!model_proto->SerializeToOstream(epctx_onnx_model)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to serialize model to file"); - } - LOGS_DEFAULT(VERBOSE) << "[OpenVINO EP] Export blob as EPContext Node"; - - return Status::OK(); + epctx_model_ = Model::Create("ovep_context_model", false, logger_); } Status EPCtxHandler::AddOVEPCtxNodeToGraph(const GraphViewer& graph_viewer, @@ -59,7 +38,7 @@ Status EPCtxHandler::AddOVEPCtxNodeToGraph(const GraphViewer& graph_viewer, // Create EP context node attributes auto node_attributes = ONNX_NAMESPACE::NodeAttributes::Create(); - node_attributes->reserve(4); + node_attributes->reserve(6); { // Create EP context node attributes @@ -70,6 +49,13 @@ Status EPCtxHandler::AddOVEPCtxNodeToGraph(const GraphViewer& graph_viewer, embed_mode_attr->set_i(embed_mode); node_attributes->emplace(EMBED_MODE, std::move(*embed_mode_attr)); + // main context + auto main_graph_attr = ONNX_NAMESPACE::AttributeProto::Create(); + main_graph_attr->set_name(MAIN_CONTEXT); + main_graph_attr->set_type(onnx::AttributeProto_AttributeType_INT); + main_graph_attr->set_i(model_blob_str.empty() ? 0 : 1); + node_attributes->emplace(MAIN_CONTEXT, std::move(*main_graph_attr)); + // ep context auto ep_cache_context_attr = ONNX_NAMESPACE::AttributeProto::Create(); ep_cache_context_attr->set_name(EP_CACHE_CONTEXT); @@ -90,6 +76,13 @@ Status EPCtxHandler::AddOVEPCtxNodeToGraph(const GraphViewer& graph_viewer, source_attr->set_type(onnx::AttributeProto_AttributeType_STRING); source_attr->set_s(kOpenVINOExecutionProvider); node_attributes->emplace(SOURCE, std::move(*source_attr)); + + // partition name + auto partition_name_attr = ONNX_NAMESPACE::AttributeProto::Create(); + partition_name_attr->set_name(PARTITION_NAME); + partition_name_attr->set_type(onnx::AttributeProto_AttributeType_STRING); + partition_name_attr->set_s(graph_name); + node_attributes->emplace(PARTITION_NAME, std::move(*partition_name_attr)); } // Create EP context node @@ -100,8 +93,30 @@ Status EPCtxHandler::AddOVEPCtxNodeToGraph(const GraphViewer& graph_viewer, return Status::OK(); } -std::unique_ptr -EPCtxHandler::GetModelBlobStream(const std::filesystem::path& so_context_file_path, const GraphViewer& graph_viewer) const { +std::shared_ptr EPCtxHandler::GetSharedContextForEpContextSubgraph(const GraphViewer& subgraph_view, const std::filesystem::path& ep_context_path) const { + if (!CheckForOVEPCtxNodeInGraph(subgraph_view)) { + return nullptr; + } + + auto first_index = *subgraph_view.GetNodesInTopologicalOrder().begin(); + auto node = subgraph_view.GetNode(first_index); + ORT_ENFORCE(node != nullptr); + auto& attrs = node->GetAttributes(); + ORT_ENFORCE(attrs.count(EP_CACHE_CONTEXT) == 1); + const auto& ep_cache_context = attrs.at(EP_CACHE_CONTEXT).s(); + + ORT_ENFORCE(attrs.count(EMBED_MODE) == 1); + bool embed_mode = static_cast(attrs.at(EMBED_MODE).i()); + + std::filesystem::path bin_path{}; + if (!embed_mode) { + bin_path = ep_context_path.parent_path() / ep_cache_context; + } + + return shared_context_manager_->GetOrCreateSharedContext(bin_path); +} + +std::unique_ptr EPCtxHandler::GetModelBlobStream(const std::filesystem::path& so_context_file_path, const GraphViewer& graph_viewer) const { auto first_index = *graph_viewer.GetNodesInTopologicalOrder().begin(); auto node = graph_viewer.GetNode(first_index); ORT_ENFORCE(node != nullptr); @@ -130,16 +145,23 @@ EPCtxHandler::GetModelBlobStream(const std::filesystem::path& so_context_file_pa bool isXML = backend_utils::IsModelStreamXML(*result); std::filesystem::path native_blob_path{}; if (!isXML) { + ORT_ENFORCE(attrs.count(PARTITION_NAME) == 1, "Expected partition name for native ep context node"); + const auto& partition_name = attrs.at(PARTITION_NAME).s(); + // If the model stream is not an XML (i.e. precompiled blob), the OpenVINO SDK version that it was // exported with must match the version that is currently running. native_blob_path = std::move(blob_filepath); ORT_ENFORCE((attrs.count(EP_SDK_VER) == 1) && (attrs.at(EP_SDK_VER).s() == openvino_sdk_version_), "EPCtx blob was exported / is compatible with OpenVINO SDK version " + attrs.at(EP_SDK_VER).s() + ", but OpenVINO SDK version currently in use is " + openvino_sdk_version_); + + result.reset(); // Release the stream as we will get the native blob from SharedContext + auto shared_context = shared_context_manager_->GetOrCreateSharedContext(native_blob_path); + return std::make_unique(shared_context->GetNativeBlobAsStream(partition_name), shared_context->GetNativeBlob(partition_name)); } LOGS_DEFAULT(VERBOSE) << "[OpenVINO EP] Read blob from EPContext Node"; - return std::make_unique(std::move(result), native_blob_path); + return std::make_unique(std::move(result), ov::Tensor()); } bool EPCtxHandler::CheckForOVEPCtxNodeInGraph(const GraphViewer& graph_viewer) const { @@ -196,5 +218,61 @@ bool EPCtxHandler::CheckEPCacheContextAttribute(const GraphViewer& graph_viewer, return false; } +void EPCtxHandler::Initialize(const std::vector& fused_nodes, const std::filesystem::path& ep_context_dir) { + bool has_embed_nodes = false; + bool has_non_embed_nodes = false; + bool has_main_context = false; + for (const auto& fused_node_graph : fused_nodes) { + const GraphViewer& graph_viewer = fused_node_graph.filtered_graph; + + // Only process graphs that contain ep context nodes. + if (!CheckForOVEPCtxNodeInGraph(graph_viewer)) { + continue; + } + + auto first_index = *graph_viewer.GetNodesInTopologicalOrder().begin(); + const Node* node = graph_viewer.GetNode(first_index); + ORT_ENFORCE(node != nullptr, "Node pointer is null despite CheckForOVEPCtxNodeInGraph returning true"); + + auto& attrs = node->GetAttributes(); + ORT_ENFORCE(attrs.count(EP_CACHE_CONTEXT) == 1, "EP_CACHE_CONTEXT attribute missing"); + + bool embed_mode = false; + if (attrs.count(EMBED_MODE) == 1) { + embed_mode = static_cast(attrs.at(EMBED_MODE).i()); + } + has_embed_nodes |= embed_mode; + has_non_embed_nodes |= !embed_mode; + + bool main_context = true; + if (attrs.count(MAIN_CONTEXT) == 1) { + main_context = static_cast(attrs.at(MAIN_CONTEXT).i()); + } + has_main_context |= main_context; + + const std::string& ep_cache_context = attrs.at(EP_CACHE_CONTEXT).s(); + if (embed_mode) { + std::filesystem::path dummy_path{}; + auto shared_context = shared_context_manager_->GetOrCreateSharedContext(dummy_path); + if (main_context) { + ORT_ENFORCE(!ep_cache_context.empty(), "Embedded EP context is indicated but EP_CACHE_CONTEXT attribute is empty."); + std::istringstream ss(ep_cache_context); + shared_context->Deserialize(ss); + } + } else { + std::filesystem::path ep_context_path = ep_context_dir / ep_cache_context; + if (ep_context_path.extension() != ".xml") { + auto shared_context = shared_context_manager_->GetOrCreateSharedContext(ep_context_path); + shared_context->Deserialize(); + } + } + } + + ORT_ENFORCE(!(has_embed_nodes && has_non_embed_nodes), + "Mixed embed and non-embed EP context nodes are not supported in a single model."); + ORT_ENFORCE(!(has_embed_nodes && !has_main_context), + "Expected at least one main context node when embedded EP context nodes are present."); +} + } // namespace openvino_ep } // namespace onnxruntime diff --git a/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.h b/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.h index f207f5014ca1f..fc2a56c1d0671 100644 --- a/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.h +++ b/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.h @@ -8,43 +8,52 @@ #include #include "core/providers/shared_library/provider_api.h" +#include "core/framework/execution_provider.h" +#include "ov_bin_manager.h" +#include "ov_shared_context.h" namespace onnxruntime { namespace openvino_ep { +class SharedBinManager; + struct ModelBlobWrapper { - ModelBlobWrapper(std::unique_ptr stream, const std::filesystem::path& native_blob_path) : stream_(std::move(stream)), maybe_native_blob_path_(native_blob_path) {} + ModelBlobWrapper(std::unique_ptr stream, const ov::Tensor& tensor) : stream_(std::move(stream)), tensor_(tensor) {} std::unique_ptr stream_; - std::filesystem::path maybe_native_blob_path_; + ov::Tensor tensor_; // May be empty if model blob is provided as stream only. }; // Utilities to handle EPContext node export and parsing of an EPContext node // to create the compiled_model object to infer on static const char EPCONTEXT_OP[] = "EPContext"; static const char EMBED_MODE[] = "embed_mode"; +static const char MAIN_CONTEXT[] = "main_context"; +static const char PARTITION_NAME[] = "partition_name"; static const char EP_CACHE_CONTEXT[] = "ep_cache_context"; static const char EP_SDK_VER[] = "ep_sdk_version"; static const char SOURCE[] = "source"; class EPCtxHandler { public: - EPCtxHandler(std::string ov_sdk_version, const logging::Logger& logger); + EPCtxHandler(std::string ov_sdk_version, const logging::Logger& logger, std::shared_ptr shared_context_manager); EPCtxHandler(const EPCtxHandler&) = delete; // No copy constructor - Status ExportEPCtxModel(const std::string& model_name); - bool CheckForOVEPCtxNodeInGraph(const GraphViewer& graph_viewer) const; + bool CheckForOVEPCtxNodeInGraph(const GraphViewer& subgraph_view) const; + std::shared_ptr GetSharedContextForEpContextSubgraph(const GraphViewer& subgraph_view, const std::filesystem::path& ep_context_path) const; bool CheckForOVEPCtxNode(const Node& node) const; - Status AddOVEPCtxNodeToGraph(const GraphViewer& graph_viewer, + Status AddOVEPCtxNodeToGraph(const GraphViewer& subgraph_view, const std::string& graph_name, const bool embed_mode, std::string&& model_blob_str) const; - std::unique_ptr GetModelBlobStream(const std::filesystem::path& so_context_file_path, const GraphViewer& graph_viewer) const; + std::unique_ptr GetModelBlobStream(const std::filesystem::path& so_context_file_path, const GraphViewer& subgraph_view) const; InlinedVector GetEPCtxNodes() const; - bool CheckEPCacheContextAttribute(const GraphViewer& graph_viewer, const std::string& target_attr_extn) const; + bool CheckEPCacheContextAttribute(const GraphViewer& subgraph_view, const std::string& target_attr_extn) const; + void Initialize(const std::vector& fused_nodes, const std::filesystem::path& ep_context_path); private: const std::string openvino_sdk_version_; std::unique_ptr epctx_model_; const logging::Logger& logger_; + std::shared_ptr shared_context_manager_; }; } // namespace openvino_ep diff --git a/onnxruntime/core/providers/openvino/openvino_execution_provider.cc b/onnxruntime/core/providers/openvino/openvino_execution_provider.cc index 049af81c9ffb2..f9c9fa2ea6f48 100644 --- a/onnxruntime/core/providers/openvino/openvino_execution_provider.cc +++ b/onnxruntime/core/providers/openvino/openvino_execution_provider.cc @@ -17,6 +17,7 @@ #ifdef USE_OVEP_NPU_MEMORY #include "core/providers/openvino/ov_allocator.h" #endif +#include "ov_interface.h" namespace onnxruntime { namespace openvino_ep { @@ -54,11 +55,12 @@ static std::vector parseDevices(const std::string& device_string, } #endif -OpenVINOExecutionProvider::OpenVINOExecutionProvider(const ProviderInfo& info, std::shared_ptr shared_context) +OpenVINOExecutionProvider::OpenVINOExecutionProvider(const ProviderInfo& info) : IExecutionProvider{onnxruntime::kOpenVINOExecutionProvider}, session_context_(info), - shared_context_{std::move(shared_context)}, - ep_ctx_handle_{session_context_.openvino_sdk_version, *GetLogger()} { + ov_core_(OVCore::Get()), + shared_context_manager_(SharedContextManager::Get()), + ep_ctx_handle_{session_context_.openvino_sdk_version, *GetLogger(), shared_context_manager_} { InitProviderOrtApi(); #ifdef _WIN32 session_id_ = global_session_counter_.fetch_add(1) + 1; @@ -72,7 +74,6 @@ OpenVINOExecutionProvider::~OpenVINOExecutionProvider() { backend_manager.ShutdownBackendManager(); } backend_managers_.clear(); - shared_context_.reset(); } std::vector> @@ -102,6 +103,11 @@ common::Status OpenVINOExecutionProvider::Compile( auto& logger = *GetLogger(); Status status = Status::OK(); + if (session_context_.so_context_enable && session_context_.so_context_embed_mode && session_context_.so_share_ep_contexts) { + return Status(common::StatusCategory::ONNXRUNTIME, common::EP_FAIL, + std::string("Invalid EP context configuration: ") + kOrtSessionOptionEpContextEmbedMode + " must be 0 if " + kOrtSessionOptionShareEpContexts + " is 1."); + } + bool is_epctx_model = false; if (!fused_nodes.empty()) { // Assume these properties are constant for all the model subgraphs, otherwise move to SubGraphContext @@ -115,24 +121,8 @@ common::Status OpenVINOExecutionProvider::Compile( is_epctx_model = ep_ctx_handle_.CheckForOVEPCtxNodeInGraph(graph_body_viewer_0); } - // The block below is executed during EP context model inference - auto& metadata = shared_context_->shared_weights.metadata; // Metadata object in memory - if (session_context_.so_share_ep_contexts && - is_epctx_model && - metadata.empty()) { - fs::path context_model_file_path = session_context_.so_context_file_path; - if (context_model_file_path.empty()) { - // If ep.context_file_path is not set the input model path is used - context_model_file_path = session_context_.onnx_model_path_name; - } - - // Metadata is always read from model location, this could be a source or epctx model - fs::path metadata_filename = context_model_file_path.stem().string() + "_metadata.bin"; - fs::path metadata_file_path = context_model_file_path.parent_path() / metadata_filename; - std::ifstream file(metadata_file_path, std::ios::binary); - ORT_RETURN_IF_NOT(file, "Metadata file was not found: " + metadata_file_path.string()); - shared_context_->shared_weights.metadata_filepath = std::move(metadata_file_path); - file >> metadata; + if (is_epctx_model) { + ep_ctx_handle_.Initialize(fused_nodes, session_context_.GetOutputBinPath().parent_path()); } struct OpenVINOEPFunctionState { @@ -153,12 +143,11 @@ common::Status OpenVINOExecutionProvider::Compile( // For original model, check if the user wants to export a model with pre-compiled blob auto& backend_manager = backend_managers_.emplace_back(session_context_, - *shared_context_, + *shared_context_manager_, fused_node, graph_body_viewer, logger, ep_ctx_handle_); - compute_info.create_state_func = [&backend_manager](ComputeContext* context, FunctionState* state) { OpenVINOEPFunctionState* p = new OpenVINOEPFunctionState{ @@ -189,42 +178,31 @@ common::Status OpenVINOExecutionProvider::Compile( }; node_compute_funcs.push_back(std::move(compute_info)); - - if (!status.IsOK()) { - break; - } } - // The block below is executed during EP context model generation - if (session_context_.so_context_enable && - session_context_.so_share_ep_contexts && - !metadata.empty()) { - // For models after the first the metadata name comes from the shared context - fs::path metadata_file_path = shared_context_->shared_weights.metadata_filepath; - if (metadata_file_path.empty()) { - metadata_file_path = session_context_.so_context_file_path; - std::string name_append{"_metadata.bin"}; - if (metadata_file_path.empty()) { - metadata_file_path = session_context_.onnx_model_path_name; - name_append = "_ctx" + name_append; - } - auto metadata_filename = metadata_file_path.stem().string() + name_append; - metadata_file_path.replace_filename(metadata_filename); - shared_context_->shared_weights.metadata_filepath = metadata_file_path; - } + // Export compiled blobs as EPContext nodes if context enable is set + if (session_context_.so_context_enable) { + auto backend_it = backend_managers_.begin(); + bool is_first = true; - // Metadata is generated only for shared contexts - // If saving metadata then save it to the provided path or use the original model path - // Multiple calls to Compile() will update the metadata and for the last call - // the resulting file will contain the aggregated content - std::ofstream file{metadata_file_path, std::ios::binary}; - ORT_RETURN_IF_NOT(file, "Metadata file could not be written: ", metadata_file_path); - file << metadata; - } + for (const auto& fused_node_graph : fused_nodes) { + const GraphViewer& graph_body_viewer = fused_node_graph.filtered_graph; + + // Set include_embed_data to true only for the first backend manager + backend_it->TryExportCompiledBlobAsEPCtxNode(graph_body_viewer, is_first); - if (session_context_.so_stop_share_ep_contexts) { - if (shared_context_) { - shared_context_->clear(); + is_first = false; + ++backend_it; + } + + // bit clunky ideally we should try to fold this into ep context handler + if (!session_context_.so_context_embed_mode) { + auto shared_context = shared_context_manager_->GetOrCreateActiveSharedContext(session_context_.GetOutputBinPath()); + shared_context->Serialize(); + if (session_context_.so_stop_share_ep_contexts) { + shared_context_manager_->ClearActiveSharedContext(); + shared_context->Clear(); + } } } diff --git a/onnxruntime/core/providers/openvino/openvino_execution_provider.h b/onnxruntime/core/providers/openvino/openvino_execution_provider.h index a375a9ee788bd..326f6de30498f 100644 --- a/onnxruntime/core/providers/openvino/openvino_execution_provider.h +++ b/onnxruntime/core/providers/openvino/openvino_execution_provider.h @@ -15,6 +15,9 @@ #include "core/providers/openvino/backend_manager.h" #include "core/providers/openvino/contexts.h" +#include "ov_shared_context.h" +#include "ov_bin_manager.h" +#include "ov_interface.h" #ifdef _WIN32 #include "core/providers/openvino/ov_tracing.h" @@ -50,7 +53,7 @@ static std::vector split(const std::string& s, char delim) { // Logical device representation. class OpenVINOExecutionProvider : public IExecutionProvider { public: - explicit OpenVINOExecutionProvider(const ProviderInfo& info, std::shared_ptr shared_context); + explicit OpenVINOExecutionProvider(const ProviderInfo& info); ~OpenVINOExecutionProvider(); std::vector> @@ -76,7 +79,9 @@ class OpenVINOExecutionProvider : public IExecutionProvider { #endif private: SessionContext session_context_; - std::shared_ptr shared_context_; + std::shared_ptr ov_core_; + std::shared_ptr shared_context_manager_; + std::list backend_managers_; // EP session owns the backend objects EPCtxHandler ep_ctx_handle_; diff --git a/onnxruntime/core/providers/openvino/openvino_provider_factory.cc b/onnxruntime/core/providers/openvino/openvino_provider_factory.cc index 298eb25713bec..cb94fb3793024 100644 --- a/onnxruntime/core/providers/openvino/openvino_provider_factory.cc +++ b/onnxruntime/core/providers/openvino/openvino_provider_factory.cc @@ -16,6 +16,7 @@ #include "core/session/onnxruntime_session_options_config_keys.h" #include "nlohmann/json.hpp" #include "core/providers/openvino/openvino_parser_utils.h" +#include "ov_interface.h" namespace onnxruntime { namespace openvino_ep { @@ -381,14 +382,14 @@ static void ParseProviderInfo(const ProviderOptions& provider_options, } struct OpenVINOProviderFactory : IExecutionProviderFactory { - OpenVINOProviderFactory(ProviderInfo provider_info, std::shared_ptr shared_context) - : provider_info_(std::move(provider_info)), shared_context_(std::move(shared_context)) {} + OpenVINOProviderFactory(ProviderInfo provider_info, std::shared_ptr ov_core) + : provider_info_(std::move(provider_info)), ov_core_(ov_core) {} ~OpenVINOProviderFactory() override {} std::unique_ptr CreateProvider() override { ParseConfigOptions(provider_info_); - return std::make_unique(provider_info_, shared_context_); + return std::make_unique(provider_info_); } // Called by InferenceSession when registering EPs. Allows creation of an EP instance that is initialized with @@ -421,7 +422,7 @@ struct OpenVINOProviderFactory : IExecutionProviderFactory { ParseProviderInfo(provider_options, &config_options, provider_info); ParseConfigOptions(provider_info); - auto ov_ep = std::make_unique(provider_info, shared_context_); + auto ov_ep = std::make_unique(provider_info); ov_ep->SetLogger(reinterpret_cast(&session_logger)); return ov_ep; } @@ -432,14 +433,14 @@ struct OpenVINOProviderFactory : IExecutionProviderFactory { std::unique_ptr CreateProvider_V2(const OrtSessionOptions& /*session_options*/, const OrtLogger& session_logger) { ProviderInfo provider_info = provider_info_; - auto ov_ep = std::make_unique(provider_info, shared_context_); + auto ov_ep = std::make_unique(provider_info); ov_ep->SetLogger(reinterpret_cast(&session_logger)); return ov_ep; } private: ProviderInfo provider_info_; - std::shared_ptr shared_context_; + std::shared_ptr ov_core_; }; struct ProviderInfo_OpenVINO_Impl : ProviderInfo_OpenVINO { @@ -464,7 +465,7 @@ struct OpenVINO_Provider : Provider { ProviderInfo pi; ParseProviderInfo(provider_options, config_options, pi); - return std::make_shared(pi, SharedContext::Get()); + return std::make_shared(pi, OVCore::Get()); } Status CreateIExecutionProvider(const OrtHardwareDevice* const* /*devices*/, @@ -571,7 +572,7 @@ struct OpenVINO_Provider : Provider { ParseConfigOptions(pi); // Create and return the execution provider - auto factory = std::make_unique(pi, SharedContext::Get()); + auto factory = std::make_unique(pi, OVCore::Get()); ep = factory->CreateProvider_V2(session_options, logger); return Status::OK(); } diff --git a/onnxruntime/core/providers/openvino/ov_bin_manager.cc b/onnxruntime/core/providers/openvino/ov_bin_manager.cc new file mode 100644 index 0000000000000..bdab631bb478b --- /dev/null +++ b/onnxruntime/core/providers/openvino/ov_bin_manager.cc @@ -0,0 +1,440 @@ +// Copyright (C) Intel Corporation +// Licensed under the MIT License + +#include "ov_bin_manager.h" +#include "ov_shared_context.h" +#include +#include "core/providers/shared_library/provider_api.h" // for ORT_VERSION and kOpenVINOExecutionProvider + +namespace onnxruntime { +namespace openvino_ep { + +static inline uint64_t AlignUp(uint64_t value, uint64_t alignment) { + return (value + alignment - 1) / alignment * alignment; +} + +// Custom streambuf that wraps an ov::Tensor's memory +// Provides us a std::istream interface over the tensor data without copying. +// Only supports input operations. +class TensorStreamBuf : public std::streambuf { + public: + explicit TensorStreamBuf(ov::Tensor& tensor) { + char* data = const_cast(tensor.data()); + size_t size = tensor.get_byte_size(); + setg(data, data, data + size); + } + + protected: + // Override seekoff for proper seeking support + std::streampos seekoff(std::streamoff off, std::ios_base::seekdir dir, std::ios_base::openmode which) override { + if (which & std::ios_base::in) { + char* new_pos = nullptr; + switch (dir) { + case std::ios_base::beg: + new_pos = eback() + off; + break; + case std::ios_base::cur: + new_pos = gptr() + off; + break; + case std::ios_base::end: + new_pos = egptr() + off; + break; + default: + return std::streampos(std::streamoff(-1)); + } + + if (new_pos >= eback() && new_pos <= egptr()) { + setg(eback(), new_pos, egptr()); + return std::streampos(new_pos - eback()); + } + } + return std::streampos(std::streamoff(-1)); + } + + // Override seekpos for proper seeking support + std::streampos seekpos(std::streampos pos, std::ios_base::openmode which) override { + return seekoff(std::streamoff(pos), std::ios_base::beg, which); + } +}; + +// Custom istream that owns the tensor to ensure proper lifetime management +class TensorStream : public std::istream { + public: + explicit TensorStream(ov::Tensor tensor) + : std::istream(&buf_), + tensor_(std::move(tensor)), + buf_(tensor_) {} + + private: + ov::Tensor tensor_; // Keep tensor alive + TensorStreamBuf buf_; // Buffer wrapping tensor data +}; + +/* + Logical layout of the single binary file: + [Header] + [BSON Metadata] ← Contains blob_metadata_map with data_offset and size for each blob + [Padding to 64K alignment] ← Blob section starts here (64K aligned) + [Blob 1] ← BSON blob_metadata_map["blob_name"].data_offset points here + [Padding to 64K alignment] ← Each blob end is 64K aligned + [Blob 2] ← BSON blob_metadata_map["blob_name2"].data_offset points here + [Padding to 64K alignment] + [Blob 3] ← BSON blob_metadata_map["blob_name3"].data_offset points here + ... + + BSON Schema: + { + "version": , // BSON schema version (semver format) + "producer": , // Producer identifier (e.g., "onnxruntime-openvino-ep-plugin") + "weights_metadata_map": { // Map of ONNX tensor names to external weight file metadata + "": { + "location": , // Relative path to external weights file + "data_offset": , // Offset within external weights file + "size": // Size of weight data in bytes + }, + ... + }, + "blob_metadata_map": { // Map of blob names to compiled model blob metadata + "": { + "data_offset": , // Absolute file offset to blob data (64K aligned) + "size": // Actual blob data size (excluding padding) + }, + ... + } + } + + Note: data_offset values in blob_metadata_map are absolute file positions. + size values exclude alignment padding bytes. +*/ + +// "OVEP_BIN" in little-endian (memory will read as 'O','V','E','P','_','B','I','N') +constexpr uint64_t kMagicNumber = 0x4E49425F5045564FULL; + +enum class BinVersion : uint64_t { + v1 = 1, + current = v1 +}; + +struct header_t { + uint64_t magic; + uint64_t version; + uint64_t header_size; + uint64_t bson_start_offset; + uint64_t bson_size; +}; + +constexpr uint64_t kBlobAlignment = 64 * 1024; + +// BSON field names +namespace BSONFields { +constexpr const char* kVersion = "version"; +constexpr const char* kProducer = "producer"; +constexpr const char* kWeightsMetadata = "weights_metadata_map"; +constexpr const char* kBlobMetadata = "blob_metadata_map"; +constexpr const char* kLocation = "location"; +constexpr const char* kDataOffset = "data_offset"; +constexpr const char* kSize = "size"; +constexpr const char* kCurrentBsonVersion = "1.0.0"; +constexpr const char* kProducerName = "onnxruntime-openvino-ep-" ORT_VERSION; +} // namespace BSONFields + +template +constexpr std::underlying_type_t to_underlying(E e) noexcept { + static_assert(std::is_enum_v, "to_underlying requires an enum type"); + return static_cast>(e); +} + +void BinManager::AddNativeBlob(const std::string& name, const ov::CompiledModel& compiled_model) { + std::unique_lock lock(mutex_); + native_blobs_[name] = BlobContainer{.compiled_model = compiled_model, .tensor = {}, .data = {}, .serialized_info = {0, 0}}; +} + +ov::Tensor BinManager::GetNativeBlob(const std::string& blob_name) { + std::unique_lock lock(mutex_); + + auto it = native_blobs_.find(blob_name); + ORT_ENFORCE(it != native_blobs_.end(), "Blob not found for ", blob_name); + + auto& blob_container = it->second; + if (blob_container.tensor) { + return blob_container.tensor; + } + + ORT_ENFORCE(blob_container.serialized_info.size > 0 || !blob_container.data.empty(), + "Blob has no serialization info or embedded data for ", blob_name); + + if (!external_bin_path_.value_or("").empty() && !mapped_bin_) { + // Use ov::read_tensor_data to create a memory-mapped tensor from external file + mapped_bin_ = ov::read_tensor_data(external_bin_path_.value()); + } + + if (mapped_bin_) { + // Create a tensor from memory-mapped external file + blob_container.tensor = ov::Tensor( + ov::element::u8, + ov::Shape{blob_container.serialized_info.size}, + mapped_bin_.data() + blob_container.serialized_info.file_offset); + } else { + // Create a tensor from embedded data vector + blob_container.tensor = ov::Tensor( + ov::element::u8, + ov::Shape{blob_container.data.size()}, + blob_container.data.data()); + } + + return blob_container.tensor; +} + +std::unique_ptr BinManager::GetNativeBlobAsStream(const std::string& blob_name) { + return std::make_unique(GetNativeBlob(blob_name)); +} + +void BinManager::Clear() { + std::unique_lock lock(mutex_); + native_blobs_.clear(); + mapped_bin_ = {}; + external_bin_path_.reset(); +} + +std::filesystem::path BinManager::GetBinPathForModel(const std::filesystem::path& model_path) { + ORT_ENFORCE(!model_path.empty()); + return model_path.parent_path() / (model_path.stem().string() + "_" + kOpenVINOExecutionProvider + ".bin"); +} + +void BinManager::Serialize(std::shared_ptr shared_context) { + auto path = GetExternalBinPath(); + std::ofstream stream(path, std::ios::out | std::ios::binary); + ORT_ENFORCE(stream.is_open(), "Failed to open file for serialization: " + path.string()); + Serialize(stream, shared_context); +} + +void BinManager::Deserialize(std::shared_ptr shared_context) { + auto path = GetExternalBinPath(); + std::ifstream stream(path, std::ios::in | std::ios::binary); + ORT_ENFORCE(stream.is_open(), "Failed to open file for deserialization: " + path.string()); + Deserialize(stream, shared_context); +} + +bool BinManager::ShouldSerialize(const std::shared_ptr& shared_context) const { + if (shared_context) { + auto metadata = shared_context->GetMetadataCopy(); + if (!metadata.empty()) { + return true; + } + } + return !native_blobs_.empty(); +} + +void BinManager::Serialize(std::ostream& stream, std::shared_ptr shared_context) { + std::shared_lock ul(mutex_); + + if (!ShouldSerialize(shared_context)) { + // nothing to serialize + return; + } + + const auto stream_start = stream.tellp(); + + auto write_alignment_padding = [&stream](uint64_t current_pos, uint64_t alignment) { + uint64_t aligned_position = AlignUp(current_pos, alignment); + uint64_t padding_size = aligned_position - current_pos; + if (padding_size > 0) { + std::vector padding(padding_size, 0); + stream.write(padding.data(), padding.size()); + ORT_ENFORCE(stream.good(), "Error: Failed to write alignment padding."); + } + }; + + // Reserve space for header (will be updated later) + header_t header{}; + header.magic = kMagicNumber; + header.version = to_underlying(BinVersion::current); + header.header_size = sizeof(header_t); + stream.write(reinterpret_cast(&header), sizeof(header)); + ORT_ENFORCE(stream.good(), "Error: Failed to write header."); + + // Build JSON metadata + nlohmann::json j; + j[BSONFields::kVersion] = BSONFields::kCurrentBsonVersion; + j[BSONFields::kProducer] = BSONFields::kProducerName; + + // Add weights metadata as a map (from SharedContext if available) + if (shared_context) { + auto metadata = shared_context->GetMetadataCopy(); + if (!metadata.empty()) { + nlohmann::json weights_map = nlohmann::json::object(); + for (const auto& [key, value] : metadata) { + nlohmann::json weight_entry; + weight_entry[BSONFields::kLocation] = value.serialized.location.string(); + weight_entry[BSONFields::kDataOffset] = value.serialized.data_offset; + weight_entry[BSONFields::kSize] = value.serialized.size; + weights_map[key] = weight_entry; + } + j[BSONFields::kWeightsMetadata] = weights_map; + } + } + + // Add blob metadata with placeholder values as a map (will be updated after writing blobs) + nlohmann::json blob_map = nlohmann::json::object(); + for (const auto& [key, value] : native_blobs_) { + nlohmann::json blob_entry; + auto max_val = std::numeric_limits::max(); + // Placehold max size since we don't know actual offsets/sizes yet, and if they aren't max they might serialize smaller. + blob_entry[BSONFields::kDataOffset] = max_val; + blob_entry[BSONFields::kSize] = max_val; + blob_map[key] = blob_entry; + } + j[BSONFields::kBlobMetadata] = blob_map; + + // Write BSON metadata (will be rewritten later with correct blob info) + header.bson_start_offset = stream.tellp(); + + size_t orig_bson_size; + { + std::vector bson_data = nlohmann::json::to_bson(j); + orig_bson_size = bson_data.size(); + stream.write(reinterpret_cast(bson_data.data()), bson_data.size()); + ORT_ENFORCE(stream.good(), "Error: Failed to write BSON data."); + } + uint64_t bson_end = stream.tellp(); + + write_alignment_padding(bson_end, kBlobAlignment); + + // Write blob data and capture actual offsets/sizes + for (auto& [blob_name, value] : native_blobs_) { + uint64_t blob_start = stream.tellp(); + value.compiled_model.export_model(stream); + ORT_ENFORCE(stream.good(), "Error: Failed to write blob data for ", blob_name); + // Seek to end of stream after writing in case export model didn't leave us there + stream.seekp(0, std::ios::end); + uint64_t blob_end = stream.tellp(); + uint64_t blob_size = blob_end - blob_start; + + // Update the BlobContainer + BSON with serialization info + value.serialized_info.file_offset = blob_start; + value.serialized_info.size = blob_size; + j[BSONFields::kBlobMetadata][blob_name][BSONFields::kDataOffset] = blob_start; + j[BSONFields::kBlobMetadata][blob_name][BSONFields::kSize] = blob_size; + + write_alignment_padding(blob_end, kBlobAlignment); + } + + // Rewrite BSON metadata with correct blob info + std::vector updated_bson_data = nlohmann::json::to_bson(j); + ORT_ENFORCE(updated_bson_data.size() <= orig_bson_size, + "Error: BSON size larger after updating blob info. Original: ", orig_bson_size, + " Updated: ", updated_bson_data.size()); + + stream.seekp(header.bson_start_offset); + stream.write(reinterpret_cast(updated_bson_data.data()), updated_bson_data.size()); + ORT_ENFORCE(stream.good(), "Error: Failed to rewrite BSON data."); + bson_end = stream.tellp(); + header.bson_size = bson_end - header.bson_start_offset; + + // Update header with BSON offsets + stream.seekp(stream_start); + stream.write(reinterpret_cast(&header), sizeof(header)); + ORT_ENFORCE(stream.good(), "Error: Failed to update header."); + + stream.seekp(0, std::ios::end); // Move to end after writing. +} + +void BinManager::Deserialize(std::istream& stream, std::shared_ptr shared_context) { + // Read and validate header + header_t header{}; + + stream.read(reinterpret_cast(&header), sizeof(header)); + ORT_ENFORCE(stream.good(), "Error: Failed to read header."); + ORT_ENFORCE(header.magic == kMagicNumber, "Error: Invalid magic number. Expected: 0x", std::hex, kMagicNumber, " Got: 0x", header.magic); + ORT_ENFORCE(header.version == to_underlying(BinVersion::current), "Error: Unsupported file version: ", header.version); + ORT_ENFORCE(header.header_size == sizeof(header_t), "Error: Header size mismatch."); + + // Seek to BSON metadata and read it + stream.seekg(header.bson_start_offset); + ORT_ENFORCE(stream.good(), "Error: Failed to seek to BSON metadata."); + + // Parse BSON + nlohmann::json j; + { + std::vector bson_data(header.bson_size); + stream.read(reinterpret_cast(bson_data.data()), header.bson_size); + j = nlohmann::json::from_bson(bson_data); + } + + // Validate BSON version (check major version compatibility) + ORT_ENFORCE(j.contains(BSONFields::kVersion), "Error: Missing version in BSON metadata."); + auto bson_version = j[BSONFields::kVersion].get(); + + // Extract major version from semver strings (format: "major.minor.patch") + auto get_major_version = [](const std::string& version) -> int { + size_t dot_pos = version.find('.'); + if (dot_pos == std::string::npos) return -1; + try { + return std::stoi(version.substr(0, dot_pos)); + } catch (...) { + return -1; + } + }; + + int file_major = get_major_version(bson_version); + int current_major = get_major_version(BSONFields::kCurrentBsonVersion); + + ORT_ENFORCE(file_major >= 0 && current_major >= 0, + "Error: Invalid BSON version format. Expected: ", BSONFields::kCurrentBsonVersion, + " Got: ", bson_version); + ORT_ENFORCE(file_major == current_major, + "Error: Incompatible BSON schema major version. Expected: ", current_major, + " Got: ", file_major, " (full version: ", bson_version, ")"); + + // Parse weights metadata and populate SharedContext if available + if (j.contains(BSONFields::kWeightsMetadata)) { + ORT_ENFORCE(shared_context, "Error: Bin contains shared weights metadata but no SharedContext was provided during deserialization."); + const auto& weights_map = j[BSONFields::kWeightsMetadata]; + if (weights_map.is_object()) { + for (const auto& [weight_name, weight_entry] : weights_map.items()) { + auto location = weight_entry[BSONFields::kLocation].get(); + auto data_offset = weight_entry[BSONFields::kDataOffset].get(); + auto size = weight_entry[BSONFields::kSize].get(); + shared_context->AddExternalWeight(weight_name, data_offset, size, location); + } + } + } + + // Parse blob metadata + ORT_ENFORCE(j.contains(BSONFields::kBlobMetadata), "Error: Missing blob metadata in BSON."); + const auto& blob_map = j[BSONFields::kBlobMetadata]; + ORT_ENFORCE(blob_map.is_object(), "Error: Blob metadata must be an object."); + + // Determine if we're deserializing from an external file or embedded stream + const bool has_external_file = !external_bin_path_.value_or("").empty(); + + std::unique_lock lock(mutex_); + for (const auto& [blob_name, blob_entry] : blob_map.items()) { + uint64_t blob_offset = blob_entry[BSONFields::kDataOffset].get(); + uint64_t blob_size = blob_entry[BSONFields::kSize].get(); + + BlobContainer container; + container.serialized_info.file_offset = blob_offset; + container.serialized_info.size = blob_size; + + // If no external file, extract blob data into vector + if (!has_external_file) { + // Seek to blob offset and read data into vector + auto current_pos = stream.tellg(); + stream.seekg(blob_offset); + ORT_ENFORCE(stream.good(), "Error: Failed to seek to blob data for ", blob_name); + + container.data.resize(blob_size); + stream.read(reinterpret_cast(container.data.data()), blob_size); + ORT_ENFORCE(stream.good(), "Error: Failed to read blob data for ", blob_name); + + // Restore stream position + stream.seekg(current_pos); + } + + native_blobs_[blob_name] = std::move(container); + } +} + +} // namespace openvino_ep +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/openvino/ov_bin_manager.h b/onnxruntime/core/providers/openvino/ov_bin_manager.h new file mode 100644 index 0000000000000..d6d6ada2d252a --- /dev/null +++ b/onnxruntime/core/providers/openvino/ov_bin_manager.h @@ -0,0 +1,77 @@ +// Copyright (C) Intel Corporation +// Licensed under the MIT License + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include "openvino/runtime/core.hpp" +#include "weak_singleton.h" + +namespace onnxruntime { +namespace openvino_ep { + +// Forward declaration +class SharedContext; + +// Manages native compiled model blobs and binary file serialization/deserialization +class BinManager { + public: + BinManager() = default; + BinManager(const std::filesystem::path& external_bin_path) : external_bin_path_(external_bin_path) {} + ~BinManager() = default; + + // Blob management + void AddNativeBlob(const std::string& name, const ov::CompiledModel& compiled_model); + ov::Tensor GetNativeBlob(const std::string& blob_name); + std::unique_ptr GetNativeBlobAsStream(const std::string& blob_name); + void Clear(); + + // Serialization/Deserialization + void Serialize(std::ostream& stream, std::shared_ptr shared_context = nullptr); + void Deserialize(std::istream& stream, std::shared_ptr shared_context = nullptr); + + void Serialize(std::shared_ptr shared_context = nullptr); + void Deserialize(std::shared_ptr shared_context = nullptr); + + // Path management + void TrySetExternalBinPath(const std::filesystem::path& bin_path) { + std::unique_lock lock(mutex_); + if (!external_bin_path_) { + external_bin_path_ = bin_path; + } + } + std::filesystem::path GetExternalBinPath() const { + std::shared_lock lock(mutex_); + return external_bin_path_.value_or(""); + } + + static std::filesystem::path GetBinPathForModel(const std::filesystem::path& model_path); + + private: + struct BlobContainer { + ov::CompiledModel compiled_model; + ov::Tensor tensor; + std::vector data; // For embedded blobs when no external file exists + struct { + uint64_t file_offset{0}; + uint64_t size{0}; + } serialized_info; + }; + + bool ShouldSerialize(const std::shared_ptr& shared_context) const; + + mutable std::shared_mutex mutex_; + std::optional external_bin_path_; + ov::Tensor mapped_bin_; + std::unordered_map native_blobs_; +}; + +} // namespace openvino_ep +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/openvino/ov_factory.cc b/onnxruntime/core/providers/openvino/ov_factory.cc index 2853cc17726ab..5119c611d3f3d 100644 --- a/onnxruntime/core/providers/openvino/ov_factory.cc +++ b/onnxruntime/core/providers/openvino/ov_factory.cc @@ -16,7 +16,7 @@ #include "onnxruntime_c_api.h" #include "ov_factory.h" #include "openvino/openvino.hpp" -#include "ov_interface.h" +#include "weak_singleton.h" using namespace onnxruntime::openvino_ep; using ov_core_singleton = onnxruntime::openvino_ep::WeakSingleton; diff --git a/onnxruntime/core/providers/openvino/ov_interface.cc b/onnxruntime/core/providers/openvino/ov_interface.cc index e97bbaceee4e2..85fc4d93d6243 100644 --- a/onnxruntime/core/providers/openvino/ov_interface.cc +++ b/onnxruntime/core/providers/openvino/ov_interface.cc @@ -199,8 +199,8 @@ OVExeNetwork OVCore::ImportModel(ModelBlobWrapper& model_blob, return OvExceptionBoundary([&]() { ov::CompiledModel obj; #if (OPENVINO_VERSION_MAJOR > 2025 || (OPENVINO_VERSION_MAJOR == 2025 && OPENVINO_VERSION_MINOR >= 3)) - if (!model_blob.maybe_native_blob_path_.empty()) { - obj = core.import_model(ov::read_tensor_data(model_blob.maybe_native_blob_path_), hw_target, device_config); + if (model_blob.tensor_) { + obj = core.import_model(model_blob.tensor_, hw_target, device_config); } else { obj = core.import_model(*model_blob.stream_, hw_target, device_config); } diff --git a/onnxruntime/core/providers/openvino/ov_interface.h b/onnxruntime/core/providers/openvino/ov_interface.h index d5d4bd1af0c6a..5df5420a427f2 100644 --- a/onnxruntime/core/providers/openvino/ov_interface.h +++ b/onnxruntime/core/providers/openvino/ov_interface.h @@ -18,6 +18,7 @@ #include "openvino/frontend/manager.hpp" #include "openvino/core/dimension.hpp" #include "openvino/core/partial_shape.hpp" +#include "weak_singleton.h" #include @@ -47,31 +48,6 @@ typedef std::shared_ptr OVTensorPtr; std::optional queryOVProperty(const std::string& property, const std::string& device_type); -template -class WeakSingleton { - public: - static std::shared_ptr Get() { - static std::weak_ptr instance; - static std::mutex mutex; - - auto ptr = instance.lock(); - if (!ptr) { - std::lock_guard lock(mutex); - // ensure another thread didn't create an instance while this thread was waiting - ptr = instance.lock(); - if (!ptr) { - ptr = std::make_shared(); - instance = ptr; - } - } - return ptr; - } - - protected: - WeakSingleton() = default; - virtual ~WeakSingleton() = default; - ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(WeakSingleton); -}; struct OVCore : WeakSingleton { ov::Core core; @@ -153,7 +129,7 @@ class OVInferRequest { virtual void Infer(); explicit OVInferRequest(ov::InferRequest obj) : ovInfReq(std::move(obj)) {} OVInferRequest() : ovInfReq(ov::InferRequest()) {} - ov::InferRequest& GetNewObj() { + ov::InferRequest& GetInfReq() { return ovInfReq; } virtual void RewindKVCache([[maybe_unused]] size_t index) {} diff --git a/onnxruntime/core/providers/openvino/ov_shared_context.cc b/onnxruntime/core/providers/openvino/ov_shared_context.cc new file mode 100644 index 0000000000000..84cce6e7e16d4 --- /dev/null +++ b/onnxruntime/core/providers/openvino/ov_shared_context.cc @@ -0,0 +1,145 @@ +// Copyright (C) Intel Corporation +// Licensed under the MIT License + +#include "ov_shared_context.h" +#include "ov_interface.h" + +#include "openvino/runtime/intel_npu/level_zero/level_zero.hpp" +#include "openvino/core/type/element_type.hpp" + +namespace onnxruntime { +namespace openvino_ep { + +SharedContext::SharedContext(std::filesystem::path bin_path) + : bin_path_(std::move(bin_path)), + bin_manager_(bin_path_) { +} + +static bool InRange(size_t offset, size_t size, size_t total_size) { + return (offset < total_size) && (size <= total_size) && (offset <= total_size - size); +} + +// Weights file handling +SharedContext::WeightsFile::WeightsFile(const std::filesystem::path& filename) : file_(filename, std::ios::in | std::ios::binary), file_path_(filename) { + try { + file_.exceptions(std::ifstream::failbit | std::ifstream::badbit); + weights_size_ = std::filesystem::file_size(filename); + } catch (std::exception& e) { + ORT_THROW("Error: Failed to open weight file at ", filename.string(), " ", e.what()); + } +} + +void SharedContext::WeightsFile::LoadWeights(size_t file_offset, void* data, size_t size) { + ORT_ENFORCE(InRange(file_offset, size, weights_size_), "Error: File offset is out of bounds."); + file_.seekg(file_offset); + file_.read(static_cast(data), size); +} + +void* SharedContext::WeightsFile::TryGetOrCreateDeviceMapping(std::optional& remote_context) { + std::string dev_name{}; + if (remote_context) { + dev_name = remote_context->get_device_name(); + } + + auto [it, inserted] = imported_device_tensors_.emplace(dev_name, MappingContainer{}); + if (inserted) { + if (dev_name == "NPU") { + // try to import the memory mapped file to remote tensor +#if (OPENVINO_VERSION_MAJOR > 2025 || (OPENVINO_VERSION_MAJOR == 2025 && OPENVINO_VERSION_MINOR >= 3)) + ORT_ENFORCE(remote_context, "Error: Remote context is required for NPU device."); + auto npu_context = remote_context->as(); + auto&& l0_tensor = npu_context.create_tensor(ov::element::Type_t::u8, {weights_size_}, ov::intel_npu::FileDescriptor(file_path_)); + it->second = MappingContainer{.ptr_ = l0_tensor.get(), .tensor_ = l0_tensor}; +#endif + } else if (dev_name.empty()) { + // CPU/virtual device case, create a CPU tensor memory mapped from file + auto&& mmaped_tensor = ov::read_tensor_data(file_path_); + it->second = MappingContainer{.ptr_ = mmaped_tensor.data(), .tensor_ = mmaped_tensor}; + } + } + + return it->second.ptr_; +} + +void SharedContext::LoadTensorFromFile( + Metadata::Value& value, + const std::filesystem::path& model_dir, + std::optional& remote_context, + const ov::element::Type& element_type, + const ov::Shape& dimensions) { + const auto weights_location = model_dir / value.serialized.location; + auto& weights_file = weight_files_[weights_location]; + if (!weights_file) { + weights_file = std::make_unique(weights_location); + } + + ov::Tensor tensor; + uint8_t* mmaped_weights = static_cast(weights_file->TryGetOrCreateDeviceMapping(remote_context)); + if (mmaped_weights) { + // We have memory mapped weights. Create a Tensor view into it for this value. + ORT_ENFORCE(InRange(value.serialized.data_offset, value.serialized.size, weights_file->Size()), "File offset + size outside of external initializer file"); + void* mmapped_offset = static_cast(mmaped_weights + value.serialized.data_offset); + tensor = ov::Tensor(element_type, dimensions, mmapped_offset); + } else { + ORT_ENFORCE(remote_context, "Unexpected: Don't have remote context and memory mapped weights is null!"); + // Can't mmap the file to device tensor, create a host tensor and copy the data + tensor = remote_context->create_host_tensor(element_type, dimensions); + ORT_ENFORCE(tensor.get_byte_size() == value.serialized.size, "Remote tensor size mismatch"); + weights_file->LoadWeights(value.serialized.data_offset, tensor.data(), value.serialized.size); + } + + ORT_ENFORCE(tensor.get_byte_size() == value.serialized.size, "Tensor size mismatch"); + value.tensor = std::make_shared(std::move(tensor)); +} + +void SharedContext::SetSharedWeightsOnInferRequest(ov::InferRequest& ir, const std::filesystem::path& model_dir) { + auto&& compiled_model = ir.get_compiled_model(); + std::optional opt_remote_ctx; + try { + opt_remote_ctx = compiled_model.get_context(); + } catch (ov::Exception&) { + // CPU may not have a remote context. + } + + std::unique_lock ul(mutex_); + for (const auto& input : compiled_model.inputs()) { + const std::string tensor_name = *input.get_names().begin(); + + auto it = metadata_.find(tensor_name); + if (it == metadata_.end()) continue; // No shared weight for this tensor + auto& value = it->second; + + if (!value.tensor) { + LoadTensorFromFile(value, model_dir, opt_remote_ctx, input.get_element_type(), input.get_shape()); + } + ir.set_tensor(tensor_name, *value.tensor); + } +} + +void SharedContext::Serialize(std::ostream& stream) { + bin_manager_.Serialize(stream, shared_from_this()); +} + +void SharedContext::Deserialize(std::istream& stream) { + bin_manager_.Deserialize(stream, shared_from_this()); +} + +void SharedContext::Serialize() { + bin_manager_.Serialize(shared_from_this()); +} + +void SharedContext::Deserialize() { + bin_manager_.Deserialize(shared_from_this()); +} + +void SharedContext::Clear() { + // Outside the mutex since bin_manager has it's own lock, and we want to keep lock ordering consistent + // It's ok for clear to not be fully atomic we're primarily interested in internal consistency. + bin_manager_.Clear(); + std::unique_lock lock(mutex_); + weight_files_.clear(); + metadata_.clear(); +} + +} // namespace openvino_ep +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/openvino/ov_shared_context.h b/onnxruntime/core/providers/openvino/ov_shared_context.h new file mode 100644 index 0000000000000..c893b64442fa4 --- /dev/null +++ b/onnxruntime/core/providers/openvino/ov_shared_context.h @@ -0,0 +1,159 @@ +// Copyright (C) Intel Corporation +// Licensed under the MIT License + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "openvino/runtime/core.hpp" +#include "ov_bin_manager.h" +#include "weak_singleton.h" + +namespace onnxruntime { +namespace openvino_ep { + +class SharedContext : public std::enable_shared_from_this { + public: + explicit SharedContext(std::filesystem::path bin_path); + SharedContext() : SharedContext("") {} + + struct Metadata { + struct Value { + struct { + std::filesystem::path location{}; + size_t data_offset{0}; + size_t size{0}; + } serialized; + + std::shared_ptr tensor; + }; + using Map = std::unordered_map; + }; + + bool IsSharedWeight(const std::string& name) const { + std::shared_lock lock(mutex_); + return metadata_.contains(name); + } + + void AddExternalWeight(const std::string& name, size_t offset, size_t size, const std::filesystem::path& location) { + Metadata::Value value; + value.serialized.data_offset = offset; + value.serialized.size = size; + value.serialized.location = location; + std::unique_lock lock(mutex_); + metadata_[name] = std::move(value); + } + + Metadata::Map GetMetadataCopy() const { + std::shared_lock lock(mutex_); + return metadata_; + } + + void SetSharedWeightsOnInferRequest(ov::InferRequest& ir, const std::filesystem::path& model_dir); + + void AddNativeBlob(const std::string& name, const ov::CompiledModel& compiled_model) { + bin_manager_.AddNativeBlob(name, compiled_model); + } + + ov::Tensor GetNativeBlob(const std::string& blob_name) { + return bin_manager_.GetNativeBlob(blob_name); + } + + std::unique_ptr GetNativeBlobAsStream(const std::string& blob_name) { + return bin_manager_.GetNativeBlobAsStream(blob_name); + } + + void Serialize(std::ostream& stream); + void Deserialize(std::istream& stream); + void Serialize(); + void Deserialize(); + + void Clear(); + + std::filesystem::path GetBinPath() const { + return bin_manager_.GetExternalBinPath(); + } + + static std::filesystem::path GetBinPathForModel(const std::filesystem::path& model_path) { + return BinManager::GetBinPathForModel(model_path); + } + + private: + struct WeightsFile { + ORT_DISALLOW_COPY_AND_ASSIGNMENT(WeightsFile); + WeightsFile() = delete; + virtual ~WeightsFile() = default; + explicit WeightsFile(const std::filesystem::path& filename); + void LoadWeights(size_t file_offset, void* data, size_t size); + void* TryGetOrCreateDeviceMapping(std::optional& remote_context); + size_t Size() const { return weights_size_; } + + private: + std::ifstream file_; + std::filesystem::path file_path_; + size_t weights_size_; + struct MappingContainer { + void* ptr_{nullptr}; + ov::Tensor tensor_; + }; + std::map imported_device_tensors_; + }; + + void LoadTensorFromFile( + Metadata::Value& value, + const std::filesystem::path& model_dir, + std::optional& remote_context, + const ov::element::Type& element_type, + const ov::Shape& dimensions); + + mutable std::shared_mutex mutex_; + std::filesystem::path bin_path_; + BinManager bin_manager_; + std::unordered_map> weight_files_; + Metadata::Map metadata_; +}; + +class SharedContextManager : public WeakSingleton { + public: + std::shared_ptr GetOrCreateActiveSharedContext(const std::filesystem::path& model_path) { + std::lock_guard lock(mutex_); + if (active_context_) { + return active_context_; + } + auto [it, inserted] = contexts_.try_emplace(model_path, nullptr); + if (inserted) { + it->second = std::make_shared(model_path); + } + active_context_ = it->second; + return it->second; + } + + std::shared_ptr GetOrCreateSharedContext(const std::filesystem::path& model_path) { + std::lock_guard lock(mutex_); + auto [it, inserted] = contexts_.try_emplace(model_path, nullptr); + if (inserted) { + it->second = std::make_shared(model_path); + } + return it->second; + } + + void ClearActiveSharedContext() { + std::lock_guard lock(mutex_); + active_context_ = nullptr; + } + + private: + mutable std::mutex mutex_; + std::unordered_map> contexts_; + std::shared_ptr active_context_; +}; + +} // namespace openvino_ep +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.cc b/onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.cc index e010851f22e50..2e5bb7b8c86be 100644 --- a/onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.cc +++ b/onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.cc @@ -704,7 +704,7 @@ Status CreateModelWithStrippedQDQNodes(const GraphViewer& src_graph, bool enable_ovep_weight_sharing, bool enable_ovep_qdq_optimizer, /*out*/ std::unique_ptr& model, - /*out*/ sw& shared_weights) { + /*out*/ SharedContext& shared_context) { // NOTE: This function is a re-implementation of GraphViewerToProto() in core/graph/graph_proto_serializer.cc // with the following differences: // - Uses onnxruntime::Graph APIs instead of onnx::GraphProto APIs. @@ -824,34 +824,28 @@ Status CreateModelWithStrippedQDQNodes(const GraphViewer& src_graph, }); // initialize map for creating metadata for initilizers with external weights - auto& metadata = shared_weights.metadata; - - const auto& insert_metadata = [&metadata](const ONNX_NAMESPACE::TensorProto& proto) { - sw::Metadata::Map::key_type key{proto.name()}; - sw::Metadata::Map::mapped_type value{}; + const auto& add_shared_weight = [&shared_context](const ONNX_NAMESPACE::TensorProto& proto) { using mutable_proto_t = ONNX_NAMESPACE::TensorProto*; auto& mutable_proto = *const_cast(&proto); auto* entry_protos = mutable_proto.mutable_external_data(); + + std::string location = ""; + size_t data_offset = 0, size = 0; for (int i = 0; i < entry_protos->size(); i++) { auto& string_entry_proto{entry_protos->at(i)}; const auto& pb_key{*(string_entry_proto.mutable_key())}; const auto& pb_value{*(string_entry_proto.mutable_value())}; if (pb_key == "location") { - value.location = pb_value; + location = pb_value; } else if (pb_key == "offset") { - value.data_offset = std::stoul(pb_value); + data_offset = std::stoul(pb_value); } else if (pb_key == "length") { - value.size = std::stoul(pb_value); + size = std::stoul(pb_value); } } - value.element_type = proto.data_type(); - value.dimensions.resize(proto.dims_size()); - for (uint32_t index = 0; auto& dim : value.dimensions) { - dim = proto.dims()[index++]; - } - metadata.emplace(key, std::move(value)); + shared_context.AddExternalWeight(proto.name(), data_offset, size, location); }; // Handle initializers @@ -871,7 +865,7 @@ Status CreateModelWithStrippedQDQNodes(const GraphViewer& src_graph, if (!is_quant_param) { // This is actual weight data - so to convert to input for weight sharing - insert_metadata(initializer_tensor); + add_shared_weight(initializer_tensor); AddInitializerAsInput(dst_graph, accumulated_inputs, src_graph, name); } else { // This is a quantization parameter - keep as initializer even if external @@ -912,7 +906,7 @@ Status CreateModelWithStrippedQDQNodes(const GraphViewer& src_graph, if (!init_with_data && utils::HasExternalData(initializer_tensor) && enable_ovep_weight_sharing) { - insert_metadata(initializer_tensor); + add_shared_weight(initializer_tensor); // Add initializer as input if it has external data AddInitializerAsInput(dst_graph, accumulated_inputs, src_graph, input->Name()); diff --git a/onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.h b/onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.h index 53de0fd019311..e649b3ec71943 100644 --- a/onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.h +++ b/onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.h @@ -10,7 +10,7 @@ namespace onnxruntime { namespace openvino_ep { -using sw = SharedContext::SharedWeights; +class SharedContext; // Creates a new model without the DQ/Q operators in the src graph as per pre-defined rulesets Status CreateModelWithStrippedQDQNodes(const GraphViewer& src_graph, @@ -18,8 +18,7 @@ Status CreateModelWithStrippedQDQNodes(const GraphViewer& src_graph, bool enable_ovep_weight_sharing, bool enable_ovep_qdq_optimizer, /*out*/ std::unique_ptr& model, - /*out*/ sw& shared_weights); + /*out*/ SharedContext& shared_context); -bool dumpMetaDataMapToBinary(const sw::Metadata::Map& shared_weights, const std::string& filename); } // namespace openvino_ep } // namespace onnxruntime diff --git a/onnxruntime/core/providers/openvino/weak_singleton.h b/onnxruntime/core/providers/openvino/weak_singleton.h new file mode 100644 index 0000000000000..949ed1b527c60 --- /dev/null +++ b/onnxruntime/core/providers/openvino/weak_singleton.h @@ -0,0 +1,40 @@ +// Copyright (C) Intel Corporation +// Licensed under the MIT License + +#pragma once + +#include +#include +#include "core/common/common.h" + +namespace onnxruntime { +namespace openvino_ep { + +template +class WeakSingleton { + public: + static std::shared_ptr Get() { + static std::weak_ptr instance; + static std::mutex mutex; + + auto ptr = instance.lock(); + if (!ptr) { + std::lock_guard lock(mutex); + // ensure another thread didn't create an instance while this thread was waiting + ptr = instance.lock(); + if (!ptr) { + ptr = std::make_shared(); + instance = ptr; + } + } + return ptr; + } + + protected: + WeakSingleton() = default; + virtual ~WeakSingleton() = default; + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(WeakSingleton); +}; + +} // namespace openvino_ep +} // namespace onnxruntime From 51493cd9fb9d8266ece5ae166736b9d81c75bfb4 Mon Sep 17 00:00:00 2001 From: Jaswanth51 Date: Thu, 13 Nov 2025 22:27:37 -0800 Subject: [PATCH 125/138] Run lintrunner and fix formatting issues (#849) --- .../providers/openvino/openvino_provider_factory.cc | 2 +- onnxruntime/core/providers/openvino/ov_interface.h | 1 - .../providers/openvino/ov_stateful_patch_utils.cc | 12 ++++++------ .../openvino/qdq_transformations/qdq_scales_fix.cpp | 2 +- onnxruntime/python/onnxruntime_pybind_state.cc | 8 +++----- 5 files changed, 11 insertions(+), 14 deletions(-) diff --git a/onnxruntime/core/providers/openvino/openvino_provider_factory.cc b/onnxruntime/core/providers/openvino/openvino_provider_factory.cc index cb94fb3793024..7eb5b062fe7c8 100644 --- a/onnxruntime/core/providers/openvino/openvino_provider_factory.cc +++ b/onnxruntime/core/providers/openvino/openvino_provider_factory.cc @@ -193,7 +193,7 @@ static void ParseInnerMap(const nlohmann::json& json_map, ov::AnyMap& inner_map, const size_t max_levels = 8; if (level >= max_levels) { ORT_THROW("ParseInnerMap: load_config can have only up to " + std::to_string(max_levels) + - " levels of nested maps. Current level = " + std::to_string(level)); + " levels of nested maps. Current level = " + std::to_string(level)); } if (!json_map.is_object()) { diff --git a/onnxruntime/core/providers/openvino/ov_interface.h b/onnxruntime/core/providers/openvino/ov_interface.h index 5df5420a427f2..8765cd040d098 100644 --- a/onnxruntime/core/providers/openvino/ov_interface.h +++ b/onnxruntime/core/providers/openvino/ov_interface.h @@ -48,7 +48,6 @@ typedef std::shared_ptr OVTensorPtr; std::optional queryOVProperty(const std::string& property, const std::string& device_type); - struct OVCore : WeakSingleton { ov::Core core; diff --git a/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc b/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc index 7f276f565f795..20c8deb0698be 100644 --- a/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc +++ b/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc @@ -60,7 +60,7 @@ bool ModelHasInputOutputNames(std::shared_ptr model, const std::strin } std::string GetInputOutputName(std::shared_ptr ov_model, - const std::vector& candidate_names) { + const std::vector& candidate_names) { for (const auto& name : candidate_names) { if (ModelHasInputOutputNames(ov_model, name)) { return name; @@ -78,12 +78,12 @@ void FuseCacheReorder(std::shared_ptr ov_model, throw std::runtime_error("Model already has fused cache"); } - // Define input name candidates in priority order + // Define input name candidates in priority order const std::vector input_name_candidates = { - "inputs_embeds", // Default fallback - "input_ids", // Most common - "input_hidden_states", // Alternative - "/model/embed_tokens/Gather_output_0" // Specific model type + "inputs_embeds", // Default fallback + "input_ids", // Most common + "input_hidden_states", // Alternative + "/model/embed_tokens/Gather_output_0" // Specific model type }; std::string main_input_name = GetInputOutputName(ov_model, input_name_candidates); diff --git a/onnxruntime/core/providers/openvino/qdq_transformations/qdq_scales_fix.cpp b/onnxruntime/core/providers/openvino/qdq_transformations/qdq_scales_fix.cpp index a7b5c51882ff4..84d391a3f2ff3 100644 --- a/onnxruntime/core/providers/openvino/qdq_transformations/qdq_scales_fix.cpp +++ b/onnxruntime/core/providers/openvino/qdq_transformations/qdq_scales_fix.cpp @@ -472,7 +472,7 @@ struct CustomGraph { ORT_ENFORCE(type_str != nullptr, "Type string is null in QDQ scales fix."); auto type_cast = type_str->find("tensor(float)") != std::string::npos ? onnx::TensorProto_DataType_FLOAT : onnx::TensorProto_DataType_FLOAT16; ORT_ENFORCE((type_cast == onnx::TensorProto_DataType_FLOAT) || (type_str->find("tensor(float16)") != std::string::npos), - "QDQ type misalignment, expected float32 or float16 output"); + "QDQ type misalignment, expected float32 or float16 output"); cast_node.AddAttribute("to", static_cast(type_cast)); original_graph.AddEdge(prev.node_ptr->Index(), cast_node.Index(), diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index 704716e80eb1d..92cf6b085c01e 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -2810,8 +2810,7 @@ including arg name, arg type (contains both type and shape).)pbdoc") ORT_THROW("TunableOp and get_tuning_results are not supported in this build."); #endif }) - .def( - "set_ep_dynamic_options", [](PyInferenceSession* sess, const py::dict& options) { + .def("set_ep_dynamic_options", [](PyInferenceSession* sess, const py::dict& options) { std::vector keys; std::vector values; std::vector key_strings; @@ -2841,9 +2840,8 @@ including arg name, arg type (contains both type and shape).)pbdoc") if (!status.IsOK()) { ORT_THROW("Failed to set EP dynamic options: " + status.ErrorMessage()); - } - }, - R"pbdoc(Set dynamic options for execution providers. + } }, + R"pbdoc(Set dynamic options for execution providers. Args: options (dict): Dictionary of key-value pairs where both keys and values are strings. From d0bac3eb16aa7515557541f14f18d15550791d4c Mon Sep 17 00:00:00 2001 From: Yaru Du Date: Sat, 15 Nov 2025 00:23:23 +0000 Subject: [PATCH 126/138] CVS-175736 - [OVEP] Optimize Stateful Path: use output-to-input strategy to get the pairs of KV name (#845) * use output-to-input strategy to get the pairs of KV name * minor change * remove regex for extracting pattern * Address review * Design strict KV patterns: only two separately for key and value; patterns have to be followed by _%d * simplify code structure * address review * remove useless comment * add brief example to explain the functionalities --------- Co-authored-by: MayureshV1 <47039074+MayureshV1@users.noreply.github.com> --- .../openvino/ov_stateful_patch_utils.cc | 145 ++++++++++++++---- 1 file changed, 116 insertions(+), 29 deletions(-) diff --git a/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc b/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc index 20c8deb0698be..c4ec47534d009 100644 --- a/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc +++ b/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc @@ -2,6 +2,8 @@ // Licensed under the MIT License #include "core/providers/openvino/ov_stateful_patch_utils.h" +#include "core/providers/shared_library/provider_api.h" +#include "core/common/common.h" namespace onnxruntime { namespace openvino_ep { @@ -132,29 +134,109 @@ void MakeStateful(std::shared_ptr& ov_model, manager.run_passes(ov_model); } -// Converted to C++ from below reference URL: -// https://github.com/huggingface/optimum-intel/blob/main/optimum/exporters/openvino/stateful.py#L281 -void PatchStatefulDecoder(std::shared_ptr model) { +// Helper function to extract KV patterns from output names dynamically +// +// Example: Given output names ["present_key_cross_0", "present_key_cross_1", "present_value_cross_0", "present_value_cross_1", "logits"] +// key_value_output_names = ["present_key_cross_0", "present_key_cross_1", "present_value_cross_0", "present_value_cross_1"] +// unique_patterns = {"key_cross", "value_cross"} +std::pair, std::unordered_set> ExtractKVPatternsFromOutputs(const std::shared_ptr& model) { + std::vector key_value_output_names; + std::unordered_set unique_patterns; + + const std::string prefix = "present_"; + const size_t prefix_len = prefix.length(); + for (const ov::Output& output : model->outputs()) { + const auto& names = output.get_names(); + for (const auto& name : names) { + if (name.find(prefix) == 0 && name.length() > prefix_len) { + size_t last_underscore_pos = name.rfind('_'); + // Extract pattern between "present_" and the last underscore + if (last_underscore_pos != std::string::npos && last_underscore_pos > prefix_len) { + std::string pattern = name.substr(prefix_len, last_underscore_pos - prefix_len); + if (!pattern.empty()) { + unique_patterns.insert(pattern); + key_value_output_names.push_back(name); + } + } + break; + } + } + } + + if (unique_patterns.size() > 2) { + ORT_THROW("More than two unique KV patterns found in output names."); + } + return std::make_pair(key_value_output_names, unique_patterns); +} + +// Main function to extract KV tensors using dynamic pattern matching +// +// Example: Given input names ["input_ids", "attention_mask", "past_key_cross_0", "past_key_cross_1", "past_value_cross_0", "past_value_cross_1"] +// kv_patterns = {"key_cross", "value_cross"} +// +// key_value_input_names = ["past_key_cross_0", "past_key_cross_1", "past_value_cross_0", "past_value_cross_1"] +// not_kv_inputs = ["input_ids", "attention_mask"] +std::pair, std::vector> ExtractInputKVTensors( + const std::shared_ptr& model, const std::unordered_set& kv_patterns) { + std::vector key_value_input_names; std::vector not_kv_inputs; + + if (kv_patterns.empty()) { + // Fallback: use original substring matching + for (const ov::Output& input : model->inputs()) { + const auto& names = input.get_names(); + const std::string input_name = input.get_any_name(); + + bool is_kv_input = false; + for (const auto& name : names) { + if (name.find("key_values") != std::string::npos || + name.find("keys") != std::string::npos || + name.find("values") != std::string::npos) { + key_value_input_names.push_back(name); + is_kv_input = true; + break; + } + } + + if (!is_kv_input) { + not_kv_inputs.push_back(input_name); + } + } + + return std::make_pair(key_value_input_names, not_kv_inputs); + } + + // Inline helper function to check if name is matched with provided pattern followed by "_%d" + auto matches_pattern = [](const std::string& name, const std::string& pattern) -> bool { + size_t pos = name.find(pattern); + if (pos == std::string::npos) { + return false; + } + + size_t after_pattern = pos + pattern.length(); + if (after_pattern >= name.length() || name[after_pattern] != '_') { + return false; + } + + std::string suffix = name.substr(after_pattern + 1); + return !suffix.empty() && std::all_of(suffix.begin(), suffix.end(), ::isdigit); + }; + for (const ov::Output& input : model->inputs()) { auto& names = input.get_names(); - bool found = false; - for (auto& name : names) { - if (name.find("key_values") != std::string::npos) { - key_value_input_names.push_back(name); - found = true; - break; - } else if (name.find("keys") != std::string::npos) { - key_value_input_names.push_back(name); - found = true; - break; - } else if (name.find("values") != std::string::npos) { - key_value_input_names.push_back(name); - found = true; - break; + + // Check if any input name contains either key or value pattern + for (const auto& name : names) { + for (const auto& pattern : kv_patterns) { + if (matches_pattern(name, pattern)) { + key_value_input_names.push_back(name); + found = true; + break; + } } + if (found) break; } if (!found) { @@ -162,20 +244,25 @@ void PatchStatefulDecoder(std::shared_ptr model) { } } - std::vector key_value_output_names; - for (const ov::Output& output : model->outputs()) { - auto& names = output.get_names(); - for (auto& name : names) { - if (name.find("present") != std::string::npos) { - key_value_output_names.push_back(name); - break; - } - } - } + return std::make_pair(key_value_input_names, not_kv_inputs); +} + +// Updated PatchStatefulDecoder function +void PatchStatefulDecoder(std::shared_ptr model) { + // Use the dynamic pattern-based extraction logic + auto [key_value_output_names, extracted_patterns] = ExtractKVPatternsFromOutputs(model); + auto [key_value_input_names, not_kv_inputs] = ExtractInputKVTensors(model, extracted_patterns); if (key_value_input_names.empty() || key_value_output_names.empty()) { - std::cout << "no key_value_input_names or key_value_output_names found" << std::endl; - return; + ORT_THROW("No key_value_input_names or key_value_output_names found"); + } + + if (key_value_input_names.size() != key_value_output_names.size()) { + ORT_THROW("Found different sizes between key_value_input_names (", + key_value_input_names.size(), + ") and key_value_output_names (", + key_value_output_names.size(), + "). They couldn't be paired."); } // By default, batch is the 0 - th but chatglm uses 1 - st dimension as batch From 75c11ae31e198e233395fe2742baf3f99d4fc0b2 Mon Sep 17 00:00:00 2001 From: Javier Martinez Date: Mon, 17 Nov 2025 09:10:04 -0800 Subject: [PATCH 127/138] CVS-167480 : Report import failure error code (#715) * Catch model import failure and report the appropriate error * Address review comments --------- Co-authored-by: ankitm3k Co-authored-by: MayureshV1 <47039074+MayureshV1@users.noreply.github.com> --- .../providers/openvino/backend_manager.cc | 40 +--- .../core/providers/openvino/exceptions.h | 88 +++++++++ .../openvino/openvino_execution_provider.cc | 181 +++++++++--------- .../core/providers/openvino/ov_interface.cc | 65 ++++--- .../core/providers/openvino/ov_interface.h | 2 + 5 files changed, 226 insertions(+), 150 deletions(-) create mode 100644 onnxruntime/core/providers/openvino/exceptions.h diff --git a/onnxruntime/core/providers/openvino/backend_manager.cc b/onnxruntime/core/providers/openvino/backend_manager.cc index abb5b31b76e44..eed08ee673e49 100644 --- a/onnxruntime/core/providers/openvino/backend_manager.cc +++ b/onnxruntime/core/providers/openvino/backend_manager.cc @@ -21,6 +21,7 @@ #include "core/providers/openvino/ov_interface.h" #include "core/providers/openvino/ov_versions/capability.h" #include "core/providers/openvino/qdq_transformations/qdq_stripping.h" +#include "core/providers/openvino/exceptions.h" #include "core/providers/openvino/qdq_transformations/qdq_scales_fix.h" #include "../../framework/tensorprotoutils.h" @@ -157,40 +158,11 @@ BackendManager::BackendManager(SessionContext& session_context, subgraph_context_.has_dynamic_input_shape = false; // OV NPU plugin is supported with fallback to OV CPU upon compilation failures. - try { - concrete_backend_ = BackendFactory::MakeBackend(model_proto, - session_context_, - subgraph_context_, - *shared_context_, - model_stream); - } catch (const OnnxRuntimeException& ex) { - std::string exception_str = ex.what(); - - if (session_context_.device_type.find("NPU") != std::string::npos && - exception_str.find("intel_npu") != std::string::npos) { - // Handle NPU device related errors -#ifndef NDEBUG - std::string suffix = session_context_.so_disable_cpu_ep_fallback ? "\nModel failed to compile on NPU. Enable CPU fallback or try another device.\n" : "\nModel needs to be recompiled\n"; - ORT_THROW(exception_str + suffix); -#else - std::string error_message = "UNKNOWN NPU ERROR"; - std::string error_code = "code 0x0"; - std::regex error_message_pattern(R"(\bZE_\w*\b)"); - std::regex error_code_pattern("code 0x[0-9a-fA-F]+"); - std::smatch matches; - if (std::regex_search(exception_str, matches, error_message_pattern)) { - error_message = matches[0]; - } - if (std::regex_search(exception_str, matches, error_code_pattern)) { - error_code = matches[0]; - } - std::string suffix = session_context_.so_disable_cpu_ep_fallback ? "\nModel failed to compile on NPU. Enable CPU fallback or try another device.\n" : "\nModel needs to be recompiled\n"; - throw std::runtime_error(error_message + ", " + error_code + suffix); -#endif - } else { - ORT_THROW(exception_str); - } - } + concrete_backend_ = BackendFactory::MakeBackend(model_proto, + session_context_, + subgraph_context_, + *shared_context_, + model_stream); } if (ShouldExportEpContext(session_context_, subgraph_context_)) { diff --git a/onnxruntime/core/providers/openvino/exceptions.h b/onnxruntime/core/providers/openvino/exceptions.h new file mode 100644 index 0000000000000..140ab1ac688ba --- /dev/null +++ b/onnxruntime/core/providers/openvino/exceptions.h @@ -0,0 +1,88 @@ +// Copyright (C) Intel Corporation +// Licensed under the MIT License + +#pragma once + +#include +#include +#include + +#include "core/common/status.h" + +namespace onnxruntime { +namespace openvino_ep { + +struct ovep_exception : public std::exception { + enum class type { + compile_model, + import_model, + query_prop, + read_model, + unknown, + }; + + ovep_exception(const std::exception& ex, enum class type exception_type) + : message_{ex.what()}, + type_{exception_type}, + error_code_{ze_result_code_from_string(message_)}, + error_name_{ze_result_name_from_string(message_)} {} + + ovep_exception(const std::string& message, enum class type exception_type) + : message_{message}, + type_{exception_type}, + error_code_{ze_result_code_from_string(message)}, + error_name_{ze_result_name_from_string(message)} {} + + const char* what() const noexcept override { + return message_.data(); + } + + uint32_t get_code() const { return error_code_; } + + operator common::Status() const { + common::StatusCategory category_ort{common::ONNXRUNTIME}; + + if (type_ == type::unknown) { + return {category_ort, common::FAIL, message_}; + } + + // Newer drivers + if ((type_ == type::import_model) && + (error_code_ == 0x7800000f /* ZE_RESULT_ERROR_INVALID_NATIVE_BINARY */)) { + std::string message{error_name_ + ", code 0x" + std::to_string(error_code_) + "\nModel needs to be recompiled\n"}; + return {category_ort, common::INVALID_GRAPH, message}; + } + + std::string error_message = "Unhandled exception type: " + std::to_string(static_cast(type_)); + return {category_ort, common::EP_FAIL, error_message}; + } + + protected: + std::string message_; + type type_{type::unknown}; + uint32_t error_code_{0}; + std::string error_name_; + + private: + uint32_t ze_result_code_from_string(const std::string& ov_exception_string) { + uint32_t error_code{0}; + std::regex error_code_pattern("code 0x([0-9a-fA-F]+)"); + std::smatch matches; + if (std::regex_search(ov_exception_string, matches, error_code_pattern)) { + std::from_chars(&(*matches[1].first), &(*matches[1].second), error_code, 16); + } + return error_code; + } + std::string ze_result_name_from_string(const std::string& ov_exception_string) { + std::string error_message = "UNKNOWN NPU ERROR"; + std::regex error_message_pattern(R"(\bZE_\w*\b)"); + std::smatch matches; + if (std::regex_search(ov_exception_string, matches, error_message_pattern)) { + error_message = matches[0]; + } + return error_message; + } +}; + +} // namespace openvino_ep +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/openvino/openvino_execution_provider.cc b/onnxruntime/core/providers/openvino/openvino_execution_provider.cc index f9c9fa2ea6f48..6dc7328d696da 100644 --- a/onnxruntime/core/providers/openvino/openvino_execution_provider.cc +++ b/onnxruntime/core/providers/openvino/openvino_execution_provider.cc @@ -12,6 +12,7 @@ #include "core/providers/openvino/onnx_ctx_model_helper.h" #include "core/providers/openvino/ov_versions/capability.h" #include "core/providers/openvino/qdq_transformations/qdq_stripping.h" +#include "core/providers/openvino/exceptions.h" #include "core/session/onnxruntime_session_options_config_keys.h" #include "openvino/core/version.hpp" #ifdef USE_OVEP_NPU_MEMORY @@ -103,107 +104,111 @@ common::Status OpenVINOExecutionProvider::Compile( auto& logger = *GetLogger(); Status status = Status::OK(); - if (session_context_.so_context_enable && session_context_.so_context_embed_mode && session_context_.so_share_ep_contexts) { - return Status(common::StatusCategory::ONNXRUNTIME, common::EP_FAIL, - std::string("Invalid EP context configuration: ") + kOrtSessionOptionEpContextEmbedMode + " must be 0 if " + kOrtSessionOptionShareEpContexts + " is 1."); - } + try { + if (session_context_.so_context_enable && session_context_.so_context_embed_mode && session_context_.so_share_ep_contexts) { + return Status(common::StatusCategory::ONNXRUNTIME, common::EP_FAIL, + std::string("Invalid EP context configuration: ") + kOrtSessionOptionEpContextEmbedMode + " must be 0 if " + kOrtSessionOptionShareEpContexts + " is 1."); + } - bool is_epctx_model = false; - if (!fused_nodes.empty()) { - // Assume these properties are constant for all the model subgraphs, otherwise move to SubGraphContext - const auto& graph_body_viewer_0 = fused_nodes[0].filtered_graph.get(); - session_context_.onnx_model_path_name = graph_body_viewer_0.ModelPath().string(); - session_context_.onnx_opset_version = - graph_body_viewer_0.DomainToVersionMap().at(kOnnxDomain); - - // OVIR wrapped in epctx should be treated as source but this code does not - // This corner case is not in use and will be addressed in a future commit - is_epctx_model = ep_ctx_handle_.CheckForOVEPCtxNodeInGraph(graph_body_viewer_0); - } + bool is_epctx_model = false; + if (!fused_nodes.empty()) { + // Assume these properties are constant for all the model subgraphs, otherwise move to SubGraphContext + const auto& graph_body_viewer_0 = fused_nodes[0].filtered_graph.get(); + session_context_.onnx_model_path_name = graph_body_viewer_0.ModelPath().string(); + session_context_.onnx_opset_version = + graph_body_viewer_0.DomainToVersionMap().at(kOnnxDomain); + + // OVIR wrapped in epctx should be treated as source but this code does not + // This corner case is not in use and will be addressed in a future commit + is_epctx_model = ep_ctx_handle_.CheckForOVEPCtxNodeInGraph(graph_body_viewer_0); + } - if (is_epctx_model) { - ep_ctx_handle_.Initialize(fused_nodes, session_context_.GetOutputBinPath().parent_path()); - } + if (is_epctx_model) { + ep_ctx_handle_.Initialize(fused_nodes, session_context_.GetOutputBinPath().parent_path()); + } - struct OpenVINOEPFunctionState { - AllocateFunc allocate_func = nullptr; - DestroyFunc destroy_func = nullptr; - AllocatorHandle allocator_handle = nullptr; - BackendManager& backend_manager; - }; - - for (const FusedNodeAndGraph& fused_node_graph : fused_nodes) { - const GraphViewer& graph_body_viewer = fused_node_graph.filtered_graph; - const Node& fused_node = fused_node_graph.fused_node; - - NodeComputeInfo compute_info; - - // During backend creation, we check if user wants to use precompiled blob onnx model or the original model - // For precompiled blob, directly load the model instead of compiling the model - // For original model, check if the user wants to export a model with pre-compiled blob - - auto& backend_manager = backend_managers_.emplace_back(session_context_, - *shared_context_manager_, - fused_node, - graph_body_viewer, - logger, - ep_ctx_handle_); - compute_info.create_state_func = - [&backend_manager](ComputeContext* context, FunctionState* state) { - OpenVINOEPFunctionState* p = new OpenVINOEPFunctionState{ - .allocate_func = context->allocate_func, - .destroy_func = context->release_func, - .allocator_handle = context->allocator_handle, - .backend_manager = backend_manager}; - *state = static_cast(p); - return 0; - }; - - compute_info.compute_func = [](FunctionState state, const OrtApi* /* api */, OrtKernelContext* context) { - auto function_state = static_cast(state); - try { - function_state->backend_manager.Compute(context); - } catch (const std::exception& ex) { - return common::Status(common::ONNXRUNTIME, common::FAIL, ex.what()); - } - return Status::OK(); + struct OpenVINOEPFunctionState { + AllocateFunc allocate_func = nullptr; + DestroyFunc destroy_func = nullptr; + AllocatorHandle allocator_handle = nullptr; + BackendManager& backend_manager; }; - compute_info.release_state_func = - [](FunctionState state) { - if (state) { - OpenVINOEPFunctionState* function_state = static_cast(state); - delete function_state; - } - }; + for (const FusedNodeAndGraph& fused_node_graph : fused_nodes) { + const GraphViewer& graph_body_viewer = fused_node_graph.filtered_graph; + const Node& fused_node = fused_node_graph.fused_node; + + NodeComputeInfo compute_info; + + // During backend creation, we check if user wants to use precompiled blob onnx model or the original model + // For precompiled blob, directly load the model instead of compiling the model + // For original model, check if the user wants to export a model with pre-compiled blob + + auto& backend_manager = backend_managers_.emplace_back(session_context_, + *shared_context_manager_, + fused_node, + graph_body_viewer, + logger, + ep_ctx_handle_); + compute_info.create_state_func = + [&backend_manager](ComputeContext* context, FunctionState* state) { + OpenVINOEPFunctionState* p = new OpenVINOEPFunctionState{ + .allocate_func = context->allocate_func, + .destroy_func = context->release_func, + .allocator_handle = context->allocator_handle, + .backend_manager = backend_manager}; + *state = static_cast(p); + return 0; + }; + + compute_info.compute_func = [](FunctionState state, const OrtApi* /* api */, OrtKernelContext* context) { + auto function_state = static_cast(state); + try { + function_state->backend_manager.Compute(context); + } catch (const std::exception& ex) { + return common::Status(common::ONNXRUNTIME, common::FAIL, ex.what()); + } + return Status::OK(); + }; - node_compute_funcs.push_back(std::move(compute_info)); - } + compute_info.release_state_func = + [](FunctionState state) { + if (state) { + OpenVINOEPFunctionState* function_state = static_cast(state); + delete function_state; + } + }; - // Export compiled blobs as EPContext nodes if context enable is set - if (session_context_.so_context_enable) { - auto backend_it = backend_managers_.begin(); - bool is_first = true; + node_compute_funcs.push_back(std::move(compute_info)); + } - for (const auto& fused_node_graph : fused_nodes) { - const GraphViewer& graph_body_viewer = fused_node_graph.filtered_graph; + // Export compiled blobs as EPContext nodes if context enable is set + if (session_context_.so_context_enable) { + auto backend_it = backend_managers_.begin(); + bool is_first = true; - // Set include_embed_data to true only for the first backend manager - backend_it->TryExportCompiledBlobAsEPCtxNode(graph_body_viewer, is_first); + for (const auto& fused_node_graph : fused_nodes) { + const GraphViewer& graph_body_viewer = fused_node_graph.filtered_graph; - is_first = false; - ++backend_it; - } + // Set include_embed_data to true only for the first backend manager + backend_it->TryExportCompiledBlobAsEPCtxNode(graph_body_viewer, is_first); + + is_first = false; + ++backend_it; + } - // bit clunky ideally we should try to fold this into ep context handler - if (!session_context_.so_context_embed_mode) { - auto shared_context = shared_context_manager_->GetOrCreateActiveSharedContext(session_context_.GetOutputBinPath()); - shared_context->Serialize(); - if (session_context_.so_stop_share_ep_contexts) { - shared_context_manager_->ClearActiveSharedContext(); - shared_context->Clear(); + // bit clunky ideally we should try to fold this into ep context handler + if (!session_context_.so_context_embed_mode) { + auto shared_context = shared_context_manager_->GetOrCreateActiveSharedContext(session_context_.GetOutputBinPath()); + shared_context->Serialize(); + if (session_context_.so_stop_share_ep_contexts) { + shared_context_manager_->ClearActiveSharedContext(); + shared_context->Clear(); + } } } + } catch (const ovep_exception& ex) { + status = ex; } return status; diff --git a/onnxruntime/core/providers/openvino/ov_interface.cc b/onnxruntime/core/providers/openvino/ov_interface.cc index 85fc4d93d6243..446ed098521cb 100644 --- a/onnxruntime/core/providers/openvino/ov_interface.cc +++ b/onnxruntime/core/providers/openvino/ov_interface.cc @@ -12,16 +12,21 @@ #include "core/providers/openvino/backends/basic_backend.h" #include "core/providers/openvino/ov_stateful_patch_utils.h" #include "core/providers/openvino/onnx_ctx_model_helper.h" +#include "core/providers/openvino/exceptions.h" namespace onnxruntime { namespace openvino_ep { -template +template inline auto OvExceptionBoundary(Func&& func, std::format_string&& fmt, Args&&... args) { try { return func(); } catch (const ov::Exception& e) { - ORT_THROW(log_tag + std::vformat(fmt.get(), std::make_format_args(args...)) + ": " + std::string(e.what())); + if constexpr (typed) { + throw ovep_exception(e, ovep_exception::type::import_model); + } else { + ORT_THROW(log_tag + std::vformat(fmt.get(), std::make_format_args(args...)) + ": " + std::string(e.what())); + } } catch (...) { ORT_THROW(log_tag + std::vformat(fmt.get(), std::make_format_args(args...))); } @@ -70,7 +75,7 @@ std::optional queryOVProperty(const std::string& property, const std::stri } std::shared_ptr OVCore::ReadModel(std::string&& model, const std::string& model_path) { - return OvExceptionBoundary([&]() { + return OvExceptionBoundary([&]() { std::istringstream modelStringStream(std::move(model)); std::istream& modelStream = modelStringStream; // Try to load with FrontEndManager @@ -88,7 +93,7 @@ std::shared_ptr OVCore::ReadModel(std::string&& model, const std::str ORT_THROW(log_tag + "Unknown exception while Reading network"); } }, - "Exception while Reading network"); + "Exception while Reading network"); } OVExeNetwork OVCore::StatefulCompileModel(std::shared_ptr& model, @@ -156,7 +161,7 @@ OVExeNetwork OVCore::CompileModel(std::shared_ptr& ie_cnn_netwo ov::AnyMap& device_config, bool enable_causallm, const std::string& name) { - return OvExceptionBoundary([&]() { + return OvExceptionBoundary([&]() { OVExeNetwork exe; if (enable_causallm) { auto mutable_model = ie_cnn_network->clone(); @@ -172,14 +177,14 @@ OVExeNetwork OVCore::CompileModel(std::shared_ptr& ie_cnn_netwo return exe; }, - "Exception while Loading Network for graph {}", name); + "Exception while Loading Network for graph {}", name); } OVExeNetwork OVCore::CompileModel(const std::string& onnx_model, std::string& hw_target, ov::AnyMap& device_config, const std::string& name) { - return OvExceptionBoundary([&]() { + return OvExceptionBoundary([&]() { ov::CompiledModel obj; obj = core.compile_model(onnx_model, ov::Tensor(), hw_target, device_config); @@ -189,14 +194,14 @@ OVExeNetwork OVCore::CompileModel(const std::string& onnx_model, OVExeNetwork exe(obj, hw_target); return exe; }, - "Exception while Loading Network for graph {}", name); + "Exception while Loading Network for graph {}", name); } OVExeNetwork OVCore::ImportModel(ModelBlobWrapper& model_blob, std::string hw_target, const ov::AnyMap& device_config, std::string name) { - return OvExceptionBoundary([&]() { + return OvExceptionBoundary([&]() { ov::CompiledModel obj; #if (OPENVINO_VERSION_MAJOR > 2025 || (OPENVINO_VERSION_MAJOR == 2025 && OPENVINO_VERSION_MINOR >= 3)) if (model_blob.tensor_) { @@ -205,7 +210,7 @@ OVExeNetwork OVCore::ImportModel(ModelBlobWrapper& model_blob, obj = core.import_model(*model_blob.stream_, hw_target, device_config); } #else - obj = core.import_model(*model_blob.stream_, hw_target, device_config); + obj = core.import_model(*model_blob.stream_, hw_target, device_config); #endif OVExeNetwork exe(obj, hw_target); @@ -214,7 +219,7 @@ OVExeNetwork OVCore::ImportModel(ModelBlobWrapper& model_blob, #endif return exe; }, - "Exception while Loading Network for graph {}", name); + "Exception while Loading Network for graph {}", name); } OVExeNetwork OVCore::ImportEPCtxOVIREncapsulation(std::istream& model_stream, @@ -222,7 +227,7 @@ OVExeNetwork OVCore::ImportEPCtxOVIREncapsulation(std::istream& model_stream, const ov::AnyMap& device_config, bool enable_causallm, std::filesystem::path model_file_path) { - return OvExceptionBoundary([&]() { + return OvExceptionBoundary([&]() { OVExeNetwork exe; bool isXML = backend_utils::IsModelStreamXML(model_stream); @@ -267,7 +272,11 @@ OVExeNetwork OVCore::ImportEPCtxOVIREncapsulation(std::istream& model_stream, #endif return exe; }, - "Exception while Loading Network from OVIR model file: {}", model_file_path.string()); + "Exception while Loading Network from OVIR model file: {}", model_file_path.string()); +} + +void OVCore::SetCache(const std::string& cache_dir_path) { + core.set_property(ov::cache_dir(cache_dir_path)); } std::vector OVCore::GetAvailableDevices() const { @@ -308,8 +317,12 @@ std::vector OVCore::GetAvailableDevices(const std::string& device_t return available_devices; } +void OVCore::SetStreams(const std::string& device_type, int num_streams) { + core.set_property(device_type, {ov::num_streams(num_streams)}); +} + std::shared_ptr OVExeNetwork::CreateInferRequest() { - return OvExceptionBoundary([&]() { + return OvExceptionBoundary([&]() { auto infReq = compiled_model_obj.create_infer_request(); std::shared_ptr ovInfReq; if (is_stateful_causallm) { @@ -320,31 +333,31 @@ std::shared_ptr OVExeNetwork::CreateInferRequest() { return ovInfReq; }, - "Exception while creating InferRequest object"); + "Exception while creating InferRequest object"); } OVTensorPtr OVInferRequest::GetTensor(const std::string& input_name) { - return OvExceptionBoundary([&]() { + return OvExceptionBoundary([&]() { auto tobj = ovInfReq.get_tensor(input_name); OVTensorPtr blob = std::make_shared(tobj); return blob; }, - " Cannot access IE Blob for input: {}", input_name); + " Cannot access IE Blob for input: {}", input_name); } std::string OVInferRequest::GetInputTensorName(uint32_t index) { - return OvExceptionBoundary([&]() -> const std::string& { + return OvExceptionBoundary([&]() { const auto& model = ovInfReq.get_compiled_model(); return *model.input(index).get_names().begin(); }, - " Cannot access IE Blob for input number: {}", index); + " Cannot access IE Blob for input number: {}", index); } void OVInferRequest::SetTensor(const std::string& name, OVTensorPtr& blob) { - OvExceptionBoundary([&]() { + OvExceptionBoundary([&]() { ovInfReq.set_tensor(name, *(blob.get())); }, - " Cannot set Remote Blob for output: {}", name); + " Cannot set Remote Blob for output: {}", name); } uint32_t OVInferRequest::GetNumInputs() { @@ -352,20 +365,16 @@ uint32_t OVInferRequest::GetNumInputs() { } void OVInferRequest::Infer() { - OvExceptionBoundary([&]() { + OvExceptionBoundary([&]() { ovInfReq.infer(); }, - "In Error Couldn't start Inference"); + "In Error Couldn't start Inference"); } StatefulOVInferRequest::StatefulOVInferRequest(ov::InferRequest infer_request, std::string device) : OVInferRequest(std::move(infer_request)), target_device(device) { bool gpu_or_npu = ((device.find("NPU") != std::string::npos) || (device.find("GPU") != std::string::npos)); - - // check if there is input_ids tensors and if the tensor type is int64, - // because logic prefill_use_full_chat_history is only for specific inputs and data type - auto input_ids_opt = FindTensor("input_ids"); - if (gpu_or_npu && input_ids_opt.has_value() && input_ids_opt->get_element_type() == ov::element::i64) { + if (gpu_or_npu) { prefill_use_full_chat_history = true; } } diff --git a/onnxruntime/core/providers/openvino/ov_interface.h b/onnxruntime/core/providers/openvino/ov_interface.h index 8765cd040d098..8a55fdcbd4fb4 100644 --- a/onnxruntime/core/providers/openvino/ov_interface.h +++ b/onnxruntime/core/providers/openvino/ov_interface.h @@ -81,6 +81,8 @@ struct OVCore : WeakSingleton { std::vector GetAvailableDevices() const; std::vector GetAvailableDevices(const std::string& device_type) const; + void SetCache(const std::string& cache_dir_path); + void SetStreams(const std::string& device_type, int num_streams); }; class OVExeNetwork { From d951954e7a8992c24f8d4c639423962fd92d1fa7 Mon Sep 17 00:00:00 2001 From: Eric Crawford Date: Mon, 17 Nov 2025 12:35:59 -0800 Subject: [PATCH 128/138] CVS-175504: Additional single bin simplifications + fixes for bi-directional compatibility (#851) * Modify shared context lifetime * Provide more helpful error message when failing to deserialize bin * Remove unused clear functions * Remove unused variable --- .../providers/openvino/backend_manager.cc | 40 ++++---------- .../core/providers/openvino/backend_manager.h | 5 +- .../core/providers/openvino/contexts.h | 11 ++-- .../openvino/onnx_ctx_model_helper.cc | 45 ++++++---------- .../openvino/onnx_ctx_model_helper.h | 7 +-- .../openvino/openvino_execution_provider.cc | 17 ++---- .../openvino/openvino_execution_provider.h | 1 + .../core/providers/openvino/ov_bin_manager.cc | 52 +++++++------------ .../core/providers/openvino/ov_bin_manager.h | 3 +- .../providers/openvino/ov_shared_context.cc | 9 ---- .../providers/openvino/ov_shared_context.h | 8 ++- 11 files changed, 69 insertions(+), 129 deletions(-) diff --git a/onnxruntime/core/providers/openvino/backend_manager.cc b/onnxruntime/core/providers/openvino/backend_manager.cc index eed08ee673e49..3426a2781bbc6 100644 --- a/onnxruntime/core/providers/openvino/backend_manager.cc +++ b/onnxruntime/core/providers/openvino/backend_manager.cc @@ -43,30 +43,18 @@ static bool ShouldExportEpContext(const SessionContext& session_context, const S } BackendManager::BackendManager(SessionContext& session_context, - SharedContextManager& shared_context_manager, + SharedContext& shared_context, const onnxruntime::Node& fused_node, const onnxruntime::GraphViewer& subgraph, const logging::Logger& logger, EPCtxHandler& ep_ctx_handle) : ep_ctx_handle_(ep_ctx_handle), session_context_(session_context), - shared_context_manager_(shared_context_manager) { + shared_context_(shared_context) { subgraph_context_.is_ep_ctx_graph = ep_ctx_handle_.CheckForOVEPCtxNodeInGraph(subgraph); // If the graph contains a OVIR wrapped node, we check if it has matching xml file name attribute subgraph_context_.is_ep_ctx_ovir_encapsulated = ep_ctx_handle_.CheckEPCacheContextAttribute(subgraph, session_context_.onnx_model_path_name.filename().replace_extension("xml").string()); - if (subgraph_context_.is_ep_ctx_graph && !subgraph_context_.is_ep_ctx_ovir_encapsulated) { - shared_context_ = ep_ctx_handle.GetSharedContextForEpContextSubgraph(subgraph, - session_context_.GetModelPath()); - } else if (session_context_.so_context_enable && session_context_.so_share_ep_contexts) { - shared_context_ = shared_context_manager_.GetOrCreateActiveSharedContext(session_context_.GetOutputBinPath()); - } else { - // Creating a shared context to satisfy backend. It won't be used for weight sharing. - // Don't make it the active share context since we don't actually want to share it. - shared_context_ = shared_context_manager_.GetOrCreateSharedContext(session_context_.GetOutputBinPath()); - } - ORT_ENFORCE(shared_context_, "Could not create a shared context."); - subgraph_context_.model_precision = [&](const GraphViewer& graph_viewer) { // return empty if graph has no inputs or if types are not one of FP32/FP16 // else assume the type of the first input @@ -138,7 +126,7 @@ BackendManager::BackendManager(SessionContext& session_context, concrete_backend_ = BackendFactory::MakeBackend(model_proto, session_context_, subgraph_context_, - *shared_context_, + shared_context_, model_stream); } catch (std::string const& msg) { ORT_THROW(msg); @@ -161,13 +149,13 @@ BackendManager::BackendManager(SessionContext& session_context, concrete_backend_ = BackendFactory::MakeBackend(model_proto, session_context_, subgraph_context_, - *shared_context_, + shared_context_, model_stream); } if (ShouldExportEpContext(session_context_, subgraph_context_)) { if (concrete_backend_) { - shared_context_->AddNativeBlob(subgraph_context_.subgraph_name, concrete_backend_->GetOVCompiledModel()); + shared_context_.AddNativeBlob(subgraph_context_.subgraph_name, concrete_backend_->GetOVCompiledModel()); } else { ORT_THROW( "Exporting dynamically compiled models at runtime is not supported. " @@ -193,19 +181,11 @@ void BackendManager::TryExportCompiledBlobAsEPCtxNode(const onnxruntime::GraphVi if (session_context_.so_context_embed_mode) { // Internal blob if (include_embed_data) { std::stringstream ss; - shared_context_->Serialize(ss); + shared_context_.Serialize(ss); model_blob_str = std::move(ss).str(); } } else { // External blob - // Build name by combining EpCtx model name (if available) and subgraph name. Model - // name is not available in when creating a session from memory - auto name = session_context_.so_context_file_path.stem().string(); - if (name.empty() && !graph_body_viewer.ModelPath().empty()) { - name = graph_body_viewer.ModelPath().stem().string(); - } - ORT_ENFORCE(!name.empty()); - - model_blob_str = shared_context_->GetBinPath().filename().string(); + model_blob_str = shared_context_.GetBinPath().filename().string(); } auto status = ep_ctx_handle_.AddOVEPCtxNodeToGraph(graph_body_viewer, @@ -521,7 +501,7 @@ BackendManager::GetModelProtoFromFusedNode(const onnxruntime::Node& fused_node, if ((session_context_.device_type.find("NPU") != std::string::npos) && (enable_ovep_qdq_optimizer || session_context_.so_share_ep_contexts)) { std::unique_ptr model; - Status status = CreateModelWithStrippedQDQNodes(subgraph, logger, session_context_.so_share_ep_contexts, enable_ovep_qdq_optimizer, model, *shared_context_); + Status status = CreateModelWithStrippedQDQNodes(subgraph, logger, session_context_.so_share_ep_contexts, enable_ovep_qdq_optimizer, model, shared_context_); auto model_proto = model->ToProto(); model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); print_model_proto_duration(); @@ -788,7 +768,7 @@ void BackendManager::Compute(OrtKernelContext* context) { dynamic_backend = BackendFactory::MakeBackend(modelproto_with_concrete_shapes, session_context_, subgraph_context_, - *shared_context_, + shared_context_, model_stream); } catch (const OnnxRuntimeException& ex) { // Build option disables fallback to CPU on compilation failures with NPU. @@ -808,7 +788,7 @@ void BackendManager::Compute(OrtKernelContext* context) { dynamic_backend = BackendFactory::MakeBackend(modelproto_with_concrete_shapes, session_context_, subgraph_context_, - *shared_context_, + shared_context_, model_stream); } catch (std::string const& msg) { ORT_THROW(msg); diff --git a/onnxruntime/core/providers/openvino/backend_manager.h b/onnxruntime/core/providers/openvino/backend_manager.h index 64dadb6c2151b..716fe3ef4cc90 100644 --- a/onnxruntime/core/providers/openvino/backend_manager.h +++ b/onnxruntime/core/providers/openvino/backend_manager.h @@ -20,7 +20,7 @@ namespace openvino_ep { class BackendManager { public: BackendManager(SessionContext& session_context, - SharedContextManager& shared_context_manager, + SharedContext& shared_context, const onnxruntime::Node& fused_node, const onnxruntime::GraphViewer& subgraph, const logging::Logger& logger, @@ -59,8 +59,7 @@ class BackendManager { SubGraphContext subgraph_context_; EPCtxHandler& ep_ctx_handle_; SessionContext& session_context_; - SharedContextManager& shared_context_manager_; - std::shared_ptr shared_context_; + SharedContext& shared_context_; }; } // namespace openvino_ep diff --git a/onnxruntime/core/providers/openvino/contexts.h b/onnxruntime/core/providers/openvino/contexts.h index b14e05191dfaa..ebb716a64162c 100644 --- a/onnxruntime/core/providers/openvino/contexts.h +++ b/onnxruntime/core/providers/openvino/contexts.h @@ -97,11 +97,12 @@ struct SessionContext : ProviderInfo { return onnx_model_path_name.empty() ? so_context_file_path : onnx_model_path_name; } - const std::filesystem::path GetOutputBinPath() const { - std::filesystem::path bin_file_name = so_context_file_path; - if (bin_file_name.empty()) { - bin_file_name = onnx_model_path_name; - } + const std::filesystem::path& GetOutputModelPath() const { + return so_context_file_path.empty() ? onnx_model_path_name : so_context_file_path; + } + + std::filesystem::path GetOutputBinPath() const { + const auto& bin_file_name = GetOutputModelPath(); if (bin_file_name.empty()) { return {}; } diff --git a/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.cc b/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.cc index 3260d18e9f43c..8f47155d34fa1 100644 --- a/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.cc +++ b/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.cc @@ -93,29 +93,6 @@ Status EPCtxHandler::AddOVEPCtxNodeToGraph(const GraphViewer& graph_viewer, return Status::OK(); } -std::shared_ptr EPCtxHandler::GetSharedContextForEpContextSubgraph(const GraphViewer& subgraph_view, const std::filesystem::path& ep_context_path) const { - if (!CheckForOVEPCtxNodeInGraph(subgraph_view)) { - return nullptr; - } - - auto first_index = *subgraph_view.GetNodesInTopologicalOrder().begin(); - auto node = subgraph_view.GetNode(first_index); - ORT_ENFORCE(node != nullptr); - auto& attrs = node->GetAttributes(); - ORT_ENFORCE(attrs.count(EP_CACHE_CONTEXT) == 1); - const auto& ep_cache_context = attrs.at(EP_CACHE_CONTEXT).s(); - - ORT_ENFORCE(attrs.count(EMBED_MODE) == 1); - bool embed_mode = static_cast(attrs.at(EMBED_MODE).i()); - - std::filesystem::path bin_path{}; - if (!embed_mode) { - bin_path = ep_context_path.parent_path() / ep_cache_context; - } - - return shared_context_manager_->GetOrCreateSharedContext(bin_path); -} - std::unique_ptr EPCtxHandler::GetModelBlobStream(const std::filesystem::path& so_context_file_path, const GraphViewer& graph_viewer) const { auto first_index = *graph_viewer.GetNodesInTopologicalOrder().begin(); auto node = graph_viewer.GetNode(first_index); @@ -218,10 +195,12 @@ bool EPCtxHandler::CheckEPCacheContextAttribute(const GraphViewer& graph_viewer, return false; } -void EPCtxHandler::Initialize(const std::vector& fused_nodes, const std::filesystem::path& ep_context_dir) { +std::shared_ptr EPCtxHandler::Initialize(const std::vector& fused_nodes, const SessionContext& session_context) { bool has_embed_nodes = false; bool has_non_embed_nodes = false; bool has_main_context = false; + + std::shared_ptr shared_context{}; for (const auto& fused_node_graph : fused_nodes) { const GraphViewer& graph_viewer = fused_node_graph.filtered_graph; @@ -241,28 +220,29 @@ void EPCtxHandler::Initialize(const std::vector(attrs.at(EMBED_MODE).i()); } - has_embed_nodes |= embed_mode; - has_non_embed_nodes |= !embed_mode; bool main_context = true; if (attrs.count(MAIN_CONTEXT) == 1) { main_context = static_cast(attrs.at(MAIN_CONTEXT).i()); } + has_main_context |= main_context; + has_embed_nodes |= embed_mode; + has_non_embed_nodes |= !embed_mode; const std::string& ep_cache_context = attrs.at(EP_CACHE_CONTEXT).s(); if (embed_mode) { std::filesystem::path dummy_path{}; - auto shared_context = shared_context_manager_->GetOrCreateSharedContext(dummy_path); + shared_context = shared_context_manager_->GetOrCreateSharedContext(dummy_path); if (main_context) { ORT_ENFORCE(!ep_cache_context.empty(), "Embedded EP context is indicated but EP_CACHE_CONTEXT attribute is empty."); std::istringstream ss(ep_cache_context); shared_context->Deserialize(ss); } } else { - std::filesystem::path ep_context_path = ep_context_dir / ep_cache_context; + std::filesystem::path ep_context_path = session_context.GetOutputModelPath().parent_path() / ep_cache_context; if (ep_context_path.extension() != ".xml") { - auto shared_context = shared_context_manager_->GetOrCreateSharedContext(ep_context_path); + shared_context = shared_context_manager_->GetOrCreateSharedContext(ep_context_path); shared_context->Deserialize(); } } @@ -272,6 +252,13 @@ void EPCtxHandler::Initialize(const std::vectorGetOrCreateActiveSharedContext(session_context.GetOutputBinPath()); + } + + return shared_context; } } // namespace openvino_ep diff --git a/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.h b/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.h index fc2a56c1d0671..fce88005a0605 100644 --- a/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.h +++ b/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.h @@ -9,14 +9,12 @@ #include "core/providers/shared_library/provider_api.h" #include "core/framework/execution_provider.h" -#include "ov_bin_manager.h" #include "ov_shared_context.h" +#include "contexts.h" namespace onnxruntime { namespace openvino_ep { -class SharedBinManager; - struct ModelBlobWrapper { ModelBlobWrapper(std::unique_ptr stream, const ov::Tensor& tensor) : stream_(std::move(stream)), tensor_(tensor) {} std::unique_ptr stream_; @@ -38,7 +36,6 @@ class EPCtxHandler { EPCtxHandler(std::string ov_sdk_version, const logging::Logger& logger, std::shared_ptr shared_context_manager); EPCtxHandler(const EPCtxHandler&) = delete; // No copy constructor bool CheckForOVEPCtxNodeInGraph(const GraphViewer& subgraph_view) const; - std::shared_ptr GetSharedContextForEpContextSubgraph(const GraphViewer& subgraph_view, const std::filesystem::path& ep_context_path) const; bool CheckForOVEPCtxNode(const Node& node) const; Status AddOVEPCtxNodeToGraph(const GraphViewer& subgraph_view, const std::string& graph_name, @@ -47,7 +44,7 @@ class EPCtxHandler { std::unique_ptr GetModelBlobStream(const std::filesystem::path& so_context_file_path, const GraphViewer& subgraph_view) const; InlinedVector GetEPCtxNodes() const; bool CheckEPCacheContextAttribute(const GraphViewer& subgraph_view, const std::string& target_attr_extn) const; - void Initialize(const std::vector& fused_nodes, const std::filesystem::path& ep_context_path); + std::shared_ptr Initialize(const std::vector& fused_nodes, const SessionContext& session_context); private: const std::string openvino_sdk_version_; diff --git a/onnxruntime/core/providers/openvino/openvino_execution_provider.cc b/onnxruntime/core/providers/openvino/openvino_execution_provider.cc index 6dc7328d696da..a099f85b2a4b9 100644 --- a/onnxruntime/core/providers/openvino/openvino_execution_provider.cc +++ b/onnxruntime/core/providers/openvino/openvino_execution_provider.cc @@ -110,22 +110,17 @@ common::Status OpenVINOExecutionProvider::Compile( std::string("Invalid EP context configuration: ") + kOrtSessionOptionEpContextEmbedMode + " must be 0 if " + kOrtSessionOptionShareEpContexts + " is 1."); } - bool is_epctx_model = false; if (!fused_nodes.empty()) { // Assume these properties are constant for all the model subgraphs, otherwise move to SubGraphContext const auto& graph_body_viewer_0 = fused_nodes[0].filtered_graph.get(); session_context_.onnx_model_path_name = graph_body_viewer_0.ModelPath().string(); session_context_.onnx_opset_version = graph_body_viewer_0.DomainToVersionMap().at(kOnnxDomain); - - // OVIR wrapped in epctx should be treated as source but this code does not - // This corner case is not in use and will be addressed in a future commit - is_epctx_model = ep_ctx_handle_.CheckForOVEPCtxNodeInGraph(graph_body_viewer_0); } - if (is_epctx_model) { - ep_ctx_handle_.Initialize(fused_nodes, session_context_.GetOutputBinPath().parent_path()); - } + shared_context_ = ep_ctx_handle_.Initialize(fused_nodes, session_context_); + ORT_ENFORCE(shared_context_, + "Failed to create or retrieve SharedContext"); struct OpenVINOEPFunctionState { AllocateFunc allocate_func = nullptr; @@ -145,7 +140,7 @@ common::Status OpenVINOExecutionProvider::Compile( // For original model, check if the user wants to export a model with pre-compiled blob auto& backend_manager = backend_managers_.emplace_back(session_context_, - *shared_context_manager_, + *shared_context_, fused_node, graph_body_viewer, logger, @@ -199,11 +194,9 @@ common::Status OpenVINOExecutionProvider::Compile( // bit clunky ideally we should try to fold this into ep context handler if (!session_context_.so_context_embed_mode) { - auto shared_context = shared_context_manager_->GetOrCreateActiveSharedContext(session_context_.GetOutputBinPath()); - shared_context->Serialize(); + shared_context_->Serialize(); if (session_context_.so_stop_share_ep_contexts) { shared_context_manager_->ClearActiveSharedContext(); - shared_context->Clear(); } } } diff --git a/onnxruntime/core/providers/openvino/openvino_execution_provider.h b/onnxruntime/core/providers/openvino/openvino_execution_provider.h index 326f6de30498f..a343ad34cae50 100644 --- a/onnxruntime/core/providers/openvino/openvino_execution_provider.h +++ b/onnxruntime/core/providers/openvino/openvino_execution_provider.h @@ -81,6 +81,7 @@ class OpenVINOExecutionProvider : public IExecutionProvider { SessionContext session_context_; std::shared_ptr ov_core_; std::shared_ptr shared_context_manager_; + std::shared_ptr shared_context_; std::list backend_managers_; // EP session owns the backend objects EPCtxHandler ep_ctx_handle_; diff --git a/onnxruntime/core/providers/openvino/ov_bin_manager.cc b/onnxruntime/core/providers/openvino/ov_bin_manager.cc index bdab631bb478b..88a50377281bc 100644 --- a/onnxruntime/core/providers/openvino/ov_bin_manager.cc +++ b/onnxruntime/core/providers/openvino/ov_bin_manager.cc @@ -189,13 +189,6 @@ std::unique_ptr BinManager::GetNativeBlobAsStream(const std::strin return std::make_unique(GetNativeBlob(blob_name)); } -void BinManager::Clear() { - std::unique_lock lock(mutex_); - native_blobs_.clear(); - mapped_bin_ = {}; - external_bin_path_.reset(); -} - std::filesystem::path BinManager::GetBinPathForModel(const std::filesystem::path& model_path) { ORT_ENFORCE(!model_path.empty()); return model_path.parent_path() / (model_path.stem().string() + "_" + kOpenVINOExecutionProvider + ".bin"); @@ -215,22 +208,12 @@ void BinManager::Deserialize(std::shared_ptr shared_context) { Deserialize(stream, shared_context); } -bool BinManager::ShouldSerialize(const std::shared_ptr& shared_context) const { - if (shared_context) { - auto metadata = shared_context->GetMetadataCopy(); - if (!metadata.empty()) { - return true; - } - } - return !native_blobs_.empty(); -} - void BinManager::Serialize(std::ostream& stream, std::shared_ptr shared_context) { std::shared_lock ul(mutex_); - if (!ShouldSerialize(shared_context)) { - // nothing to serialize - return; + auto metadata = shared_context ? shared_context->GetMetadataCopy() : SharedContext::Metadata::Map{}; + if (metadata.empty() && native_blobs_.empty()) { + return; // Nothing to serialize } const auto stream_start = stream.tellp(); @@ -259,19 +242,16 @@ void BinManager::Serialize(std::ostream& stream, std::shared_ptr j[BSONFields::kProducer] = BSONFields::kProducerName; // Add weights metadata as a map (from SharedContext if available) - if (shared_context) { - auto metadata = shared_context->GetMetadataCopy(); - if (!metadata.empty()) { - nlohmann::json weights_map = nlohmann::json::object(); - for (const auto& [key, value] : metadata) { - nlohmann::json weight_entry; - weight_entry[BSONFields::kLocation] = value.serialized.location.string(); - weight_entry[BSONFields::kDataOffset] = value.serialized.data_offset; - weight_entry[BSONFields::kSize] = value.serialized.size; - weights_map[key] = weight_entry; - } - j[BSONFields::kWeightsMetadata] = weights_map; + if (!metadata.empty()) { + nlohmann::json weights_map = nlohmann::json::object(); + for (const auto& [key, value] : metadata) { + nlohmann::json weight_entry; + weight_entry[BSONFields::kLocation] = value.serialized.location.string(); + weight_entry[BSONFields::kDataOffset] = value.serialized.data_offset; + weight_entry[BSONFields::kSize] = value.serialized.size; + weights_map[key] = weight_entry; } + j[BSONFields::kWeightsMetadata] = weights_map; } // Add blob metadata with placeholder values as a map (will be updated after writing blobs) @@ -340,6 +320,14 @@ void BinManager::Serialize(std::ostream& stream, std::shared_ptr } void BinManager::Deserialize(std::istream& stream, std::shared_ptr shared_context) { + try { + DeserializeImpl(stream, shared_context); + } catch (const std::exception& e) { + ORT_THROW(e.what(), "\nCould not deserialize binary data. This could mean the bin is corrupted or incompatible. Try re-generating ep context cache."); + } +} + +void BinManager::DeserializeImpl(std::istream& stream, const std::shared_ptr& shared_context) { // Read and validate header header_t header{}; diff --git a/onnxruntime/core/providers/openvino/ov_bin_manager.h b/onnxruntime/core/providers/openvino/ov_bin_manager.h index d6d6ada2d252a..b50cfc460ec96 100644 --- a/onnxruntime/core/providers/openvino/ov_bin_manager.h +++ b/onnxruntime/core/providers/openvino/ov_bin_manager.h @@ -31,7 +31,6 @@ class BinManager { void AddNativeBlob(const std::string& name, const ov::CompiledModel& compiled_model); ov::Tensor GetNativeBlob(const std::string& blob_name); std::unique_ptr GetNativeBlobAsStream(const std::string& blob_name); - void Clear(); // Serialization/Deserialization void Serialize(std::ostream& stream, std::shared_ptr shared_context = nullptr); @@ -65,7 +64,7 @@ class BinManager { } serialized_info; }; - bool ShouldSerialize(const std::shared_ptr& shared_context) const; + void DeserializeImpl(std::istream& stream, const std::shared_ptr& shared_context); mutable std::shared_mutex mutex_; std::optional external_bin_path_; diff --git a/onnxruntime/core/providers/openvino/ov_shared_context.cc b/onnxruntime/core/providers/openvino/ov_shared_context.cc index 84cce6e7e16d4..f48284d0cc974 100644 --- a/onnxruntime/core/providers/openvino/ov_shared_context.cc +++ b/onnxruntime/core/providers/openvino/ov_shared_context.cc @@ -132,14 +132,5 @@ void SharedContext::Deserialize() { bin_manager_.Deserialize(shared_from_this()); } -void SharedContext::Clear() { - // Outside the mutex since bin_manager has it's own lock, and we want to keep lock ordering consistent - // It's ok for clear to not be fully atomic we're primarily interested in internal consistency. - bin_manager_.Clear(); - std::unique_lock lock(mutex_); - weight_files_.clear(); - metadata_.clear(); -} - } // namespace openvino_ep } // namespace onnxruntime diff --git a/onnxruntime/core/providers/openvino/ov_shared_context.h b/onnxruntime/core/providers/openvino/ov_shared_context.h index c893b64442fa4..aee6d5570d8fa 100644 --- a/onnxruntime/core/providers/openvino/ov_shared_context.h +++ b/onnxruntime/core/providers/openvino/ov_shared_context.h @@ -75,8 +75,6 @@ class SharedContext : public std::enable_shared_from_this { void Serialize(); void Deserialize(); - void Clear(); - std::filesystem::path GetBinPath() const { return bin_manager_.GetExternalBinPath(); } @@ -132,6 +130,7 @@ class SharedContextManager : public WeakSingleton { it->second = std::make_shared(model_path); } active_context_ = it->second; + active_context_path_ = model_path; return it->second; } @@ -146,6 +145,10 @@ class SharedContextManager : public WeakSingleton { void ClearActiveSharedContext() { std::lock_guard lock(mutex_); + if (active_context_) { + contexts_.erase(active_context_path_); + active_context_path_.clear(); + } active_context_ = nullptr; } @@ -153,6 +156,7 @@ class SharedContextManager : public WeakSingleton { mutable std::mutex mutex_; std::unordered_map> contexts_; std::shared_ptr active_context_; + std::filesystem::path active_context_path_; }; } // namespace openvino_ep From 0d68ee5621f1240bca133ed1cd76af8e02163992 Mon Sep 17 00:00:00 2001 From: Ankit Maheshkar Date: Tue, 18 Nov 2025 15:13:36 +0530 Subject: [PATCH 129/138] CVS-176574 : Fix memory leaks for protobuf & DataOps (#852) * fix: fix mem leaks * fix linux builds --- cmake/onnxruntime_providers_openvino.cmake | 5 ++ onnxruntime/core/dll/dllmain.cc | 4 ++ .../openvino/openvino_provider_dllmain.cc | 51 +++++++++++++++++++ .../openvino/ov_versions/capability.cc | 8 +-- .../openvino/ov_versions/capability.h | 2 +- 5 files changed, 65 insertions(+), 5 deletions(-) create mode 100644 onnxruntime/core/providers/openvino/openvino_provider_dllmain.cc diff --git a/cmake/onnxruntime_providers_openvino.cmake b/cmake/onnxruntime_providers_openvino.cmake index 8c8d58c30f594..882fc56d9a40b 100644 --- a/cmake/onnxruntime_providers_openvino.cmake +++ b/cmake/onnxruntime_providers_openvino.cmake @@ -33,6 +33,11 @@ source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_openvino_cc_srcs}) onnxruntime_add_shared_library_module(onnxruntime_providers_openvino ${onnxruntime_providers_openvino_cc_srcs} "${ONNXRUNTIME_ROOT}/core/dll/onnxruntime.rc") + # Propagate leak check define if enabled at top level + if(onnxruntime_ENABLE_MEMLEAK_CHECKER) + target_compile_definitions(onnxruntime_providers_openvino PRIVATE ONNXRUNTIME_ENABLE_MEMLEAK_CHECK) + endif() + onnxruntime_add_include_to_target(onnxruntime_providers_openvino onnxruntime_common onnx nlohmann_json::nlohmann_json) install(FILES ${PROJECT_SOURCE_DIR}/../include/onnxruntime/core/providers/openvino/openvino_provider_factory.h DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/onnxruntime/) diff --git a/onnxruntime/core/dll/dllmain.cc b/onnxruntime/core/dll/dllmain.cc index 7cc00fa4ca74a..9e50c6e07738f 100644 --- a/onnxruntime/core/dll/dllmain.cc +++ b/onnxruntime/core/dll/dllmain.cc @@ -30,6 +30,10 @@ BOOL APIENTRY DllMain(HMODULE /*hModule*/, if (lpvReserved != nullptr) { g_is_shutting_down = true; // do not do cleanup if process termination scenario +#if defined(ONNXRUNTIME_ENABLE_MEMLEAK_CHECK) + // In leak-check builds we still want protobuf shutdown to avoid flagged leaks. + ::google::protobuf::ShutdownProtobufLibrary(); +#endif } else { // Cleanup protobuf library. // NOTE: it might be too early to do so, as all function local statics and global objects are not destroyed yet. diff --git a/onnxruntime/core/providers/openvino/openvino_provider_dllmain.cc b/onnxruntime/core/providers/openvino/openvino_provider_dllmain.cc new file mode 100644 index 0000000000000..08f9cc065aaae --- /dev/null +++ b/onnxruntime/core/providers/openvino/openvino_provider_dllmain.cc @@ -0,0 +1,51 @@ +// Copyright (c) Intel Corporation. +// Licensed under the MIT License. +#ifdef _WIN32 + +#include +#ifdef __GNUC__ +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wignored-qualifiers" +#pragma GCC diagnostic ignored "-Wunused-parameter" +#else +#endif +#include +#ifdef __GNUC__ +#pragma GCC diagnostic pop +#endif +#include + +// Reuse the global shutdown indicator (do NOT set it here; that is owned by the core DLL). +extern std::atomic g_is_shutting_down; + +// NOTE: +// This DllMain exists because the OpenVINO provider DLL statically links protobuf independently +// of the core onnxruntime DLL. The core DLL's DllMain won't clean up this copy. +// We perform protobuf shutdown on dynamic unload, and (optionally) during process termination +// when memory leak checking is enabled. +BOOL APIENTRY DllMain(HMODULE /*hModule*/, + DWORD ul_reason_for_call, + LPVOID lpvReserved) { + switch (ul_reason_for_call) { + case DLL_PROCESS_ATTACH: + case DLL_THREAD_ATTACH: + case DLL_THREAD_DETACH: + break; + case DLL_PROCESS_DETACH: + // Windows API doc says: "When handling DLL_PROCESS_DETACH, a DLL should free resources such as heap memory only if the DLL is being unloaded dynamically" + if (lpvReserved != nullptr) { + // Process termination. Normally skipped for speed/safety, + // but in leak-check builds we reclaim protobuf heap. +#if defined(ONNXRUNTIME_ENABLE_MEMLEAK_CHECK) + ::google::protobuf::ShutdownProtobufLibrary(); +#endif + } else { + // Dynamic unload: safe to clean up. + ::google::protobuf::ShutdownProtobufLibrary(); + } + break; + } + return TRUE; +} + +#endif // defined(_WIN32) diff --git a/onnxruntime/core/providers/openvino/ov_versions/capability.cc b/onnxruntime/core/providers/openvino/ov_versions/capability.cc index 1893700cab09c..9185f7a188328 100644 --- a/onnxruntime/core/providers/openvino/ov_versions/capability.cc +++ b/onnxruntime/core/providers/openvino/ov_versions/capability.cc @@ -42,13 +42,13 @@ GetCapability::GetCapability(const EPCtxHandler& ep_ctx_handler, } #if OPENVINO_VERSION_MAJOR == 2025 && OPENVINO_VERSION_MINOR == 0 - data_ops_ = new DataOps(graph_viewer_, V_2025_0, device_type_, npu_qdq_optimizer_enabled); + data_ops_ = std::make_unique(graph_viewer_, V_2025_0, device_type_, npu_qdq_optimizer_enabled); #elif OPENVINO_VERSION_MAJOR == 2025 && OPENVINO_VERSION_MINOR == 1 - data_ops_ = new DataOps(graph_viewer_, V_2025_1, device_type_, npu_qdq_optimizer_enabled); + data_ops_ = std::make_unique(graph_viewer_, V_2025_1, device_type_, npu_qdq_optimizer_enabled); #elif OPENVINO_VERSION_MAJOR == 2025 && OPENVINO_VERSION_MINOR == 2 - data_ops_ = new DataOps(graph_viewer_, V_2025_2, device_type_, npu_qdq_optimizer_enabled); + data_ops_ = std::make_unique(graph_viewer_, V_2025_2, device_type_, npu_qdq_optimizer_enabled); #else - data_ops_ = new DataOps(graph_viewer_, V_2025_2, device_type_, npu_qdq_optimizer_enabled); + data_ops_ = std::make_unique(graph_viewer_, V_2025_2, device_type_, npu_qdq_optimizer_enabled); #endif } diff --git a/onnxruntime/core/providers/openvino/ov_versions/capability.h b/onnxruntime/core/providers/openvino/ov_versions/capability.h index 364e79a76f154..3974bdc3b8ff9 100644 --- a/onnxruntime/core/providers/openvino/ov_versions/capability.h +++ b/onnxruntime/core/providers/openvino/ov_versions/capability.h @@ -16,7 +16,7 @@ class GetCapability { const EPCtxHandler& ep_ctx_handler_; const GraphViewer& graph_viewer_; std::string device_type_; - DataOps* data_ops_; + std::unique_ptr data_ops_; bool is_wholly_supported_graph_ = false; bool has_external_weights_ = false; From dbd1ce04d3b9d312d09e6fcbab9ff6c4689a5385 Mon Sep 17 00:00:00 2001 From: Eric Crawford Date: Tue, 18 Nov 2025 16:02:30 -0800 Subject: [PATCH 130/138] CVS-175504 Fix mixing weight shared and non-shared models (#854) --- .../core/providers/openvino/onnx_ctx_model_helper.cc | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.cc b/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.cc index 8f47155d34fa1..60a461f7159f3 100644 --- a/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.cc +++ b/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.cc @@ -253,9 +253,14 @@ std::shared_ptr EPCtxHandler::Initialize(const std::vectorGetOrCreateActiveSharedContext(session_context.GetOutputBinPath()); + if (session_context.so_context_enable && session_context.so_share_ep_contexts) { + // We're creating a shared ep context model get or create the active context. + shared_context = shared_context_manager_->GetOrCreateActiveSharedContext(session_context.GetOutputBinPath()); + } else { + shared_context = shared_context_manager_->GetOrCreateSharedContext(session_context.GetOutputBinPath()); + } } return shared_context; From fa7ab09f1df6122c774bbaea9f00fe2d78b05831 Mon Sep 17 00:00:00 2001 From: Ryan Metcalfe <107415876+RyanMetcalfeInt8@users.noreply.github.com> Date: Tue, 18 Nov 2025 19:29:22 -0800 Subject: [PATCH 131/138] ovep stateful: Enable explicit slice of prefill logits when NPUW_SLICE_OUT is disabled (#850) * ovep stateful: Enable explicit slice of prefill logits when NPUW_SLICE_OUT is disabled * Update onnxruntime/core/providers/openvino/ov_interface.cc Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: MayureshV1 <47039074+MayureshV1@users.noreply.github.com> --- .../core/providers/openvino/ov_interface.cc | 74 ++++++++++++++++++- .../core/providers/openvino/ov_interface.h | 6 +- 2 files changed, 78 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/openvino/ov_interface.cc b/onnxruntime/core/providers/openvino/ov_interface.cc index 446ed098521cb..23be3447b8799 100644 --- a/onnxruntime/core/providers/openvino/ov_interface.cc +++ b/onnxruntime/core/providers/openvino/ov_interface.cc @@ -374,11 +374,42 @@ void OVInferRequest::Infer() { StatefulOVInferRequest::StatefulOVInferRequest(ov::InferRequest infer_request, std::string device) : OVInferRequest(std::move(infer_request)), target_device(device) { bool gpu_or_npu = ((device.find("NPU") != std::string::npos) || (device.find("GPU") != std::string::npos)); - if (gpu_or_npu) { + + _npu_logits_slice_required = IsNPULogitsSliceRequired(); + + // check if there is input_ids tensors and if the tensor type is int64, + // because logic prefill_use_full_chat_history is only for specific inputs and data type + auto input_ids_opt = FindTensor("input_ids"); + if (gpu_or_npu && input_ids_opt.has_value() && input_ids_opt->get_element_type() == ov::element::i64) { prefill_use_full_chat_history = true; } } +static inline bool IsNPUWSliceOutEnabled(const ov::CompiledModel& compiled_model) { + auto slice_out_val = compiled_model.get_property("NPUW_SLICE_OUT"); + if (!slice_out_val.empty()) { + if (slice_out_val.is()) { + return (slice_out_val.as() == "YES"); + } else if (slice_out_val.is()) { + return slice_out_val.as(); + } + } + + return false; +} + +bool StatefulOVInferRequest::IsNPULogitsSliceRequired() { + if (target_device.find("NPU") != std::string::npos) { + const auto& model = ovInfReq.get_compiled_model(); + // If NPUW_SLICE_OUT is enabled, it means that it's not required to slice within OVEP. + // Otherwise, if NPUW_SLICE_OUT is NOT enabled, then we need to perform some explicit logit + // slicing in OVEP. + return !IsNPUWSliceOutEnabled(model); + } + + return false; +} + void StatefulOVInferRequest::FillTensor(const std::string& tensor_name, const ov::element::Type& type, const std::vector& shape, int32_t fill_value) { ov::Tensor tensor = ov::Tensor(type, shape); @@ -519,5 +550,46 @@ void StatefulOVInferRequest::RewindKVCache(size_t index) { } } } + +OVTensorPtr StatefulOVInferRequest::GetTensor(const std::string& input_name) { + + auto tobj = OVInferRequest::GetTensor(input_name); + + if (_npu_logits_slice_required) { + if (input_name == "logits") { + if (tobj->get_shape().size() != 3) { + ORT_THROW(log_tag + std::format("Expected logits to have shape of rank 3, but it has shape of rank {}", + tobj->get_shape().size())); + } + + // When _npu_logits_slice_required is true, it means that prefill may produce logits of shape: + // [, sequence_length, ] + // (Where 'sequence_length' is number of input tokens to prefill) + // But, ORT GenAI is expecting to receive logits of shape: + // [, 1, ] + // In this case, detect when shape[1] is not 1. When it is, create a slice of shape [, 1, ] + if (tobj->get_shape()[1] > 1) { + return OvExceptionBoundary([&]() { + const ov::Coordinate begin = {0, tobj->get_shape()[1] - 1, 0}; + const ov::Coordinate end = {tobj->get_shape()[0], tobj->get_shape()[1], tobj->get_shape()[2]}; + auto sliced_tensor = ov::Tensor(*tobj, begin, end); + if (sliced_tensor.is_continuous()) { + OVTensorPtr blob = std::make_shared(sliced_tensor); + return blob; + } else { + auto continuous_sliced_tensor = ov::Tensor(sliced_tensor.get_element_type(), sliced_tensor.get_shape()); + sliced_tensor.copy_to(continuous_sliced_tensor); + OVTensorPtr blob = std::make_shared(continuous_sliced_tensor); + return blob; + } + }, + "Could not create sliced logits tensor"); + } + } + } + + return tobj; +} + } // namespace openvino_ep } // namespace onnxruntime diff --git a/onnxruntime/core/providers/openvino/ov_interface.h b/onnxruntime/core/providers/openvino/ov_interface.h index 8a55fdcbd4fb4..8fc28b8885e5d 100644 --- a/onnxruntime/core/providers/openvino/ov_interface.h +++ b/onnxruntime/core/providers/openvino/ov_interface.h @@ -110,7 +110,7 @@ class OVInferRequest { public: uint32_t GetNumInputs(); - OVTensorPtr GetTensor(const std::string& name); + virtual OVTensorPtr GetTensor(const std::string& name); std::string GetInputTensorName(uint32_t index); // Set tensor call infer req tensor if ort_ptr differs from last set ptr. @@ -147,6 +147,7 @@ class StatefulOVInferRequest : public OVInferRequest { void CacheTensor(const std::string& tensor_name, std::vector& cache); void SetTensorFromCache(const std::string& tensor_name, const std::vector& cache_data); std::optional FindTensor(const std::string& tensor_name); + OVTensorPtr GetTensor(const std::string& name) override; private: void PreProcessInferRequest(); @@ -157,6 +158,9 @@ class StatefulOVInferRequest : public OVInferRequest { bool prefill_use_full_chat_history = false; std::vector cached_input_ids; std::vector cached_position_ids; + + bool IsNPULogitsSliceRequired(); + bool _npu_logits_slice_required = false; }; } // namespace openvino_ep From 24c833c97bacb3a5be73c73274a4ebd44c765d6c Mon Sep 17 00:00:00 2001 From: jatinwadhwa921 <110383850+jatinwadhwa921@users.noreply.github.com> Date: Wed, 19 Nov 2025 10:52:41 +0530 Subject: [PATCH 132/138] Updating OVEP to support 2025.4.0 (#853) --- .../core/providers/openvino/ov_versions/capability.cc | 10 ++++++---- .../core/providers/openvino/ov_versions/data_ops.cc | 8 ++++---- .../core/providers/openvino/ov_versions/data_ops.h | 4 +++- 3 files changed, 13 insertions(+), 9 deletions(-) diff --git a/onnxruntime/core/providers/openvino/ov_versions/capability.cc b/onnxruntime/core/providers/openvino/ov_versions/capability.cc index 9185f7a188328..c7f92ad4d9be1 100644 --- a/onnxruntime/core/providers/openvino/ov_versions/capability.cc +++ b/onnxruntime/core/providers/openvino/ov_versions/capability.cc @@ -41,14 +41,16 @@ GetCapability::GetCapability(const EPCtxHandler& ep_ctx_handler, npu_qdq_optimizer_enabled = true; // see data_ops.cc ~615 where we check for int16 types for gpu, this may change to a better approach later } -#if OPENVINO_VERSION_MAJOR == 2025 && OPENVINO_VERSION_MINOR == 0 - data_ops_ = std::make_unique(graph_viewer_, V_2025_0, device_type_, npu_qdq_optimizer_enabled); -#elif OPENVINO_VERSION_MAJOR == 2025 && OPENVINO_VERSION_MINOR == 1 +#if OPENVINO_VERSION_MAJOR == 2025 && OPENVINO_VERSION_MINOR == 1 data_ops_ = std::make_unique(graph_viewer_, V_2025_1, device_type_, npu_qdq_optimizer_enabled); #elif OPENVINO_VERSION_MAJOR == 2025 && OPENVINO_VERSION_MINOR == 2 data_ops_ = std::make_unique(graph_viewer_, V_2025_2, device_type_, npu_qdq_optimizer_enabled); +#elif OPENVINO_VERSION_MAJOR == 2025 && OPENVINO_VERSION_MINOR == 3 + data_ops_ = std::make_unique(graph_viewer_, V_2025_3, device_type_, npu_qdq_optimizer_enabled); +#elif OPENVINO_VERSION_MAJOR == 2025 && OPENVINO_VERSION_MINOR == 4 + data_ops_ = std::make_unique(graph_viewer_, V_2025_4, device_type_, npu_qdq_optimizer_enabled); #else - data_ops_ = std::make_unique(graph_viewer_, V_2025_2, device_type_, npu_qdq_optimizer_enabled); + data_ops_ = std::make_unique(graph_viewer_, V_2025_4, device_type_, npu_qdq_optimizer_enabled); #endif } diff --git a/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc b/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc index 4156b45cd638a..373b2121a9b60 100644 --- a/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc +++ b/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc @@ -408,7 +408,7 @@ void DataOps::populate_op_mode_supported() { // populate unsupportedmode_t { - UnsupportedOpMode obj = {{V_2024_1, V_2024_2, V_2024_3, V_2024_4, V_2024_5, V_2024_6, V_2025_0, V_2025_1, V_2025_2}, + UnsupportedOpMode obj = {{V_2024_1, V_2024_2, V_2024_3, V_2024_4, V_2024_5, V_2024_6, V_2025_0, V_2025_1, V_2025_2, V_2025_3, V_2025_4}, [this](const Node* node, const InitializedTensorSet&) { // If the Input of ReduceMax op is UINT8, it is rejected (Due to output mismatch) for (size_t i = 0; i < node->InputDefs().size(); i++) { @@ -425,7 +425,7 @@ void DataOps::populate_op_mode_supported() { { UnsupportedOpMode obj = {{V_2023_1, V_2023_2, V_2023_3, V_2024_0, V_2024_1, V_2024_2, V_2024_3, V_2024_4, V_2024_5, V_2024_6, V_2025_0, V_2025_1, - V_2025_2}, + V_2025_2, V_2025_3, V_2025_4}, [this](const Node* node, const InitializedTensorSet&) { const auto& input_args = node->InputDefs(); const auto& input_arg = (input_args.size() > 1) ? input_args[1] : input_args[0]; @@ -445,7 +445,7 @@ void DataOps::populate_op_mode_supported() { { UnsupportedOpMode obj = {{V_2023_1, V_2023_2, V_2023_3, V_2024_0, V_2024_1, V_2024_2, V_2024_3, V_2024_4, V_2024_5, V_2024_6, V_2025_0, V_2025_1, - V_2025_2}, + V_2025_2, V_2025_3, V_2025_4}, [this](const Node* node, const InitializedTensorSet&) { // If the operator is unsqueeze // If axes is an input, then we cannot produce a static graph. @@ -461,7 +461,7 @@ void DataOps::populate_op_mode_supported() { } { UnsupportedOpMode obj = {{V_2023_1, V_2023_2, V_2023_3, V_2024_0, V_2024_1, V_2024_2, V_2024_3, V_2024_4, V_2024_5, - V_2024_6, V_2025_0, V_2025_1, V_2025_2}, + V_2024_6, V_2025_0, V_2025_1, V_2025_2, V_2025_3, V_2025_4}, [this](const Node* node, const InitializedTensorSet&) { // check for attributes auto& upsample_attr = node->GetAttributes(); diff --git a/onnxruntime/core/providers/openvino/ov_versions/data_ops.h b/onnxruntime/core/providers/openvino/ov_versions/data_ops.h index 95905e010541e..cf6290ee07921 100644 --- a/onnxruntime/core/providers/openvino/ov_versions/data_ops.h +++ b/onnxruntime/core/providers/openvino/ov_versions/data_ops.h @@ -36,7 +36,9 @@ enum versionNum { V_2024_6, V_2025_0, V_2025_1, - V_2025_2 + V_2025_2, + V_2025_3, + V_2025_4 }; using VersionNum = enum versionNum; From a69cbf7c55554fb2d7d59c518f49903dc606f6e4 Mon Sep 17 00:00:00 2001 From: Rajeev Sekar Date: Wed, 19 Nov 2025 12:02:16 +0530 Subject: [PATCH 133/138] CVS-175119-[OVEP] Fixed possibility of array index out of bounds in subgraph partitioning (#838) * added a line to add initializers to be a part of meta_def -> inputs * fixed possible array index out of bound problem which caused some models to fail rather than getting sg partitioned * changed loop logic * reverting to the previous logic to ensure j value is retained and not incremented if append_node == true * updated loop logic --------- Co-authored-by: Preetha Veeramalai --- onnxruntime/core/providers/openvino/ov_versions/capability.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/openvino/ov_versions/capability.cc b/onnxruntime/core/providers/openvino/ov_versions/capability.cc index c7f92ad4d9be1..40036212ca125 100644 --- a/onnxruntime/core/providers/openvino/ov_versions/capability.cc +++ b/onnxruntime/core/providers/openvino/ov_versions/capability.cc @@ -181,7 +181,7 @@ std::vector> GetCapability::Execute() { omit_subgraph = false; } else if (j < total_clusters - 1) { bool append_node = false; - while (j < total_clusters && !append_node) { + while (j < total_clusters - 1 && !append_node) { j = j + 1; append_node = AddTrivialClusterToNextClusterIfConnected(graph_viewer_, index, connected_clusters[j]); } From 2f7212c394e439233296c5c6f7d761186f68dfcd Mon Sep 17 00:00:00 2001 From: Rajeev Sekar Date: Thu, 20 Nov 2025 19:06:29 +0530 Subject: [PATCH 134/138] skipped failing testcase (MathOpTest.Clip_Default_int64) (#860) * skipped testcase --- onnxruntime/test/providers/cpu/math/clip_test.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/onnxruntime/test/providers/cpu/math/clip_test.cc b/onnxruntime/test/providers/cpu/math/clip_test.cc index c1452ab686279..7a4af4f4f504a 100644 --- a/onnxruntime/test/providers/cpu/math/clip_test.cc +++ b/onnxruntime/test/providers/cpu/math/clip_test.cc @@ -99,7 +99,8 @@ TEST(MathOpTest, Clip_Default_int64) { -5, 9, 82}); // TensorRT does not support Clip opset 12 yet. - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); + // Skipping for OpenVINO because of the following error: Expected equality of these values: cur_expected[i] Which is: 11 cur_actual[i] Which is: 0 + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); } TEST(MathOpTest, Clip_Default_uint64) { From 5a06f6891a0e9fe60fe851dfd6951130051cc8ba Mon Sep 17 00:00:00 2001 From: Ankit Maheshkar Date: Tue, 25 Nov 2025 11:54:48 +0530 Subject: [PATCH 135/138] reset ort.eprp file changes (#862) --- ort.wprp | 6 ------ 1 file changed, 6 deletions(-) diff --git a/ort.wprp b/ort.wprp index 99a5d72e597e7..5dd2332cb1f9f 100644 --- a/ort.wprp +++ b/ort.wprp @@ -17,11 +17,6 @@ - - - - @@ -29,7 +24,6 @@ - From 39d6db58a737449b0bd3ddc9b5a8777f4426d99d Mon Sep 17 00:00:00 2001 From: Jaswanth51 Date: Wed, 3 Dec 2025 11:35:16 +0530 Subject: [PATCH 136/138] Sync with Microsoft ONNX Runtime - 03/12/2025 (#867) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Fix npm audit vulnerabilities in /js directory (#26632) ### Description Resolved all security vulnerabilities in JavaScript packages under `/js` by running `npm audit fix`. All updates are non-breaking patch/minor version bumps. **Fixed vulnerabilities:** - `/js` root: 1 high severity - `glob` 10.4.5 → 10.5.0 (command injection - GHSA-5j98-mcp5-4vw2) - `/js/react_native`: 7 vulnerabilities (1 high, 3 moderate, 3 low) - `image-size` → 1.2.1 (high: DoS via infinite loop - GHSA-m5qc-5hw7-8vg7) - `@babel/helpers` 7.25.6 → 7.28.4 (moderate: RegExp complexity - GHSA-968p-4wvh-cqc8) - `@babel/runtime` 7.25.6 → 7.28.4 (moderate: RegExp complexity - GHSA-968p-4wvh-cqc8) - `js-yaml` → fixed (moderate: prototype pollution - GHSA-mh29-5h37-fv8m) - `brace-expansion` 2.0.1 → 2.0.2 (low: ReDoS - GHSA-v6h2-p8h4-qcjw) - `on-headers` → fixed (low: header manipulation - GHSA-76c9-3jph-rj3q) **Files modified:** - `js/package-lock.json` - `js/react_native/package-lock.json` **Result:** All JS packages (`/js`, `/js/common`, `/js/web`, `/js/node`, `/js/react_native`) now report 0 vulnerabilities. ### Motivation and Context Security maintenance to address dependency vulnerabilities identified by `npm audit`. No breaking changes or code modifications required.
Original prompt > Please create a pull request that runs `npm audit fix` for the JavaScript/TypeScript portion of the repository under the `/js` directory of [microsoft/onnxruntime](https://github.com/microsoft/onnxruntime). > > Requirements: > > 1. **Scope** > - Work only within the `/js` folder and its subpackages (e.g., `js/web`, `js/node`, `js/common`, etc.). > - Do not modify files outside `/js`. > > 2. **Dependency updates** > - Run `npm audit fix` (and, if necessary to fully resolve high/critical issues while staying non-breaking, `npm audit fix --force` on specific subpackages) to address security vulnerabilities. > - Prefer minimal, non-breaking version bumps (patch and minor) that satisfy `npm audit` while keeping semver ranges sensible. > - If any **major** upgrades are required to clear vulnerabilities, handle them cautiously: > - Apply the upgrade only if tests still pass and typings/build setup remain compatible. > - If a major bump would require code changes or creates breaking behavior, **do not** apply it; instead, leave a TODO comment in the PR description summarizing which packages remain vulnerable and why. > > 3. **Validation** > - Run the existing JS-related checks that the repo supports from `/js`, such as: > - `npm test` or package-specific test scripts. > - Any documented lint/build/test commands for JS packages (e.g., `npm run build`, `npm run lint`) where applicable. > - Ensure the updated lockfiles (if present) are consistent, and the project installs cleanly with `npm ci` (or the repo's documented install command) in the `/js` area. > > 4. **Files to update** > - Update `package.json` and lockfiles under `/js` (e.g., `package-lock.json`, `npm-shrinkwrap.json`, or workspace-specific lock files) to reflect the audited dependency tree. > - Do not manually edit `node_modules`; rely on `npm` to manage dependencies and only commit manifest/lockfile changes. > > 5. **Repository conventions** > - Follow this repo's existing conventions for formatting, commit messages, and JS tooling. > - Keep the diff focused on the dependency and lockfile updates plus any absolutely necessary code tweaks to maintain compatibility. > > 6. **Pull request description** > - In the PR body, include: > - A short summary: that `npm audit fix` was run in `/js` to address dependency vulnerabilities. > - A bullet list of notable dependency changes (especially any major version bumps), with packages and old/new versions. > - A brief testing summary (commands run and their results). > - A note about any remaining vulnerabilities that could not be fixed without breaking changes (if applicable), including the affected packages and advisories if available. > > The goal is a clean, minimal PR that improves the security posture of the JS packages under `/js` in `microsoft/onnxruntime` without introducing breaking changes.
*This pull request was created as a result of the following prompt from Copilot chat.* > Please create a pull request that runs `npm audit fix` for the JavaScript/TypeScript portion of the repository under the `/js` directory of [microsoft/onnxruntime](https://github.com/microsoft/onnxruntime). > > Requirements: > > 1. **Scope** > - Work only within the `/js` folder and its subpackages (e.g., `js/web`, `js/node`, `js/common`, etc.). > - Do not modify files outside `/js`. > > 2. **Dependency updates** > - Run `npm audit fix` (and, if necessary to fully resolve high/critical issues while staying non-breaking, `npm audit fix --force` on specific subpackages) to address security vulnerabilities. > - Prefer minimal, non-breaking version bumps (patch and minor) that satisfy `npm audit` while keeping semver ranges sensible. > - If any **major** upgrades are required to clear vulnerabilities, handle them cautiously: > - Apply the upgrade only if tests still pass and typings/build setup remain compatible. > - If a major bump would require code changes or creates breaking behavior, **do not** apply it; instead, leave a TODO comment in the PR description summarizing which packages remain vulnerable and why. > > 3. **Validation** > - Run the existing JS-related checks that the repo supports from `/js`, such as: > - `npm test` or package-specific test scripts. > - Any documented lint/build/test commands for JS packages (e.g., `npm run build`, `npm run lint`) where applicable. > - Ensure the updated lockfiles (if present) are consistent, and the project installs cleanly with `npm ci` (or the repo's documented install command) in the `/js` area. > > 4. **Files to update** > - Update `package.json` and lockfiles under `/js` (e.g., `package-lock.json`, `npm-shrinkwrap.json`, or workspace-specific lock files) to reflect the audited dependency tree. > - Do not manually edit `node_modules`; rely on `npm` to manage dependencies and only commit manifest/lockfile changes. > > 5. **Repository conventions** > - Follow this repo's existing conventions for formatting, commit messages, and JS tooling. > - Keep the diff focused on the dependency and lockfile updates plus any absolutely necessary code tweaks to maintain compatibility. > > 6. **Pull request description** > - In the PR body, include: > - A short summary: that `npm audit fix` was run in `/js` to address dependency vulnerabilities. > - A bullet list of notable dependency changes (especially any major version bumps), with packages and old/new versions. > - A brief testing summary (commands run and their results). > - A note about any remaining vulnerabilities that could not be fixed without breaking changes (if applicable), including the affected packages and advisories if available. > > The goal is a clean, minimal PR that improves the security posture of the JS packages under `/js` in `microsoft/onnxruntime` without introducing breaking changes. --- ✨ Let Copilot coding agent [set things up for you](https://github.com/microsoft/onnxruntime/issues/new?title=✨+Set+up+Copilot+instructions&body=Configure%20instructions%20for%20this%20repository%20as%20documented%20in%20%5BBest%20practices%20for%20Copilot%20coding%20agent%20in%20your%20repository%5D%28https://gh.io/copilot-coding-agent-tips%29%2E%0A%0A%3COnboard%20this%20repo%3E&assignees=copilot) — coding agent works faster and does higher quality work when set up for your repo. --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: fs-eire <7679871+fs-eire@users.noreply.github.com> * [webgpu] Optimize InstanceNormalization by removing redundant transpose (#26626) ### Description This PR optimizes `InstanceNormalization` by removing redundant transpose. Given the implementation of `InstanceNormalization` for `NCHW` is more effiencient, we don't need to add wrapper `Transpose` to make it run in `NHWC`, which helps use to elide redundant transpose and improve performance. Testing on Lunar Lake shows about `~60%` performance improvement in `InstanceNormalization` operations. #### `InstanceNormalization` OP benchmark The input tensor shape: `(1,32,1048576)` The scale tensor shape: `(32)` The B tensor shape: `(32)` | time cost (ms) | baseline | opt | diff | | ---------------- | -------- | ---- | ---- | | Lunar Lake | 82.6 | 34.2 | 58% | #### Model benchmark | time cost (ms) | baseline | opt | diff | | ---------------- | -------- | ---- | ---- | | sd-turbo-vae-decoder-fp16-demo | 2437.6 | 1835.9 | 25% | ### Motivation and Context Please see above * [webgpu] refactor a few "context" classes (#26602) ### Description This PR refactors a few "context" classes to make it clearer and support new features. --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: Copilot <198982749+Copilot@users.noreply.github.com> * Bump actions/checkout from 5 to 6 (#26641) Bumps [actions/checkout](https://github.com/actions/checkout) from 5 to 6.
Release notes

Sourced from actions/checkout's releases.

v6.0.0

What's Changed

Full Changelog: https://github.com/actions/checkout/compare/v5.0.0...v6.0.0

v6-beta

What's Changed

Updated persist-credentials to store the credentials under $RUNNER_TEMP instead of directly in the local git config.

This requires a minimum Actions Runner version of v2.329.0 to access the persisted credentials for Docker container action scenarios.

v5.0.1

What's Changed

Full Changelog: https://github.com/actions/checkout/compare/v5...v5.0.1

Changelog

Sourced from actions/checkout's changelog.

Changelog

V6.0.0

V5.0.1

V5.0.0

V4.3.1

V4.3.0

v4.2.2

v4.2.1

v4.2.0

v4.1.7

v4.1.6

v4.1.5

... (truncated)

Commits

[![Dependabot compatibility score](https://dependabot-badges.githubapp.com/badges/compatibility_score?dependency-name=actions/checkout&package-manager=github_actions&previous-version=5&new-version=6)](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores) Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting `@dependabot rebase`. [//]: # (dependabot-automerge-start) [//]: # (dependabot-automerge-end) ---
Dependabot commands and options
You can trigger Dependabot actions by commenting on this PR: - `@dependabot rebase` will rebase this PR - `@dependabot recreate` will recreate this PR, overwriting any edits that have been made to it - `@dependabot merge` will merge this PR after your CI passes on it - `@dependabot squash and merge` will squash and merge this PR after your CI passes on it - `@dependabot cancel merge` will cancel a previously requested merge and block automerging - `@dependabot reopen` will reopen this PR if it is closed - `@dependabot close` will close this PR and stop Dependabot recreating it. You can achieve the same result by closing it manually - `@dependabot show ignore conditions` will show all of the ignore conditions of the specified dependency - `@dependabot ignore this major version` will close this PR and stop Dependabot creating any more for this major version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this minor version` will close this PR and stop Dependabot creating any more for this minor version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this dependency` will close this PR and stop Dependabot creating any more for this dependency (unless you reopen the PR or upgrade to it yourself)
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> * add LogEvaluationStart for ReplayGraph (#26645) ### Description add LogEvaluationStart for ReplayGraph to match LogEvaluationStop ### Motivation and Context So by using ETW, could capture run time correctly Co-authored-by: hualxie * add LogCompileModel to mark the session usage (#26646) ### Description add LogCompileModel to mark the session usage as Compile because that session will not be used for inference We could also use it to log compile model parameters if needed ### Motivation and Context We are building a profiling tool for WinML and we want to differentiate Compile session and inference session. I think there are two ways to do it but I don't know which is better https://github.com/microsoft/onnxruntime/pull/26646 https://github.com/microsoft/onnxruntime/pull/26647 --------- Co-authored-by: hualxie * [webgpu] Fix bug introduced by RoE (#26661) Fix bug introduced by #26563 which used the wrong condition by accident and results incorrect result in graph capture mode. * [QNN-EP] Enable verbose and artifacts saving in onnxruntime_provider_test.exe (#26396) ### Description - The change allows users to better debug unit tests by adding the following environment variables: - `QNN_DUMP_ONNX`: Dump input onnx model - `QNN_DUMP_JSON`: Dump json qnn graph with provider_option `dump_json_qnn_graph` - `QNN_DUMP_DLC`: Dump dlc with provider_option `qnn_ir_backend_path` - `QNN_VERBOSE`: Use the log level `ORT_LOGGING_LEVEL_VERBOSE` - Developers can use the environment variables above to save the artifacts of QNN-EP testcases to a directory named with `_` ``` . ├── QnnCPUBackendTests_BatchNorm2D_fp32 # RunQnnModelTest │ ├── dumped_f32_model.onnx # float32 ONNX model │ ├── QNNExecutionProvider_QNN_XXXX_X_X.dlc │ └── QNNExecutionProvider_QNN_XXXX_X_X.json ├── QnnHTPBackendTests_BatchNorm_FP16 # TestFp16ModelAccuracy │ ├── dumped_f16_model.onnx # float16 ONNX model │ ├── dumped_f32_model.onnx # float32 ONNX model │ ├── QNNExecutionProvider_QNN_XXXX_X_X.dlc │ └── QNNExecutionProvider_QNN_XXXX_X_X.json └── QnnHTPBackendTests_BatchNorm2D_U8U8S32 # TestQDQModelAccuracy ├── dumped_f32_model.onnx # float32 ONNX model ├── dumped_qdq_model.onnx # QDQ ONNX model ├── QNNExecutionProvider_QNN_XXXX_X_X.dlc └── QNNExecutionProvider_QNN_XXXX_X_X.json # All artifact files are placed under the current working directory from which the test binary is invoked. ``` ### Motivation and Context - The Json qnn graph/dlc are helpful for backend to debug performance/accuracy issues - By comparing the onnx and Json qnn graph/dlc, we can locate the issue about graph manipulation. * [webgpu] Use multiplication instead of pow if exponent is 2 (#26667) ### Description More accurately compute Pow(2.0) on WebGPU EP. Reproduction script: ```py from onnx import helper, TensorProto import onnxruntime as ort import numpy as np # 1. Create the ONNX model # Define input and output input_info = helper.make_tensor_value_info('X', TensorProto.FLOAT, [1, 1]) output_info = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [1, 1]) # Create a constant tensor for the exponent (2.0) exponent_tensor = helper.make_tensor('exponent', TensorProto.FLOAT, [], [2.0]) exponent_node = helper.make_node('Constant', [], ['exponent_out'], value=exponent_tensor) # Create the Pow node # Pow takes two inputs: Base (X) and Power (exponent_out) pow_node = helper.make_node( 'Pow', inputs=['X', 'exponent_out'], outputs=['Y'], name='PowNode' ) # Create the graph graph_def = helper.make_graph( [exponent_node, pow_node], 'test-model', [input_info], [output_info] ) # Create the model model_def = helper.make_model(graph_def, producer_name='onnx-example') opset = model_def.opset_import[0] opset.version = 13 # Ensure opset version supports the operations # 2. Convert model to string (bytes) model_str = model_def.SerializeToString() # 3. Prepare input data np.random.seed(0) input_data = np.array([[-2e3]], dtype=np.float32) # 4. Run on CPUExecutionProvider sess_cpu = ort.InferenceSession(model_str, providers=['CPUExecutionProvider']) res_cpu = sess_cpu.run(['Y'], {'X': input_data})[0] print("CPU Result:", res_cpu) # 5. Run on WebGpuExecutionProvider sess_webgpu = ort.InferenceSession(model_str, providers=['WebGpuExecutionProvider']) res_webgpu = sess_webgpu.run(['Y'], {'X': input_data})[0] print("WebGPU Result:", res_webgpu) # Compare results diff = np.abs(res_cpu - res_webgpu) max_diff = diff.max().item() assert max_diff < 1e-5, f"Results do not match within tolerance! Max diff: {max_diff}" print("Results match!") ``` currently produces ``` CPU Result: [[4.e+06]] WebGPU Result: [[3.999999e+06]] --------------------------------------------------------------------------- AssertionError Traceback (most recent call last) Cell In[1], [line 56](vscode-notebook-cell:?execution_count=1&line=56) 54 diff = np.abs(res_cpu - res_webgpu) 55 max_diff = diff.max().item() ---> [56](vscode-notebook-cell:?execution_count=1&line=56) assert max_diff < 1e-5, f"Results do not match within tolerance! Max diff: {max_diff}" 57 print("Results match!") AssertionError: Results do not match within tolerance! Max diff: 1.0 ``` but with this PR: ``` CPU Result: [[4.e+06]] WebGPU Result: [[4.e+06]] Results match! ``` ### Motivation and Context Leads to downstream issues/inaccuracies for certain models, especially those which have larger values to compute pow(x,2) for. cc @guschmue * Avoid creation of temporary protobuf object (#26681) ### Description While profiling session creation time for large graphs (number of nodes, not size of tensors), we noticed that the creations and subsequent destructions of protobuf objects were the major hotspot. This PR avoids its creation. Signed-off-by: Christian Bourjau * Use `std::string_view` directly as key to `absl::flat_hash_map::find` (#26682) ### Description Use `std::string_view` directly as key in `find` method of `flat_hash_map`. This part of the absl documentation may provide further insights: https://abseil.io/docs/cpp/guides/container#heterogeneous-lookup ### Motivation and Context We noticed this when profiling the session creation of large models (in terms of the number of nodes). Signed-off-by: Christian Bourjau * [webgpu] Convert i32 to u32 in uniforms (#26676) In debug mode, `webgpu_context.cc:257 Run Uniform variable[5] (head_size) data type mismatch in program "SplitPackedQKVWithRotaryEmbeddingAndCopyKV", Expected: u32, Actual: i32`. No issue in release mode. Convert i32 to u32 to avoid this issue. * [webgpu] Fix BatchNormalization ShapeInferenceError for 2D inputs (#26659) ### Description Test model (happens with any 2D inputs): [2191__visual_projection_visual_projection.1_BatchNormalization.onnx.zip](https://github.com/user-attachments/files/23758390/2191__visual_projection_visual_projection.1_BatchNormalization.onnx.zip) Command: ``` python -c "import onnxruntime as ort; ort.InferenceSession('2191__visual_projection_visual_projection.1_BatchNormalization.onnx', providers=['WebGpuExecutionProvider'])" ``` Before (failure): ``` Op (BatchNormalization) [ShapeInferenceError] Tensor must have at least 3 dimensions to convert between channels first and channels last. ``` After (success): ``` (nothing, meaning success) ``` ### Motivation and Context This fixes BatchNormalization on WebGPU, matching CPU version. cc @guschmue * Clear cuda error on unsupported CudaMemPool test (#26629) ### Description CudaMemPool test checks if it is supported in a given environment. We need to clear the error not to affect subsequent tests. ### Motivation and Context Potential test failure. * [QNN-EP] Include detailed error message in the returned status (#26546) ### Description The original error message only shows: "Failed to setup QNN input tensors for graph: " This change adds more detailed error information by logging the failure reason from [SetupTensors](https://github.com/microsoft/onnxruntime/blob/ea55c160a36d658eae61a4c7aeda6cb55dd54dec/onnxruntime/core/providers/qnn/builder/qnn_model.cc#L386), making it easier to debug issues. ### Motivation and Context User requires detailed error logging for the ORT online context binary generation. * add support for int32_t in webgpu / slice (#26693) fix for https://github.com/microsoft/onnxruntime/issues/26690 * [webgpu] Remove `global_id` and `workgroup_id` in gemm_utils.cc (#26662) ### Description This patch replaces `global_id` and `workgroup_id` with `logical_global_id` and `logical_workgroup_id` which are computed from `workgroup_idx` and the dispatch workgroup sizes set in `ProgramBase::SetDispatchGroupSize()`. ### Motivation and Context We shouldn't use `global_id` or `workgroup_id` directly because the dispatch workgroup sizes may be normalized in `ProgramManager::NormalizeDispatchGroupSize()`. * [webgpu] Correct definition of large numbers, fixes softmax(max_negative_number) in float32 (#26670) ### Description The correct definition of the most negative number is `-3.40282346638528e+38`, according to IEEE 754, but it is being incorrectly registered inline as a truncated version `-3.402823e+38f`. ```py >>> import numpy as np >>> np.finfo(np.float32).min np.float32(-3.4028235e+38) >>> np.finfo(np.float32).min.item() -3.4028234663852886e+38 ``` For this reason, values less than this threshold were handled incorrectly. While this may seem like a small/irrelevant detail, it's essential in attention masking, where we do in fact use this value, leading to large numerical errors down the line. Reproduction: ```py from onnx import helper, TensorProto import onnxruntime as ort import numpy as np # 1. Create the ONNX model # Define input and output input_shape = [1, 2] input_info = helper.make_tensor_value_info('X', TensorProto.FLOAT, input_shape) output_info = helper.make_tensor_value_info('Y', TensorProto.FLOAT, input_shape) # Create the Softmax node # Softmax takes one input: X softmax_node = helper.make_node( 'Softmax', inputs=['X'], outputs=['Y'], name='SoftmaxNode', axis=-1 # Default axis is -1, usually applied to the last dimension ) # Create the graph graph_def = helper.make_graph( [softmax_node], 'test-model', [input_info], [output_info] ) # Create the model model_def = helper.make_model(graph_def, producer_name='onnx-example') opset = model_def.opset_import[0] opset.version = 13 # Ensure opset version supports the operations # 2. Convert model to string (bytes) model_str = model_def.SerializeToString() # 3. Prepare input data np.random.seed(0) input_data = np.array( [[-3.40282346638528e+38, -3.40282346638528e+38]] # [[-3.4028234663852886e+38, -3.4028234663852886e+38]] ).astype(np.float32) print(input_data.tolist()) # 4. Run on CPUExecutionProvider sess_cpu = ort.InferenceSession(model_str, providers=['CPUExecutionProvider']) res_cpu = sess_cpu.run(['Y'], {'X': input_data})[0] print("CPU Result:", res_cpu) # 5. Run on WebGpuExecutionProvider sess_webgpu = ort.InferenceSession(model_str, providers=['WebGpuExecutionProvider']) res_webgpu = sess_webgpu.run(['Y'], {'X': input_data})[0] print("WebGPU Result:", res_webgpu) # Compare results diff = np.abs(res_cpu - res_webgpu) max_diff = diff.max().item() print(diff) print(f"Max diff: {max_diff}") assert max_diff < 1e-5, f"Results do not match within tolerance! Max diff: {max_diff}" print("Results match!") ``` Before: ``` [[-3.4028234663852886e+38, -3.4028234663852886e+38]] CPU Result: [[0.5 0.5]] WebGPU Result: [[0. 0.]] [[0.5 0.5]] Max diff: 0.5 AssertionError: Results do not match within tolerance! Max diff: 0.5 ``` After: ``` [[-3.4028234663852886e+38, -3.4028234663852886e+38]] CPU Result: [[0.5 0.5]] WebGPU Result: [[0.5 0.5]] [[0. 0.]] Max diff: 0.0 Results match! ``` cc @guschmue * [TRT/TRT RTX EP] Fix bug for missing outputs in the returning ComputeCapability/IndexedSubGraph (#26444) ### Description For TRT EP's `GetCapability()`, in some case, the `GetSubGraph()` won't add graph's output to the `ComputeCapability/IndexedSubGraph` returning to ORT. The issue if from following code: ````c++ ... if (node->GetOutputEdgesCount() > node->OutputDefs().size()) { ... // execute here } else { ... if (graph_output_names.find(output->Name()) != graph_output_names.end()) { graph_outputs_to_add[output] = output_order; // missing this } } ```` Update TRT RTX EP as well. ### Motivation and Context https://github.com/microsoft/onnxruntime/issues/25373 * [ROCM] Remove docker, contrib ops, ci scripts related to ROCM EP (#26697) ### Description This is follow up of https://github.com/microsoft/onnxruntime/pull/25181 to remove ROCM EP related files to avoid confusion. Documents will be updated later. ### Motivation and Context https://github.com/microsoft/onnxruntime/issues/26692 --------- Signed-off-by: dependabot[bot] Signed-off-by: Christian Bourjau Co-authored-by: Copilot <198982749+Copilot@users.noreply.github.com> Co-authored-by: fs-eire <7679871+fs-eire@users.noreply.github.com> Co-authored-by: Wenqin Yang Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: xieofxie Co-authored-by: hualxie Co-authored-by: Jiajia Qin Co-authored-by: qti-hungjuiw Co-authored-by: Joshua Lochner Co-authored-by: Christian Bourjau Co-authored-by: Xiaofei Han Co-authored-by: Dmitri Smirnov Co-authored-by: chunghow-qti Co-authored-by: Guenther Schmuelling Co-authored-by: Jiawei Shao Co-authored-by: Chi Lo <54722500+chilo-ms@users.noreply.github.com> Co-authored-by: Tianlei Wu --- .github/workflows/android.yml | 6 +- .github/workflows/cffconvert.yml | 2 +- .github/workflows/codeql.yml | 2 +- .../workflows/gradle-wrapper-validation.yml | 2 +- .github/workflows/ios.yml | 2 +- .github/workflows/lint.yml | 8 +- .../linux-wasm-ci-build-and-test-workflow.yml | 2 +- .github/workflows/linux_cuda_ci.yml | 2 +- .github/workflows/linux_minimal_build.yml | 20 +- .github/workflows/linux_tensorrt_ci.yml | 2 +- .github/workflows/mac.yml | 4 +- .../macos-ci-build-and-test-workflow.yml | 2 +- .github/workflows/pr_checks.yml | 2 +- .github/workflows/publish-c-apidocs.yml | 2 +- .github/workflows/publish-csharp-apidocs.yml | 2 +- .github/workflows/publish-java-apidocs.yml | 2 +- .github/workflows/publish-js-apidocs.yml | 2 +- .../workflows/publish-objectivec-apidocs.yml | 2 +- .github/workflows/publish-python-apidocs.yml | 2 +- .github/workflows/react_native.yml | 8 +- .github/workflows/reusable_linux_build.yml | 2 +- .github/workflows/web.yml | 2 +- .github/workflows/windows-web-ci-workflow.yml | 2 +- .github/workflows/windows_build_x64_asan.yml | 2 +- .github/workflows/windows_cuda.yml | 4 +- .github/workflows/windows_dml.yml | 2 +- .github/workflows/windows_openvino.yml | 2 +- .github/workflows/windows_qnn_x64.yml | 2 +- .github/workflows/windows_tensorrt.yml | 4 +- .github/workflows/windows_webgpu.yml | 6 +- .../windows_x64_debug_build_x64_debug.yml | 2 +- .../windows_x64_release_build_x64_release.yml | 2 +- ...build_x64_release_ep_generic_interface.yml | 2 +- ..._x64_release_vitisai_build_x64_release.yml | 2 +- .../workflows/windows_x64_release_xnnpack.yml | 2 +- .github/workflows/windows_x86.yml | 2 +- dockerfiles/Dockerfile.rocm | 24 - dockerfiles/README.md | 17 +- dockerfiles/scripts/install_rocm_deps.sh | 84 -- js/package-lock.json | 144 +-- js/react_native/package-lock.json | 125 ++- js/web/lib/wasm/jsep/webgpu/ops/attention.ts | 4 +- js/web/lib/wasm/jsep/webgpu/ops/softmax.ts | 2 +- .../contrib_ops/rocm/bert/attention.cu | 215 ---- onnxruntime/contrib_ops/rocm/bert/attention.h | 33 - .../contrib_ops/rocm/bert/attention_impl.cu | 435 --------- .../contrib_ops/rocm/bert/attention_impl.h | 180 ---- .../contrib_ops/rocm/bert/attention_softmax.h | 465 --------- .../bert/batched_gemm_permute_pipelines.cuh | 125 --- .../impl.cuh | 177 ---- .../impl_fp16.cu | 60 -- .../impl_fp16_biased.cu | 60 -- .../impl_fp16_biased_biased.cu | 60 -- ...ed_gemm_softmax_gemm_permute_pipelines.cuh | 915 ------------------ .../rocm/bert/decoder_attention_impl.h | 46 - .../contrib_ops/rocm/bert/elementwise.h | 84 -- .../rocm/bert/elementwise_impl/impl.cuh | 256 ----- .../bert/elementwise_impl/impl_fastgelu.cu | 9 - .../rocm/bert/elementwise_impl/impl_gelu.cu | 9 - .../rocm/bert/elementwise_impl/impl_relu.cu | 8 - .../contrib_ops/rocm/bert/gemm_fast_gelu.cc | 75 -- .../contrib_ops/rocm/bert/gemm_fast_gelu.h | 23 - .../rocm/bert/gemm_fast_gelu_ck.cuh | 133 --- .../rocm/bert/gemm_fast_gelu_common.h | 47 - .../rocm/bert/gemm_fast_gelu_impl.cu | 91 -- .../rocm/bert/gemm_fast_gelu_impl.h | 40 - .../rocm/bert/gemm_fast_gelu_tunable.cuh | 83 -- .../rocm/bert/group_query_attention.cu | 530 ---------- .../rocm/bert/group_query_attention.h | 38 - .../contrib_ops/rocm/bert/layer_norm.cuh | 270 ------ .../rocm/bert/multihead_attention.cu | 286 ------ .../rocm/bert/multihead_attention.h | 51 - .../contrib_ops/rocm/bert/skip_layer_norm.cc | 132 --- .../contrib_ops/rocm/bert/skip_layer_norm.h | 26 - .../rocm/bert/skip_layer_norm_impl.cu | 86 -- .../rocm/bert/skip_layer_norm_impl.h | 31 - .../rocm/bert/skip_layer_norm_impl_kernel.h | 162 ---- .../rocm/bert/skip_layer_norm_tunable_op.h | 161 --- .../rocm/bert/transformer_common.cc | 37 - .../rocm/bert/transformer_common.h | 46 - .../rocm/diffusion/group_norm_ck.cuh | 105 -- .../diffusion/group_norm_ck_impl/impl.cuh | 130 --- .../diffusion/group_norm_ck_impl/impl_fp16.cu | 39 - .../diffusion/group_norm_ck_impl/impl_fp32.cu | 39 - .../rocm/diffusion/group_norm_common.h | 56 -- .../rocm/diffusion/group_norm_impl.cu | 76 -- .../rocm/diffusion/group_norm_triton.cuh | 105 -- .../rocm/diffusion/group_norm_triton.py | 135 --- .../rocm/diffusion/group_norm_tunable_op.h | 220 ----- .../contrib_ops/rocm/diffusion/nhwc_conv.cc | 27 - onnxruntime/contrib_ops/rocm/fused_conv.cc | 439 --------- .../contrib_ops/rocm/math/gemm_float8.cu | 213 ---- .../contrib_ops/rocm/math/gemm_float8_ck.cuh | 276 ------ .../math/gemm_float8_ck_impl/add_instance.cu | 124 --- ...xdl_splitk_f16_f8_f16_mk_kn_mn_instance.cu | 97 -- ...k_f16_f8_f16_mk_kn_mn_instance_original.cu | 80 -- ...xdl_splitk_f16_f8_f16_mk_nk_mn_instance.cu | 94 -- ...k_f8_f16_f16_mk_kn_mn_instance_original.cu | 97 -- .../contrib_ops/rocm/rocm_contrib_kernels.cc | 347 ------- .../contrib_ops/rocm/rocm_contrib_kernels.h | 14 - .../contrib_ops/webgpu/bert/attention.cc | 6 +- .../webgpu/bert/flash_attention.cc | 6 +- .../webgpu/bert/flash_attention.wgsl.template | 2 +- .../flash_attention_decode_qkt.wgsl.template | 2 +- ...sh_attention_decode_split_vx.wgsl.template | 2 +- .../webgpu/bert/group_query_attention.cc | 4 +- .../contrib_ops/webgpu/moe/gate.wgsl.template | 2 +- .../core/framework/allocation_planner.cc | 3 +- .../core/framework/ort_value_name_idx_map.h | 2 +- .../contrib_ops/nhwc_inference_context.h | 7 +- onnxruntime/core/platform/telemetry.cc | 4 + onnxruntime/core/platform/telemetry.h | 2 + .../core/platform/windows/telemetry.cc | 14 + onnxruntime/core/platform/windows/telemetry.h | 2 + .../core/providers/js/operators/unary.cc | 2 +- .../nv_tensorrt_rtx/nv_execution_provider.cc | 77 +- .../qnn/builder/opbuilder/base_op_builder.cc | 3 + .../core/providers/qnn/builder/qnn_def.cc | 4 + .../core/providers/qnn/builder/qnn_def.h | 2 + .../core/providers/qnn/builder/qnn_model.cc | 4 +- .../tensorrt/tensorrt_execution_provider.cc | 77 +- .../vsinpu/builders/impl/clip_op_builder.cc | 4 +- .../core/providers/webgpu/allocator.cc | 2 +- onnxruntime/core/providers/webgpu/allocator.h | 5 + .../core/providers/webgpu/compute_context.cc | 23 +- .../core/providers/webgpu/compute_context.h | 103 +- .../webgpu/math/binary_elementwise_ops.cc | 11 +- .../core/providers/webgpu/math/gemm_packed.cc | 15 +- .../core/providers/webgpu/math/gemm_packed.h | 5 +- .../core/providers/webgpu/math/gemm_utils.cc | 46 +- .../core/providers/webgpu/math/matmul.cc | 4 +- .../providers/webgpu/math/matmul_packed.h | 5 +- .../core/providers/webgpu/math/softmax.cc | 2 +- onnxruntime/core/providers/webgpu/nn/conv.cc | 40 + onnxruntime/core/providers/webgpu/nn/conv.h | 7 + .../core/providers/webgpu/nn/conv2d_mm.cc | 5 +- .../core/providers/webgpu/nn/conv2d_mm.h | 5 +- .../core/providers/webgpu/tensor/slice.cc | 22 +- .../core/providers/webgpu/tensor/transpose.cc | 2 +- .../core/providers/webgpu/tensor/transpose.h | 2 +- .../core/providers/webgpu/webgpu_context.cc | 18 +- .../core/providers/webgpu/webgpu_context.h | 21 +- .../webgpu/webgpu_execution_provider.cc | 14 +- .../core/providers/webgpu/webgpu_kernel.cc | 47 +- .../core/providers/webgpu/webgpu_kernel.h | 33 + .../core/providers/webgpu/webgpu_utils.cc | 15 +- .../core/providers/webgpu/webgpu_utils.h | 5 +- onnxruntime/core/session/inference_session.cc | 2 + onnxruntime/core/session/utils.cc | 1 + .../providers/cuda/cuda_mempool_arena_test.cc | 15 +- .../nv_tensorrt_rtx/nv_basic_test.cc | 42 + onnxruntime/test/providers/qnn/README.md | 70 ++ .../test/providers/qnn/qnn_test_utils.cc | 60 ++ .../test/providers/qnn/qnn_test_utils.h | 147 ++- .../providers/tensorrt/tensorrt_basic_test.cc | 49 +- .../test/testdata/node_output_not_used.onnx | Bin 0 -> 189 bytes .../test/testdata/node_output_not_used.py | 43 + .../topk_and_multiple_graph_outputs.onnx | Bin 0 -> 393 bytes .../topk_and_multiple_graph_outputs.py | 78 ++ .../github/linux/build_rocm_c_api_package.sh | 40 - .../docker/scripts/setup_rocm_yum_repo.sh | 43 - 161 files changed, 1176 insertions(+), 8816 deletions(-) delete mode 100644 dockerfiles/Dockerfile.rocm delete mode 100644 dockerfiles/scripts/install_rocm_deps.sh delete mode 100644 onnxruntime/contrib_ops/rocm/bert/attention.cu delete mode 100644 onnxruntime/contrib_ops/rocm/bert/attention.h delete mode 100644 onnxruntime/contrib_ops/rocm/bert/attention_impl.cu delete mode 100644 onnxruntime/contrib_ops/rocm/bert/attention_impl.h delete mode 100644 onnxruntime/contrib_ops/rocm/bert/attention_softmax.h delete mode 100644 onnxruntime/contrib_ops/rocm/bert/batched_gemm_permute_pipelines.cuh delete mode 100644 onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl.cuh delete mode 100644 onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16.cu delete mode 100644 onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16_biased.cu delete mode 100644 onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16_biased_biased.cu delete mode 100644 onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh delete mode 100644 onnxruntime/contrib_ops/rocm/bert/decoder_attention_impl.h delete mode 100644 onnxruntime/contrib_ops/rocm/bert/elementwise.h delete mode 100644 onnxruntime/contrib_ops/rocm/bert/elementwise_impl/impl.cuh delete mode 100644 onnxruntime/contrib_ops/rocm/bert/elementwise_impl/impl_fastgelu.cu delete mode 100644 onnxruntime/contrib_ops/rocm/bert/elementwise_impl/impl_gelu.cu delete mode 100644 onnxruntime/contrib_ops/rocm/bert/elementwise_impl/impl_relu.cu delete mode 100644 onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu.cc delete mode 100644 onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu.h delete mode 100644 onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_ck.cuh delete mode 100644 onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_common.h delete mode 100644 onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_impl.cu delete mode 100644 onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_impl.h delete mode 100644 onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_tunable.cuh delete mode 100644 onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu delete mode 100644 onnxruntime/contrib_ops/rocm/bert/group_query_attention.h delete mode 100644 onnxruntime/contrib_ops/rocm/bert/layer_norm.cuh delete mode 100644 onnxruntime/contrib_ops/rocm/bert/multihead_attention.cu delete mode 100644 onnxruntime/contrib_ops/rocm/bert/multihead_attention.h delete mode 100644 onnxruntime/contrib_ops/rocm/bert/skip_layer_norm.cc delete mode 100644 onnxruntime/contrib_ops/rocm/bert/skip_layer_norm.h delete mode 100644 onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl.cu delete mode 100644 onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl.h delete mode 100644 onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl_kernel.h delete mode 100644 onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_tunable_op.h delete mode 100644 onnxruntime/contrib_ops/rocm/bert/transformer_common.cc delete mode 100644 onnxruntime/contrib_ops/rocm/bert/transformer_common.h delete mode 100644 onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck.cuh delete mode 100644 onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl.cuh delete mode 100644 onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl_fp16.cu delete mode 100644 onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl_fp32.cu delete mode 100644 onnxruntime/contrib_ops/rocm/diffusion/group_norm_common.h delete mode 100644 onnxruntime/contrib_ops/rocm/diffusion/group_norm_impl.cu delete mode 100644 onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh delete mode 100644 onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.py delete mode 100644 onnxruntime/contrib_ops/rocm/diffusion/group_norm_tunable_op.h delete mode 100644 onnxruntime/contrib_ops/rocm/diffusion/nhwc_conv.cc delete mode 100644 onnxruntime/contrib_ops/rocm/fused_conv.cc delete mode 100644 onnxruntime/contrib_ops/rocm/math/gemm_float8.cu delete mode 100644 onnxruntime/contrib_ops/rocm/math/gemm_float8_ck.cuh delete mode 100644 onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/add_instance.cu delete mode 100644 onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instance.cu delete mode 100644 onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instance_original.cu delete mode 100644 onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instance.cu delete mode 100644 onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instance_original.cu delete mode 100644 onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc delete mode 100644 onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.h create mode 100644 onnxruntime/test/providers/qnn/README.md create mode 100644 onnxruntime/test/testdata/node_output_not_used.onnx create mode 100644 onnxruntime/test/testdata/node_output_not_used.py create mode 100644 onnxruntime/test/testdata/topk_and_multiple_graph_outputs.onnx create mode 100644 onnxruntime/test/testdata/topk_and_multiple_graph_outputs.py delete mode 100755 tools/ci_build/github/linux/build_rocm_c_api_package.sh delete mode 100755 tools/ci_build/github/linux/docker/scripts/setup_rocm_yum_repo.sh diff --git a/.github/workflows/android.yml b/.github/workflows/android.yml index 7f7ff74959d52..f12eadc2ce794 100644 --- a/.github/workflows/android.yml +++ b/.github/workflows/android.yml @@ -27,7 +27,7 @@ jobs: runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] steps: - name: Checkout repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: submodules: false @@ -112,7 +112,7 @@ jobs: android_nnapi_ep: runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 - name: Use jdk 17 uses: actions/setup-java@v5 @@ -187,7 +187,7 @@ jobs: name: Android CI Pipeline runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 - name: Use jdk 17 uses: actions/setup-java@v5 diff --git a/.github/workflows/cffconvert.yml b/.github/workflows/cffconvert.yml index 30f832f67c5ee..ddf4a52a0ccb0 100644 --- a/.github/workflows/cffconvert.yml +++ b/.github/workflows/cffconvert.yml @@ -12,7 +12,7 @@ jobs: runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] steps: - name: Check out a copy of the repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 - name: Check whether the citation metadata from CITATION.cff is valid uses: citation-file-format/cffconvert-github-action@2.0.0 diff --git a/.github/workflows/codeql.yml b/.github/workflows/codeql.yml index d33e4d923a0bc..1db84400c272a 100644 --- a/.github/workflows/codeql.yml +++ b/.github/workflows/codeql.yml @@ -38,7 +38,7 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 # Initializes the CodeQL tools for scanning. - name: Initialize CodeQL diff --git a/.github/workflows/gradle-wrapper-validation.yml b/.github/workflows/gradle-wrapper-validation.yml index 04177b11e9c30..d8f13d13d3f88 100644 --- a/.github/workflows/gradle-wrapper-validation.yml +++ b/.github/workflows/gradle-wrapper-validation.yml @@ -15,7 +15,7 @@ jobs: name: "Validation" runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 - uses: gradle/actions/wrapper-validation@v5 concurrency: group: ${{ github.workflow }}-${{ github.event_name == 'pull_request' && github.ref || github.sha }} diff --git a/.github/workflows/ios.yml b/.github/workflows/ios.yml index 0d2046b980783..ed572aa339ce9 100644 --- a/.github/workflows/ios.yml +++ b/.github/workflows/ios.yml @@ -20,7 +20,7 @@ jobs: runs-on: macos-14 steps: - name: Checkout repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: submodules: false - uses: microsoft/onnxruntime-github-actions/setup-build-tools@v0.0.9 diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 5aaab5f8e1a10..5c618dc5787a5 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -17,7 +17,7 @@ jobs: name: Optional Lint runs-on: ubuntu-latest steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 - name: misspell # Check spellings as well uses: reviewdog/action-misspell@v1 with: @@ -42,7 +42,7 @@ jobs: contents: read security-events: write steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 - name: Setup Python uses: actions/setup-python@v6 with: @@ -87,7 +87,7 @@ jobs: name: Optional Lint C++ runs-on: ubuntu-latest steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 - name: Update PATH run: | echo "$HOME/.local/bin" >> "$GITHUB_PATH" @@ -116,7 +116,7 @@ jobs: name: Lint JavaScript runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 - uses: actions/setup-node@v6 with: node-version: 20 diff --git a/.github/workflows/linux-wasm-ci-build-and-test-workflow.yml b/.github/workflows/linux-wasm-ci-build-and-test-workflow.yml index 2370c631b7a7a..5763b9c39bcc6 100644 --- a/.github/workflows/linux-wasm-ci-build-and-test-workflow.yml +++ b/.github/workflows/linux-wasm-ci-build-and-test-workflow.yml @@ -49,7 +49,7 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: submodules: recursive diff --git a/.github/workflows/linux_cuda_ci.yml b/.github/workflows/linux_cuda_ci.yml index 886705471b7de..e7e3be8c5f9ed 100644 --- a/.github/workflows/linux_cuda_ci.yml +++ b/.github/workflows/linux_cuda_ci.yml @@ -48,7 +48,7 @@ jobs: packages: read steps: - name: Checkout code - uses: actions/checkout@v5 + uses: actions/checkout@v6 - uses: microsoft/onnxruntime-github-actions/build-docker-image@v0.0.9 id: build_docker_image_step diff --git a/.github/workflows/linux_minimal_build.yml b/.github/workflows/linux_minimal_build.yml index af86975ee6cdc..4d9579a746892 100644 --- a/.github/workflows/linux_minimal_build.yml +++ b/.github/workflows/linux_minimal_build.yml @@ -28,7 +28,7 @@ jobs: packages: write steps: - name: Checkout repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: submodules: false @@ -65,7 +65,7 @@ jobs: id-token: write # If using OIDC for ACR login steps: - name: Checkout repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: submodules: false - uses: actions/setup-node@v6 @@ -122,7 +122,7 @@ jobs: id-token: write # If using OIDC for ACR login steps: - name: Checkout repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: submodules: false - uses: actions/setup-node@v6 @@ -156,7 +156,7 @@ jobs: id-token: write # If using OIDC for ACR login steps: - name: Checkout repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: submodules: false - uses: actions/setup-node@v6 @@ -188,7 +188,7 @@ jobs: id-token: write # If using OIDC for ACR login steps: - name: Checkout repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: submodules: false - uses: actions/setup-node@v6 @@ -222,7 +222,7 @@ jobs: id-token: write # If using OIDC for ACR login steps: - name: Checkout repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: submodules: false - uses: actions/setup-node@v6 @@ -286,7 +286,7 @@ jobs: id-token: write # If using OIDC for ACR login steps: - name: Checkout repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: submodules: false @@ -363,7 +363,7 @@ jobs: id-token: write # If using OIDC for ACR login steps: - name: Checkout repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: submodules: false @@ -430,7 +430,7 @@ jobs: id-token: write # If using OIDC for ACR login steps: - name: Checkout repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: submodules: false @@ -505,7 +505,7 @@ jobs: id-token: write # If using OIDC for ACR login steps: - name: Checkout repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: submodules: false - uses: actions/setup-node@v6 diff --git a/.github/workflows/linux_tensorrt_ci.yml b/.github/workflows/linux_tensorrt_ci.yml index 0e26576829e94..47b7c1ba7e889 100644 --- a/.github/workflows/linux_tensorrt_ci.yml +++ b/.github/workflows/linux_tensorrt_ci.yml @@ -48,7 +48,7 @@ jobs: packages: read steps: - name: Checkout code - uses: actions/checkout@v5 + uses: actions/checkout@v6 # --- Build the Docker image needed for testing --- - name: Build Docker Image for Testing diff --git a/.github/workflows/mac.yml b/.github/workflows/mac.yml index e545406d8d20f..8ba87bc1f731c 100644 --- a/.github/workflows/mac.yml +++ b/.github/workflows/mac.yml @@ -76,7 +76,7 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@v5 + uses: actions/checkout@v6 - uses: microsoft/onnxruntime-github-actions/setup-build-tools@v0.0.9 with: vcpkg-version: '2025.06.13' @@ -124,7 +124,7 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@v5 + uses: actions/checkout@v6 - uses: microsoft/onnxruntime-github-actions/setup-build-tools@v0.0.9 with: vcpkg-version: '2025.06.13' diff --git a/.github/workflows/macos-ci-build-and-test-workflow.yml b/.github/workflows/macos-ci-build-and-test-workflow.yml index 329584c68d7d1..8e1d0264496f6 100644 --- a/.github/workflows/macos-ci-build-and-test-workflow.yml +++ b/.github/workflows/macos-ci-build-and-test-workflow.yml @@ -75,7 +75,7 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@v5 + uses: actions/checkout@v6 - uses: microsoft/onnxruntime-github-actions/setup-build-tools@v0.0.9 with: vcpkg-version: '2025.06.13' diff --git a/.github/workflows/pr_checks.yml b/.github/workflows/pr_checks.yml index abe627f4ff7bc..7ca330742f69b 100644 --- a/.github/workflows/pr_checks.yml +++ b/.github/workflows/pr_checks.yml @@ -24,7 +24,7 @@ jobs: contents: read pull-requests: write steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 - name: Setup Python uses: actions/setup-python@v6 with: diff --git a/.github/workflows/publish-c-apidocs.yml b/.github/workflows/publish-c-apidocs.yml index 25b7899584bbf..d9fb72271967f 100644 --- a/.github/workflows/publish-c-apidocs.yml +++ b/.github/workflows/publish-c-apidocs.yml @@ -24,7 +24,7 @@ jobs: name: Generate C/C++ API docs runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 - name: Install doxygen and dependencies run: | sudo apt update diff --git a/.github/workflows/publish-csharp-apidocs.yml b/.github/workflows/publish-csharp-apidocs.yml index 34b9c1af9552f..dd55bbd917337 100644 --- a/.github/workflows/publish-csharp-apidocs.yml +++ b/.github/workflows/publish-csharp-apidocs.yml @@ -24,7 +24,7 @@ jobs: env: DOCFXVERSION: 2.62.2 steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 - name: Install DocFX run: | dotnet tool update -g docfx diff --git a/.github/workflows/publish-java-apidocs.yml b/.github/workflows/publish-java-apidocs.yml index 656d0627ed17d..81defeae518a3 100644 --- a/.github/workflows/publish-java-apidocs.yml +++ b/.github/workflows/publish-java-apidocs.yml @@ -23,7 +23,7 @@ jobs: name: Generate Java docs runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 - name: Set up JDK 11 uses: actions/setup-java@v5 with: diff --git a/.github/workflows/publish-js-apidocs.yml b/.github/workflows/publish-js-apidocs.yml index e71d3b3c57a4b..9da78d7d9ed9c 100644 --- a/.github/workflows/publish-js-apidocs.yml +++ b/.github/workflows/publish-js-apidocs.yml @@ -23,7 +23,7 @@ jobs: name: Generate JS API docs runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 - name: Setup Node.js uses: actions/setup-node@v6 with: diff --git a/.github/workflows/publish-objectivec-apidocs.yml b/.github/workflows/publish-objectivec-apidocs.yml index 983d3d478a49d..a73b62eba6050 100644 --- a/.github/workflows/publish-objectivec-apidocs.yml +++ b/.github/workflows/publish-objectivec-apidocs.yml @@ -23,7 +23,7 @@ jobs: name: Generate Objective-C API docs runs-on: macos-latest steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 - uses: microsoft/onnxruntime-github-actions/setup-build-tools@v0.0.9 with: vcpkg-version: '2025.06.13' diff --git a/.github/workflows/publish-python-apidocs.yml b/.github/workflows/publish-python-apidocs.yml index 389d1683fb1ff..e35e6a04adbef 100644 --- a/.github/workflows/publish-python-apidocs.yml +++ b/.github/workflows/publish-python-apidocs.yml @@ -24,7 +24,7 @@ jobs: name: Generate Python API docs runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 - name: Install tools run: | sudo apt-get update diff --git a/.github/workflows/react_native.yml b/.github/workflows/react_native.yml index 343186b1aec8c..4a56dfbd35406 100644 --- a/.github/workflows/react_native.yml +++ b/.github/workflows/react_native.yml @@ -20,7 +20,7 @@ jobs: aar_path: ${{ runner.temp }}/.artifacts steps: - name: Checkout repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: submodules: false @@ -75,7 +75,7 @@ jobs: run: echo "ANDROID_AVD_HOME=${{ runner.temp }}/android-avd" >> $GITHUB_ENV - name: Checkout repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 - name: Use Python 3.12 uses: actions/setup-python@v6 @@ -175,7 +175,7 @@ jobs: timeout-minutes: 120 steps: - name: Checkout repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 - name: Use Xcode 15.3.0 run: sudo xcode-select --switch /Applications/Xcode_15.3.0.app/Contents/Developer @@ -218,7 +218,7 @@ jobs: timeout-minutes: 90 steps: - name: Checkout repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 - name: Download iOS pod artifact uses: actions/download-artifact@v6 diff --git a/.github/workflows/reusable_linux_build.yml b/.github/workflows/reusable_linux_build.yml index 795e35b06bfb0..f0da87647b8b0 100644 --- a/.github/workflows/reusable_linux_build.yml +++ b/.github/workflows/reusable_linux_build.yml @@ -75,7 +75,7 @@ jobs: id-token: write steps: - name: Checkout code - uses: actions/checkout@v5 + uses: actions/checkout@v6 - name: Set up Python ${{ inputs.python_version }} if: inputs.architecture != 'arm64' diff --git a/.github/workflows/web.yml b/.github/workflows/web.yml index 016feab5e0d94..6ae25ccc0bf3e 100644 --- a/.github/workflows/web.yml +++ b/.github/workflows/web.yml @@ -22,7 +22,7 @@ jobs: commit_sha: ${{ steps.extract_commit.outputs.commit_sha }} steps: - name: Checkout code - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: submodules: true diff --git a/.github/workflows/windows-web-ci-workflow.yml b/.github/workflows/windows-web-ci-workflow.yml index eee98332056f6..c16ce6eb222eb 100644 --- a/.github/workflows/windows-web-ci-workflow.yml +++ b/.github/workflows/windows-web-ci-workflow.yml @@ -29,7 +29,7 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: submodules: false diff --git a/.github/workflows/windows_build_x64_asan.yml b/.github/workflows/windows_build_x64_asan.yml index 05fd4acd4de9a..ac5f08717155f 100644 --- a/.github/workflows/windows_build_x64_asan.yml +++ b/.github/workflows/windows_build_x64_asan.yml @@ -19,7 +19,7 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: submodules: false diff --git a/.github/workflows/windows_cuda.yml b/.github/workflows/windows_cuda.yml index fd5b65eb039a3..5d6e9b1da31a2 100644 --- a/.github/workflows/windows_cuda.yml +++ b/.github/workflows/windows_cuda.yml @@ -21,7 +21,7 @@ jobs: name: Windows GPU CUDA CI Pipeline runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-vs2022-latest"] steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 with: fetch-depth: 0 submodules: 'none' @@ -152,7 +152,7 @@ jobs: timeout-minutes: 300 runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Win2022-GPU-A10"] steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 with: fetch-depth: 0 submodules: 'none' diff --git a/.github/workflows/windows_dml.yml b/.github/workflows/windows_dml.yml index e8ee7751348b4..0abf6b650f986 100644 --- a/.github/workflows/windows_dml.yml +++ b/.github/workflows/windows_dml.yml @@ -27,7 +27,7 @@ jobs: ONNXRUNTIME_TEST_GPU_DEVICE_ID: '0' runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Win2022-GPU-A10"] steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 with: fetch-depth: 0 # Fetch all history for all tags and branches submodules: 'none' diff --git a/.github/workflows/windows_openvino.yml b/.github/workflows/windows_openvino.yml index b608c0879aa45..537ff1fb00071 100644 --- a/.github/workflows/windows_openvino.yml +++ b/.github/workflows/windows_openvino.yml @@ -31,7 +31,7 @@ jobs: steps: - name: Checkout - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: fetch-depth: 0 submodules: none diff --git a/.github/workflows/windows_qnn_x64.yml b/.github/workflows/windows_qnn_x64.yml index 4f0b50e65df6e..f6176164354bb 100644 --- a/.github/workflows/windows_qnn_x64.yml +++ b/.github/workflows/windows_qnn_x64.yml @@ -31,7 +31,7 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 - name: Setup Python uses: actions/setup-python@v6 diff --git a/.github/workflows/windows_tensorrt.yml b/.github/workflows/windows_tensorrt.yml index 229efb01f0018..4a564a3b1cb36 100644 --- a/.github/workflows/windows_tensorrt.yml +++ b/.github/workflows/windows_tensorrt.yml @@ -21,7 +21,7 @@ jobs: name: Windows GPU TensorRT CI Pipeline runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-vs2022-latest"] steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 with: fetch-depth: 0 submodules: 'none' @@ -157,7 +157,7 @@ jobs: timeout-minutes: 300 runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Win2022-GPU-A10"] steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 with: fetch-depth: 0 submodules: 'none' diff --git a/.github/workflows/windows_webgpu.yml b/.github/workflows/windows_webgpu.yml index 899a8b66eac7a..f729cda5ea576 100644 --- a/.github/workflows/windows_webgpu.yml +++ b/.github/workflows/windows_webgpu.yml @@ -34,7 +34,7 @@ jobs: ONNXRUNTIME_TEST_GPU_DEVICE_ID: "0" steps: - name: Checkout - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: fetch-depth: 0 submodules: none @@ -156,7 +156,7 @@ jobs: timeout-minutes: 300 steps: - name: Checkout - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: fetch-depth: 0 submodules: none @@ -209,7 +209,7 @@ jobs: ONNXRUNTIME_TEST_GPU_DEVICE_ID: "0" steps: - name: Checkout - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: fetch-depth: 0 submodules: none diff --git a/.github/workflows/windows_x64_debug_build_x64_debug.yml b/.github/workflows/windows_x64_debug_build_x64_debug.yml index d62c7130e0ebb..385d03c1a6705 100644 --- a/.github/workflows/windows_x64_debug_build_x64_debug.yml +++ b/.github/workflows/windows_x64_debug_build_x64_debug.yml @@ -18,7 +18,7 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: submodules: false diff --git a/.github/workflows/windows_x64_release_build_x64_release.yml b/.github/workflows/windows_x64_release_build_x64_release.yml index a2991bb0f1131..ee045b70b6efa 100644 --- a/.github/workflows/windows_x64_release_build_x64_release.yml +++ b/.github/workflows/windows_x64_release_build_x64_release.yml @@ -18,7 +18,7 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: submodules: false diff --git a/.github/workflows/windows_x64_release_ep_generic_interface_build_x64_release_ep_generic_interface.yml b/.github/workflows/windows_x64_release_ep_generic_interface_build_x64_release_ep_generic_interface.yml index bb6c5035b0dce..25dfc41e6922c 100644 --- a/.github/workflows/windows_x64_release_ep_generic_interface_build_x64_release_ep_generic_interface.yml +++ b/.github/workflows/windows_x64_release_ep_generic_interface_build_x64_release_ep_generic_interface.yml @@ -18,7 +18,7 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: submodules: false diff --git a/.github/workflows/windows_x64_release_vitisai_build_x64_release.yml b/.github/workflows/windows_x64_release_vitisai_build_x64_release.yml index 4378231338673..e738db262f3a2 100644 --- a/.github/workflows/windows_x64_release_vitisai_build_x64_release.yml +++ b/.github/workflows/windows_x64_release_vitisai_build_x64_release.yml @@ -18,7 +18,7 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: submodules: false diff --git a/.github/workflows/windows_x64_release_xnnpack.yml b/.github/workflows/windows_x64_release_xnnpack.yml index b453cd570ac05..5672e4043c624 100644 --- a/.github/workflows/windows_x64_release_xnnpack.yml +++ b/.github/workflows/windows_x64_release_xnnpack.yml @@ -18,7 +18,7 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: submodules: false diff --git a/.github/workflows/windows_x86.yml b/.github/workflows/windows_x86.yml index d20778d56f60b..381d9dda5cd42 100644 --- a/.github/workflows/windows_x86.yml +++ b/.github/workflows/windows_x86.yml @@ -18,7 +18,7 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: submodules: false diff --git a/dockerfiles/Dockerfile.rocm b/dockerfiles/Dockerfile.rocm deleted file mode 100644 index aca8c3feaff71..0000000000000 --- a/dockerfiles/Dockerfile.rocm +++ /dev/null @@ -1,24 +0,0 @@ -# -------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. -# -------------------------------------------------------------- -# Dockerfile to run ONNXRuntime with ROCm integration -#-------------------------------------------------------------------------- - -FROM rocm/pytorch:rocm6.2.3_ubuntu22.04_py3.10_pytorch_release_2.3.0 - -ARG ONNXRUNTIME_REPO=https://github.com/Microsoft/onnxruntime -ARG ONNXRUNTIME_BRANCH=main - -WORKDIR /code - -ENV PATH=/code/cmake-3.27.3-linux-x86_64/bin:${PATH} - -# Prepare onnxruntime repository & build onnxruntime -RUN git clone --single-branch --branch ${ONNXRUNTIME_BRANCH} --recursive ${ONNXRUNTIME_REPO} onnxruntime &&\ - /bin/sh onnxruntime/dockerfiles/scripts/install_common_deps.sh &&\ - cd onnxruntime &&\ - /bin/sh ./build.sh --allow_running_as_root --config Release --build_wheel --update --build --parallel --cmake_extra_defines\ - ONNXRUNTIME_VERSION=$(cat ./VERSION_NUMBER) --use_rocm --rocm_home=/opt/rocm &&\ - pip install /code/onnxruntime/build/Linux/Release/dist/*.whl &&\ - cd .. diff --git a/dockerfiles/README.md b/dockerfiles/README.md index 4c69098103edd..88c542b63ccd2 100644 --- a/dockerfiles/README.md +++ b/dockerfiles/README.md @@ -1,9 +1,8 @@ # Dockerfiles **Execution Providers** - CPU: [Dockerfile](Dockerfile.source), [Instructions](#cpu) -- CUDA/cuDNN: [Dockerfile](Dockerfile.cuda), [Instructions](#cuda) +- CUDA: [Dockerfile](Dockerfile.cuda), [Instructions](#cuda) - MIGraphX: [Dockerfile](Dockerfile.migraphx), [Instructions](#migraphx) -- ROCm: [Dockerfile](Dockerfile.rocm), [Instructions](#rocm) - OpenVINO: [Dockerfile](Dockerfile.openvino), [Instructions](#openvino) - TensorRT: [Dockerfile](Dockerfile.tensorrt), [Instructions](#tensorrt) - VitisAI: [Dockerfile](Dockerfile.vitisai) @@ -304,17 +303,3 @@ Note: When running the container you built in Docker, please either use 'nvidia- ``` docker run -it --device=/dev/kfd --device=/dev/dri --group-add video onnxruntime-migraphx ``` - - ## ROCm -**Ubuntu 22.04, ROCm6.2.3** - -1. Build the docker image from the Dockerfile in this repository. - ``` - docker build -t onnxruntime-rocm -f Dockerfile.rocm . - ``` - -2. Run the Docker image - - ``` - docker run -it --device=/dev/kfd --device=/dev/dri --group-add video onnxruntime-rocm - ``` diff --git a/dockerfiles/scripts/install_rocm_deps.sh b/dockerfiles/scripts/install_rocm_deps.sh deleted file mode 100644 index fd445be87479b..0000000000000 --- a/dockerfiles/scripts/install_rocm_deps.sh +++ /dev/null @@ -1,84 +0,0 @@ -#!/bin/bash -prefix=/opt/rocm -DEBIAN_FRONTEND=noninteractive -apt-get update && apt-get install -y --no-install-recommends \ - wget \ - zip \ - ca-certificates \ - build-essential \ - curl \ - libcurl4-openssl-dev \ - libssl-dev \ - python3-dev - -# rocm-cmake -rocm_cmake_version=4.5.2 -wget --quiet https://github.com/RadeonOpenCompute/rocm-cmake/archive/refs/tags/rocm-${rocm_cmake_version}.tar.gz -tar -xzvf rocm-${rocm_cmake_version}.tar.gz -rm rocm-${rocm_cmake_version}.tar.gz -cd rocm-cmake-rocm-${rocm_cmake_version} -mkdir build -cd build -cmake -DCMAKE_INSTALL_PREFIX=$prefix .. -make -j8 -make install -cd ../.. -rm -rf rocm-cmake-rocm-${rocm_cmake_version} - -# rccl -rccl_version=4.5.2 -wget --quiet https://github.com/ROCmSoftwarePlatform/rccl/archive/refs/tags/rocm-${rccl_version}.tar.gz -tar -xzvf rocm-${rccl_version}.tar.gz -rm rocm-${rccl_version}.tar.gz -cd rccl-rocm-${rccl_version} -mkdir build -cd build -CXX=/opt/rocm/bin/hipcc cmake -DCMAKE_INSTALL_PREFIX=$prefix .. -make -j8 -make install -cd ../.. -rm -rf rccl-rocm-${rccl_version} - -#rocrand -rocrand_version=4.5.2 -wget --quiet https://github.com/ROCmSoftwarePlatform/rocRAND/archive/refs/tags/rocm-${rocrand_version}.tar.gz -tar -xzvf rocm-${rocrand_version}.tar.gz -rm rocm-${rocrand_version}.tar.gz -cd rocRAND-rocm-${rocrand_version} -mkdir build -cd build -CXX=/opt/rocm/bin/hipcc cmake -DCMAKE_INSTALL_PREFIX=$prefix .. -make -j8 -make install -cd ../.. -rm -rf rocRAND-rocm-${rocrand_version} - -#hipcub -hipcub_version=4.5.2 -wget --quiet https://github.com/ROCmSoftwarePlatform/hipCUB/archive/refs/tags/rocm-${hipcub_version}.tar.gz -tar -xzvf rocm-${hipcub_version}.tar.gz -rm rocm-${hipcub_version}.tar.gz -cd hipCUB-rocm-${hipcub_version} -mkdir build -cd build -CXX=/opt/rocm/bin/hipcc cmake -DCMAKE_INSTALL_PREFIX=$prefix .. -make -j8 -make package -make install -cd ../.. -rm -rf hipCUB-rocm-${hipcub_version} - -#rocprim -rocprim_version=4.5.2 -wget --quiet https://github.com/ROCmSoftwarePlatform/rocPRIM/archive/refs/tags/rocm-${rocprim_version}.tar.gz -tar -xzvf rocm-${rocprim_version}.tar.gz -rm rocm-${rocprim_version}.tar.gz -cd rocPRIM-rocm-${rocprim_version} -mkdir build -cd build -CXX=/opt/rocm/bin/hipcc cmake -DCMAKE_INSTALL_PREFIX=$prefix .. -make -j8 -make install -cd ../.. -rm -rf rocPRIM-rocm-${rocprim_version} - diff --git a/js/package-lock.json b/js/package-lock.json index 1e9f5cb29fe6c..0fca515b61238 100644 --- a/js/package-lock.json +++ b/js/package-lock.json @@ -4,6 +4,7 @@ "requires": true, "packages": { "": { + "name": "js", "license": "MIT", "devDependencies": { "@eslint/compat": "^1.4.0", @@ -3230,6 +3231,27 @@ "url": "https://github.com/sponsors/ljharb" } }, + "node_modules/glob": { + "version": "10.5.0", + "resolved": "https://registry.npmjs.org/glob/-/glob-10.5.0.tgz", + "integrity": "sha512-DfXN8DfhJ7NH3Oe7cFmu3NCu1wKbkReJ8TorzSAFbSKrlNaQSKfIzqYqVY8zlbs2NLBbWpRiU52GX2PbaBVNkg==", + "dev": true, + "license": "ISC", + "dependencies": { + "foreground-child": "^3.1.0", + "jackspeak": "^3.1.2", + "minimatch": "^9.0.4", + "minipass": "^7.1.2", + "package-json-from-dist": "^1.0.0", + "path-scurry": "^1.11.1" + }, + "bin": { + "glob": "dist/esm/bin.mjs" + }, + "funding": { + "url": "https://github.com/sponsors/isaacs" + } + }, "node_modules/glob-parent": { "version": "6.0.2", "resolved": "https://registry.npmjs.org/glob-parent/-/glob-parent-6.0.2.tgz", @@ -3242,6 +3264,32 @@ "node": ">=10.13.0" } }, + "node_modules/glob/node_modules/brace-expansion": { + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-2.0.2.tgz", + "integrity": "sha512-Jt0vHyM+jmUBqojB7E1NIYadt0vI0Qxjxd2TErW94wDz+E2LAm5vKMXXwg6ZZBTHPuUlDgQHKXvjGBdfcF1ZDQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "balanced-match": "^1.0.0" + } + }, + "node_modules/glob/node_modules/minimatch": { + "version": "9.0.5", + "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-9.0.5.tgz", + "integrity": "sha512-G6T0ZX48xgozx7587koeX9Ys2NYy6Gmv//P89sEte9V9whIapMNF4idKxnW2QtCcLiTWlb/wfCabAtAFWhhBow==", + "dev": true, + "license": "ISC", + "dependencies": { + "brace-expansion": "^2.0.1" + }, + "engines": { + "node": ">=16 || 14 >=14.17" + }, + "funding": { + "url": "https://github.com/sponsors/isaacs" + } + }, "node_modules/global-agent": { "version": "3.0.0", "resolved": "https://registry.npmjs.org/global-agent/-/global-agent-3.0.0.tgz", @@ -4311,43 +4359,6 @@ "balanced-match": "^1.0.0" } }, - "node_modules/mocha/node_modules/glob": { - "version": "10.4.5", - "resolved": "https://registry.npmjs.org/glob/-/glob-10.4.5.tgz", - "integrity": "sha512-7Bv8RF0k6xjo7d4A/PxYLbUCfb6c+Vpd2/mB2yRDlew7Jb5hEXiCD9ibfO7wpk8i4sevK6DFny9h7EYbM3/sHg==", - "dev": true, - "license": "ISC", - "dependencies": { - "foreground-child": "^3.1.0", - "jackspeak": "^3.1.2", - "minimatch": "^9.0.4", - "minipass": "^7.1.2", - "package-json-from-dist": "^1.0.0", - "path-scurry": "^1.11.1" - }, - "bin": { - "glob": "dist/esm/bin.mjs" - }, - "funding": { - "url": "https://github.com/sponsors/isaacs" - } - }, - "node_modules/mocha/node_modules/glob/node_modules/minimatch": { - "version": "9.0.5", - "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-9.0.5.tgz", - "integrity": "sha512-G6T0ZX48xgozx7587koeX9Ys2NYy6Gmv//P89sEte9V9whIapMNF4idKxnW2QtCcLiTWlb/wfCabAtAFWhhBow==", - "dev": true, - "license": "ISC", - "dependencies": { - "brace-expansion": "^2.0.1" - }, - "engines": { - "node": ">=16 || 14 >=14.17" - }, - "funding": { - "url": "https://github.com/sponsors/isaacs" - } - }, "node_modules/mocha/node_modules/minimatch": { "version": "5.1.6", "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-5.1.6.tgz", @@ -8078,6 +8089,40 @@ "get-intrinsic": "^1.2.6" } }, + "glob": { + "version": "10.5.0", + "resolved": "https://registry.npmjs.org/glob/-/glob-10.5.0.tgz", + "integrity": "sha512-DfXN8DfhJ7NH3Oe7cFmu3NCu1wKbkReJ8TorzSAFbSKrlNaQSKfIzqYqVY8zlbs2NLBbWpRiU52GX2PbaBVNkg==", + "dev": true, + "requires": { + "foreground-child": "^3.1.0", + "jackspeak": "^3.1.2", + "minimatch": "^9.0.4", + "minipass": "^7.1.2", + "package-json-from-dist": "^1.0.0", + "path-scurry": "^1.11.1" + }, + "dependencies": { + "brace-expansion": { + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-2.0.2.tgz", + "integrity": "sha512-Jt0vHyM+jmUBqojB7E1NIYadt0vI0Qxjxd2TErW94wDz+E2LAm5vKMXXwg6ZZBTHPuUlDgQHKXvjGBdfcF1ZDQ==", + "dev": true, + "requires": { + "balanced-match": "^1.0.0" + } + }, + "minimatch": { + "version": "9.0.5", + "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-9.0.5.tgz", + "integrity": "sha512-G6T0ZX48xgozx7587koeX9Ys2NYy6Gmv//P89sEte9V9whIapMNF4idKxnW2QtCcLiTWlb/wfCabAtAFWhhBow==", + "dev": true, + "requires": { + "brace-expansion": "^2.0.1" + } + } + } + }, "glob-parent": { "version": "6.0.2", "resolved": "https://registry.npmjs.org/glob-parent/-/glob-parent-6.0.2.tgz", @@ -8772,31 +8817,6 @@ "balanced-match": "^1.0.0" } }, - "glob": { - "version": "10.4.5", - "resolved": "https://registry.npmjs.org/glob/-/glob-10.4.5.tgz", - "integrity": "sha512-7Bv8RF0k6xjo7d4A/PxYLbUCfb6c+Vpd2/mB2yRDlew7Jb5hEXiCD9ibfO7wpk8i4sevK6DFny9h7EYbM3/sHg==", - "dev": true, - "requires": { - "foreground-child": "^3.1.0", - "jackspeak": "^3.1.2", - "minimatch": "^9.0.4", - "minipass": "^7.1.2", - "package-json-from-dist": "^1.0.0", - "path-scurry": "^1.11.1" - }, - "dependencies": { - "minimatch": { - "version": "9.0.5", - "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-9.0.5.tgz", - "integrity": "sha512-G6T0ZX48xgozx7587koeX9Ys2NYy6Gmv//P89sEte9V9whIapMNF4idKxnW2QtCcLiTWlb/wfCabAtAFWhhBow==", - "dev": true, - "requires": { - "brace-expansion": "^2.0.1" - } - } - } - }, "minimatch": { "version": "5.1.6", "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-5.1.6.tgz", diff --git a/js/react_native/package-lock.json b/js/react_native/package-lock.json index e6ed2bdb9e17b..de8d631362db7 100644 --- a/js/react_native/package-lock.json +++ b/js/react_native/package-lock.json @@ -33,6 +33,7 @@ "version": "1.24.0", "license": "MIT", "devDependencies": { + "globby": "^15.0.0", "typedoc": "^0.25.7" } }, @@ -61,15 +62,15 @@ } }, "node_modules/@babel/code-frame": { - "version": "7.26.2", - "resolved": "https://registry.npmjs.org/@babel/code-frame/-/code-frame-7.26.2.tgz", - "integrity": "sha512-RJlIHRueQgwWitWgF8OdFYGZX328Ax5BCemNGlqHfplnRT9ESi8JkFlvaVYbS+UubVY6dpv87Fs2u5M29iNFVQ==", + "version": "7.27.1", + "resolved": "https://registry.npmjs.org/@babel/code-frame/-/code-frame-7.27.1.tgz", + "integrity": "sha512-cjQ7ZlQ0Mv3b47hABuTevyTuYN4i+loJKGeV9flcCgIK37cCXRh+L1bd3iBHlynerhQ7BhCkn2BPbQUL+rGqFg==", "dev": true, "license": "MIT", "dependencies": { - "@babel/helper-validator-identifier": "^7.25.9", + "@babel/helper-validator-identifier": "^7.27.1", "js-tokens": "^4.0.0", - "picocolors": "^1.0.0" + "picocolors": "^1.1.1" }, "engines": { "node": ">=6.9.0" @@ -410,9 +411,9 @@ } }, "node_modules/@babel/helper-string-parser": { - "version": "7.25.9", - "resolved": "https://registry.npmjs.org/@babel/helper-string-parser/-/helper-string-parser-7.25.9.tgz", - "integrity": "sha512-4A/SCr/2KLd5jrtOMFzaKjVtAei3+2r/NChoBNoZ3EyP/+GlhoaEGoWOZUmFmoITP7zOJyHIMm+DYRd8o3PvHA==", + "version": "7.27.1", + "resolved": "https://registry.npmjs.org/@babel/helper-string-parser/-/helper-string-parser-7.27.1.tgz", + "integrity": "sha512-qMlSxKbpRlAridDExk92nSobyDdpPijUq2DW6oDnUqd0iOGxmQjyqhMIihI9+zv4LPyZdRje2cavWPbCbWm3eA==", "dev": true, "license": "MIT", "engines": { @@ -420,9 +421,9 @@ } }, "node_modules/@babel/helper-validator-identifier": { - "version": "7.25.9", - "resolved": "https://registry.npmjs.org/@babel/helper-validator-identifier/-/helper-validator-identifier-7.25.9.tgz", - "integrity": "sha512-Ed61U6XJc3CVRfkERJWDz4dJwKe7iLmmJsbOGu9wSloNSFttHV0I8g6UAgb7qnK5ly5bGLPd4oXZlxCdANBOWQ==", + "version": "7.28.5", + "resolved": "https://registry.npmjs.org/@babel/helper-validator-identifier/-/helper-validator-identifier-7.28.5.tgz", + "integrity": "sha512-qSs4ifwzKJSV39ucNjsvc6WVHs6b7S03sOh2OcHF9UHfVPqWWALUsNUVzhSBiItjRZoLHx7nIarVjqKVusUZ1Q==", "dev": true, "license": "MIT", "engines": { @@ -455,27 +456,27 @@ } }, "node_modules/@babel/helpers": { - "version": "7.25.6", - "resolved": "https://registry.npmjs.org/@babel/helpers/-/helpers-7.25.6.tgz", - "integrity": "sha512-Xg0tn4HcfTijTwfDwYlvVCl43V6h4KyVVX2aEm4qdO/PC6L2YvzLHFdmxhoeSA3eslcE6+ZVXHgWwopXYLNq4Q==", + "version": "7.28.4", + "resolved": "https://registry.npmjs.org/@babel/helpers/-/helpers-7.28.4.tgz", + "integrity": "sha512-HFN59MmQXGHVyYadKLVumYsA9dBFun/ldYxipEjzA4196jpLZd8UjEEBLkbEkvfYreDqJhZxYAWFPtrfhNpj4w==", "dev": true, "license": "MIT", "dependencies": { - "@babel/template": "^7.25.0", - "@babel/types": "^7.25.6" + "@babel/template": "^7.27.2", + "@babel/types": "^7.28.4" }, "engines": { "node": ">=6.9.0" } }, "node_modules/@babel/parser": { - "version": "7.26.9", - "resolved": "https://registry.npmjs.org/@babel/parser/-/parser-7.26.9.tgz", - "integrity": "sha512-81NWa1njQblgZbQHxWHpxxCzNsa3ZwvFqpUg7P+NNUU6f3UU2jBEg4OlF/J6rl8+PQGh1q6/zWScd001YwcA5A==", + "version": "7.28.5", + "resolved": "https://registry.npmjs.org/@babel/parser/-/parser-7.28.5.tgz", + "integrity": "sha512-KKBU1VGYR7ORr3At5HAtUQ+TV3SzRCXmA/8OdDZiLDBIZxVyzXuztPjfLd3BV1PRAQGCMWWSHYhL0F8d5uHBDQ==", "dev": true, "license": "MIT", "dependencies": { - "@babel/types": "^7.26.9" + "@babel/types": "^7.28.5" }, "bin": { "parser": "bin/babel-parser.js" @@ -2114,35 +2115,25 @@ } }, "node_modules/@babel/runtime": { - "version": "7.25.6", - "resolved": "https://registry.npmjs.org/@babel/runtime/-/runtime-7.25.6.tgz", - "integrity": "sha512-VBj9MYyDb9tuLq7yzqjgzt6Q+IBQLrGZfdjOekyEirZPHxXWoTSGUTMrpsfi58Up73d13NfYLv8HT9vmznjzhQ==", + "version": "7.28.4", + "resolved": "https://registry.npmjs.org/@babel/runtime/-/runtime-7.28.4.tgz", + "integrity": "sha512-Q/N6JNWvIvPnLDvjlE1OUBLPQHH6l3CltCEsHIujp45zQUSSh8K+gHnaEX45yAT1nyngnINhvWtzN+Nb9D8RAQ==", "dev": true, "license": "MIT", - "dependencies": { - "regenerator-runtime": "^0.14.0" - }, "engines": { "node": ">=6.9.0" } }, - "node_modules/@babel/runtime/node_modules/regenerator-runtime": { - "version": "0.14.1", - "resolved": "https://registry.npmjs.org/regenerator-runtime/-/regenerator-runtime-0.14.1.tgz", - "integrity": "sha512-dYnhHh0nJoMfnkZs6GmmhFknAGRrLznOu5nc9ML+EJxGvrx6H7teuevqVqCuPcPK//3eDrrjQhehXVx9cnkGdw==", - "dev": true, - "license": "MIT" - }, "node_modules/@babel/template": { - "version": "7.26.9", - "resolved": "https://registry.npmjs.org/@babel/template/-/template-7.26.9.tgz", - "integrity": "sha512-qyRplbeIpNZhmzOysF/wFMuP9sctmh2cFzRAZOn1YapxBsE1i9bJIY586R/WBLfLcmcBlM8ROBiQURnnNy+zfA==", + "version": "7.27.2", + "resolved": "https://registry.npmjs.org/@babel/template/-/template-7.27.2.tgz", + "integrity": "sha512-LPDZ85aEJyYSd18/DkjNh4/y1ntkE5KwUHWTiqgRxruuZL2F1yuHligVHLvcHY2vMHXttKFpJn6LwfI7cw7ODw==", "dev": true, "license": "MIT", "dependencies": { - "@babel/code-frame": "^7.26.2", - "@babel/parser": "^7.26.9", - "@babel/types": "^7.26.9" + "@babel/code-frame": "^7.27.1", + "@babel/parser": "^7.27.2", + "@babel/types": "^7.27.1" }, "engines": { "node": ">=6.9.0" @@ -2189,14 +2180,14 @@ "license": "MIT" }, "node_modules/@babel/types": { - "version": "7.26.9", - "resolved": "https://registry.npmjs.org/@babel/types/-/types-7.26.9.tgz", - "integrity": "sha512-Y3IR1cRnOxOCDvMmNiym7XpXQ93iGDDPHx+Zj+NM+rg0fBaShfQLkg+hKPaZCEvg5N/LeCo4+Rj/i3FuJsIQaw==", + "version": "7.28.5", + "resolved": "https://registry.npmjs.org/@babel/types/-/types-7.28.5.tgz", + "integrity": "sha512-qQ5m48eI/MFLQ5PxQj4PFaprjyCTLI37ElWMmNs0K8Lk3dVeOdNpB3ks8jc7yM5CDmVC73eMVk/trk3fgmrUpA==", "dev": true, "license": "MIT", "dependencies": { - "@babel/helper-string-parser": "^7.25.9", - "@babel/helper-validator-identifier": "^7.25.9" + "@babel/helper-string-parser": "^7.27.1", + "@babel/helper-validator-identifier": "^7.28.5" }, "engines": { "node": ">=6.9.0" @@ -3319,9 +3310,9 @@ } }, "node_modules/babel-plugin-module-resolver/node_modules/brace-expansion": { - "version": "2.0.1", - "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-2.0.1.tgz", - "integrity": "sha512-XnAIvQ8eM+kC6aULx6wuQiwVsnzsi9d3WxzV3FpWTGA19F621kwdbsAcFKXgKUHZWsy+mY6iL1sHTxWEFCytDA==", + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-2.0.2.tgz", + "integrity": "sha512-Jt0vHyM+jmUBqojB7E1NIYadt0vI0Qxjxd2TErW94wDz+E2LAm5vKMXXwg6ZZBTHPuUlDgQHKXvjGBdfcF1ZDQ==", "dev": true, "license": "MIT", "dependencies": { @@ -3477,7 +3468,9 @@ } }, "node_modules/brace-expansion": { - "version": "1.1.11", + "version": "1.1.12", + "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-1.1.12.tgz", + "integrity": "sha512-9T9UjW3r0UW5c1Q7GTwllptXwhvYmEzFhzMfZ9H7FQWt+uZePjZPjBP/W1ZEyZ1twGWom5/56TF4lPcqjnDHcg==", "dev": true, "license": "MIT", "dependencies": { @@ -3831,9 +3824,9 @@ } }, "node_modules/compression": { - "version": "1.8.0", - "resolved": "https://registry.npmjs.org/compression/-/compression-1.8.0.tgz", - "integrity": "sha512-k6WLKfunuqCYD3t6AsuPGvQWaKwuLLh2/xHNcX4qE+vIfDNXpSqnrhwA7O53R7WVQUnt8dVAIW+YHr7xTgOgGA==", + "version": "1.8.1", + "resolved": "https://registry.npmjs.org/compression/-/compression-1.8.1.tgz", + "integrity": "sha512-9mAqGPHLakhCLeNyxPkK4xVo746zQ/czLH1Ky+vkitMnWfWZps8r0qXuwhwizagCRttsL4lfG4pIOvaWLpAP0w==", "dev": true, "license": "MIT", "dependencies": { @@ -3841,7 +3834,7 @@ "compressible": "~2.0.18", "debug": "2.6.9", "negotiator": "~0.6.4", - "on-headers": "~1.0.2", + "on-headers": "~1.1.0", "safe-buffer": "5.2.1", "vary": "~1.1.2" }, @@ -4821,9 +4814,9 @@ } }, "node_modules/image-size": { - "version": "1.2.0", - "resolved": "https://registry.npmjs.org/image-size/-/image-size-1.2.0.tgz", - "integrity": "sha512-4S8fwbO6w3GeCVN6OPtA9I5IGKkcDMPcKndtUlpJuCwu7JLjtj7JZpwqLuyY2nrmQT3AWsCJLSKPsc2mPBSl3w==", + "version": "1.2.1", + "resolved": "https://registry.npmjs.org/image-size/-/image-size-1.2.1.tgz", + "integrity": "sha512-rH+46sQJ2dlwfjfhCyNx5thzrv+dtmBIhPHk0zgRUukHzZ/kRueTJXoYYsclBaKcSMBWuGbOFXtioLpzTb5euw==", "dev": true, "license": "MIT", "dependencies": { @@ -5250,7 +5243,9 @@ "license": "MIT" }, "node_modules/js-yaml": { - "version": "3.14.1", + "version": "3.14.2", + "resolved": "https://registry.npmjs.org/js-yaml/-/js-yaml-3.14.2.tgz", + "integrity": "sha512-PMSmkqxr106Xa156c2M265Z+FTrPl+oxd/rgOQy2tijQeK5TxQ43psO1ZCwhVOSdnn+RzkzlRz/eY4BgJBYVpg==", "dev": true, "license": "MIT", "dependencies": { @@ -6544,9 +6539,9 @@ } }, "node_modules/on-headers": { - "version": "1.0.2", - "resolved": "https://registry.npmjs.org/on-headers/-/on-headers-1.0.2.tgz", - "integrity": "sha512-pZAE+FJLoyITytdqK0U5s+FIpjN0JP3OzFi/u8Rx+EV5/W+JTWGXG8xFzevE7AjBfDqHv/8vL8qQsIhHnqRkrA==", + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/on-headers/-/on-headers-1.1.0.tgz", + "integrity": "sha512-737ZY3yNnXy37FHkQxPzt4UZ2UWPWiCZWLvFZ4fu5cueciegX0zGPnrlY6bwRg4FdQOe9YU8MkmJwGhoMybl8A==", "dev": true, "license": "MIT", "engines": { @@ -7130,9 +7125,9 @@ "license": "Python-2.0" }, "node_modules/react-native-builder-bob/node_modules/brace-expansion": { - "version": "2.0.1", - "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-2.0.1.tgz", - "integrity": "sha512-XnAIvQ8eM+kC6aULx6wuQiwVsnzsi9d3WxzV3FpWTGA19F621kwdbsAcFKXgKUHZWsy+mY6iL1sHTxWEFCytDA==", + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-2.0.2.tgz", + "integrity": "sha512-Jt0vHyM+jmUBqojB7E1NIYadt0vI0Qxjxd2TErW94wDz+E2LAm5vKMXXwg6ZZBTHPuUlDgQHKXvjGBdfcF1ZDQ==", "dev": true, "license": "MIT", "dependencies": { @@ -7203,9 +7198,9 @@ } }, "node_modules/react-native-builder-bob/node_modules/js-yaml": { - "version": "4.1.0", - "resolved": "https://registry.npmjs.org/js-yaml/-/js-yaml-4.1.0.tgz", - "integrity": "sha512-wpxZs9NoxZaJESJGIZTyDEaYpl0FKSA+FB9aJiyemKhMwkxQg63h4T1KJgUGHpTqPDNRcmmYLugrRjJlBtWvRA==", + "version": "4.1.1", + "resolved": "https://registry.npmjs.org/js-yaml/-/js-yaml-4.1.1.tgz", + "integrity": "sha512-qQKT4zQxXl8lLwBtHMWwaTcGfFOZviOJet3Oy/xmGk2gZH677CJM9EvtfdSkgWcATZhj/55JZ0rmy3myCT5lsA==", "dev": true, "license": "MIT", "dependencies": { diff --git a/js/web/lib/wasm/jsep/webgpu/ops/attention.ts b/js/web/lib/wasm/jsep/webgpu/ops/attention.ts index 6a8dffb73fa08..f0f7527f665b9 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/attention.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/attention.ts @@ -360,7 +360,7 @@ const createInPlaceSoftmaxProgramInfo = ( let local_offset = local_idx * uniforms.elements_per_thread; let offset = (global_idx / ${WG}) * uniforms.total_sequence_length + local_offset; let seq_causal_length = ${seqLens ? 'u32(past_sequence_length + workgroup_id.y + 1)' : 'total_sequence_length'}; - var thread_max_vector = ${f32Type}(-3.402823e+38f); + var thread_max_vector = ${f32Type}(-3.4028234663852886e+38f); for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < seq_causal_length; i++) { thread_max_vector = max(${f32Type}(x[offset + i]), thread_max_vector); } @@ -378,7 +378,7 @@ const createInPlaceSoftmaxProgramInfo = ( })()}; workgroupBarrier(); - var max_value = f32(-3.402823e+38f); + var max_value = f32(-3.4028234663852886e+38f); for (var i = 0u; i < ${WG}; i++) { max_value = max(thread_max[i], max_value); } diff --git a/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts b/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts index 2056416873df5..f6882280e91df 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts @@ -81,7 +81,7 @@ const createSoftmaxProgramInfo = (context: ComputeContext, attributes: SoftmaxAt // 6.2.4 in wgsl spec const threadMaxDecl = tensorTypeToWsglStorageType(transposedInput.dataType) === 'f32' - ? `var threadMax = ${valueType}(-3.402823e+38f);` + ? `var threadMax = ${valueType}(-3.4028234663852886e+38f);` : `var threadMax = ${valueType}(-65504.0h);`; const getShaderSource = (shaderHelper: ShaderHelper) => ` var rowMaxShared : ${valueType}; diff --git a/onnxruntime/contrib_ops/rocm/bert/attention.cu b/onnxruntime/contrib_ops/rocm/bert/attention.cu deleted file mode 100644 index b40fc2bf0eef8..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/attention.cu +++ /dev/null @@ -1,215 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "contrib_ops/rocm/bert/attention.h" -#include "contrib_ops/rocm/bert/attention_impl.h" -#include "contrib_ops/rocm/bert/batched_gemm_permute_pipelines.cuh" -#include "contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh" -#include "contrib_ops/rocm/bert/transformer_common.h" -#include "core/providers/rocm/rocm_common.h" -#include "core/providers/rocm/shared_inc/fpgeneric.h" -#include "core/providers/rocm/tunable/gemm.h" - -using namespace onnxruntime::rocm; -using namespace ::onnxruntime::common; -using namespace ONNX_NAMESPACE; - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -constexpr int kPastSequenceLengthInputIndex = 6; -constexpr int kPastInputIndex = 4; -constexpr int kPresentOutputIndex = 1; - -#define REGISTER_KERNEL_TYPED(T) \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - Attention, \ - kMSDomain, \ - 1, \ - T, \ - kRocmExecutionProvider, \ - (*KernelDefBuilder::Create()) \ - .MayInplace(kPastInputIndex, kPresentOutputIndex) \ - .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ - .InputMemoryType(OrtMemTypeCPUInput, kPastSequenceLengthInputIndex), \ - Attention); - -REGISTER_KERNEL_TYPED(float) -REGISTER_KERNEL_TYPED(MLFloat16) - -template -Attention::Attention(const OpKernelInfo& info) - : RocmKernel(info), AttentionBase(info, true), attn_type_(kAttention) { - using HipT = typename ToHipType::MappedType; - using AttentionTunableOp = GemmSoftmaxGemmPermuteTunableOp; - tunable_op_ = std::make_shared(); -} - -template -Status Attention::ComputeInternal(OpKernelContext* context) const { - const Tensor* input = context->Input(0); - const Tensor* weights = context->Input(1); - const Tensor* bias = context->Input(2); - const Tensor* mask_index = context->Input(3); - const Tensor* past = context->Input(4); - const Tensor* attention_bias = context->Input(5); - const Tensor* past_seq_len = context->Input(kPastSequenceLengthInputIndex); - - auto& device_prop = GetDeviceProp(); - RocmAttentionParameters attn; - ORT_RETURN_IF_ERROR(CheckInputs(input->Shape(), - weights->Shape(), - bias->Shape(), - mask_index, - past, - attention_bias, - &attn, - device_prop.maxThreadsPerBlock, - past_seq_len)); - ORT_ENFORCE(attn.sequence_length == attn.kv_sequence_length); // self attention - ORT_ENFORCE(attn.qkv_format == Q_K_V_BNSH); // non-packed, permuted - - TensorShapeVector output_shape(3); - output_shape[0] = static_cast(attn.batch_size); - output_shape[1] = static_cast(attn.sequence_length); - output_shape[2] = static_cast(attn.v_hidden_size); - Tensor* output = context->Output(0, output_shape); - - std::vector present_dims{ - 2, attn.batch_size, attn.num_heads, - past_present_share_buffer_ ? attn.max_sequence_length : attn.total_sequence_length, - attn.head_size}; - TensorShape present_shape(present_dims); - Tensor* present = context->Output(kPresentOutputIndex, present_shape); - - auto stream = Stream(context); - hipblasHandle_t hipblas = GetHipblasHandle(context); - - using HipT = typename ToHipType::MappedType; - using QkvProjectGeneric = GemmPermuteGenericPipeline; - using AttentionGeneric = GemmSoftmaxGemmPermuteGenericPipeline; - using AttentionTunableOp = GemmSoftmaxGemmPermuteTunableOp; - - ORT_RETURN_IF_ERROR(ClassifyAttentionMode(attn_type_, &attn, /*qkv=*/{}, /*past=*/{past}, /*present=*/{present})); - ORT_ENFORCE(attn.mode == QFMT_KFMT_VFMT_NONE_NONE_NONE_NONE || - attn.mode == QFMT_KFMT_VFMT_NONE_NONE_2BNTH_NONE || - attn.mode == QFMT_KFMT_VFMT_NONE_NONE_2BNMH_NONE || - attn.mode == QFMT_KFMT_VFMT_2BNPH_NONE_2BNTH_NONE || - attn.mode == QFMT_KFMT_VFMT_2BNMH_NONE_2BNMH_NONE); - - size_t qkv_project_output_bytes = QkvProjectGeneric::GetOutputNumBytes(&attn); - size_t shared_workspace_bytes = std::max(QkvProjectGeneric::GetWorkspaceNumBytes(&attn), - AttentionGeneric::GetWorkspaceNumBytes(&attn)); - if (GetTuningContext()->IsTunableOpEnabled()) { - shared_workspace_bytes = std::max(shared_workspace_bytes, AttentionTunableOp::GetWorkspaceNumBytes(&attn)); - } - - auto qkv_project_output = GetScratchBuffer(qkv_project_output_bytes, context->GetComputeStream()); - auto workspace = GetScratchBuffer(shared_workspace_bytes, context->GetComputeStream()); - - GemmPermuteParams gemm_permute_params; - { - auto& params = gemm_permute_params; - params.tuning_ctx = GetTuningContext(); - params.stream = context->GetComputeStream(); - params.handle = hipblas; - params.attention = &attn; - params.device_prop = &device_prop; - - params.input_buffer = reinterpret_cast(input->DataRaw()); - params.weight_buffer = reinterpret_cast(weights->DataRaw()); - params.bias_buffer = reinterpret_cast(bias->DataRaw()); - params.out_buffer = reinterpret_cast(qkv_project_output.get()); - params.ones = GetConstOnes(attn.batch_size * attn.sequence_length, stream); - params.workspace_buffer = reinterpret_cast(workspace.get()); - } - - ORT_RETURN_IF_ERROR(QkvProjectGeneric::Run(&gemm_permute_params)); - auto [q_buffer, k_buffer, v_buffer] = QkvProjectGeneric::UnspliceOutputQKV(&gemm_permute_params); - - // NOTE: GemmPermute always output 3BNSH, k_buffer and v_buffer can be treated as 2BNSH - if (nullptr != present) { - Strides dst_strides; // the output buffer is present Tensor, the buffer is the same - - int4 add_shape{2 * attn.batch_size, attn.num_heads, attn.sequence_length, attn.head_size}; - HipT* add_dest = nullptr; // destination of concatenated data to present - const HipT* const add_src = k_buffer; // source of concatenated data to present - const auto add_src_strides = Strides::BNSHMemory( - 2 * attn.batch_size, attn.num_heads, attn.sequence_length, attn.head_size); - - if (attn.mode == QFMT_KFMT_VFMT_NONE_NONE_2BNTH_NONE) { - dst_strides = Strides::BNSHMemory(2 * attn.batch_size, attn.num_heads, attn.total_sequence_length, attn.head_size); - add_dest = reinterpret_cast(present->MutableDataRaw()) /* + dst_strides.OffsetAt(0, 0, 0, 0)*/; - } else if (attn.mode == QFMT_KFMT_VFMT_2BNPH_NONE_2BNTH_NONE) { - dst_strides = Strides::BNSHMemory(2 * attn.batch_size, attn.num_heads, attn.total_sequence_length, attn.head_size); - add_dest = reinterpret_cast(present->MutableDataRaw()) + dst_strides.OffsetAt(0, 0, attn.past_sequence_length, 0); - - // We only need to copy past to present in this case. All other cases will be build the present incrementally - const int4 past_shape = {2 * attn.batch_size, attn.num_heads, attn.past_sequence_length, attn.head_size}; - HipT* const past_dest = reinterpret_cast(present->MutableDataRaw()); - const HipT* const past_src = reinterpret_cast(past->DataRaw()); - const Strides past_src_strides = Strides::BNSHMemory( - 2 * attn.batch_size, attn.num_heads, attn.past_sequence_length, attn.head_size); - - ORT_RETURN_IF_ERROR(LaunchStridedCopy(stream, past_src, past_shape, past_src_strides.ForBNSHCoord(), - past_dest, dst_strides.ForBNSHCoord(), device_prop.maxThreadsPerBlock)); - } else if (attn.mode == QFMT_KFMT_VFMT_NONE_NONE_2BNMH_NONE) { - dst_strides = Strides::BNSHMemory(2 * attn.batch_size, attn.num_heads, attn.max_sequence_length, attn.head_size); - add_dest = reinterpret_cast(present->MutableDataRaw()) /* + dst_strides.OffsetAt(0, 0, 0, 0)*/; - } else if (attn.mode == QFMT_KFMT_VFMT_2BNMH_NONE_2BNMH_NONE) { - dst_strides = Strides::BNSHMemory(2 * attn.batch_size, attn.num_heads, attn.max_sequence_length, attn.head_size); - add_dest = reinterpret_cast(present->MutableDataRaw()) + dst_strides.OffsetAt(0, 0, attn.past_sequence_length, 0); - } - - ORT_RETURN_IF_ERROR(LaunchStridedCopy(stream, add_src, add_shape, add_src_strides.ForBNSHCoord(), - add_dest, dst_strides.ForBNSHCoord(), device_prop.maxThreadsPerBlock)); - - // update pointers to present_k and present_v. TODO: switch to ConvertToOffsetedBufferViews - k_buffer = reinterpret_cast(present->MutableDataRaw()); - v_buffer = reinterpret_cast(present->MutableDataRaw()) + dst_strides.OffsetAt(attn.batch_size, 0, 0, 0); - } - - // For testing, environment variable ORT_TRANSFORMER_OPTIONS=1 could enable persistent softmax - const TransformerOptions* options = TransformerOptions::GetInstance(); - bool use_persistent_softmax = options->IsPrecisionMode() && !options->DisablePersistentSoftmax(); - - GemmSoftmaxGemmPermuteParams gemm_softmax_gemm_permute_params; - { - auto& params = gemm_softmax_gemm_permute_params; - params.tuning_ctx = GetTuningContext(); - params.stream = context->GetComputeStream(); - params.handle = hipblas; - params.attention = &attn; - params.device_prop = &device_prop; - // FIXME: the params.scale seems to be different from AttentionParameters::scale; - params.scale = 1.0f / sqrt(static_cast(attn.head_size)); - // TODO: switch to ConvertToOffsetedBufferViews - params.q_buffer = q_buffer; - params.k_buffer = k_buffer; - params.v_buffer = v_buffer; - params.out_buffer = reinterpret_cast(output->MutableDataRaw()); - - if (attention_bias != nullptr) { - params.bias_buffer = reinterpret_cast(attention_bias->DataRaw()); - } - - if (mask_index != nullptr) { - params.mask_index_buffer = mask_index->Data(); - params.mask_index_dims = mask_index->Shape().AsShapeVector(); - } - - params.workspace_buffer = reinterpret_cast(workspace.get()); - } - - if (this->GetTuningContext()->IsTunableOpEnabled() && - !use_persistent_softmax) { - return (*std::static_pointer_cast(tunable_op_))(&gemm_softmax_gemm_permute_params); - } else { - return AttentionGeneric::Run(&gemm_softmax_gemm_permute_params, use_persistent_softmax); - } -} - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/attention.h b/onnxruntime/contrib_ops/rocm/bert/attention.h deleted file mode 100644 index 7204fd660a516..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/attention.h +++ /dev/null @@ -1,33 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include "core/providers/rocm/rocm_kernel.h" -#include "contrib_ops/cpu/bert/attention_base.h" -#include "contrib_ops/rocm/bert/attention_impl.h" - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -using namespace onnxruntime::rocm; - -template -class Attention final : public RocmKernel, public AttentionBase { - public: - Attention(const OpKernelInfo& info); - Status ComputeInternal(OpKernelContext* context) const override; - - public: - AttentionType attn_type_; - - // type-erased GemmSoftmaxGemmPermuteTunableOp, the reason for this is: - // 1. We don't want to include the cuh file where GemmSoftmaxGemmPermuteTunableOp is defined. - // 2. We don't want to construct the object repeatly (which is expansive) during Compute. - std::shared_ptr tunable_op_; -}; - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/attention_impl.cu b/onnxruntime/contrib_ops/rocm/bert/attention_impl.cu deleted file mode 100644 index 270a8e51daf88..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/attention_impl.cu +++ /dev/null @@ -1,435 +0,0 @@ -/* - The implementation of this file is based on qkvToContext plugin in TensorRT demo: - https://github.com/NVIDIA/TensorRT/tree/release/5.1/demo/BERT/ - -Copyright 2019 NVIDIA Corporation - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -// Modifications: scaling is moved from masked softmax to the gemm before that. -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include -#include "core/providers/rocm/cu_inc/common.cuh" -#include "core/providers/rocm/rocm_common.h" -#include "core/providers/rocm/shared_inc/fpgeneric.h" -#include "core/providers/rocm/tunable/gemm.h" -#include "core/providers/rocm/tunable/rocm_tunable.h" -#include "contrib_ops/cpu/bert/attention_base.h" -#include "contrib_ops/rocm/bert/attention_impl.h" -#include "contrib_ops/rocm/bert/attention_softmax.h" -#include "contrib_ops/rocm/bert/decoder_attention_impl.h" - -using namespace onnxruntime::rocm; - -namespace blas = onnxruntime::rocm::tunable::blas; - -#define CHECK_ROCM(expr) HIP_RETURN_IF_ERROR(expr) - -using namespace onnxruntime::rocm; -using namespace ::onnxruntime::common; -using namespace ONNX_NAMESPACE; - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -static size_t AlignTo(size_t a, size_t b) { - return CeilDiv(a, b) * b; -} - -size_t GetAttentionScratchSize(size_t element_size, - int batch_size, - int num_heads, - int sequence_length, - int total_sequence_length) { - const size_t bytes = element_size * batch_size * num_heads * sequence_length * total_sequence_length; - - const size_t alignment = 256; - const size_t bytesAligned = AlignTo(bytes, alignment); - return bytesAligned; -} - -size_t GetAttentionWorkspaceSize( - size_t element_size, - int batch_size, - int num_heads, - int head_size, - int sequence_length, - int total_sequence_length) { - size_t qkv_size = element_size * 3 * batch_size * sequence_length * num_heads * head_size; - return qkv_size + 2 * GetAttentionScratchSize(element_size, batch_size, num_heads, - sequence_length, total_sequence_length); -} - -inline int3 Get2DMaskStrides(int total_sequence_length) { - // stride == 0 indicate broadcasting - return {total_sequence_length, 0, 1}; -} - -Status ClassifyAttentionMode( - AttentionType attn_type, - RocmAttentionParameters* attn, - const std::vector& qkv, - const std::vector& past, - const std::vector& present) { - size_t num_qkv = std::count_if(qkv.cbegin(), qkv.cend(), [](auto it) { return it != nullptr; }); - size_t num_past = std::count_if(past.cbegin(), past.cend(), [](auto it) { return it != nullptr; }); - size_t num_present = std::count_if(present.cbegin(), present.cend(), [](auto it) { return it != nullptr; }); - - auto hint = MakeString(num_qkv, " qkv inputs, ", num_past, " past inputs and ", num_present, " present inputs"); - LOGS_DEFAULT(VERBOSE) << hint; - - if (attn_type == kAttention) { - ORT_ENFORCE(num_qkv == 0); - if (num_past == 0 && num_present == 0) { - attn->mode = QFMT_KFMT_VFMT_NONE_NONE_NONE_NONE; - return Status::OK(); - } else if (num_past == 0 && num_present == 1) { - if (attn->past_present_share_buffer == false) { - attn->mode = QFMT_KFMT_VFMT_NONE_NONE_2BNTH_NONE; - return Status::OK(); - } else { - attn->mode = QFMT_KFMT_VFMT_NONE_NONE_2BNMH_NONE; - return Status::OK(); - } - } else if (num_past == 1 && num_present == 1) { - if (attn->past_present_share_buffer == false) { - attn->mode = QFMT_KFMT_VFMT_2BNPH_NONE_2BNTH_NONE; - return Status::OK(); - } else { - attn->mode = QFMT_KFMT_VFMT_2BNMH_NONE_2BNMH_NONE; - return Status::OK(); - } - } - } else if (attn_type == kMultiHeadAttention || attn_type == kDecoderMaskedMultiHeadAttention) { - if (num_qkv == 3 && num_past == 0 && num_present == 0) { - if (attn->qkv_format == Q_K_V_BSNH) { - attn->mode = BSNH_BLNH_BLNH_NONE_NONE_NONE_NONE; - return Status::OK(); - } else if (attn->qkv_format == Q_K_V_BSNH_BNSH_BNSH) { - attn->mode = BSNH_BNLH_BNLH_NONE_NONE_NONE_NONE; - return Status::OK(); - } - } else if (num_qkv == 3 && num_past == 0 && num_present == 2) { - if (attn->past_present_share_buffer == false) { - if (attn->qkv_format == Q_K_V_BSNH) { - attn->mode = BSNH_BLNH_BLNH_NONE_NONE_BNTH_BNTH; - return Status::OK(); - } else if (attn->qkv_format == Q_K_V_BSNH_BNSH_BNSH) { - attn->mode = BSNH_BNLH_BNLH_NONE_NONE_BNTH_BNTH; - return Status::OK(); - } - } else { - if (attn->qkv_format == Q_K_V_BSNH) { - attn->mode = BSNH_BLNH_BLNH_NONE_NONE_BNMH_BNMH; - return Status::OK(); - } else if (attn->qkv_format == Q_K_V_BSNH_BNSH_BNSH) { - attn->mode = BSNH_BNLH_BNLH_NONE_NONE_BNMH_BNMH; - return Status::OK(); - } - } - } else if (num_qkv == 3 && num_past == 2 && num_present == 2) { - if (attn->past_present_share_buffer == false) { - if (attn->qkv_format == Q_K_V_BSNH) { - attn->mode = BSNH_BLNH_BLNH_BNPH_BNPH_BNTH_BNTH; - return Status::OK(); - } else if (attn->qkv_format == Q_K_V_BSNH_BNSH_BNSH) { - attn->mode = BSNH_BNLH_BNLH_BNPH_BNPH_BNTH_BNTH; - return Status::OK(); - } - } else { - if (attn->qkv_format == Q_K_V_BSNH) { - attn->mode = BSNH_BLNH_BLNH_BNMH_BNMH_BNMH_BNMH; - return Status::OK(); - } else if (attn->qkv_format == Q_K_V_BSNH_BNSH_BNSH) { - attn->mode = BSNH_BNLH_BNLH_BNMH_BNMH_BNMH_BNMH; - return Status::OK(); - } - } - } else if (num_qkv == 1 && num_past == 0 && num_present == 0) { - if (attn->qkv_format == QKV_BSN3H) { - attn->mode = BLN3H_NONE_NONE_NONE_NONE_NONE_NONE; - return Status::OK(); - } - } else if (num_qkv == 2 && num_past == 0 && num_present == 0) { - if (attn->qkv_format == Q_KV_BSNH_BSN2H) { - attn->mode = BSNH_BLN2H_NONE_NONE_NONE_NONE_NONE; - return Status::OK(); - } - } - } - return ORT_MAKE_STATUS( - ONNXRUNTIME, INVALID_ARGUMENT, - "Unsupported AttentionMode for ", attn_type, ". Got qkv format ", attn->qkv_format, - ". Got ", hint); -} - -template -Status DecoderQkvToContext( - const hipDeviceProp_t& prop, - RocmTuningContext* tuning_ctx, - Stream* ort_stream, - hipblasHandle_t& hipblas, - const size_t element_size, - const int batch_size, - const int sequence_length, - const int kv_sequence_length, - const int num_heads, - const int head_size, - const bool static_kv, - const bool use_past, - const bool has_layer_state, - const bool has_key_padding_mask, - const float mask_filter_value, - const T* gemm_query_buffer, - const T* gemm_kv_buffer, - const bool* key_padding_mask, - const T* key_cache, - const T* value_cache, - T* qkv_buffer, - T* workspace_buffer, - T* output, - T* new_key_cache, - T* new_value_cache) { - const int max_threads_per_block = prop.maxThreadsPerBlock; - const int BN = batch_size * num_heads; - const int BHN = BN * head_size; - const int BNS = BN * sequence_length; - const int k_buffer_offset = sequence_length * BHN; - const int v_buffer_offset = (sequence_length + kv_sequence_length) * BHN; - - T* temp_qkv_buffer = workspace_buffer; - auto stream = static_cast(ort_stream->GetHandle()); - - const T* q = qkv_buffer; - // transpose q and copy them to qkv_buffer - ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, sequence_length, batch_size, head_size, - num_heads, max_threads_per_block, true, gemm_query_buffer, qkv_buffer)); - - const T* k = qkv_buffer + k_buffer_offset; - const T* v = qkv_buffer + v_buffer_offset; - if (!has_layer_state || !use_past) { - if (!static_kv) { - // transpose kv and copy them to qkv_buffer - ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 2, sequence_length, batch_size, head_size, num_heads, - max_threads_per_block, true, gemm_kv_buffer, qkv_buffer + k_buffer_offset)); - } else { - // transpose kv and copy them to qkv_buffer - ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 2, kv_sequence_length, batch_size, head_size, num_heads, - max_threads_per_block, true, gemm_kv_buffer, qkv_buffer + k_buffer_offset)); - } - } else { - if (!static_kv) { - // transpose kv and copy them to temp_buffer - ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 2, sequence_length, batch_size, head_size, num_heads, - max_threads_per_block, true, gemm_kv_buffer, temp_qkv_buffer)); - // concat cache-k with k and copy to qkv_buffer - if (nullptr != key_cache) { - ORT_RETURN_IF_ERROR(LaunchConcatTensorToTensor(stream, kv_sequence_length, sequence_length, - batch_size, head_size, num_heads, - max_threads_per_block, 1, key_cache, - temp_qkv_buffer, qkv_buffer + k_buffer_offset)); - } - // concat cache-v with v and copy to qkv_buffer - if (nullptr != value_cache) { - ORT_RETURN_IF_ERROR(LaunchConcatTensorToTensor(stream, kv_sequence_length, sequence_length, - batch_size, head_size, num_heads, - max_threads_per_block, 1, value_cache, - temp_qkv_buffer + k_buffer_offset, - qkv_buffer + v_buffer_offset)); - } - } - } - - if (has_layer_state) { - if (use_past && static_kv) { - CHECK_ROCM(hipMemcpyAsync(new_key_cache, key_cache, - kv_sequence_length * BHN * sizeof(T), hipMemcpyDeviceToDevice, stream)); - CHECK_ROCM(hipMemcpyAsync(new_value_cache, value_cache, - kv_sequence_length * BHN * sizeof(T), hipMemcpyDeviceToDevice, stream)); - } else { - CHECK_ROCM(hipMemcpyAsync(new_key_cache, k, - kv_sequence_length * BHN * sizeof(T), hipMemcpyDeviceToDevice, stream)); - CHECK_ROCM(hipMemcpyAsync(new_value_cache, v, - kv_sequence_length * BHN * sizeof(T), hipMemcpyDeviceToDevice, stream)); - } - } - - // scratch1: BxNxSxS* buffer - // scratch2: BxNxSxS* buffer - // scratch3: BxNxSxH buffer - T* scratch1 = temp_qkv_buffer + 3 * BHN * sequence_length; - T* scratch2 = scratch1 + BNS * kv_sequence_length; - T* scratch3 = scratch2 + BNS * kv_sequence_length; - - // compute Q*K' (as K'*Q), scaled by 1/sqrt(H) and store in scratch1: BxNxSxS* - // Q: BxNxSxH, K (present_k): BxNxS*xH, Q*K': BxNxSxS* - const float rsqrt_head_size = 1.f / sqrt(static_cast(head_size)); - const int temp_matrix_size = sequence_length * kv_sequence_length; - - const int strideA = kv_sequence_length * head_size; - const int strideB = sequence_length * head_size; - if (use_past && static_kv) { - ORT_RETURN_IF_ERROR(blas::column_major::StridedBatchedGemm( - tuning_ctx, ort_stream, hipblas, - blas::BlasOp::Trans, blas::BlasOp::NonTrans, - kv_sequence_length, sequence_length, head_size, - /*alpha=*/rsqrt_head_size, - key_cache, head_size, strideA, - q, head_size, strideB, - /*beta=*/0.0f, - scratch1, kv_sequence_length, temp_matrix_size, - BN)); - } else { - ORT_RETURN_IF_ERROR(blas::column_major::StridedBatchedGemm( - tuning_ctx, ort_stream, hipblas, - blas::BlasOp::Trans, blas::BlasOp::NonTrans, - kv_sequence_length, sequence_length, head_size, - /*alpha=*/rsqrt_head_size, - k, head_size, strideA, - q, head_size, strideB, - /*beta=*/0.0f, - scratch1, kv_sequence_length, temp_matrix_size, - BN)); - } - - if (has_key_padding_mask) { - int3 strides = Get2DMaskStrides(kv_sequence_length); - ORT_RETURN_IF_ERROR(ComputeSoftmaxWithRawMask( - ort_stream, kv_sequence_length, sequence_length, batch_size, num_heads, - strides, nullptr, key_padding_mask, nullptr, scratch1, scratch2, - false, 1.0f, false, nullptr, mask_filter_value)); - } else { - ORT_RETURN_IF_ERROR(ComputeSoftmax(stream, kv_sequence_length, sequence_length, batch_size, - num_heads, nullptr, scratch1, scratch2, false)); - } - - // compute P*V (as V*P), and store in scratch3: BxNxSxH - if (use_past && static_kv) { - ORT_RETURN_IF_ERROR(blas::column_major::StridedBatchedGemm( - tuning_ctx, ort_stream, hipblas, - blas::BlasOp::NonTrans, blas::BlasOp::NonTrans, - head_size, sequence_length, kv_sequence_length, - /*alpha=*/1.0f, - value_cache, head_size, strideA, - scratch2, kv_sequence_length, temp_matrix_size, - /*beta=*/0.0f, - scratch3, head_size, strideB, - BN)); - } else { - ORT_RETURN_IF_ERROR(blas::column_major::StridedBatchedGemm( - tuning_ctx, ort_stream, hipblas, - blas::BlasOp::NonTrans, blas::BlasOp::NonTrans, - head_size, sequence_length, kv_sequence_length, - /*alpha=*/1.0f, - v, head_size, strideA, - scratch2, kv_sequence_length, temp_matrix_size, - /*beta=*/0.0f, - scratch3, head_size, strideB, - BN)); - } - - // scratch3 is BxNxSxH, transpose to output SxBxNxH - return LaunchTransCtx(stream, sequence_length, batch_size, head_size, - num_heads, max_threads_per_block, true, scratch3, output); -} - -Status LaunchDecoderAttentionKernel( - const hipDeviceProp_t& prop, - RocmTuningContext* tuning_ctx, - Stream* stream, - hipblasHandle_t& hipblas, - const size_t element_size, - const int batch_size, - const int sequence_length, - const int kv_sequence_length, - const int num_heads, - const int head_size, - const bool static_kv, - const bool use_past, - const bool has_layer_state, - const bool has_key_padding_mask, - const float mask_filter_value, - const void* gemm_query_buffer, - const void* gemm_kv_buffer, - const bool* key_padding_mask, - const void* key_cache, - const void* value_cache, - void* qkv_buffer, - void* workspace_buffer, - void* output, - void* new_key_cache, - void* new_value_cache) { - if (element_size == 2) { - return DecoderQkvToContext( - prop, - tuning_ctx, - stream, - hipblas, - element_size, - batch_size, - sequence_length, - kv_sequence_length, - num_heads, - head_size, - static_kv, - use_past, - has_layer_state, - has_key_padding_mask, - mask_filter_value, - reinterpret_cast(gemm_query_buffer), - reinterpret_cast(gemm_kv_buffer), - key_padding_mask, - reinterpret_cast(key_cache), - reinterpret_cast(value_cache), - reinterpret_cast(qkv_buffer), - reinterpret_cast(workspace_buffer), - reinterpret_cast(output), - reinterpret_cast(new_key_cache), - reinterpret_cast(new_value_cache)); - } else { - return DecoderQkvToContext( - prop, - tuning_ctx, - stream, - hipblas, - element_size, - batch_size, - sequence_length, - kv_sequence_length, - num_heads, - head_size, - static_kv, - use_past, - has_layer_state, - has_key_padding_mask, - mask_filter_value, - reinterpret_cast(gemm_query_buffer), - reinterpret_cast(gemm_kv_buffer), - key_padding_mask, - reinterpret_cast(key_cache), - reinterpret_cast(value_cache), - reinterpret_cast(qkv_buffer), - reinterpret_cast(workspace_buffer), - reinterpret_cast(output), - reinterpret_cast(new_key_cache), - reinterpret_cast(new_value_cache)); - } -} - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/attention_impl.h b/onnxruntime/contrib_ops/rocm/bert/attention_impl.h deleted file mode 100644 index 07d875e90fa4b..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/attention_impl.h +++ /dev/null @@ -1,180 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include -#include -#include "contrib_ops/cpu/bert/attention_common.h" -#include "contrib_ops/cpu/bert/attention_parameters.h" -#include "core/providers/rocm/shared_inc/rocm_utils.h" -#include "core/providers/rocm/tunable/rocm_tunable.h" - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -typedef struct __align__(32) { - long long int x, y, z, w; -} LongLong4; - -size_t GetAttentionScratchSize( - size_t element_size, - int batch_size, - int num_heads, - int sequence_length, - int all_sequence_length); - -size_t GetAttentionWorkspaceSize( - size_t element_size, - int batch_size, - int num_heads, - int head_size, - int sequence_length, - int past_sequence_length); - -Status LaunchTransCtx(hipStream_t stream, - const int sequence_length, const int batch_size, const int head_size, const int num_heads, - const int max_threads_per_block, const bool reversed_bs, const float* input, float* output); - -Status LaunchTransCtx(hipStream_t stream, - const int sequence_length, const int batch_size, const int head_size, const int num_heads, - const int max_threads_per_block, const bool reversed_bs, const half* input, half* output); - -Status LaunchTransQkv(hipStream_t stream, const int matrix_num, - const int sequence_length, const int batch_size, const int head_size, const int num_heads, - const int max_threads_per_block, const bool reversed_bs, const float* input, float* output, - int total_matrix_count = -1); - -Status LaunchTransQkv(hipStream_t stream, const int matrix_num, - const int sequence_length, const int batch_size, const int head_size, const int num_heads, - const int max_threads_per_block, const bool reversed_bs, const half* input, half* output, - int total_matrix_count = -1); - -Status LaunchConcatTensorToTensor(hipStream_t stream, - const int all_sequence_length, - const int sequence_length, - const int batch_size, - const int head_size, - const int num_heads, - const int max_threads_per_block, - const int matrix_num, - const float* tensor_in, - const float* tensor_add, - float* tensor_out); - -Status LaunchConcatTensorToTensor(hipStream_t stream, - const int all_sequence_length, - const int sequence_length, - const int batch_size, - const int head_size, - const int num_heads, - const int max_threads_per_block, - const int matrix_num, - const half* tensor_in, - const half* tensor_add, - half* tensor_out); - -inline hipblasStatus_t _compat_hipblas_gemm_strided_batched_ex(hipblasHandle_t handle, - hipblasOperation_t transa, - hipblasOperation_t transb, - int m, - int n, - int k, - const void* alpha, - const void* A, - hipDataType a_type, - int lda, - hipblasStride stride_A, - const void* b, - hipDataType b_type, - int ldb, - hipblasStride stride_b, - const void* beta, - void* c, - hipDataType c_type, - int ldc, - hipblasStride stride_c, - int batch_count, - hipblasComputeType_t compute_type, - hipblasGemmAlgo_t algo) { - return hipblasGemmStridedBatchedEx(handle, - transa, - transb, - m, // m - n, // n - k, // k - alpha, // alpha - A, // A - a_type, // A type - lda, // lda - stride_A, // strideA - b, // B - b_type, // B type - ldb, // ldb - stride_b, // strideB - beta, // beta - c, // C - c_type, // C type - ldc, // ldc - stride_c, // strideC - batch_count, // batch count - compute_type, - algo); -} - -// Compatible for CublasMathModeSetter -class CompatHipblasMathModeSetter { - public: - CompatHipblasMathModeSetter(const hipDeviceProp_t&, - hipblasHandle_t, - int) { - } -}; - -enum AttentionMode { - // Q,K,V,PastK,PastV,PresentK,PresentV - QFMT_KFMT_VFMT_NONE_NONE_NONE_NONE, - QFMT_KFMT_VFMT_NONE_NONE_2BNTH_NONE, - QFMT_KFMT_VFMT_NONE_NONE_2BNMH_NONE, - QFMT_KFMT_VFMT_2BNPH_NONE_2BNTH_NONE, - QFMT_KFMT_VFMT_2BNMH_NONE_2BNMH_NONE, - BSNH_BLNH_BLNH_NONE_NONE_NONE_NONE, - BSNH_BNLH_BNLH_NONE_NONE_NONE_NONE, - BSNH_BLNH_BLNH_NONE_NONE_BNTH_BNTH, - BSNH_BNLH_BNLH_NONE_NONE_BNTH_BNTH, - BSNH_BLNH_BLNH_NONE_NONE_BNMH_BNMH, - BSNH_BNLH_BNLH_NONE_NONE_BNMH_BNMH, - BSNH_BLNH_BLNH_BNPH_BNPH_BNTH_BNTH, - BSNH_BNLH_BNLH_BNPH_BNPH_BNTH_BNTH, - BSNH_BLNH_BLNH_BNMH_BNMH_BNMH_BNMH, - BSNH_BNLH_BNLH_BNMH_BNMH_BNMH_BNMH, - BLN3H_NONE_NONE_NONE_NONE_NONE_NONE, - BSNH_BLN2H_NONE_NONE_NONE_NONE_NONE, -}; - -struct RocmAttentionParameters : AttentionParameters { - AttentionMode mode; -}; - -Status ClassifyAttentionMode(AttentionType type, - RocmAttentionParameters* attn, - const std::vector& qkv, - const std::vector& past, - const std::vector& present); - -template -Status LaunchStridedCopy( - hipStream_t stream, - const T* in, int4 in_shape, LongLong4 in_strides, const int* in_seqlens_offset, // coord (b,n,s,h) - T* out, LongLong4 out_strides, const int* out_seqlens_offset, // coord (b,n,s,h) - int max_threads_per_block); - -template -Status LaunchStridedCopy(hipStream_t stream, - const T* in, int4 in_shape, LongLong4 in_strides, // coord (b,n,s,h) - T* out, LongLong4 out_strides, // coord (b,n,s,h) - int max_threads_per_block); -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/attention_softmax.h b/onnxruntime/contrib_ops/rocm/bert/attention_softmax.h deleted file mode 100644 index 9f2faa228cf79..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/attention_softmax.h +++ /dev/null @@ -1,465 +0,0 @@ -#include "hip/hip_runtime.h" -/* - The implementation of this file is based on qkvToContext plugin in TensorRT demo: - https://github.com/NVIDIA/TensorRT/tree/release/5.1/demo/BERT/ - -Copyright 2019 NVIDIA Corporation - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -#pragma once - -#include -#include -#include -#include "core/providers/rocm/cu_inc/common.cuh" -#include "core/providers/rocm/rocm_common.h" -#include "core/providers/rocm/math/softmax.h" - -#define ROCMRT_INF_F __int_as_float(0x7f800000) - -using namespace onnxruntime::rocm; -using namespace hipcub; - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -template -__device__ inline void Softmax(const int all_sequence_length, - const int valid_end, - const int valid_start, - const T* attn_bias, - const T* input, - T* output) { - using BlockReduce = hipcub::BlockReduce; - __shared__ typename BlockReduce::TempStorage tmp_storage; - - __shared__ float sum_reverse_block; - __shared__ float max_block; - - float thread_data_max(-ROCMRT_INF_F); - - // e^x is represented as infinity if x is large enough, like 100.f. - // Infinity divided by Infinity is a NAN. Thus, softmax gets a NAN if one or more item are large enough. - // a math transform as below is leveraged to get a stable softmax: - // e^xi/(e^x1 + ...e^xn) = e^(xi - max) / (e^(x1 - max) + ... + e^(xn - max)) - const int offset = (blockIdx.y * gridDim.x + blockIdx.x) * all_sequence_length; - for (int i = threadIdx.x; i < valid_end; i += TPB) { - if (i >= valid_start) { - const int index = offset + i; - float input_at_idx = attn_bias == nullptr - ? static_cast(input[index]) - : static_cast(input[index] + attn_bias[index]); - if (thread_data_max < input_at_idx) { - thread_data_max = input_at_idx; - } - } - } - - const auto max = BlockReduce(tmp_storage).Reduce(thread_data_max, hipcub::Max()); - - // Store max value - if (threadIdx.x == 0) { - max_block = max; - } - __syncthreads(); - - float thread_data_sum(0.f); - for (int i = threadIdx.x; i < valid_end; i += TPB) { - if (i >= valid_start) { - const int index = offset + i; - float val = attn_bias == nullptr ? input[index] : input[index] + attn_bias[index]; - thread_data_sum += expf(val - max_block); - } - } - - const auto sum = BlockReduce(tmp_storage).Reduce(thread_data_sum, hipcub::Sum()); - if (threadIdx.x == 0) { - sum_reverse_block = 1.f / sum; - } - __syncthreads(); - - for (int i = threadIdx.x; i < all_sequence_length; i += TPB) { - const int index = offset + i; - float input_at_idx = attn_bias == nullptr - ? static_cast(input[index]) - : static_cast(input[index] + attn_bias[index]); - const float val = (i >= valid_start && i < valid_end) ? expf(input_at_idx - max_block) * sum_reverse_block : 0.f; - output[index] = T(val); - } -} - -template -__device__ inline void SoftmaxSmall(const int all_sequence_length, - const int sequence_length, - const int valid_end, - const int valid_start, - const T* attn_bias, - const T* input, - T* output, - bool causal) { - using BlockReduce = hipcub::BlockReduce; - __shared__ typename BlockReduce::TempStorage tmp_storage; - - __shared__ float sum_reverse_block; - __shared__ float max_block; - - // Input dimension is BxNxSxS*; blockIdx.y is batch index b; gridDim.x=N*S; blockIdx.x is index within N*S; - const int offset = (blockIdx.y * gridDim.x + blockIdx.x) * all_sequence_length; - const int index = offset + threadIdx.x; - - bool is_valid = false; // whether it has attention mask == 1. - - // Update end position for causal. - int end = valid_end; - if (causal) { - const int end_causal = all_sequence_length - sequence_length + (blockIdx.x % sequence_length) + 1; - if (end_causal < end) { - end = end_causal; - } - } - - is_valid = (threadIdx.x >= valid_start && threadIdx.x < end); - - // e^x is represented as infinity if x is large enough, like 100.f. - // Infinity divided by Infinity is a NAN. Thus, softmax gets a NAN if one or more item are large enough. - // a math transform as below is leveraged to get a stable softmax: - // e^xi/(e^x1 + ...e^xn) = e^(xi - max) / (e^(x1 - max) + ... + e^(xn - max)) - float input_data = attn_bias == nullptr - ? static_cast(input[index]) - : static_cast(input[index] + attn_bias[index]); - float thread_data_max = is_valid ? input_data : float(-ROCMRT_INF_F); - const auto max = BlockReduce(tmp_storage).Reduce(thread_data_max, hipcub::Max(), end); - - // Store max value - if (threadIdx.x == 0) { - max_block = max; - } - __syncthreads(); - - float thread_data_exp(0.f); - if (is_valid) { - thread_data_exp = expf(input_data - max_block); - } - - const auto sum = BlockReduce(tmp_storage).Reduce(thread_data_exp, hipcub::Sum(), end); - - // Store value of 1.0/sum. - if (threadIdx.x == 0) { - sum_reverse_block = (1.f) / sum; - } - __syncthreads(); - - // threadIdx.x might be larger than all_sequence_length due to alignment to 32x. - if (threadIdx.x < all_sequence_length) { - output[index] = is_valid ? T(thread_data_exp * sum_reverse_block) : T(0.f); - } -} - -// Note about the attention_mask_strides and attention_mask/key_padding_mask -// attention_mask accepts 2D, 3D or 4D tensor, but it will be viewed as 3D tensor uniformally and it will be indexed -// as [batch_index, sequence_index, token_index]. -template -__global__ void SoftmaxWithRawMaskSmallKernel( - const int all_sequence_length, - const int sequence_length, - const int3 attention_mask_strides, - const int* attention_mask, // 2D, 3D or 4D attention mask - const bool* key_padding_mask, - const T* attn_bias, - const T* input, - T* output, - const bool causal, - const float rsqrt_head_size, - const bool skip_softmax, - const float mask_filter_value) { - using BlockReduce = hipcub::BlockReduce; - __shared__ typename BlockReduce::TempStorage tmp_storage; - - __shared__ float sum_reverse_block; - __shared__ float max_block; - - // Input dimension is BxNxSxS*; blockIdx.y is batch index b; gridDim.x=N*S; blockIdx.x is index within N*S; - int index = (blockIdx.y * gridDim.x + blockIdx.x) * all_sequence_length + threadIdx.x; - - // Mask all thread_data values to negative infinity to allow BlockReduce Max operation over all thread_data - // members with all invalid members set to a value that does not impact the final result. This is necessary - // to avoid the performance impact from using the valid_items interface. - float thread_data = -ROCMRT_INF_F; - if (threadIdx.x < all_sequence_length) { - thread_data = float(input[index]) * rsqrt_head_size; - - const int sequence_index = blockIdx.x % sequence_length; - if (causal) { - int from_index = all_sequence_length - sequence_length + sequence_index; // offset in all sequence length. - if (threadIdx.x > from_index) { - thread_data = -ROCMRT_INF_F; - } - } - - const int batch_index = blockIdx.y; - int mask_offset = attention_mask_strides.x * batch_index + - attention_mask_strides.y * sequence_index + - attention_mask_strides.z * threadIdx.x; - - if (nullptr == key_padding_mask) { - const int& mask = attention_mask[mask_offset]; - if (mask == 0) - thread_data += mask_filter_value; - } else { - const bool mask = key_padding_mask[mask_offset]; - if (mask) { - thread_data = -ROCMRT_INF_F; - } - } - - if (attn_bias != nullptr) { - thread_data += float(attn_bias[index]); - } - } - - if (skip_softmax) { - if (threadIdx.x < all_sequence_length) { - output[index] = T(thread_data); - } - return; - } - - const float max = BlockReduce(tmp_storage).Reduce(thread_data, hipcub::Max()); - - // Store max value - if (threadIdx.x == 0) { - max_block = max; - } - __syncthreads(); - - // Mask all thread_data_exp values to zero to allow BlockReduce Sum operation over all thread_data_exp - // members with all invalid members set to a value that does not impact the final result. This is necessary - // to avoid the performance impact from using the valid_items interface. - float thread_data_exp = threadIdx.x < all_sequence_length ? expf(thread_data - max_block) : 0.0f; - const auto sum = BlockReduce(tmp_storage).Reduce(thread_data_exp, hipcub::Sum()); - - // Store value of 1.0/sum - if (threadIdx.x == 0) { - sum_reverse_block = (1.f) / sum; - } - __syncthreads(); - - if (threadIdx.x < all_sequence_length) { - output[index] = T(thread_data_exp * sum_reverse_block); - } -} - -template -__global__ void SoftmaxKernelSmall(const int all_sequence_length, const int sequence_length, - const T* attn_bias, const T* input, T* output, bool causal) { - SoftmaxSmall(all_sequence_length, sequence_length, all_sequence_length, 0, - attn_bias, input, output, causal); -} - -template -__global__ void SoftmaxKernel(const int all_sequence_length, const T* attn_bias, const T* input, T* output) { - Softmax(all_sequence_length, all_sequence_length, 0, attn_bias, input, output); -} - -template -Status ComputeSoftmax( - hipStream_t stream, - const int all_sequence_length, const int sequence_length, const int batch_size, const int num_heads, - const T* attn_bias, const T* input, T* output, bool causal) { - const dim3 grid(sequence_length * num_heads, batch_size, 1); - if (all_sequence_length <= 32) { - const int blockSize = 32; - SoftmaxKernelSmall<<>>( - all_sequence_length, sequence_length, attn_bias, input, output, causal); - } else if (all_sequence_length <= 64) { - const int blockSize = 64; - SoftmaxKernelSmall<<>>( - all_sequence_length, sequence_length, attn_bias, input, output, causal); - } else if (all_sequence_length <= 128) { - const int blockSize = 128; - SoftmaxKernelSmall<<>>( - all_sequence_length, sequence_length, attn_bias, input, output, causal); - } else if (all_sequence_length <= 256) { - const int blockSize = 256; - SoftmaxKernelSmall<<>>( - all_sequence_length, sequence_length, attn_bias, input, output, causal); - } else if (all_sequence_length <= 512) { - const int blockSize = 512; - SoftmaxKernelSmall<<>>( - all_sequence_length, sequence_length, attn_bias, input, output, causal); - } else if (all_sequence_length <= 1024) { - const int blockSize = 1024; - SoftmaxKernelSmall<<>>( - all_sequence_length, sequence_length, attn_bias, input, output, causal); - } else if (!causal) { - const int blockSize = 1024; - SoftmaxKernel<<>>( - all_sequence_length, attn_bias, input, output); - } else { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Attention ROCM operator does not support total sequence length > 1024."); - } - - return HIP_CALL(hipPeekAtLastError()); -} - -template -__global__ void MaskedSoftmaxKernelSmall(const int all_sequence_length, const int sequence_length, - const int* mask_end, const int* mask_start, - const T* attn_bias, const T* input, T* output, - bool causal) { - __shared__ int start_position; - __shared__ int end_position; - - if (threadIdx.x == 0) { - const int batch = blockIdx.y; - start_position = mask_start != nullptr ? max(0, mask_start[batch]) : 0; - end_position = min(all_sequence_length, mask_end[batch]); - - // Attend to no word has same effect as attend to all words. This is added to get parity with CPU result. - if (start_position >= end_position) { - start_position = 0; - end_position = all_sequence_length; - } - } - __syncthreads(); - - SoftmaxSmall(all_sequence_length, sequence_length, end_position, start_position, - attn_bias, input, output, causal); -} - -template -__global__ void MaskedSoftmaxKernel(const int all_sequence_length, const int* mask_end, const int* mask_start, - const T* attn_bias, const T* input, T* output) { - __shared__ int start_position; - __shared__ int end_position; - - if (threadIdx.x == 0) { - const int batch = blockIdx.y; - start_position = mask_start != nullptr ? max(0, mask_start[batch]) : 0; - end_position = min(all_sequence_length, mask_end[batch]); - - // Attend to no word has same effect as attend to all words. This is added to get parity with CPU result. - if (start_position >= end_position) { - start_position = 0; - end_position = all_sequence_length; - } - } - __syncthreads(); - - Softmax(all_sequence_length, end_position, start_position, attn_bias, input, output); -} - -template -Status ComputeSoftmaxWithMask1D( - hipStream_t stream, - const int all_sequence_length, const int sequence_length, const int batch_size, const int num_heads, - const int* mask_index, const int* mask_start, - const T* attn_bias, const T* input, T* output, const bool causal) { - const dim3 grid(sequence_length * num_heads, batch_size, 1); - -#define DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(block_size) \ - MaskedSoftmaxKernelSmall<<>>( \ - all_sequence_length, sequence_length, mask_index, mask_start, \ - attn_bias, input, output, causal); - - if (all_sequence_length <= 32) { - DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(32); - } else if (all_sequence_length <= 64) { - DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(64); - } else if (all_sequence_length <= 128) { - DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(128); - } else if (all_sequence_length <= 256) { - DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(256); - } else if (all_sequence_length <= 512) { - DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(512); - } else if (all_sequence_length <= 1024) { - DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(1024); - } else if (!causal) { - const int blockSize = 1024; - MaskedSoftmaxKernel<<>>( - all_sequence_length, mask_index, mask_start, - attn_bias, input, output); - } else { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Attention ROCM operator does not support total sequence length > 1024."); - } - -#undef DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE - - return HIP_CALL(hipPeekAtLastError()); -} - -template -Status ComputeSoftmaxWithRawMask(Stream* ort_stream, - const int all_sequence_length, - const int sequence_length, - const int batch_size, - const int num_heads, - const int3 attention_mask_strides, - const int* attention_mask, - const bool* key_padding_mask, - const T* attn_bias, - const T* input, - T* output, - const bool causal, - const float rsqrt_head_size, - const bool use_persistent_softmax, - T* persistent_softmax_workspace, - const float mask_filter_value) { - const dim3 grid(sequence_length * num_heads, batch_size, 1); - - T* out = use_persistent_softmax ? persistent_softmax_workspace : output; - auto stream = static_cast(ort_stream->GetHandle()); - -#define DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(block_size) \ - SoftmaxWithRawMaskSmallKernel<<>>( \ - all_sequence_length, sequence_length, attention_mask_strides, \ - attention_mask, key_padding_mask, attn_bias, input, out, \ - causal, rsqrt_head_size, \ - use_persistent_softmax, mask_filter_value); - - if (all_sequence_length <= 32) { - DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(32); - } else if (all_sequence_length <= 64) { - DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(64); - } else if (all_sequence_length <= 128) { - DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(128); - } else if (all_sequence_length <= 256) { - DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(256); - } else if (all_sequence_length <= 512) { - DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(512); - } else if (all_sequence_length <= 1024) { - DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(1024); - } else { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Attention ROCM operator does not support total sequence length > 1024."); - } - -#undef DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE - - if (use_persistent_softmax) { - return dispatch_warpwise_softmax_forward(ort_stream, - output, - persistent_softmax_workspace, - all_sequence_length, - all_sequence_length, - batch_size * num_heads * sequence_length); - } - - return HIP_CALL(hipPeekAtLastError()); -} - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_permute_pipelines.cuh b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_permute_pipelines.cuh deleted file mode 100644 index 213940f132963..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_permute_pipelines.cuh +++ /dev/null @@ -1,125 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include "core/providers/rocm/rocm_common.h" -#include "core/providers/rocm/rocm_kernel.h" -#include "core/providers/rocm/tunable/gemm.h" -#include "core/providers/rocm/tunable/rocm_tunable.h" -#include "contrib_ops/cpu/bert/attention_common.h" -#include "contrib_ops/cpu/bert/attention_parameters.h" - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -namespace blas = onnxruntime::rocm::tunable::blas; - -namespace { -std::tuple GetQkvProjectGemmMNKBatch(const AttentionParameters* attention) { - int m = attention->sequence_length; - int n = (attention->hidden_size + attention->hidden_size + attention->v_hidden_size); // q + k + v - int k = attention->input_hidden_size; - int batch = attention->batch_size; - return {m, n, k, batch}; -} -} // namespace - -template -struct GemmPermuteParams : onnxruntime::rocm::tunable::OpParams { - std::string Signature() const override { - auto [m, n, k, batch] = GetQkvProjectGemmMNKBatch(attention); - return MakeString("M", m, "_N", n, "_K", k, "_B", batch); - } - - hipblasHandle_t handle; - const AttentionParameters* attention; - const hipDeviceProp_t* device_prop; - - const T* input_buffer; - const T* weight_buffer; - const T* bias_buffer; - T* out_buffer; - - int3 bias_strides; - - const T* ones; // used for broadcasting bias if the underlying algorithm does not support strides - T* workspace_buffer; -}; - -template -struct GemmPermuteGenericPipeline { - inline static size_t GetOutputNumBytes(const AttentionParameters* attn) { - auto [m, n, _, batch] = GetQkvProjectGemmMNKBatch(attn); - return sizeof(T) * m * n * batch; - } - - inline static size_t GetWorkspaceNumBytes(const AttentionParameters* attn) { - return GetOutputNumBytes(attn); - } - - inline static std::tuple GetGemmMNK(const GemmPermuteParams* params) { - auto [m, n, k, batch] = GetQkvProjectGemmMNKBatch(params->attention); - return {batch * m, n, k}; - } - - inline static std::tuple UnspliceOutputQKV(const GemmPermuteParams* params) { - auto* attn = params->attention; - int64_t batch = attn->batch_size * attn->num_heads; - int64_t num_elems_per_batch = attn->sequence_length * attn->head_size; - int64_t num_elems = batch * num_elems_per_batch; - auto q = params->out_buffer + 0 * num_elems; - auto k = params->out_buffer + 1 * num_elems; - auto v = params->out_buffer + 2 * num_elems; - return {q, k, v}; - } - - inline static Status BroadcastBias(const GemmPermuteParams* params) { - auto [m, n, k] = GetGemmMNK(params); - // Bias shape is (N), broadcast using B(M, N) = ones(M, 1) x bias(1, N). - // TODO: use custom kernel of expand to improve the performance. - return blas::row_major::Gemm( - params->TuningContext(), params->Stream(), params->handle, - blas::BlasOp::NonTrans, blas::BlasOp::NonTrans, - m, n, 1, - /*alpha=*/1.0f, - params->ones, 1, - params->bias_buffer, n, - /*beta=*/0.0f, - params->workspace_buffer, n); - } - - inline static Status Gemm(const GemmPermuteParams* params) { - auto [m, n, k] = GetGemmMNK(params); - // result(M, N) = input x weights + bias. - return blas::row_major::Gemm( - params->TuningContext(), params->Stream(), params->handle, - blas::BlasOp::NonTrans, blas::BlasOp::NonTrans, - m, n, k, - /*alpha=*/1.0f, - params->input_buffer, k, - params->weight_buffer, n, - /*beta=*/1.0f, - params->workspace_buffer, n); - } - - inline static Status Permute0213(const GemmPermuteParams* params) { - auto* attn = params->attention; - // input should be BxSx3xNxH => gemm_buffer: 3xBxNxSxH - return LaunchTransQkv( - params->StreamHandle(), 3, attn->sequence_length, attn->batch_size, attn->head_size, attn->num_heads, - params->device_prop->maxThreadsPerBlock, false, params->workspace_buffer, params->out_buffer); - } - - static Status Run(const GemmPermuteParams* params) { - ORT_RETURN_IF_ERROR(BroadcastBias(params)); - ORT_RETURN_IF_ERROR(Gemm(params)); - ORT_RETURN_IF_ERROR(Permute0213(params)); - return Status::OK(); - } -}; - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl.cuh b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl.cuh deleted file mode 100644 index be8508670e4b1..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl.cuh +++ /dev/null @@ -1,177 +0,0 @@ -// SPDX-License-Identifier: MIT -// Modifications Copyright (c) Microsoft Corporation. All rights reserved. -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#ifdef USE_COMPOSABLE_KERNEL -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp" -#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp" -#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" -#include "ck/utility/data_type.hpp" - -namespace onnxruntime { -namespace contrib { -namespace rocm { -namespace internal { - -using F16 = ck::half_t; -using F32 = float; - -template -using S = ck::Sequence; - -using MaskingSpecialization = ck::tensor_operation::device::MaskingSpecialization; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -using ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute; // the interface -using ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle; // the implementation - -static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; -static constexpr auto GemmPadded = ck::tensor_operation::device::GemmSpecialization::MNKOPadding; - -static constexpr auto TensorDefault = ck::tensor_operation::device::TensorSpecialization::Default; - -template -using device_batched_gemm_softmax_gemm_permute_instances = - std::tuple< - // clang-format off - // #############################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| AData| B0Data| B1Data| CData| Acc0BiasData| Acc1BiasData| AccData| CShuffle| A| B0| Acc0| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| MaskingSpec| D0s Bias| - // #############################################| | | | | | Type| Type| Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | SrcScalar| - // #############################################| | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | PerVector| - // #############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, DT, DT, DT, DT, D0sDT, ck::Tuple<>, AccDT, DT, PassThrough, PassThrough, D0Op, PassThrough, PassThrough, GemmPadded, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec, 1>, - DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, DT, DT, DT, DT, D0sDT, ck::Tuple<>, AccDT, DT, PassThrough, PassThrough, D0Op, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 256, 128, 32, 64, 32, 8, 8, 2, 32, 32, 2, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, - DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, DT, DT, DT, DT, D0sDT, ck::Tuple<>, AccDT, DT, PassThrough, PassThrough, D0Op, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 256, 128, 32, 128, 32, 8, 8, 2, 32, 32, 2, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, -#if ROCM_VERSION >= 50500 - DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, DT, DT, DT, DT, D0sDT, ck::Tuple<>, AccDT, DT, PassThrough, PassThrough, D0Op, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 256, 32, 64, 32, 8, 8, 2, 32, 32, 1, 8, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, -#endif - DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, DT, DT, DT, DT, D0sDT, ck::Tuple<>, AccDT, DT, PassThrough, PassThrough, D0Op, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 256, 32, 128, 32, 8, 8, 2, 32, 32, 1, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, - DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, DT, DT, DT, DT, D0sDT, ck::Tuple<>, AccDT, DT, PassThrough, PassThrough, D0Op, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 64, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, - DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, DT, DT, DT, DT, D0sDT, ck::Tuple<>, AccDT, DT, PassThrough, PassThrough, D0Op, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 32, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, - DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, DT, DT, DT, DT, D0sDT, ck::Tuple<>, AccDT, DT, PassThrough, PassThrough, D0Op, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, - DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, DT, DT, DT, DT, D0sDT, ck::Tuple<>, AccDT, DT, PassThrough, PassThrough, D0Op, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 32, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, - DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, DT, DT, DT, DT, D0sDT, ck::Tuple<>, AccDT, DT, PassThrough, PassThrough, D0Op, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 64, 256, 32, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 8, S<1, 16, 1,16>, 8, MaskingSpec>, - DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, DT, DT, DT, DT, D0sDT, ck::Tuple<>, AccDT, DT, PassThrough, PassThrough, D0Op, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 64, 256, 32, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, 8, MaskingSpec>, - DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, DT, DT, DT, DT, D0sDT, ck::Tuple<>, AccDT, DT, PassThrough, PassThrough, D0Op, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 64, 256, 64, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 8, S<1, 16, 1,16>, 8, MaskingSpec>, - DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, DT, DT, DT, DT, D0sDT, ck::Tuple<>, AccDT, DT, PassThrough, PassThrough, D0Op, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 64, 256, 64, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, 8, MaskingSpec>, - // Padded fallback kernel - DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, DT, DT, DT, DT, D0sDT, ck::Tuple<>, AccDT, DT, PassThrough, PassThrough, D0Op, PassThrough, PassThrough, GemmPadded, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec, 1>, - DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, DT, DT, DT, DT, D0sDT, ck::Tuple<>, AccDT, DT, PassThrough, PassThrough, D0Op, PassThrough, PassThrough, GemmPadded, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, - DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, DT, DT, DT, DT, D0sDT, ck::Tuple<>, AccDT, DT, PassThrough, PassThrough, D0Op, PassThrough, PassThrough, GemmPadded, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec> - // clang-format on - >; - -struct PreSoftmaxAttentionScoreOp { - PreSoftmaxAttentionScoreOp(float scale) : scale_(scale) {} - - // non-biased, non-masked - __host__ __device__ void operator()(float& y, const float& x) const { - y = scale_ * x; - } - - // biased or converted masked - __host__ __device__ void operator()(float& y, const float& x, const F16& bias) const { - y = scale_ * x + ck::type_convert(bias); - } - - // biased and converted masked - __host__ __device__ void operator()(float& y, const float& x, const F16& bias, const F16& converted_mask) const { - y = scale_ * x + ck::type_convert(bias) + ck::type_convert(converted_mask); - } - - float scale_; -}; - -// Use this function to gat implementation -template -std::vector, - PassThrough, PassThrough, D0Op, PassThrough, PassThrough, - MaskingSpec>>> -GetDeviceBatchedGemmSoftmaxGemmPermuteInstances() { - return {}; -} - -// implemented in impl_{fp16,bf16}[_biased][_masked].cu -// fp16, non-biased, non-masked -template <> -std::vector, ck::Tuple<>, - PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough, - MaskingSpecialization::MaskDisabled>>> -GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< - F16, ck::Tuple<>, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskDisabled>(); - -// fp16, biased, non-masked -template <> -std::vector, ck::Tuple<>, - PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough, - MaskingSpecialization::MaskDisabled>>> -GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< - F16, ck::Tuple, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskDisabled>(); - -// fp16, biased, fp16 masked, basically, two bias -template <> -std::vector, ck::Tuple<>, - PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough, - MaskingSpecialization::MaskDisabled>>> -GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< - F16, ck::Tuple, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskDisabled>(); - -template <> -std::vector, ck::Tuple<>, - PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough, - MaskingSpecialization::MaskOutUpperTriangle>>> -GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< - F16, ck::Tuple<>, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskOutUpperTriangle>(); - -// fp16, biased, non-masked -template <> -std::vector, ck::Tuple<>, - PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough, - MaskingSpecialization::MaskOutUpperTriangle>>> -GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< - F16, ck::Tuple, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskOutUpperTriangle>(); - -// fp16, biased, fp16 masked, basically, two bias -template <> -std::vector, ck::Tuple<>, - PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough, - MaskingSpecialization::MaskOutUpperTriangle>>> -GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< - F16, ck::Tuple, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskOutUpperTriangle>(); - -} // namespace internal -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime -#endif // USE_COMPOSABLE_KERNEL diff --git a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16.cu b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16.cu deleted file mode 100644 index 2e32a6594d164..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16.cu +++ /dev/null @@ -1,60 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#ifdef USE_COMPOSABLE_KERNEL -#include "contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl.cuh" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp" - -namespace onnxruntime { -namespace contrib { -namespace rocm { -namespace internal { - -using NonBiasedNonmasked = DeviceBatchedGemmSoftmaxGemmPermute< - 2, 1, 1, 1, 1, - F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, - PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough, - MaskingSpecialization::MaskDisabled>; - -template <> -std::vector> -GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< - F16, ck::Tuple<>, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskDisabled>() { - std::vector> instances; - ck::tensor_operation::device::instance::add_device_operation_instances( - instances, - device_batched_gemm_softmax_gemm_permute_instances< - 2, 1, 1, 1, 1, - F16, ck::Tuple<>, F32, PreSoftmaxAttentionScoreOp, - MaskingSpecialization::MaskDisabled>{}); - - return instances; -} - -using NonBiasedNonmaskedCausal = DeviceBatchedGemmSoftmaxGemmPermute< - 2, 1, 1, 1, 1, - F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, - PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough, - MaskingSpecialization::MaskOutUpperTriangle>; - -template <> -std::vector> -GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< - F16, ck::Tuple<>, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskOutUpperTriangle>() { - std::vector> instances; - ck::tensor_operation::device::instance::add_device_operation_instances( - instances, - device_batched_gemm_softmax_gemm_permute_instances< - 2, 1, 1, 1, 1, - F16, ck::Tuple<>, F32, PreSoftmaxAttentionScoreOp, - MaskingSpecialization::MaskOutUpperTriangle>{}); - - return instances; -} - -} // namespace internal -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime -#endif // USE_COMPOSABLE_KERNEL diff --git a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16_biased.cu b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16_biased.cu deleted file mode 100644 index 91da8d9e1f9a8..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16_biased.cu +++ /dev/null @@ -1,60 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#ifdef USE_COMPOSABLE_KERNEL -#include "contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl.cuh" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp" - -namespace onnxruntime { -namespace contrib { -namespace rocm { -namespace internal { - -using BiasedNonmasked = DeviceBatchedGemmSoftmaxGemmPermute< - 2, 1, 1, 1, 1, - F16, F16, F16, F16, ck::Tuple, ck::Tuple<>, - PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough, - MaskingSpecialization::MaskDisabled>; - -template <> -std::vector> -GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< - F16, ck::Tuple, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskDisabled>() { - std::vector> instances; - ck::tensor_operation::device::instance::add_device_operation_instances( - instances, - device_batched_gemm_softmax_gemm_permute_instances< - 2, 1, 1, 1, 1, - F16, ck::Tuple, F32, PreSoftmaxAttentionScoreOp, - MaskingSpecialization::MaskDisabled>{}); - - return instances; -} - -using BiasedNonmaskedCausal = DeviceBatchedGemmSoftmaxGemmPermute< - 2, 1, 1, 1, 1, - F16, F16, F16, F16, ck::Tuple, ck::Tuple<>, - PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough, - MaskingSpecialization::MaskOutUpperTriangle>; - -template <> -std::vector> -GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< - F16, ck::Tuple, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskOutUpperTriangle>() { - std::vector> instances; - ck::tensor_operation::device::instance::add_device_operation_instances( - instances, - device_batched_gemm_softmax_gemm_permute_instances< - 2, 1, 1, 1, 1, - F16, ck::Tuple, F32, PreSoftmaxAttentionScoreOp, - MaskingSpecialization::MaskOutUpperTriangle>{}); - - return instances; -} - -} // namespace internal -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime -#endif // USE_COMPOSABLE_KERNEL diff --git a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16_biased_biased.cu b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16_biased_biased.cu deleted file mode 100644 index b08123be18977..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16_biased_biased.cu +++ /dev/null @@ -1,60 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#ifdef USE_COMPOSABLE_KERNEL -#include "contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl.cuh" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp" - -namespace onnxruntime { -namespace contrib { -namespace rocm { -namespace internal { - -using BiasedNonmasked = DeviceBatchedGemmSoftmaxGemmPermute< - 2, 1, 1, 1, 1, - F16, F16, F16, F16, ck::Tuple, ck::Tuple<>, - PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough, - MaskingSpecialization::MaskDisabled>; - -template <> -std::vector> -GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< - F16, ck::Tuple, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskDisabled>() { - std::vector> instances; - ck::tensor_operation::device::instance::add_device_operation_instances( - instances, - device_batched_gemm_softmax_gemm_permute_instances< - 2, 1, 1, 1, 1, - F16, ck::Tuple, F32, PreSoftmaxAttentionScoreOp, - MaskingSpecialization::MaskDisabled>{}); - - return instances; -} - -using BiasedNonmaskedCausal = DeviceBatchedGemmSoftmaxGemmPermute< - 2, 1, 1, 1, 1, - F16, F16, F16, F16, ck::Tuple, ck::Tuple<>, - PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough, - MaskingSpecialization::MaskOutUpperTriangle>; - -template <> -std::vector> -GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< - F16, ck::Tuple, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskOutUpperTriangle>() { - std::vector> instances; - ck::tensor_operation::device::instance::add_device_operation_instances( - instances, - device_batched_gemm_softmax_gemm_permute_instances< - 2, 1, 1, 1, 1, - F16, ck::Tuple, F32, PreSoftmaxAttentionScoreOp, - MaskingSpecialization::MaskOutUpperTriangle>{}); - - return instances; -} - -} // namespace internal -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime -#endif // USE_COMPOSABLE_KERNEL diff --git a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh deleted file mode 100644 index 226b89cfb2b86..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh +++ /dev/null @@ -1,915 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -/* About Computing in these Pipelines - -B: batch size of Attention Op. NOTE: To be disambiguated with batch size of GEMMs -S: sequence length -T: total sequence length -N: num of heads -H: head dimension - -The following use qkv_format == Q_K_V_BNSH (mode == BSNH_BNLH_BNLH_NONE_NONE_NONE_NONE) as a example: - -BN: B*N, which is the batch size of GEMMs. NOTE: To be disambiguated with batch size of Attention Op - -In QKV projection (prior to this pipeline): - /-> Q [B,S,N*H] ->Reshape-> [B,S,N,H] ->Permute0213-> [B,N,S,H] -X --o--> K [B,T,N*H] ->Reshape-> [B,T,N,H] ->Permute0213-> [B,N,T,H] - \-> V [B,T,N*H] ->Reshape-> [B,T,N,H] ->Permute0213-> [B,N,T,H] - -pre_softmax_attn_scores = Q*K' = [B,N,S,H] * [BxNxTxH]' = [B,N,S,T] Batched GEMM1 -pre_softmax_attn_scores_masked = pre_softmax_attn_scores * scale +? bias +? mask Scale Add Bias, +? is optional -attn_scores = softmax(pre_softmax_attn_scores_masked) = [B,N,S,T] Softmax -scaled_multi_head_attn = attn_scores * V = [B,N,S,T] * [B,N,T,H] = [B,N,S,H] Batched GEMM2 - -Op outputs scaled_multi_head_attn: -[B,N,S,H] ->Permute0213-> [B,S,N,H] ->Reshape-> [B,S,N*H] - - -For the computing of pre_softmax_attn_scores +? mask +? bias: - -GemmSoftmaxGemmPermuteGenericPipeline handles it in specialized softmax. TODO: remove it! - -CK in GemmSoftmaxGemmPermuteTunablePipeline - - Q*K' ---> scale ---> [B,N,S,T] -------+?--> masked - bias --------------> [B,N,S,T] --+?--/ -mask_2d ---> [B,T] ---> [B,1,1,T] -/ - - Q*K' ---> scale ---> [B,N,S,T] -------+?--> masked - bias --------------> [B,N,S,T] --+?--/ -mask_3d --> [B,S,T] --> [B,1,S,T] -/ - - Q*K' ---> scale ---> [B,N,S,T] -------+?--> masked - bias --------------> [B,N,S,T] --+?--/ -mask_4d -> [B,1,M,M] -> [B,1,S,T] -/ M is max_sequence_length from megatron, we will create a - **sub-view** from original mask buffer - -For CK implementation, there will be four cases combined: -non-biased, non-masked, no special processing. - biased, non-masked, no special processing, add the mask directly. -non-biased, masked, convert the mask to [B,1,1_or_S,T] and perform broadcast add with scaled Q*K'. - biased, masked, convert the mask to [B,1,1_or_S,T] and perform broadcast add with bias and scaled Q*K'. - -Broadcast add is not actually perform the broadcasting, just broadcast the load operation from memory. The impl details -are in composable kernels. The scale and add logic is performed via Acc0ElementOp - -# Classified modes: - -| Q | K | V | past(K)| pastV | present(K)| presentV | Op, desc -| ---- | ---- | ---- | ------ | ----- | --------- | -------- | --------- -| QFMT | KFMT | VFMT | - | - | - | - | A, basic, qkv is impl dependent by qkv_format -| QFMT | KFMT | VFMT | 2BNPH | - | 2BNTH *^ | - | A, past_present_share_buffer = false, qkv is impl dependent by qkv_format -| QFMT | KFMT | VFMT | 2BNMH | - | 2BNMH *^ | - | A, past_present_share_buffer = true, qkv is impl dependent by qkv_format -| BSNH | BLNH*| BLNH^| - | - | - | - | MHA basic -| BSNH | BNLH*| BNLH^| - | - | - | - | MHA cross, pass_past_in_kv = true -| BSNH | - | - | - | - | BNLH * | BNLH ^ | MHA cross, pass_past_in_kv = false -| BSNH | BLNH | BLNH | - | - | BNTH * | BNTH ^ | MHA cross, past_present_share_buffer = false -| BSNH | BNLH | BNLH | - | - | BNTH * | BNTH ^ | MHA cross, past_present_share_buffer = false -| BSNH | BLNH | BLNH | - | - | BNMH * | BNMH ^ | MHA cross, past_present_share_buffer = true -| BSNH | BNLH | BNLH | - | - | BNMH * | BNMH ^ | MHA cross, past_present_share_buffer = true -| BSNH | BLNH | BLNH | BNPH | BNPH | BNTH * | BNTH ^ | MHA self, past_present_share_buffer = false -| BSNH | BNLH | BNLH | BNPH | BNPH | BNTH * | BNTH ^ | MHA self, past_present_share_buffer = false -| BSNH | BLNH | BLNH | BNMH | BNMH | BNMH * | BNMH ^ | MHA self, past_present_share_buffer = true -| BSNH | BNLH | BNLH | BNMH | BNMH | BNMH * | BNMH ^ | MHA self, past_present_share_buffer = true -| BLN3H*^| - | - | - | - | - | - | MHA basic, qkv_packed -| BSNH | BLN2H*^| - | - | - | - | - | MHA basic, kv_packed - -Q, K, V, past(K), pastV, present(K), presentV is the Input of the contrib OpKernel - -About k_buffer and v_buffer, we always explicitly concat past to present and use present_k for k_buffer and present_v for v_buffer - -- Marked with `*` indicate the Tensor is used for k_buffer passing. -- Marked with `^` indicate the Tensor is used for v_buffer passing. - -# Supported Op - -- A: Attention -- MHA: MultiHeadAttention - -# Dim Value - -- B: batch_size -- N: num_heads -- H: head_size - -- S: sequence_length -- L: kv_sequence_length -- P: past_sequence_length -- T: total_sequence_length = P + L -- M: max_sequence_length -*/ - -#include "core/framework/tensor_shape.h" -#include "core/providers/rocm/tunable/gemm.h" -#include "core/providers/rocm/tunable/rocm_tunable.h" -#include "contrib_ops/cpu/bert/attention_base.h" -#include "contrib_ops/rocm/bert/attention_impl.h" -#include "contrib_ops/rocm/bert/attention_softmax.h" -#ifdef USE_COMPOSABLE_KERNEL -#include "contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl.cuh" -#include "core/providers/rocm/composable_kernel_common.h" - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#endif // USE_COMPOSABLE_KERNEL - -#include -#include - -namespace blas = onnxruntime::rocm::tunable::blas; - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -inline int3 Get2DMaskStrides(int total_sequence_length) { - // stride == 0 indicate broadcasting - return {total_sequence_length, 0, 1}; -} - -// A stride maps from natural coordinate to physical offset of underlying memory storage buffer offset. We need to -// specify both of the natural coordinate order, say (b,n,s,h), (b,s,n,h) or (b,n,h,s), and memory order, say BNSH or -// BSNH, to determain the strides. To obtain the offset, we just do the inner product of coordinate with the strides. -// This wrapper create the stride vector from the physical dimension (or physical shape). -struct Strides { - // Create the strides for BNSH physically indexed memory buffer - static Strides BNSHMemory(int batch_dim, - int num_head_dim, - int seqlen_dim, - int head_size_dim) { - ORT_UNUSED_PARAMETER(batch_dim); - return Strides{LongLong4{ - static_cast(num_head_dim) * seqlen_dim * head_size_dim, - static_cast(seqlen_dim) * head_size_dim, - static_cast(head_size_dim), - static_cast(1), - }}; - } - - // Create the strides for BSNH physically indexed memory buffer - static Strides BSNHMemory(int batch_dim, - int seqlen_dim, - int num_head_dim, - int head_size_dim) { - ORT_UNUSED_PARAMETER(batch_dim); - return Strides{LongLong4{ - static_cast(seqlen_dim) * num_head_dim * head_size_dim, - static_cast(head_size_dim), - static_cast(num_head_dim) * head_size_dim, - static_cast(1), - }}; - } - - template - T ForBNSHCoord() const { - using E = typename T::value_type; - return T{static_cast(strides_for_bnsh_coord.x), - static_cast(strides_for_bnsh_coord.y), - static_cast(strides_for_bnsh_coord.z), - static_cast(strides_for_bnsh_coord.w)}; - } - - template - T ForBSNHCoord() const { - using E = typename T::value_type; - return T{static_cast(strides_for_bnsh_coord.x), - static_cast(strides_for_bnsh_coord.z), - static_cast(strides_for_bnsh_coord.y), - static_cast(strides_for_bnsh_coord.w)}; - } - - template - T ForBNHSCoord() const { - using E = typename T::value_type; - return T{static_cast(strides_for_bnsh_coord.x), - static_cast(strides_for_bnsh_coord.y), - static_cast(strides_for_bnsh_coord.w), - static_cast(strides_for_bnsh_coord.z)}; - } - - int64_t OffsetAt(int b, int n, int s, int h) const { - return strides_for_bnsh_coord.x * b + strides_for_bnsh_coord.y * n + - strides_for_bnsh_coord.z * s + strides_for_bnsh_coord.w * h; - } - - // store intermediate strides in the canonical (b,n,s,h) coordinate order - LongLong4 strides_for_bnsh_coord; -}; - -template -std::tuple ConvertToOffsetedBufferViews( - const RocmAttentionParameters* attn, - const T* query = nullptr, // q or packed_qkv - const T* key = nullptr, // k or packed kv - const T* value = nullptr, // - const T* present = nullptr, // present or present_k - const T* present_v = nullptr) { - switch (attn->mode) { - case QFMT_KFMT_VFMT_NONE_NONE_NONE_NONE: - case BSNH_BNLH_BNLH_NONE_NONE_NONE_NONE: - case BSNH_BLNH_BLNH_NONE_NONE_NONE_NONE: { - return {reinterpret_cast(query), - reinterpret_cast(key), - reinterpret_cast(value)}; - } - case QFMT_KFMT_VFMT_NONE_NONE_2BNTH_NONE: - case QFMT_KFMT_VFMT_2BNPH_NONE_2BNTH_NONE: { - auto offset = static_cast(attn->batch_size) * attn->num_heads * attn->total_sequence_length * - attn->head_size; - return {reinterpret_cast(query), - reinterpret_cast(present), - reinterpret_cast(present) + offset}; - } - case QFMT_KFMT_VFMT_NONE_NONE_2BNMH_NONE: - case QFMT_KFMT_VFMT_2BNMH_NONE_2BNMH_NONE: { - auto offset = static_cast(attn->batch_size) * attn->num_heads * attn->max_sequence_length * - attn->head_size; - return {reinterpret_cast(query), - reinterpret_cast(present), - reinterpret_cast(present) + offset}; - } - case BSNH_BLNH_BLNH_NONE_NONE_BNTH_BNTH: - case BSNH_BNLH_BNLH_NONE_NONE_BNTH_BNTH: - case BSNH_BLNH_BLNH_NONE_NONE_BNMH_BNMH: - case BSNH_BNLH_BNLH_NONE_NONE_BNMH_BNMH: - case BSNH_BLNH_BLNH_BNPH_BNPH_BNTH_BNTH: - case BSNH_BNLH_BNLH_BNPH_BNPH_BNTH_BNTH: - case BSNH_BLNH_BLNH_BNMH_BNMH_BNMH_BNMH: - case BSNH_BNLH_BNLH_BNMH_BNMH_BNMH_BNMH: - return {reinterpret_cast(query), - reinterpret_cast(present), - reinterpret_cast(present_v)}; - case BSNH_BLN2H_NONE_NONE_NONE_NONE_NONE: { - auto packed_kv = reinterpret_cast(key); - return {reinterpret_cast(query), packed_kv, packed_kv + attn->head_size}; - } - case BLN3H_NONE_NONE_NONE_NONE_NONE_NONE: { - auto packed_qkv = reinterpret_cast(query); - return {packed_qkv, packed_qkv + 1 * attn->head_size, packed_qkv + 2 * attn->head_size}; - } - default: - ORT_ENFORCE("unreachable"); - return {}; - } -} - -inline std::tuple GetQkvStrides(const RocmAttentionParameters* attn) { - // G0 not used, because it is the slowest dimension - const int& B = attn->batch_size; - const int& N = attn->num_heads; - const int& S = attn->sequence_length; - const int& L = attn->kv_sequence_length; - const int& T = attn->total_sequence_length; - const int& M = attn->max_sequence_length; - const int& H = attn->head_size; - const int& Hv = attn->v_head_size; - - switch (attn->mode) { - case QFMT_KFMT_VFMT_NONE_NONE_NONE_NONE: - if (attn->qkv_format == Q_K_V_BNSH) { - return { - Strides::BNSHMemory(B, N, S, H), - Strides::BNSHMemory(B, N, L, H), - Strides::BNSHMemory(B, N, L, Hv), - }; - } else if (attn->qkv_format == Q_K_V_BSNH) { - return { - Strides::BSNHMemory(B, S, N, H), - Strides::BSNHMemory(B, L, N, H), - Strides::BSNHMemory(B, L, N, Hv), - }; - } - case BSNH_BLNH_BLNH_NONE_NONE_BNTH_BNTH: - case BSNH_BNLH_BNLH_NONE_NONE_BNTH_BNTH: - case BSNH_BLNH_BLNH_BNPH_BNPH_BNTH_BNTH: - case BSNH_BNLH_BNLH_BNPH_BNPH_BNTH_BNTH: - return { - Strides::BSNHMemory(B, S, N, H), - Strides::BNSHMemory(B, N, T, H), - Strides::BNSHMemory(B, N, T, Hv), - }; - case BSNH_BLNH_BLNH_NONE_NONE_BNMH_BNMH: - case BSNH_BNLH_BNLH_NONE_NONE_BNMH_BNMH: - case BSNH_BLNH_BLNH_BNMH_BNMH_BNMH_BNMH: - case BSNH_BNLH_BNLH_BNMH_BNMH_BNMH_BNMH: - return { - Strides::BSNHMemory(B, S, N, H), - Strides::BNSHMemory(B, N, M, H), - Strides::BNSHMemory(B, N, M, Hv), - }; - case BSNH_BLNH_BLNH_NONE_NONE_NONE_NONE: - return { - Strides::BSNHMemory(B, S, N, H), - Strides::BSNHMemory(B, L, N, H), - Strides::BSNHMemory(B, L, N, Hv), - }; - case BSNH_BNLH_BNLH_NONE_NONE_NONE_NONE: - return { - Strides::BSNHMemory(B, S, N, H), - Strides::BNSHMemory(B, N, L, H), - Strides::BNSHMemory(B, N, L, Hv), - }; - case BSNH_BLN2H_NONE_NONE_NONE_NONE_NONE: - return { - Strides::BSNHMemory(B, S, N, H), - Strides::BSNHMemory(B, L, N, 2 * H), - Strides::BSNHMemory(B, L, N, 2 * Hv), - }; - case BLN3H_NONE_NONE_NONE_NONE_NONE_NONE: - return { - Strides::BSNHMemory(B, L, N, 3 * H), - Strides::BSNHMemory(B, L, N, 3 * H), - Strides::BSNHMemory(B, L, N, 3 * Hv), - }; - default: - ORT_ENFORCE("unreachable"); - return {}; - } -} - -inline std::tuple GetRawMaskBufferAddrSizesAndStrides( - const int* buffer, const RocmAttentionParameters* attn) { - const int* offseted_buffer{buffer}; // how to view the mask buffer - int3 sizes{0, 0, 0}; // the logical shape of the view - int3 strides{-1, -1, -1}; // the physical memory layout - switch (attn->mask_type) { - case MASK_NONE: - case MASK_2D_DUMMY: - break; // No mask - case MASK_2D_KEY_PADDING: - sizes = {attn->batch_size, 1, attn->total_sequence_length}; - strides = Get2DMaskStrides(attn->total_sequence_length); - break; - case MASK_3D_ATTENTION: - sizes = {attn->batch_size, attn->sequence_length, attn->total_sequence_length}; - strides = {attn->sequence_length * attn->total_sequence_length, attn->total_sequence_length, 1}; - break; - case MASK_4D_MEGATRON: - // offset to skip past sequence part, so that we can index it with [batch_index, sequence_index, token_index] - offseted_buffer = buffer + attn->past_sequence_length * attn->max_sequence_length; - sizes = {attn->batch_size, attn->sequence_length, attn->total_sequence_length}; - strides = {attn->max_sequence_length * attn->max_sequence_length, attn->max_sequence_length, 1}; - break; - default: - LOGS_DEFAULT(FATAL) << "unsupported mask type: " << attn->mask_type; - throw std::runtime_error("unsupported mask type"); - } - return {offseted_buffer, sizes, strides}; -} - -template -struct GemmSoftmaxGemmPermuteParams : onnxruntime::rocm::tunable::OpParams { - std::string Signature() const override { - return MakeString( - "B", attention->batch_size, - "_S", attention->sequence_length, - "_T", attention->total_sequence_length, - "_N", attention->num_heads, - "_H", attention->head_size, - "_Hv", attention->v_head_size, - bias_buffer != nullptr ? "_B" : "_NB", - "_M", mask_index_dims.size(), - "_QKV", attention->qkv_format, - "_MODE", attention->mode); - } - - std::tuple GetGemmsMNKOBatch() const { - ORT_ENFORCE(attention != nullptr); - auto m = attention->sequence_length; - auto n = attention->total_sequence_length; - auto k = attention->head_size; - auto o = attention->v_head_size; - auto batch = attention->batch_size * attention->num_heads; - return {m, n, k, o, batch}; - } - - hipblasHandle_t handle; - const RocmAttentionParameters* attention; - const hipDeviceProp_t* device_prop; - - float scale; - const T* q_buffer; - const T* k_buffer; - const T* v_buffer; - T* out_buffer; - - // optional, attention bias [B,N,S,T] - // TODO: support shape [B,1,S,T], [1, N, S, T], [1, 1, S, T] with broadcast. - const T* bias_buffer{nullptr}; - - // optional, mask value - const int* mask_index_buffer{nullptr}; - TensorShapeVector mask_index_dims{}; - - // optional, internal - void* workspace_buffer{nullptr}; -}; - -inline bool IsKVBNMH(AttentionMode mode) { - switch (mode) { - case QFMT_KFMT_VFMT_NONE_NONE_2BNMH_NONE: - case QFMT_KFMT_VFMT_2BNMH_NONE_2BNMH_NONE: - case BSNH_BLNH_BLNH_NONE_NONE_BNMH_BNMH: - case BSNH_BNLH_BNLH_NONE_NONE_BNMH_BNMH: - case BSNH_BLNH_BLNH_BNMH_BNMH_BNMH_BNMH: - case BSNH_BNLH_BNLH_BNMH_BNMH_BNMH_BNMH: - return true; - default: - return false; - } -} - -template -struct GemmSoftmaxGemmPermuteGenericPipeline { - static bool UseRawAttentionMask(const GemmSoftmaxGemmPermuteParams* params) { - return params->mask_index_buffer != nullptr && params->mask_index_dims.size() >= 2; - } - - static std::tuple GetWorkspacePlan(const GemmSoftmaxGemmPermuteParams* params) { - auto bytes = GetAttentionScratchSize( - sizeof(T), - params->attention->batch_size, - params->attention->num_heads, - params->attention->sequence_length, - params->attention->total_sequence_length); - auto gemm1_out = reinterpret_cast(params->workspace_buffer); - auto softmax_out = gemm1_out + (bytes / sizeof(T)); - auto gemm2_out = softmax_out + (bytes / sizeof(T)); - return {gemm1_out, softmax_out, gemm2_out}; - } - - inline static size_t GetWorkspaceNumBytes(const RocmAttentionParameters* attn) { - return GetAttentionWorkspaceSize( - sizeof(T), - attn->batch_size, - attn->num_heads, - attn->head_size, - attn->sequence_length, - attn->total_sequence_length); - } - - inline static Status Gemm1(const GemmSoftmaxGemmPermuteParams* params) { - auto [m, n, k, o, batch] = params->GetGemmsMNKOBatch(); - auto [gemm1_out, softmax_out, gemm2_out] = GetWorkspacePlan(params); - - int k_buffer_stride = n * k; - if (IsKVBNMH(params->attention->mode)) { - k_buffer_stride = params->attention->max_sequence_length * params->attention->head_size; - } - - // GEMM1 [m,k] * [n,k]' -> [m,n] - return blas::row_major::StridedBatchedGemm( - params->TuningContext(), params->Stream(), params->handle, - blas::BlasOp::NonTrans, blas::BlasOp::Trans, - m, n, k, - // For raw attention mask, the scalar is moved to softmax computation. - /*alpha=*/UseRawAttentionMask(params) ? 1.0f : params->scale, - params->q_buffer, k, m * k, - params->k_buffer, k, k_buffer_stride, - /*beta=*/0.0f, - gemm1_out, n, m * n, - batch); - } - - inline static Status SoftmaxRawMask(const GemmSoftmaxGemmPermuteParams* params, bool use_persistent_softmax) { - // Softmax on [m,n] along the n dimension. - // Raw attention mask could be 2D (B,S) or 3D (B,S,T) or 4D(B,1,M,M), where M is the max sequence length. - auto attn = params->attention; - auto [buffer, sizes, strides] = GetRawMaskBufferAddrSizesAndStrides(params->mask_index_buffer, attn); - auto [gemm1_out, softmax_out, gemm2_out] = GetWorkspacePlan(params); - T* persistent_softmax_workspace = gemm1_out; // replace Q*K' in place if persistent softmax is selected. - return ComputeSoftmaxWithRawMask( - params->Stream(), attn->total_sequence_length, attn->sequence_length, attn->batch_size, attn->num_heads, - strides, buffer, nullptr, params->bias_buffer, gemm1_out, softmax_out, - attn->is_unidirectional, /* FIXME: this must not be attn.scale! */ params->scale, - use_persistent_softmax, persistent_softmax_workspace, attn->mask_filter_value); - } - - inline static Status Softmax1DIndexMask(const GemmSoftmaxGemmPermuteParams* params) { - auto mask_1d = params->mask_index_buffer; - auto mask_1d_size = params->mask_index_dims[0]; - // Softmax on [m,n] along the n dimension. - // mask_index has 1D shape: either (batch_size) or (2*batch_size). Only the later one has start postions. - auto attn = params->attention; - auto [gemm1_out, softmax_out, gemm2_out] = GetWorkspacePlan(params); - const int* mask_start = (mask_1d_size > attn->batch_size) ? mask_1d + attn->batch_size : nullptr; - return ComputeSoftmaxWithMask1D( - params->StreamHandle(), attn->total_sequence_length, attn->sequence_length, attn->batch_size, attn->num_heads, - mask_1d, mask_start, params->bias_buffer, gemm1_out, softmax_out, attn->is_unidirectional); - } - - inline static Status SoftmaxNoMask(const GemmSoftmaxGemmPermuteParams* params) { - // Softmax on [m,n] along the n dimension. - auto attn = params->attention; - auto [gemm1_out, softmax_out, gemm2_out] = GetWorkspacePlan(params); - return ComputeSoftmax( - params->StreamHandle(), attn->total_sequence_length, attn->sequence_length, attn->batch_size, attn->num_heads, - params->bias_buffer, gemm1_out, softmax_out, attn->is_unidirectional); - } - - inline static Status Gemm2(const GemmSoftmaxGemmPermuteParams* params) { - auto [m, n, k, o, batch] = params->GetGemmsMNKOBatch(); - auto [gemm1_out, softmax_out, gemm2_out] = GetWorkspacePlan(params); - - int v_buffer_stride = n * o; - if (IsKVBNMH(params->attention->mode)) { - v_buffer_stride = params->attention->max_sequence_length * params->attention->v_head_size; - } - - // GEMM2 [m,n] * [n,o] -> [m,o] - // semantically, the output buffer contains B*N matrices of shape [S,H], compactly, thus B,N,S,H. - return blas::row_major::StridedBatchedGemm( - params->TuningContext(), params->Stream(), params->handle, - blas::BlasOp::NonTrans, blas::BlasOp::NonTrans, - m, o, n, - /*alpha=*/1.0f, - softmax_out, n, m * n, - params->v_buffer, o, v_buffer_stride, - /*beta=*/0.0f, - gemm2_out, o, m * o, - batch); - } - - inline static Status Permute0213(const GemmSoftmaxGemmPermuteParams* params) { - // Permute 0213 - // gemm2_out is B,N,S,H, transpose to out_buffer as B,S,N,H - auto attn = params->attention; - auto [gemm1_out, softmax_out, gemm2_out] = GetWorkspacePlan(params); - return LaunchTransCtx( - params->StreamHandle(), - attn->sequence_length, attn->batch_size, attn->head_size, attn->num_heads, - params->device_prop->maxThreadsPerBlock, false, gemm2_out, params->out_buffer); - } - - static Status GetSupportedStatus(const GemmSoftmaxGemmPermuteParams* params) { - const auto& attn = params->attention; - // TODO: address the BNMH k,v strides - switch (attn->mode) { - case QFMT_KFMT_VFMT_NONE_NONE_NONE_NONE: - case QFMT_KFMT_VFMT_NONE_NONE_2BNTH_NONE: - case QFMT_KFMT_VFMT_2BNPH_NONE_2BNTH_NONE: - case QFMT_KFMT_VFMT_NONE_NONE_2BNMH_NONE: - case QFMT_KFMT_VFMT_2BNMH_NONE_2BNMH_NONE: - if (attn->qkv_format == Q_K_V_BNSH) { - return Status::OK(); - } else { - return TUNABLE_OP_UNSUPPORTED("GenericPipeline only supports qkv_format as Q_K_V_BNSH, got ", - attn->qkv_format); - } - case BSNH_BLNH_BLNH_NONE_NONE_NONE_NONE: - return TUNABLE_OP_UNSUPPORTED("GenericPipeline only supports qkv_format as Q_K_V_BNSH but k, v are BLNH"); - case BSNH_BNLH_BNLH_NONE_NONE_NONE_NONE: - case BSNH_BLNH_BLNH_NONE_NONE_BNTH_BNTH: - case BSNH_BNLH_BNLH_NONE_NONE_BNTH_BNTH: - case BSNH_BLNH_BLNH_BNPH_BNPH_BNTH_BNTH: - case BSNH_BNLH_BNLH_BNPH_BNPH_BNTH_BNTH: - case BSNH_BLNH_BLNH_NONE_NONE_BNMH_BNMH: - case BSNH_BLNH_BLNH_BNMH_BNMH_BNMH_BNMH: - case BSNH_BNLH_BNLH_NONE_NONE_BNMH_BNMH: - case BSNH_BNLH_BNLH_BNMH_BNMH_BNMH_BNMH: - // If sequence_length is 1, query of B1NH can be simply viewed as BN1H. - if (attn->sequence_length == 1) { - return Status::OK(); - } else { - return TUNABLE_OP_UNSUPPORTED("GenericPipeline only supports qkv_format as Q_K_V_BNSH, ", - "only if sequence_length is 1, query of BSNH can be viewed as BNSH"); - } - case BLN3H_NONE_NONE_NONE_NONE_NONE_NONE: - case BSNH_BLN2H_NONE_NONE_NONE_NONE_NONE: - return TUNABLE_OP_UNSUPPORTED("GenericPipeline only supports qkv_format as Q_K_V_BNSH"); - default: - return TUNABLE_OP_UNSUPPORTED("unknonw"); - } - return TUNABLE_OP_UNSUPPORTED("unknonw case"); - } - - static Status Run(const GemmSoftmaxGemmPermuteParams* params, bool use_persistent_softmax) { - auto supported_status = GetSupportedStatus(params); - if (!supported_status.IsOK()) { - return supported_status; - } - ORT_RETURN_IF_ERROR(Gemm1(params)); - - if (UseRawAttentionMask(params)) { - ORT_RETURN_IF_ERROR(SoftmaxRawMask(params, use_persistent_softmax)); - } else if (params->mask_index_dims.size() == 1) { // 1d index mask - ORT_RETURN_IF_ERROR(Softmax1DIndexMask(params)); - } else { - ORT_RETURN_IF_ERROR(SoftmaxNoMask(params)); - } - - ORT_RETURN_IF_ERROR(Gemm2(params)); - ORT_RETURN_IF_ERROR(Permute0213(params)); - return Status::OK(); - } -}; - -template -class GemmSoftmaxGemmPermuteTunableOp : public tunable::TunableOp> { - public: - GemmSoftmaxGemmPermuteTunableOp(); - - inline static bool IsSupportedMode(const RocmAttentionParameters* attn) { - switch (attn->mode) { - case QFMT_KFMT_VFMT_NONE_NONE_NONE_NONE: - case QFMT_KFMT_VFMT_2BNPH_NONE_2BNTH_NONE: - // depends on qkv format - if (attn->qkv_format == Q_K_V_BNSH || attn->qkv_format == Q_K_V_BSNH) { - return true; - } else { - return false; - } - case BSNH_BLNH_BLNH_BNPH_BNPH_BNTH_BNTH: - case BSNH_BNLH_BNLH_BNPH_BNPH_BNTH_BNTH: - case BSNH_BLNH_BLNH_BNMH_BNMH_BNMH_BNMH: - case BSNH_BNLH_BNLH_BNMH_BNMH_BNMH_BNMH: - case BSNH_BLNH_BLNH_NONE_NONE_NONE_NONE: - case BSNH_BNLH_BNLH_NONE_NONE_NONE_NONE: - case BSNH_BLNH_BLNH_NONE_NONE_BNTH_BNTH: - case BSNH_BNLH_BNLH_NONE_NONE_BNTH_BNTH: - case BSNH_BLNH_BLNH_NONE_NONE_BNMH_BNMH: - case BSNH_BNLH_BNLH_NONE_NONE_BNMH_BNMH: - case BSNH_BLN2H_NONE_NONE_NONE_NONE_NONE: - case BLN3H_NONE_NONE_NONE_NONE_NONE_NONE: - return true; - default: - return false; - } - } - - inline static bool IsSupportedMaskType(const RocmAttentionParameters* attn) { - switch (attn->mask_type) { - case MASK_NONE: - case MASK_2D_DUMMY: - case MASK_2D_KEY_PADDING: - case MASK_3D_ATTENTION: - case MASK_4D_MEGATRON: - return true; - default: - return false; - } - } - - inline static size_t GetWorkspaceNumBytes(const RocmAttentionParameters* attn) { - size_t num_bytes = GemmSoftmaxGemmPermuteGenericPipeline::GetWorkspaceNumBytes(attn); - -#ifdef USE_COMPOSABLE_KERNEL - if (IsSupportedMaskType(attn)) { - auto [buffer, sizes, strides] = GetRawMaskBufferAddrSizesAndStrides(nullptr, attn); - num_bytes = std::max(num_bytes, sizeof(T) * sizes.x * sizes.y * sizes.z); - } -#endif - - return num_bytes; - } - - template - __global__ static void ConvertToFilledMaskValue( - T* __restrict__ out, - const int3 out_strides, - const int* __restrict__ mask_buffer, - const int3 mask_lengths, // [B,S,T] - const int3 mask_strides, - Converter cvt) { - const int64_t global_idx = blockDim.x * blockIdx.x + threadIdx.x; - if (global_idx >= mask_lengths.x * mask_lengths.y * CeilDiv(mask_lengths.z, VecSize)) { - return; - } - - const int tidx = (global_idx % CeilDiv(mask_lengths.z, VecSize)) * VecSize; - const int bs_idx = global_idx / CeilDiv(mask_lengths.z, VecSize); - const int sidx = bs_idx % mask_lengths.y; - const int bidx = bs_idx / mask_lengths.y; - - int64_t in_offset = mask_strides.x * bidx + mask_strides.y * sidx + mask_strides.z * tidx; - int64_t out_offset = out_strides.x * bidx + out_strides.y * sidx + out_strides.z * tidx; - - if (tidx + VecSize <= mask_lengths.z) { - using LoadT = const aligned_vector; - using StoreT = aligned_vector; - LoadT load = *reinterpret_cast(mask_buffer + in_offset); - StoreT store; - -#pragma unroll - for (int i = 0; i < VecSize; i++) { - store.val[i] = cvt(load.val[i]); - } - *reinterpret_cast(out + out_offset) = store; - } else { -#pragma unroll - for (int i = 0; i < mask_lengths.z - tidx; i++) { - out[out_offset + i] = cvt(mask_buffer[in_offset + i]); - } - } - } - - static Status LaunchConvertToFilledMaskValue(const GemmSoftmaxGemmPermuteParams* params) { - constexpr const int kThreadPerBlock = 256; - constexpr const int kVecSize = 4; - - auto attn = params->attention; - auto [buffer, lengths, strides] = GetRawMaskBufferAddrSizesAndStrides(params->mask_index_buffer, attn); - int64_t total_threads = lengths.x * lengths.y * CeilDiv(lengths.z, kVecSize); - auto num_blocks = CeilDiv(total_threads, kThreadPerBlock); - - auto mask_filter_value = attn->mask_filter_value; - auto cvt = [=] __device__(int v) -> T { - return v == 1 ? 0 : mask_filter_value; - }; - - ConvertToFilledMaskValue<<StreamHandle()>>>( - reinterpret_cast(params->workspace_buffer), {lengths.y * lengths.z, lengths.z, 1}, // out desc - buffer, lengths, strides, // mask desc - cvt); - - return HIP_CALL(hipGetLastError()); - } -}; - -#ifdef USE_COMPOSABLE_KERNEL - -template -auto GetArgAndRunInvoker(const U& impl, const V& invoker, const GemmSoftmaxGemmPermuteParams* params) { - constexpr const int kNumBiasBuffer = static_cast(USE_BIAS) + static_cast(USE_MASK); - - using Nop = ck::tensor_operation::element_wise::PassThrough; - using Acc0ElementOp = internal::PreSoftmaxAttentionScoreOp; - - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - !GemmSoftmaxGemmPermuteTunableOp::IsSupportedMode(params->attention), - "attention mode is not supported, got ", params->attention->mode); - if constexpr (USE_BIAS) { - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - params->bias_buffer == nullptr, "biased version only support input with bias"); - } else { - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - params->bias_buffer != nullptr, "non-biased version only support input without bias"); - } - if constexpr (USE_MASK) { - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - !GemmSoftmaxGemmPermuteTunableOp::IsSupportedMaskType(params->attention), - "mask type is not supported, got ", params->attention->mask_type); - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - params->mask_index_buffer == nullptr, "masked version only support input with mask"); - } else { - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - params->mask_index_buffer != nullptr, "non-masked version only support input without mask"); - } - - auto attn = params->attention; - const int& G0 = attn->batch_size; - const int& G1 = attn->num_heads; - const int& M = attn->sequence_length; - const int& N = attn->total_sequence_length; - const int& K = attn->head_size; - const int& O = attn->v_head_size; - { - auto [m, n, k, o, batch] = params->GetGemmsMNKOBatch(); - ORT_ENFORCE(M == m && N == n && K == k && O == o && G0 * G1 == batch, "semantic mismatch"); - } - - auto [qs, ks, vs] = GetQkvStrides(attn); - std::vector q_buffer_lengths = {G0, G1, M, K}; - std::vector q_buffer_strides = qs.template ForBNSHCoord>(); - std::vector k_buffer_lengths = {G0, G1, N, K}; - std::vector k_buffer_strides = ks.template ForBNSHCoord>(); - std::vector v_buffer_lengths = {G0, G1, O, N}; - std::vector v_buffer_strides = vs.template ForBNHSCoord>(); - std::vector out_buffer_lengths = {G0, G1, M, O}; - std::vector out_buffer_strides = {M * G1 * O, O, G1 * O, 1}; // permute 0213 - - std::array bias_buffers{}; - std::array, kNumBiasBuffer> bias_lengths{}; - std::array, kNumBiasBuffer> bias_strides{}; - if constexpr (USE_BIAS) { - bias_buffers[0] = const_cast(params->bias_buffer); - bias_lengths[0] = {G0, G1, M, N}; // BN(G0*G1), S(M), T(N) - bias_strides[0] = {G1 * M * N, M * N, N, 1}; - } - if constexpr (USE_MASK) { - bias_buffers[kNumBiasBuffer - 1] = params->workspace_buffer; - bias_lengths[kNumBiasBuffer - 1] = {G0, G1, M, N}; // BN(G0*G1), S(M), T(N) - if (params->mask_index_dims.size() == 2) { // [B,T] - bias_strides[kNumBiasBuffer - 1] = {N, 0, 0, 1}; - } else if (params->mask_index_dims.size() == 3) { // [B,S,T] - bias_strides[kNumBiasBuffer - 1] = {M * N, 0, N, 1}; - } else if (params->mask_index_dims.size() == 4) { // [B,1,max_seq_len,max_seq_len] -->convert--> [B,S,T] - bias_strides[kNumBiasBuffer - 1] = {M * N, 0, N, 1}; - } else { - ORT_ENFORCE(false, "Unreachable"); - } - } - - auto arg = impl->MakeArgumentPointer( - params->q_buffer, params->k_buffer, params->v_buffer, params->out_buffer, - bias_buffers, // Gemm1 bias, as attention mask - {}, // Gemm2 bias - q_buffer_lengths, q_buffer_strides, - k_buffer_lengths, k_buffer_strides, - v_buffer_lengths, v_buffer_strides, - out_buffer_lengths, out_buffer_strides, - bias_lengths, bias_strides, - {}, - {}, - Nop{}, - Nop{}, - Acc0ElementOp{params->scale}, - Nop{}, - Nop{}); - - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!impl->IsSupportedArgument(arg.get()), - impl->GetTypeString(), " does not support the params"); - - if constexpr (USE_MASK) { - ORT_RETURN_IF_ERROR(GemmSoftmaxGemmPermuteTunableOp::LaunchConvertToFilledMaskValue(params)); - } - - invoker->Run(arg.get(), StreamConfig{params->StreamHandle()}); - return Status::OK(); -} - -template -auto GetCKGemmSoftmaxGemmPermuteTypeStringAndOps() { - using CKDataType = typename CKDataTypeAdaptor::type; - using D0DataType = typename ck::detail::tuple_concat< - std::conditional_t, ck::Tuple<>>, - std::conditional_t, ck::Tuple<>>>::type; - - constexpr static auto MaskingSpecMaskDisabled = - ck::tensor_operation::device::MaskingSpecialization::MaskDisabled; - constexpr static auto MaskingSpecMaskOutUpperTriangle = - ck::tensor_operation::device::MaskingSpecialization::MaskOutUpperTriangle; - - std::vector>>> - ret; - - for (auto&& impl : internal::GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< - CKDataType, D0DataType, internal::F32, internal::PreSoftmaxAttentionScoreOp, MaskingSpecMaskDisabled>()) { - auto type_string = impl->GetTypeString(); - - auto invoker = impl->MakeInvokerPointer(); - auto op = [impl = std::move(impl), invoker = std::move(invoker)]( - const GemmSoftmaxGemmPermuteParams* params) -> Status { - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - params->attention->is_unidirectional, "unidirectional attention is not supported with MaskingSpecMaskDisabled"); - - return GetArgAndRunInvoker(impl, invoker, params); - }; - ret.emplace_back(std::make_pair(std::move(type_string), std::move(op))); - } - - for (auto&& impl : internal::GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< - CKDataType, D0DataType, internal::F32, internal::PreSoftmaxAttentionScoreOp, MaskingSpecMaskOutUpperTriangle>()) { - auto type_string = impl->GetTypeString(); - - auto invoker = impl->MakeInvokerPointer(); - auto op = [impl = std::move(impl), invoker = std::move(invoker)]( - const GemmSoftmaxGemmPermuteParams* params) -> Status { - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - !params->attention->is_unidirectional, "bidirectional attention is not supported with MaskingSpecMaskOutUpperTriangle"); - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - params->attention->sequence_length != params->attention->total_sequence_length, - "seqence_length != total_seqence_length is not supported with MaskingSpecMaskOutUpperTriangle"); - - return GetArgAndRunInvoker(impl, invoker, params); - }; - ret.emplace_back(std::make_pair(std::move(type_string), std::move(op))); - } - - return ret; -} -#endif // USE_COMPOSABLE_KERNEL - -template -GemmSoftmaxGemmPermuteTunableOp::GemmSoftmaxGemmPermuteTunableOp() { - this->RegisterOp([](const GemmSoftmaxGemmPermuteParams* params) { - return GemmSoftmaxGemmPermuteGenericPipeline::Run(params, false); - }); - -#ifdef USE_COMPOSABLE_KERNEL - for (auto&& [_, op] : GetCKGemmSoftmaxGemmPermuteTypeStringAndOps()) { - this->RegisterOp(std::move(op)); - } - - for (auto&& [_, op] : GetCKGemmSoftmaxGemmPermuteTypeStringAndOps()) { - this->RegisterOp(std::move(op)); - } - - for (auto&& [_, op] : GetCKGemmSoftmaxGemmPermuteTypeStringAndOps()) { - this->RegisterOp(std::move(op)); - } - - for (auto&& [_, op] : GetCKGemmSoftmaxGemmPermuteTypeStringAndOps()) { - this->RegisterOp(std::move(op)); - } -#endif -} - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/decoder_attention_impl.h b/onnxruntime/contrib_ops/rocm/bert/decoder_attention_impl.h deleted file mode 100644 index 0aff519d20e99..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/decoder_attention_impl.h +++ /dev/null @@ -1,46 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include -#include -#include "contrib_ops/cpu/bert/attention_common.h" -#include "core/providers/rocm/shared_inc/rocm_utils.h" -#include "core/providers/rocm/tunable/rocm_tunable.h" - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -Status LaunchDecoderAttentionKernel( - const hipDeviceProp_t& prop, // Device Properties - RocmTuningContext* tuning_ctx, // context for tuning - Stream* stream, // ORT Stream - hipblasHandle_t& hipblas, // hipblas handle - const size_t element_size, // Element size of input tensor - const int batch_size, // Batch size (B) - const int sequence_length, // Sequence length (S) - const int kv_sequence_length, // Key/Value/Cache sequence length - const int num_heads, // Number of attention heads (N) - const int head_size, // Hidden layer size per head (H) - const bool static_kv, // Whether cross attention or not - const bool use_past, // Whether use cache or not - const bool has_layer_state, // Whether output cache or not - const bool has_key_padding_mask, // Whether use key_padding_mask or not - const float mask_filter_value, // Mask filter value - const void* gemm_query_buffer, // Query buffer - const void* gemm_kv_buffer, // Key and value buffer - const bool* key_padding_mask, // Key padding mask - const void* key_cache, // Input key cache - const void* value_cache, // Input value cache - void* qkv_buffer, // Temporary buffer - void* workspace_buffer, // Temporary buffer - void* output, // Output tensor - void* new_key_cache, // New_key_cache tensor - void* new_value_cache // New_value_cache tensor -); - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/elementwise.h b/onnxruntime/contrib_ops/rocm/bert/elementwise.h deleted file mode 100644 index 768295767835a..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/elementwise.h +++ /dev/null @@ -1,84 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include "core/common/common.h" -#include "core/providers/rocm/tunable/rocm_tunable.h" - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -template -Status LaunchElementwiseKernel(RocmTuningContext* tuning_ctx, Stream* stream, - const T* input, int input_length, - const T* bias, int bias_length, - T* output); - -// The following is LaunchElementwiseKernel implementation detail. Their interfaces are exposed for kernel explorer. -namespace internal { - -template -struct ElementwiseParams : OpParams { - ElementwiseParams(RocmTuningContext* tuning_ctx, onnxruntime::Stream* stream, - const T* input, const T* bias, T* output, int input_length, int bias_length) - : OpParams(tuning_ctx, stream), - input(input), - bias(bias), - output(output), - input_length(input_length), - bias_length(bias_length) {} - - std::string Signature() const override { - std::string sig = std::to_string(input_length) + "_" + std::to_string(bias_length); - return sig; - } - - const T* input; - const T* bias; - T* output; - int input_length; - int bias_length; -}; - -template -class ElementwiseOp { - public: - Status operator()(const ElementwiseParams* params); - Status IsSupported(const ElementwiseParams* params); -}; - -template -Status ElementwiseStaticSelection(const ElementwiseParams* params); - -template -class ElementwiseTunableOp : public TunableOp> { - public: - ElementwiseTunableOp(); -}; - -} // namespace internal - -#define ELEMENTWISE_FWD_DECL(FnName, T) \ - namespace functor { \ - struct FnName; \ - } - -ELEMENTWISE_FWD_DECL(FastGeLU, float); -ELEMENTWISE_FWD_DECL(FastGeLU, double); -ELEMENTWISE_FWD_DECL(FastGeLU, half); -ELEMENTWISE_FWD_DECL(FastGeLU, BFloat16); - -ELEMENTWISE_FWD_DECL(GeLU, float); -ELEMENTWISE_FWD_DECL(GeLU, double); -ELEMENTWISE_FWD_DECL(GeLU, half); -ELEMENTWISE_FWD_DECL(GeLU, BFloat16); - -ELEMENTWISE_FWD_DECL(ReLU, float); -ELEMENTWISE_FWD_DECL(ReLU, half); -ELEMENTWISE_FWD_DECL(ReLU, BFloat16); - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/elementwise_impl/impl.cuh b/onnxruntime/contrib_ops/rocm/bert/elementwise_impl/impl.cuh deleted file mode 100644 index 8255e70d27e48..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/elementwise_impl/impl.cuh +++ /dev/null @@ -1,256 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include "core/providers/rocm/tunable/util.h" -#include "core/providers/rocm/cu_inc/common.cuh" -#include "contrib_ops/rocm/bert/elementwise.h" - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -namespace functor { - -struct FastGeLU { - template - __host__ __device__ __forceinline__ void operator()(T& y, const T& x) const { - constexpr const float b = 0.7978845608028654f; // sqrt(2.0/M_PI) - - // const T cdf = a + a * _Tanh(in * (c * in * in + b)); - const T xb = x * T(b); - const T u = xb * T(0.044715f) * x * x + xb; - const T emu = __expf(-u - u); - const T cdf = T(1.0f) / (T(1.0f) + emu); - y = x * cdf; - } -}; - -struct GeLU { - template - __host__ __device__ __forceinline__ void operator()(T& y, const T& x) const { - y = T(0.5f) * x * (T(1.f) + T(erf(0.70710678118f * float(x)))); - } -}; - -struct ReLU { - template - __host__ __device__ __forceinline__ void operator()(T& y, const T& x) const { - y = x >= T{} ? x : T{}; - } -}; - -} // namespace functor - -using onnxruntime::rocm::CeilDiv; -using onnxruntime::rocm::GPU_WARP_SIZE; - -template -__global__ void ElementwiseKernel( - const T* __restrict__ input, int input_length, - const T* __restrict__ bias, int bias_length, - T* __restrict__ output) { - const int idx = blockIdx.x * TPB + threadIdx.x; - Fn f{}; - - if (idx < input_length) { - const T x = input[idx] + (bias == nullptr ? T{} : bias[idx % bias_length]); - f(output[idx], x); - } -} - -template -__global__ void ElementwiseKernelVec( - const T* __restrict__ input, int input_length, - const T* __restrict__ bias, int bias_length, - T* output) { - using VecT = onnxruntime::rocm::aligned_vector; - Fn f{}; - - const int idx = (blockIdx.x * TPB + threadIdx.x) * ILP; - if (idx < input_length) { - T input_v[ILP]; - VecT* input_val = reinterpret_cast(&input_v); - *input_val = *reinterpret_cast(&input[idx]); - T output_v[ILP]; - VecT* output_val = reinterpret_cast(&output_v); - T bias_v[ILP]; - if (bias != nullptr) { - VecT* bias_val = reinterpret_cast(&bias_v); - *bias_val = *reinterpret_cast(&bias[idx % bias_length]); - } - -#pragma unroll - for (int i = 0; i < ILP; i++) { - const T x = (bias == nullptr) ? input_v[i] : (T)(input_v[i] + bias_v[i]); - f(output_v[i], x); - } - *(reinterpret_cast(&output[idx])) = *output_val; - } -} - -template -Status LaunchElementwiseKernel( - RocmTuningContext* tuning_ctx, Stream* stream, - const T* input, int input_length, - const T* bias, int bias_length, - T* output) { - internal::ElementwiseParams params(tuning_ctx, stream, input, bias, output, input_length, bias_length); - if (tuning_ctx->IsTunableOpEnabled()) { - static internal::ElementwiseTunableOp op; - return op(¶ms); - } - - return internal::ElementwiseStaticSelection(¶ms); -} - -namespace internal { - -template -Status ElementwiseOp::operator()(const ElementwiseParams* params) { - dim3 blocks(CeilDiv(params->input_length, ThreadsPerBlock * VecSize)); - ElementwiseKernelVec<<StreamHandle()>>>( - params->input, params->input_length, - params->bias, params->bias_length, - params->output); - return HIP_CALL(hipGetLastError()); -} - -template -Status ElementwiseOp::IsSupported(const ElementwiseParams* params) { - // TODO(anyone): Add tail handling for FastGelu - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - !((params->bias_length > 0 && params->bias_length % VecSize == 0 && params->input_length % VecSize == 0) || - (params->bias_length == 0 && params->input_length % VecSize == 0))); - // Avoid redundant configurations - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!(params->input_length > (ThreadsPerBlock - GPU_WARP_SIZE) * VecSize)); - - return Status::OK(); -} - -template -Status ElementwiseStaticSelection(const ElementwiseParams* params) { - constexpr int block_size = 256; - if constexpr (std::is_same_v) { - if (params->bias != nullptr) { - if (0 == (params->bias_length % 8) && (params->input_length >= 3145728)) { // 3145728=8*128*3072 - const int grid_size = (params->input_length / 8 + block_size - 1) / block_size; - ElementwiseKernelVec<<StreamHandle()>>>( - params->input, params->input_length, params->bias, params->bias_length, params->output); - } else if (0 == (params->bias_length % 4)) { - const int grid_size = (params->input_length / 4 + block_size - 1) / block_size; - ElementwiseKernelVec<<StreamHandle()>>>( - params->input, params->input_length, params->bias, params->bias_length, params->output); - } else if (0 == (params->bias_length % 2)) { - const int grid_size = (params->input_length / 2 + block_size - 1) / block_size; - ElementwiseKernelVec<<StreamHandle()>>>( - params->input, params->input_length, params->bias, params->bias_length, params->output); - } else { - const int grid_size = (params->input_length + block_size - 1) / block_size; - ElementwiseKernel<<StreamHandle()>>>( - params->input, params->input_length, params->bias, params->bias_length, params->output); - } - } else { - if (0 == (params->input_length % 8) && (params->input_length >= 3145728)) { // 3145728=8*128*3072 - const int grid_size = (params->input_length / 8 + block_size - 1) / block_size; - ElementwiseKernelVec<<StreamHandle()>>>( - params->input, params->input_length, params->bias, params->bias_length, params->output); - } else if (0 == (params->input_length % 4)) { - const int grid_size = (params->input_length / 4 + block_size - 1) / block_size; - ElementwiseKernelVec<<StreamHandle()>>>( - params->input, params->input_length, params->bias, params->bias_length, params->output); - } else if (0 == (params->input_length % 2)) { - const int grid_size = (params->input_length / 2 + block_size - 1) / block_size; - ElementwiseKernelVec<<StreamHandle()>>>( - params->input, params->input_length, params->bias, params->bias_length, params->output); - } else { - const int grid_size = (params->input_length + block_size - 1) / block_size; - ElementwiseKernel<<StreamHandle()>>>( - params->input, params->input_length, params->bias, params->bias_length, params->output); - } - } - } else { - const int grid_size = (params->input_length + block_size - 1) / block_size; - ElementwiseKernel<<StreamHandle()>>>( - params->input, params->input_length, params->bias, params->bias_length, params->output); - } - return HIP_CALL(hipGetLastError()); -} - -template -ElementwiseTunableOp::ElementwiseTunableOp() { - this->RegisterOp(ElementwiseStaticSelection); - - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); - this->RegisterOp(ElementwiseOp{}); -} - -#undef ADD_OP - -} // namespace internal - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime - -#define ELEMENTWISE_KERNEL_IMPL(Fn, T) \ - namespace onnxruntime { \ - namespace contrib { \ - namespace rocm { \ - template Status LaunchElementwiseKernel( \ - RocmTuningContext * tuning_ctx, Stream* stream, \ - const T* input, int input_length, \ - const T* bias, int bias_length, \ - T* output); \ - namespace internal { \ - template class ElementwiseTunableOp; \ - } \ - } \ - } \ - } diff --git a/onnxruntime/contrib_ops/rocm/bert/elementwise_impl/impl_fastgelu.cu b/onnxruntime/contrib_ops/rocm/bert/elementwise_impl/impl_fastgelu.cu deleted file mode 100644 index c2a670ea76aca..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/elementwise_impl/impl_fastgelu.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "contrib_ops/rocm/bert/elementwise_impl/impl.cuh" - -ELEMENTWISE_KERNEL_IMPL(functor::FastGeLU, float); -ELEMENTWISE_KERNEL_IMPL(functor::FastGeLU, double); -ELEMENTWISE_KERNEL_IMPL(functor::FastGeLU, half); -ELEMENTWISE_KERNEL_IMPL(functor::FastGeLU, BFloat16); diff --git a/onnxruntime/contrib_ops/rocm/bert/elementwise_impl/impl_gelu.cu b/onnxruntime/contrib_ops/rocm/bert/elementwise_impl/impl_gelu.cu deleted file mode 100644 index 97f0f74640c6e..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/elementwise_impl/impl_gelu.cu +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "contrib_ops/rocm/bert/elementwise_impl/impl.cuh" - -ELEMENTWISE_KERNEL_IMPL(functor::GeLU, double); -ELEMENTWISE_KERNEL_IMPL(functor::GeLU, float); -ELEMENTWISE_KERNEL_IMPL(functor::GeLU, half); -ELEMENTWISE_KERNEL_IMPL(functor::GeLU, BFloat16); diff --git a/onnxruntime/contrib_ops/rocm/bert/elementwise_impl/impl_relu.cu b/onnxruntime/contrib_ops/rocm/bert/elementwise_impl/impl_relu.cu deleted file mode 100644 index 67e50869133f5..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/elementwise_impl/impl_relu.cu +++ /dev/null @@ -1,8 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "contrib_ops/rocm/bert/elementwise_impl/impl.cuh" - -ELEMENTWISE_KERNEL_IMPL(functor::ReLU, float); -ELEMENTWISE_KERNEL_IMPL(functor::ReLU, half); -ELEMENTWISE_KERNEL_IMPL(functor::ReLU, BFloat16); diff --git a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu.cc b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu.cc deleted file mode 100644 index fdb62d3a2aec5..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu.cc +++ /dev/null @@ -1,75 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "contrib_ops/rocm/bert/gemm_fast_gelu.h" - -#include "contrib_ops/rocm/bert/gemm_fast_gelu_common.h" -#include "contrib_ops/rocm/bert/gemm_fast_gelu_impl.h" -#include "core/providers/cpu/math/matmul_helper.h" -#include "core/providers/rocm/rocm_common.h" - -using onnxruntime::rocm::ToHipType; - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -#define REGISTER_KERNEL_TYPED(T) \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - GemmFastGelu, \ - kMSDomain, \ - 1, \ - T, \ - kRocmExecutionProvider, \ - (*KernelDefBuilder::Create()) \ - .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - GemmFastGelu); - -REGISTER_KERNEL_TYPED(float) -REGISTER_KERNEL_TYPED(MLFloat16) -REGISTER_KERNEL_TYPED(BFloat16) - -template -Status GemmFastGelu::ComputeInternal(OpKernelContext* ctx) const { - typedef typename ToHipType::MappedType HipT; - - const auto* X = ctx->Input(0); - const auto* W = ctx->Input(1); - const auto* bias = ctx->Input(2); - - bool transa = false; - bool transb = false; - bool trans_batch_a = false; - bool trans_batch_b = false; - - MatMulComputeHelper helper; - ORT_RETURN_IF_ERROR(helper.Compute(X->Shape(), W->Shape(), transa, transb, trans_batch_a, trans_batch_b, false)); - - Tensor* Y = ctx->Output(0, helper.OutputShape()); - - // Bail out early if the output is going to be empty - if (Y->Shape().Size() == 0) - return Status::OK(); - - // gemmfastgelu only support alpha == 1 and beta == 0 - const HipT alpha = ToHipType::FromFloat(1.0f); - const HipT beta = ToHipType::FromFloat(0.0f); - - using onnxruntime::rocm::tunable::blas::BlasOp; - - return blas::row_major::GemmFastGelu( - GetTuningContext(), ctx->GetComputeStream(), GetHipblasHandle(ctx), - transa ? BlasOp::Trans : BlasOp::NonTrans, - transb ? BlasOp::Trans : BlasOp::NonTrans, - helper.M(), helper.N(), helper.K(), - alpha, - reinterpret_cast(X->Data()), helper.Lda(transa), - reinterpret_cast(W->Data()), helper.Ldb(transb), - (nullptr != bias) ? reinterpret_cast(bias->Data()) : nullptr, - beta, - reinterpret_cast(Y->MutableData()), helper.Ldc()); -} - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu.h b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu.h deleted file mode 100644 index ae4f84fa5f033..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu.h +++ /dev/null @@ -1,23 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once -#include "core/common/common.h" -#include "core/providers/rocm/rocm_kernel.h" - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -using onnxruntime::rocm::RocmKernel; - -template -class GemmFastGelu final : public RocmKernel { - public: - GemmFastGelu(const OpKernelInfo& op_kernel_info) : RocmKernel(op_kernel_info) {} - Status ComputeInternal(OpKernelContext* ctx) const override; -}; - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_ck.cuh b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_ck.cuh deleted file mode 100644 index 77f53f9eed027..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_ck.cuh +++ /dev/null @@ -1,133 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include -#include -#include - -#ifdef USE_COMPOSABLE_KERNEL -#include "core/providers/rocm/composable_kernel_common.h" - -#include "ck/ck.hpp" -#include "ck/library/tensor_operation_instance/gpu/gemm_add_fastgelu.hpp" -#include "ck/library/tensor_operation_instance/gpu/gemm_fastgelu.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#endif - -#include "contrib_ops/rocm/bert/gemm_fast_gelu_common.h" - -using onnxruntime::rocm::ToHipType; - -namespace onnxruntime { -namespace contrib { -namespace rocm { -namespace blas { -namespace internal { - -#ifdef USE_COMPOSABLE_KERNEL - -using onnxruntime::rocm::CKBlasOpAdaptor; -using onnxruntime::rocm::CKDataTypeAdaptor; - -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; - -using Nop = ck::tensor_operation::element_wise::PassThrough; -using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu; -using FastGelu = ck::tensor_operation::element_wise::FastGelu; - -template -auto GetCKGemmAddFastGeluTypeStringAndOps() { - using CKDataType = typename CKDataTypeAdaptor::type; - using ALayout = typename CKBlasOpAdaptor::type; - using BLayout = typename CKBlasOpAdaptor::type; - using DeviceGemmAddFastGelu = ck::tensor_operation::device::DeviceGemmMultipleD< - ALayout, BLayout, ck::Tuple, Row, - CKDataType, CKDataType, ck::Tuple, CKDataType, - Nop, Nop, AddFastGelu>; - using InstanceFactory = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory; - - std::vector>>> ret; - for (auto&& impl : InstanceFactory::GetInstances()) { - auto type_string = onnxruntime::MakeString("withbias ", impl->GetTypeString()); - - auto invoker = impl->MakeInvokerPointer(); - auto ck_gemmfastgelu_op = [impl = std::move(impl), invoker = std::move(invoker)](const GemmFastGeluParams* params) -> Status { - auto one = ToHipType::FromFloat(1.0f); - auto zero = ToHipType::FromFloat(0.0f); - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - params->alpha != one || params->beta != zero || params->bias == nullptr, - impl->GetTypeString(), " only supports alpha == 1 and beta == 0 and bias != nullptr"); - - auto nop = Nop{}; - auto addfastgelu = AddFastGelu{}; - auto arg = impl->MakeArgumentPointer(params->a, params->b, std::array{params->bias}, params->c, - params->m, params->n, params->k, - params->lda, params->ldb, std::array{0}, params->ldc, - nop, nop, addfastgelu); - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!impl->IsSupportedArgument(arg.get()), - impl->GetTypeString(), " does not support the params"); - invoker->Run(arg.get(), StreamConfig{params->StreamHandle()}); - return Status::OK(); - }; - ret.emplace_back(std::make_pair(std::move(type_string), std::move(ck_gemmfastgelu_op))); - } - return ret; -} - -template -auto GetCKGemmFastGeluTypeStringAndOps() { - using CKDataType = typename CKDataTypeAdaptor::type; - using ALayout = typename CKBlasOpAdaptor::type; - using BLayout = typename CKBlasOpAdaptor::type; - using DeviceGemmFastGelu = ck::tensor_operation::device::DeviceGemmMultipleD< - ALayout, BLayout, ck::Tuple<>, Row, - CKDataType, CKDataType, ck::Tuple<>, CKDataType, - Nop, Nop, FastGelu>; - using InstanceFactory = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory; - - std::vector>>> ret; - for (auto&& impl : InstanceFactory::GetInstances()) { - auto type_string = onnxruntime::MakeString("nobias ", impl->GetTypeString()); - auto invoker = impl->MakeInvokerPointer(); - auto ck_gemmfastgelu_op = [impl = std::move(impl), invoker = std::move(invoker)](const GemmFastGeluParams* params) -> Status { - auto one = ToHipType::FromFloat(1.0f); - auto zero = ToHipType::FromFloat(0.0f); - - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - params->alpha != one || params->beta != zero || params->bias != nullptr, - impl->GetTypeString(), " only supports alpha == 1 and beta == 0 and bias == nullptr"); - - auto nop = Nop{}; - auto fastgelu = FastGelu{}; - auto arg = impl->MakeArgumentPointer(params->a, params->b, - {}, - params->c, - params->m, params->n, params->k, - params->lda, params->ldb, - {}, - params->ldc, - nop, nop, fastgelu); - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!impl->IsSupportedArgument(arg.get()), - impl->GetTypeString(), " does not support the params"); - invoker->Run(arg.get(), StreamConfig{params->StreamHandle()}); - return Status::OK(); - }; - ret.emplace_back(std::make_pair(std::move(type_string), std::move(ck_gemmfastgelu_op))); - } - return ret; -} -#else -struct Row {}; -struct Col {}; -#endif // USE_COMPOSABLE_KERNEL - -} // namespace internal -} // namespace blas -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_common.h b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_common.h deleted file mode 100644 index 2b8a21b83f177..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_common.h +++ /dev/null @@ -1,47 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include -#include - -#include "core/providers/rocm/rocm_common.h" -#include "core/providers/rocm/tunable/gemm_common.h" -#include "core/providers/rocm/tunable/rocm_tunable.h" - -using onnxruntime::rocm::tunable::blas::BlasOp; -using onnxruntime::rocm::tunable::blas::BlasOpToString; - -namespace onnxruntime { -namespace contrib { -namespace rocm { -namespace blas { - -template -struct GemmFastGeluParams : OpParams { - std::string Signature() const override { - bool has_bias = (nullptr != bias) ? 0 : 1; - return MakeString(BlasOpToString(opa), BlasOpToString(opb), "_", m, "_", n, "_", k, '_', has_bias); - } - hipblasHandle_t handle; - BlasOp opa; - BlasOp opb; - int64_t m; - int64_t n; - int64_t k; - T alpha; - const T* a; - int64_t lda; - const T* b; - int64_t ldb; - const T* bias; - T beta; - T* c; - int64_t ldc; -}; - -} // namespace blas -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_impl.cu b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_impl.cu deleted file mode 100644 index 8d7e64b1015be..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_impl.cu +++ /dev/null @@ -1,91 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#define _GEMM_FASTGELU_H_KEEP_SIGNATURE_DEFINES -#include "contrib_ops/rocm/bert/gemm_fast_gelu_impl.h" - -#include -#include - -#include "contrib_ops/rocm/bert/gemm_fast_gelu_tunable.cuh" -#include "core/providers/rocm/shared_inc/fpgeneric.h" - -namespace onnxruntime { -namespace contrib { -namespace rocm { -namespace blas { - -namespace row_major { - -template -inline GEMMFASTGELU(T, ScalarT) { - GemmFastGeluParams params; - params.tuning_ctx = tuning_ctx; - params.stream = stream; - params.handle = handle; - - params.opa = opa; - params.opb = opb; - params.m = m; - params.n = n; - params.k = k; - if constexpr (!std::is_same_v && std::is_same_v) { - params.alpha = ToHipType::FromFloat(std::forward(alpha)); - } else { - params.alpha = alpha; - } - params.a = a; - params.lda = lda; - params.b = b; - params.ldb = ldb; - params.bias = bias; - if constexpr (!std::is_same_v && std::is_same_v) { - params.beta = ToHipType::FromFloat(std::forward(beta)); - } else { - params.beta = beta; - } - params.c = c; - params.ldc = ldc; - - if (tuning_ctx->IsTunableOpEnabled()) { - if (opa == BlasOp::N && opb == BlasOp::N) { - static internal::GemmFastGeluTunableOp gemm_fast_gelu{}; - return gemm_fast_gelu(¶ms); - } else if (opa == BlasOp::T && opb == BlasOp::N) { - static internal::GemmFastGeluTunableOp gemm_fast_gelu{}; - return gemm_fast_gelu(¶ms); - } else if (opa == BlasOp::N && opb == BlasOp::T) { - static internal::GemmFastGeluTunableOp gemm_fast_gelu{}; - return gemm_fast_gelu(¶ms); - } else /*if (opa == BlasOp::T && opb == BlasOp::T)*/ { - static internal::GemmFastGeluTunableOp gemm_fast_gelu{}; - return gemm_fast_gelu(¶ms); - } - } - - return internal::GemmFastGeluUnfused(¶ms); -} - -#define CALL_GEMMFASTGELU(T, ScalarT) \ - GemmFastGelu(tuning_ctx, stream, handle, \ - opa, opb, \ - m, n, k, \ - alpha, a, lda, b, ldb, bias, \ - beta, c, ldc) - -// clang-format off -GEMMFASTGELU(float, float ) { return CALL_GEMMFASTGELU(float, float ); } -GEMMFASTGELU(half, half ) { return CALL_GEMMFASTGELU(half, half ); } -GEMMFASTGELU(BFloat16, BFloat16) { return CALL_GEMMFASTGELU(BFloat16, BFloat16); } -GEMMFASTGELU(half, float ) { return CALL_GEMMFASTGELU(half, float ); } -GEMMFASTGELU(BFloat16, float ) { return CALL_GEMMFASTGELU(BFloat16, float ); } -// clang-format on - -#undef CALL_GEMMFASTGELU - -} // namespace row_major - -} // namespace blas -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_impl.h b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_impl.h deleted file mode 100644 index b707c63ef44be..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_impl.h +++ /dev/null @@ -1,40 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include "contrib_ops/rocm/bert/gemm_fast_gelu_common.h" -#include "core/common/status.h" -#include "core/common/float16.h" - -namespace onnxruntime { -namespace contrib { -namespace rocm { -namespace blas { - -#define GEMMFASTGELU(T, ScalarT) \ - common::Status GemmFastGelu( \ - RocmTuningContext* tuning_ctx, Stream* stream, hipblasHandle_t handle, \ - BlasOp opa, BlasOp opb, \ - std::int64_t m, std::int64_t n, std::int64_t k, \ - ScalarT alpha, const T* a, std::int64_t lda, const T* b, std::int64_t ldb, \ - const T* bias, ScalarT beta, T* c, std::int64_t ldc) - -namespace row_major { - -GEMMFASTGELU(float, float); -GEMMFASTGELU(half, half); -GEMMFASTGELU(BFloat16, BFloat16); -GEMMFASTGELU(half, float); -GEMMFASTGELU(BFloat16, float); - -} // namespace row_major - -} // namespace blas -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime - -#ifndef _GEMM_FASTGELU_H_KEEP_SIGNATURE_DEFINES -#undef GEMMFASTGELU -#endif diff --git a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_tunable.cuh b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_tunable.cuh deleted file mode 100644 index e157aa57f8c43..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_tunable.cuh +++ /dev/null @@ -1,83 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include -#include - -#include "contrib_ops/rocm/bert/elementwise.h" -#include "contrib_ops/rocm/bert/gemm_fast_gelu_ck.cuh" -#include "contrib_ops/rocm/bert/gemm_fast_gelu_common.h" -#include "core/providers/rocm/tunable/gemm.h" -#include "core/providers/rocm/tunable/gemm_hipblaslt.h" -#include "core/providers/rocm/tunable/rocm_tunable.h" - -namespace onnxruntime { -namespace contrib { -namespace rocm { -namespace blas { -namespace internal { - -using namespace onnxruntime::rocm::tunable::blas::internal; - -template -Status GemmFastGeluUnfused(const GemmFastGeluParams* params) { - namespace column_major = onnxruntime::rocm::tunable::blas::column_major; - ORT_RETURN_IF_ERROR(column_major::Gemm(params->tuning_ctx, params->stream, params->handle, - params->opb, params->opa, - params->n, params->m, params->k, - params->alpha, params->b, params->ldb, params->a, params->lda, - params->beta, params->c, params->ldc)); - - int64_t fast_gelu_input_length = params->m * params->n; - int64_t bias_length = (params->bias != nullptr) ? params->n : 0; - - // Because of GemmFastGeluUnfused is a combination of GemmOp and FastGeluOp, FastGeluOp in this combination is - // an inplace computation. - // 1. If we call GemmFastGeluUnfused directly with enabled tuning, it may cause the input buffer of FastGelu been - // updated accumulatedly and result in incorrect result finally. This only happens if the tuning's FindFastest is invoked. - // 2. It's safe to call GemmFastGeluUnfused with disabled tuning, FastGelu only run once and produce correct result. - // 3. It's safe to call GemmFastGeluUnfused as part of GemmFastGeluTunableOp with enable tuning, GemmTunableOp and - // FastGeluTunableOp will do tune in first warmup step separately during GemmFastGeluUnfused profiling process. - // After that, the call to GemmFastGeluUnfused not invoke tuning's FindFastest of FastGelu. - // - // Note: If any change cause directly usage of GemmFastGeluUnfused, add PreTuning() and PostTuning() in FastGeluTunableOp - // to protect original input value. - return onnxruntime::contrib::rocm::LaunchElementwiseKernel( - params->tuning_ctx, params->Stream(), - params->c, static_cast(fast_gelu_input_length), - params->bias, static_cast(bias_length), - params->c); -} - -template -class GemmFastGeluTunableOp : public TunableOp> { - public: - GemmFastGeluTunableOp() { - this->RegisterOp(GemmFastGeluUnfused); -#ifdef USE_COMPOSABLE_KERNEL - for (auto&& [_, op] : GetCKGemmAddFastGeluTypeStringAndOps()) { - ORT_UNUSED_PARAMETER(_); - this->RegisterOp(std::move(op)); - } - for (auto&& [_, op] : GetCKGemmFastGeluTypeStringAndOps()) { - ORT_UNUSED_PARAMETER(_); - this->RegisterOp(std::move(op)); - } -#endif - -#ifdef USE_HIPBLASLT - for (auto&& [_, op] : GetHipBlasLtGemmFastGeluTypeStringAndOps()) { - ORT_UNUSED_PARAMETER(_); - this->RegisterOp(std::move(op)); - } -#endif - } -}; - -} // namespace internal -} // namespace blas -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu b/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu deleted file mode 100644 index 09a6550549614..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu +++ /dev/null @@ -1,530 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/providers/shared_library/provider_api.h" -#include "core/providers/rocm/rocm_common.h" -#include "core/platform/env_var_utils.h" -#include "contrib_ops/rocm/bert/group_query_attention.h" -#include "contrib_ops/cpu/bert/group_query_attention_helper.h" -#include "contrib_ops/rocm/bert/rotary_embedding_impl.h" -#include "contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh" - -#ifdef USE_COMPOSABLE_KERNEL_CK_TILE -#include "ck_tile/core/numeric/integer.hpp" -#include "fmha_fwd.hpp" -#endif - -using namespace onnxruntime::rocm; -using namespace ::onnxruntime::common; -using namespace ONNX_NAMESPACE; - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -#define REGISTER_KERNEL_TYPED(T) \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - GroupQueryAttention, \ - kMSDomain, \ - 1, \ - T, \ - kRocmExecutionProvider, \ - (*KernelDefBuilder::Create()) \ - .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ - .TypeConstraint("M", DataTypeImpl::GetTensorType()) \ - .MayInplace(3, 1) \ - .MayInplace(4, 2) \ - .InputMemoryType(OrtMemTypeCPUInput, 6), \ - GroupQueryAttention); - -// REGISTER_KERNEL_TYPED(float) -REGISTER_KERNEL_TYPED(MLFloat16) -// REGISTER_KERNEL_TYPED(BFloat16) - -template -std::string GetCkFmhaDataTypeString(); - -template <> -std::string GetCkFmhaDataTypeString() { - return "fp16"; -} - -template <> -std::string GetCkFmhaDataTypeString() { - return "bf16"; -} - -__global__ void seqlens_inc_kernel(const int* seqlens, int* out, int num_elems, int inc) { - int idx = blockDim.x * blockIdx.x + threadIdx.x; - if (idx < num_elems) { - out[idx] = seqlens[idx] + inc; - } -} - -Status LaunchSeqlensInc(hipStream_t stream, const int* seqlens, int* out, int num_elems, int inc) { - constexpr int NumThreads = 128; - int num_blks = CeilDiv(num_elems, NumThreads); - seqlens_inc_kernel<<>>(seqlens, out, num_elems, inc); - return HIP_CALL(hipGetLastError()); -} - -__global__ void seqstart_init_kernel(int* out, int num_elems, int length_per_seq) { - int idx = blockDim.x * blockIdx.x + threadIdx.x; - if (idx < num_elems) { - out[idx] = idx * length_per_seq; - } - if (idx == 0) { - out[num_elems] = num_elems * length_per_seq; - } -} - -Status LaunchSeqStartInit(hipStream_t stream, int* out, int num_elems, int length_per_seq) { - constexpr int NumThreads = 128; - int num_blks = CeilDiv(num_elems, NumThreads); - seqstart_init_kernel<<>>(out, num_elems, length_per_seq); - return HIP_CALL(hipGetLastError()); -} - -// Kernel to convert seqlens_k to position_ids -__global__ void SeqlensToPosIdsPrompt(const int32_t* seqlens_k, int64_t* position_ids, const int seqlen, - const int batch_size) { - int tid = blockDim.x * blockIdx.x + threadIdx.x; - int b = tid / seqlen; - int s = tid % seqlen; - if (b < batch_size) { - if (s < seqlens_k[b] + 1) { - position_ids[tid] = s; - } else { - position_ids[tid] = 1; - } - } -} - -// Kernel to convert seqlens_k to position_ids -__global__ void SeqlensToPosIdsToken(const int32_t* seqlens_k, int64_t* position_ids, const int batch_size) { - int tid = blockDim.x * blockIdx.x + threadIdx.x; - if (tid < batch_size) { - position_ids[tid] = seqlens_k[tid]; - } -} - -// Convert seqlens_k to position_ids -Status LaunchSeqlensToPosIds(contrib::GroupQueryAttentionParameters& parameters, const int32_t* seqlens_k, - int64_t* position_ids, hipStream_t stream, const int max_threads_per_block) { - const int seqlen = parameters.sequence_length; - const int batch_size = parameters.batch_size; - const int threads = max_threads_per_block; - const int blocks = (batch_size * seqlen + threads - 1) / threads; - if (parameters.is_first_prompt) { - SeqlensToPosIdsPrompt<<>>(seqlens_k, position_ids, seqlen, batch_size); - } else { - SeqlensToPosIdsToken<<>>(seqlens_k, position_ids, batch_size); - } - return HIP_CALL(hipGetLastError()); -} - -template -GroupQueryAttention::GroupQueryAttention(const OpKernelInfo& info) - : RocmKernel(info) { - int64_t num_heads = 0; - int64_t kv_num_heads = 0; - ORT_ENFORCE(info.GetAttr("num_heads", &num_heads).IsOK() && num_heads > 0); - ORT_ENFORCE(info.GetAttr("kv_num_heads", &kv_num_heads).IsOK() && kv_num_heads > 0 && num_heads % kv_num_heads == 0); - num_heads_ = static_cast(num_heads); - kv_num_heads_ = static_cast(kv_num_heads); - is_past_bsnh_ = false; - is_unidirectional_ = true; - local_window_size_ = static_cast(info.GetAttrOrDefault("local_window_size", -1)); - do_rotary_ = info.GetAttrOrDefault("do_rotary", 0) == 1; - rotary_interleaved_ = info.GetAttrOrDefault("rotary_interleaved", 0) == 1; - scale_ = info.GetAttrOrDefault("scale", 0.0f); -} - -template <> -std::once_flag GroupQueryAttention::arch_checking_{}; - -template <> -std::once_flag GroupQueryAttention::arch_checking_{}; - -template -Status GroupQueryAttention::ComputeInternal(OpKernelContext* ctx) const { -#if USE_COMPOSABLE_KERNEL_CK_TILE - auto hip_stream = static_cast(ctx->GetComputeStream()->GetHandle()); - const Tensor* query = ctx->Input(0); - const Tensor* key = ctx->Input(1); - const Tensor* value = ctx->Input(2); - const Tensor* past_key = ctx->Input(3); - const Tensor* past_value = ctx->Input(4); - const Tensor* seqlens_k = ctx->Input(5); - const Tensor* total_seqlen = ctx->Input(6); - const Tensor* cos_cache = ctx->Input(7); - const Tensor* sin_cache = ctx->Input(8); - - auto& device_prop = GetDeviceProp(); - std::call_once( - arch_checking_, - [](const hipDeviceProp_t& device_prop) { - if (std::string_view(device_prop.gcnArchName).find("gfx90a") == std::string_view::npos && - std::string_view(device_prop.gcnArchName).find("gfx942") == std::string_view::npos) { - LOGS_DEFAULT(WARNING) - << "GroupQueryAttention currently only supports ck_tile fmha backend which only supports " - << "CDNA2 and CDNA3 archs."; - LOGS_DEFAULT(WARNING) - << "GroupQueryAttention running on an unsuppoted GPU may result in " - << "hipErrorNoBinaryForGpu or hipErrorSharedObjectInitFailedshared error."; - } - }, - device_prop); - - GroupQueryAttentionParameters parameters; - using HipT = typename ToHipType::MappedType; - - const int max_thr_per_blk = device_prop.maxThreadsPerBlock; - - ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckInputs(query, - key, - value, - past_key, - past_value, - cos_cache, - sin_cache, - ¶meters, - num_heads_, - kv_num_heads_, - seqlens_k, - total_seqlen, - is_past_bsnh_, - scale_, - max_thr_per_blk)); - - const int batch_size = parameters.batch_size; - const int sequence_length = parameters.sequence_length; - const int kv_sequence_length = parameters.sequence_length; - const int num_heads = parameters.num_heads; - const int kv_num_heads = parameters.kv_num_heads; - const int head_size = parameters.head_size; - AttentionQkvFormat past_kv_format = parameters.past_kv_format; - - parameters.local_window_size = local_window_size_; - parameters.is_unidirectional = is_unidirectional_; - // parameters.zeros_count = kZerosCount; - // parameters.zero_ptr = zeros_.get(); - // parameters.left_padding = left_padding_; - parameters.do_rotary = do_rotary_; - parameters.rotary_interleaved = rotary_interleaved_; - - ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckNoQKOutput( - context->OutputCount(), - static_cast(Info().GetAttrOrDefault("qk_output", static_cast(QKOutputType::NO_OUTPUT))))); - - if (do_rotary_ && (cos_cache == nullptr || sin_cache == nullptr)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "cos_cache and sin_cache must be passed to GroupQueryAttention when do_rotary = 1"); - } - - TensorShapeVector output_shape(3); - output_shape[0] = static_cast(batch_size); - output_shape[1] = static_cast(sequence_length); - output_shape[2] = static_cast(parameters.hidden_size); - Tensor* output = ctx->Output(0, output_shape); - Strides output_strides = Strides::BSNHMemory(batch_size, sequence_length, num_heads, head_size); - - int4 past_shape; - std::vector present_dims; - Strides present_strides; - Strides past_strides; - if (past_kv_format == AttentionQkvFormat::Q_K_V_BSNH) { - past_shape = { - batch_size, parameters.seqlen_past_kv_cache, kv_num_heads, head_size}; - past_strides = Strides::BSNHMemory( - batch_size, parameters.seqlen_past_kv_cache, kv_num_heads, head_size); - present_dims = { - batch_size, parameters.seqlen_present_kv_cache, kv_num_heads, head_size}; - present_strides = Strides::BSNHMemory( - batch_size, parameters.seqlen_present_kv_cache, kv_num_heads, head_size); - } else { // BNSH - past_shape = { - batch_size, kv_num_heads, parameters.seqlen_past_kv_cache, head_size}; - past_strides = Strides::BNSHMemory( - batch_size, kv_num_heads, parameters.seqlen_past_kv_cache, head_size); - present_dims = { - batch_size, kv_num_heads, parameters.seqlen_present_kv_cache, head_size}; - present_strides = Strides::BNSHMemory( - batch_size, kv_num_heads, parameters.seqlen_present_kv_cache, head_size); - } - TensorShape present_shape(present_dims); - Tensor* present_key = ctx->Output(1, present_shape); - Tensor* present_value = ctx->Output(2, present_shape); - - Strides query_strides; - Strides key_strides; - Strides value_strides; - int4 kv_shape{batch_size, kv_num_heads, kv_sequence_length, head_size}; // BNSH coord - const HipT* query_ptr = reinterpret_cast(query->DataRaw()); - const HipT* key_ptr; - const HipT* value_ptr; - if (!parameters.is_packed_qkv) { - query_strides = Strides::BSNHMemory(batch_size, sequence_length, num_heads, head_size); - key_strides = Strides::BSNHMemory(batch_size, kv_sequence_length, kv_num_heads, head_size); - value_strides = key_strides; - key_ptr = reinterpret_cast(key->DataRaw()); - value_ptr = reinterpret_cast(value->DataRaw()); - } else { - query_strides = Strides::BSNHMemory(batch_size, sequence_length, num_heads + 2 * kv_num_heads, head_size); - key_strides = Strides::BSNHMemory(batch_size, sequence_length, num_heads + 2 * kv_num_heads, head_size); - value_strides = query_strides; - const size_t key_offset = static_cast(num_heads * head_size); - const size_t value_offset = static_cast(kv_num_heads * head_size); - key_ptr = query_ptr + key_offset; - value_ptr = key_ptr + value_offset; - } - - IAllocatorUniquePtr rotary_q_tmp; - IAllocatorUniquePtr rotary_k_tmp; - if (parameters.do_rotary) { - size_t q_size = static_cast(batch_size * sequence_length * num_heads * head_size); - size_t k_size = static_cast(batch_size * sequence_length * kv_num_heads * head_size); - auto rotary_q_strides = Strides::BSNHMemory(batch_size, sequence_length, num_heads, head_size); - auto rotary_k_strides = Strides::BSNHMemory(batch_size, sequence_length, kv_num_heads, head_size); - - rotary_q_tmp = GetScratchBuffer(q_size, ctx->GetComputeStream()); - rotary_k_tmp = GetScratchBuffer(k_size, ctx->GetComputeStream()); - auto rotary_position_ids_tmp = GetScratchBuffer(sequence_length * batch_size, ctx->GetComputeStream()); - ORT_RETURN_IF_ERROR(LaunchSeqlensToPosIds(parameters, - reinterpret_cast(seqlens_k->DataRaw()), - reinterpret_cast(rotary_position_ids_tmp.get()), - hip_stream, max_thr_per_blk)); - // Launch rotary embedding kernel - ORT_RETURN_IF_ERROR(LaunchRotaryEmbeddingKernel(hip_stream, rotary_q_tmp.get(), query_ptr, - reinterpret_cast(rotary_position_ids_tmp.get()), - reinterpret_cast(cos_cache->DataRaw()), - reinterpret_cast(sin_cache->DataRaw()), - parameters.batch_size, parameters.sequence_length, - parameters.num_heads, parameters.head_size, - parameters.rotary_dim, parameters.seqlen_present_kv_cache, - /*position_ids_format*/ 1, parameters.rotary_interleaved, - max_thr_per_blk, - query_strides.ForBNSHCoord(), - rotary_q_strides.ForBNSHCoord())); - ORT_RETURN_IF_ERROR(LaunchRotaryEmbeddingKernel(hip_stream, rotary_k_tmp.get(), key_ptr, - reinterpret_cast(rotary_position_ids_tmp.get()), - reinterpret_cast(cos_cache->DataRaw()), - reinterpret_cast(sin_cache->DataRaw()), - parameters.batch_size, parameters.sequence_length, - parameters.kv_num_heads, parameters.head_size, - parameters.rotary_dim, parameters.seqlen_present_kv_cache, - /*position_ids_format*/ 1, parameters.rotary_interleaved, - max_thr_per_blk, - key_strides.ForBNSHCoord(), - rotary_k_strides.ForBNSHCoord())); - query_ptr = reinterpret_cast(rotary_q_tmp.get()); - key_ptr = reinterpret_cast(rotary_k_tmp.get()); - query_strides = rotary_q_strides; - key_strides = rotary_k_strides; - } - - const int* seqlens_k_ptr = seqlens_k ? reinterpret_cast(seqlens_k->DataRaw()) : nullptr; - IAllocatorUniquePtr seqlens_k_tmp; - - // build present kv cache - auto* present_key_ptr = reinterpret_cast(present_key->MutableDataRaw()); - auto* present_value_ptr = reinterpret_cast(present_value->MutableDataRaw()); - if (parameters.is_first_prompt) { - // copy prompt kv to present kv - ORT_RETURN_IF_ERROR(LaunchStridedCopy(hip_stream, key_ptr, kv_shape, key_strides.ForBNSHCoord(), - present_key_ptr, present_strides.ForBNSHCoord(), max_thr_per_blk)); - ORT_RETURN_IF_ERROR(LaunchStridedCopy(hip_stream, value_ptr, kv_shape, value_strides.ForBNSHCoord(), - present_value_ptr, present_strides.ForBNSHCoord(), max_thr_per_blk)); - } else { - const auto* past_key_ptr = past_key == nullptr ? nullptr : reinterpret_cast(past_key->DataRaw()); - const auto* past_value_ptr = past_key == nullptr ? nullptr : reinterpret_cast(past_value->DataRaw()); - parameters.kv_share_buffer = past_key_ptr == present_key_ptr; // FIXME: - if (!parameters.kv_share_buffer) { - // copy past to present, - // NOTE: we do a low perf full buffer copy due to the seqlens_k indicate the seqlen of different seqs are - // not the same, aka, can not be as simple as strided - ORT_RETURN_IF_ERROR(LaunchStridedCopy(hip_stream, past_key_ptr, past_shape, past_strides.ForBNSHCoord(), - present_key_ptr, present_strides.ForBNSHCoord(), max_thr_per_blk)); - ORT_RETURN_IF_ERROR(LaunchStridedCopy(hip_stream, past_value_ptr, past_shape, past_strides.ForBNSHCoord(), - present_value_ptr, present_strides.ForBNSHCoord(), max_thr_per_blk)); - } else { - // In the case of share buffer - ORT_ENFORCE(past_key_ptr == nullptr || past_key_ptr == present_key_ptr); - ORT_ENFORCE(past_key_ptr == nullptr || past_value_ptr == present_value_ptr); - } - // then append new kv to present - size_t buffer_offset = seqlens_k ? 0 : present_strides.OffsetAt(0, 0, kv_sequence_length, 0); - ORT_RETURN_IF_ERROR(LaunchStridedCopy( - hip_stream, key_ptr, kv_shape, key_strides.ForBNSHCoord(), /*in_seqlens_offset=*/nullptr, - present_key_ptr + buffer_offset, present_strides.ForBNSHCoord(), seqlens_k_ptr, - max_thr_per_blk)); - ORT_RETURN_IF_ERROR(LaunchStridedCopy( - hip_stream, value_ptr, kv_shape, value_strides.ForBNSHCoord(), /*in_seqlens_offset=*/nullptr, - present_value_ptr + buffer_offset, present_strides.ForBNSHCoord(), seqlens_k_ptr, - max_thr_per_blk)); - - // NOTE: ORT: seqlens_k Indicates past sequence lengths for token generation case. - // we should call fmha with total sequence lengths - seqlens_k_tmp = GetScratchBuffer(batch_size * sizeof(int), ctx->GetComputeStream()); - ORT_RETURN_IF_ERROR(LaunchSeqlensInc(hip_stream, seqlens_k_ptr, seqlens_k_tmp.get(), batch_size, sequence_length)); - seqlens_k_ptr = seqlens_k_tmp.get(); - } - static_assert(std::is_same_v); - - const float scale = parameters.scale == 0.0f - ? 1.f / sqrt(static_cast(parameters.head_size)) - : parameters.scale; - bias_enum bias_type = bias_enum::no_bias; - - mask_info mask = [&]() { - if (local_window_size_ != -1) { - mask_info ret; - ret.type = mask_enum::window_generic; - ret.left = local_window_size_; - ret.right = parameters.is_unidirectional ? 0 : -1; - // ret.x = kv_sequence_length - (sequence_length - ret.left); - // ret.y = sequence_length + (ret.right - kv_sequence_length); - return ret; - } - - if (parameters.is_first_prompt && is_unidirectional_) { - return mask_info::decode("t", sequence_length, kv_sequence_length); - } - - return mask_info::decode("0", sequence_length, kv_sequence_length); - }(); - - auto seqstart_q_tmp = GetScratchBuffer((batch_size + 1) * sizeof(int), ctx->GetComputeStream()); - auto seqstart_k_tmp = GetScratchBuffer((batch_size + 1) * sizeof(int), ctx->GetComputeStream()); - ORT_RETURN_IF_ERROR(LaunchSeqStartInit( - hip_stream, seqstart_q_tmp.get(), batch_size, - query_strides.strides_for_bnsh_coord.x / query_strides.strides_for_bnsh_coord.z)); - ORT_RETURN_IF_ERROR(LaunchSeqStartInit( - hip_stream, seqstart_k_tmp.get(), batch_size, - present_strides.strides_for_bnsh_coord.x / present_strides.strides_for_bnsh_coord.z)); - - fmha_fwd_args args{ - query_ptr, - present_key->DataRaw(), - present_value->DataRaw(), - nullptr, // bias, alibi/element - nullptr, // lse, logsumexp buffer - output->MutableDataRaw(), - seqstart_q_tmp.get(), // seqstart_q_ptr, for group mode - seqstart_k_tmp.get(), // seqstart_k_ptr, for group mode - seqlens_k_ptr, // seqlen_k_ptr, for group mode - sequence_length, // seqlen_q, for batch mode - kv_sequence_length, // seqlen_k, for batch mode - parameters.batch_size, // batch - parameters.sequence_length, // max_seqlen_q - parameters.head_size, // hdim_q - parameters.head_size, // hdim_v - parameters.num_heads, - parameters.kv_num_heads, - scale, - 1.0f, // scale_p of squant, useless - 1.0f, // scale_o of squant, useless - static_cast(query_strides.strides_for_bnsh_coord.z), // stride_q, to be regarded as stride of dim S - static_cast(present_strides.strides_for_bnsh_coord.z), // stride_k, to be regarded as stride of dim S - static_cast(present_strides.strides_for_bnsh_coord.z), // stride_v, to be regarded as stride of dim S - batch_size, // stride_bias, if alibi, b*h need set this to h, 1*h need set this to 0 - static_cast(output_strides.strides_for_bnsh_coord.z), // stride_o, to be regarded as stride of dim S - static_cast(query_strides.strides_for_bnsh_coord.y), // nhead_stride_q, to be regarded as stride of dim N - static_cast(present_strides.strides_for_bnsh_coord.y), // nhead_stride_k, to be regarded as stride of dim N - static_cast(present_strides.strides_for_bnsh_coord.y), // nhead_stride_v, to be regarded as stride of dim N - 0, // nhead_stride_bias - batch_size, // nhead_stride_lse - static_cast(output_strides.strides_for_bnsh_coord.y), // batch_stride_o, to be regarded as stride of dim B - static_cast(query_strides.strides_for_bnsh_coord.x), // batch_stride_q, to be regarded as stride of dim B - static_cast(present_strides.strides_for_bnsh_coord.x), // batch_stride_k, to be regarded as stride of dim B - static_cast(present_strides.strides_for_bnsh_coord.x), // batch_stride_v, to be regarded as stride of dim B - 0, // batch_stride_bias - num_heads * batch_size, // batch_stride_lse - static_cast(output_strides.strides_for_bnsh_coord.x), // batch_stride_o, to be regarded as stride of dim B - mask.left, // window_size_left - mask.right, // window_size_right - static_cast(mask.type)}; - -#if 0 - std::cout - << "\n sequence_length:" << sequence_length - << "\n kv_sequence_length:" << kv_sequence_length - << "\n seqlen_past_kv_cache:" << parameters.seqlen_past_kv_cache - << "\n seqlen_present_kv_cache:" << parameters.seqlen_present_kv_cache << std::endl; - - std::cout - << "\n q_ptr:" << args.q_ptr - << "\n k_ptr:" << args.k_ptr - << "\n v_ptr:" << args.v_ptr - << "\n bias_ptr:" << args.bias_ptr - << "\n lse_ptr:" << args.lse_ptr - << "\n o_ptr:" << args.o_ptr - << "\n seqstart_q_ptr:" << args.seqstart_q_ptr - << "\n seqstart_k_ptr:" << args.seqstart_k_ptr - << "\n seqlen_k_ptr:" << args.seqlen_k_ptr - << "\n seqlen_q:" << args.seqlen_q - << "\n seqlen_k:" << args.seqlen_k - << "\n batch:" << args.batch - << "\n max_seqlen_q:" << args.max_seqlen_q - << "\n hdim_q:" << args.hdim_q - << "\n hdim_v:" << args.hdim_v - << "\n nhead_q:" << args.nhead_q - << "\n nhead_k:" << args.nhead_k - << "\n scale_s:" << args.scale_s - << "\n scale_p:" << args.scale_p - << "\n scale_o:" << args.scale_o - << "\n stride_q:" << args.stride_q - << "\n stride_k:" << args.stride_k - << "\n stride_v:" << args.stride_v - << "\n stride_bias:" << args.stride_bias - << "\n stride_o:" << args.stride_o - << "\n nhead_stride_q:" << args.nhead_stride_q - << "\n nhead_stride_k:" << args.nhead_stride_k - << "\n nhead_stride_v:" << args.nhead_stride_v - << "\n nhead_stride_bias:" << args.nhead_stride_bias - << "\n nhead_stride_lse:" << args.nhead_stride_lse - << "\n nhead_stride_o:" << args.nhead_stride_o - << "\n batch_stride_q:" << args.batch_stride_q - << "\n batch_stride_k:" << args.batch_stride_k - << "\n batch_stride_v:" << args.batch_stride_v - << "\n batch_stride_bias:" << args.batch_stride_bias - << "\n batch_stride_lse:" << args.batch_stride_lse - << "\n batch_stride_o:" << args.batch_stride_o - << "\n window_size_left:" << args.window_size_left - << "\n window_size_right:" << args.window_size_right - << "\n mask_type:" << args.mask_type - << std::endl; -#endif - - fmha_fwd_traits traits{ - parameters.head_size, - parameters.head_size, // v head size - GetCkFmhaDataTypeString(), - !parameters.is_first_prompt, // true, // is_group_mode - true, // is_v_rowmajor ? dim is fastest : seq is fastest - mask.type, - bias_type, - false, // has_lse - false, // do_fp8_static_quant, aka, squant - }; - - ck_tile::stream_config stream_config{ - hip_stream, - false // time_kernel - }; - - auto duration = fmha_fwd(traits, args, stream_config); - if (duration < 0) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "fmha_fwd internal error"); - } - HIP_RETURN_IF_ERROR(hipGetLastError()); - - return Status::OK(); -#else - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "GroupQueryAttention requires ck_tile to be enabled"); -#endif -} - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/group_query_attention.h b/onnxruntime/contrib_ops/rocm/bert/group_query_attention.h deleted file mode 100644 index ce0de1f761aa5..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/group_query_attention.h +++ /dev/null @@ -1,38 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include -#include -#include "core/providers/rocm/rocm_kernel.h" - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -using namespace onnxruntime::rocm; - -template -class GroupQueryAttention final : public RocmKernel { - public: - GroupQueryAttention(const OpKernelInfo& info); - Status ComputeInternal(OpKernelContext* context) const override; - - protected: - int num_heads_; // number of attention heads - int kv_num_heads_; // different for k and v for group query attention - int local_window_size_; - bool is_unidirectional_; - bool is_past_bsnh_; - bool do_rotary_; - bool rotary_interleaved_; - float scale_; - - private: - static std::once_flag arch_checking_; -}; - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/layer_norm.cuh b/onnxruntime/contrib_ops/rocm/bert/layer_norm.cuh deleted file mode 100644 index 2eeb7c3e8f279..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/layer_norm.cuh +++ /dev/null @@ -1,270 +0,0 @@ -#include "hip/hip_runtime.h" -/* - The implementation of this file is based on bert plugins in TensorRT demo: - https://github.com/NVIDIA/TensorRT/tree/release/5.1/demo/BERT/ - -Copyright 2019 NVIDIA Corporation - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. -#pragma once - -#include -#include -#include -#include "core/providers/rocm/rocm_common.h" -#include "core/providers/rocm/cu_inc/common.cuh" -#include "core/providers/rocm/shared_inc/rocm_call.h" - -using namespace onnxruntime::rocm; - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -template -__device__ inline T Rsqrt(const T& x); - -template <> -__device__ inline float Rsqrt(const float& x) { - return rsqrtf(x); -} - -template <> -__device__ inline half Rsqrt(const half& x) { -#if __CUDA_ARCH__ >= 530 || !defined(__CUDA_ARCH__) - return hrsqrt(x); -#else - return half(rsqrtf(static_cast(x))); -#endif -} - -__device__ inline half2 AddHalf2(const half2 a, const half2 b) { -#if __CUDA_ARCH__ >= 530 || !defined(__CUDA_ARCH__) - return __hadd2(a, b); -#else - return __halves2half2(__hadd(a.x, b.x), __hadd(a.y, b.y)); -#endif -} - -struct KeyValuePairSum { - __device__ inline hipcub::KeyValuePair operator()(const hipcub::KeyValuePair& a, - const hipcub::KeyValuePair& b) { - return hipcub::KeyValuePair(a.key + b.key, a.value + b.value); - } - - __device__ inline hipcub::KeyValuePair operator()(const hipcub::KeyValuePair& a, - const hipcub::KeyValuePair& b) { - const half2 a2 = __halves2half2(a.key, a.value); - const half2 b2 = __halves2half2(b.key, b.value); - const half2 res = AddHalf2(a2, b2); - return hipcub::KeyValuePair(__low2half(res), __high2half(res)); - } - - __device__ inline hipcub::KeyValuePair operator()(const hipcub::KeyValuePair& a, - const hipcub::KeyValuePair& b) { - return hipcub::KeyValuePair(AddHalf2(a.key, b.key), AddHalf2(a.value, b.value)); - } -}; - -template -__device__ inline void LayerNorm( - const hipcub::KeyValuePair& thread_data, const int ld, const int offset, const V* beta, - const V* gamma, const U epsilon, V* output) { - // Assuming thread_data is already divided by ld - - using BlockReduce = hipcub::BlockReduce, TPB>; - __shared__ typename BlockReduce::TempStorage temp_storage; - __shared__ U mu; // mean - __shared__ U rsigma; // 1 / std.dev. - - KeyValuePairSum pair_sum; - const auto sum_kv = BlockReduce(temp_storage).Reduce(thread_data, pair_sum); - - if (threadIdx.x == 0) { - mu = sum_kv.key; - rsigma = Rsqrt(sum_kv.value - mu * mu + epsilon); - } - __syncthreads(); - - for (int i = threadIdx.x; i < ld; i += TPB) { - const int idx = offset + i; - const U val = static_cast(output[idx]); - const U g = static_cast(gamma[i]); - const U b = (nullptr == beta) ? U(0.f) : static_cast(beta[i]); - output[idx] = static_cast(g * (val - mu) * rsigma + b); - } -} - -template -__device__ inline void SimplifiedLayerNorm( - const U& thread_data, const int ld, const int offset, const V* gamma, const U epsilon, V* output) { - // Assuming thread_data is already divided by ld - - using BlockReduce = hipcub::BlockReduce; - __shared__ typename BlockReduce::TempStorage temp_storage; - __shared__ U rsigma; // 1 / std.dev. - - const U sum = BlockReduce(temp_storage).Sum(thread_data); - - if (threadIdx.x == 0) { - rsigma = Rsqrt(sum + epsilon); - } - __syncthreads(); - - for (int i = threadIdx.x; i < ld; i += TPB) { - const int idx = offset + i; - const U val = static_cast(output[idx]); - const U g = static_cast(gamma[i]); - output[idx] = static_cast(g * val * rsigma); - } -} - -template -__device__ inline void SimplifiedLayerNormVec( - const U& thread_data, const int ld, const int offset, const V* gamma, const U epsilon, V* output) { - // Assuming thread_data is already divided by ld - using VecV = aligned_vector; - using BlockReduce = hipcub::BlockReduce; - __shared__ typename BlockReduce::TempStorage temp_storage; - __shared__ U rsigma; // 1 / std.dev. - - const U sum = BlockReduce(temp_storage).Sum(thread_data); - - if (threadIdx.x == 0) { - rsigma = Rsqrt(sum + epsilon); - } - __syncthreads(); - - if (ILP * threadIdx.x < ld) { - for (int i = threadIdx.x * ILP; i < ld; i += TPB * ILP) { - int idx = offset + i; - const VecV gamma_v = *reinterpret_cast(gamma + i); - VecV output_v = *reinterpret_cast(output + idx); - -#pragma unroll - for (int k = 0; k < ILP; k++) { - output_v.val[k] = U(gamma_v.val[k]) * U(output_v.val[k]) * rsigma; - } - *(reinterpret_cast(output + idx)) = output_v; - } - } -} - -template -__device__ inline void LayerNormVec( - const hipcub::KeyValuePair& thread_data, const int ld, const int offset, const V* beta, - const V* gamma, const U epsilon, V* output) { - // Assuming thread_data is already divided by ld - using VecV = aligned_vector; - using BlockReduce = hipcub::BlockReduce, TPB>; - __shared__ typename BlockReduce::TempStorage temp_storage; - __shared__ U mu; // mean - __shared__ U rsigma; // 1 / std.dev. - - KeyValuePairSum pair_sum; - const auto sum_kv = BlockReduce(temp_storage).Reduce(thread_data, pair_sum); - - if (threadIdx.x == 0) { - mu = sum_kv.key; - rsigma = Rsqrt(sum_kv.value - mu * mu + epsilon); - } - __syncthreads(); - - if (ILP * threadIdx.x < ld) { - for (int i = threadIdx.x * ILP; i < ld; i += TPB * ILP) { - int idx = offset + i; - const VecV beta_v = (beta != nullptr) ? *reinterpret_cast(beta + i) : VecV(); - const VecV gamma_v = *reinterpret_cast(gamma + i); - VecV output_v = *reinterpret_cast(output + idx); - -#pragma unroll - for (int k = 0; k < ILP; k++) { - output_v.val[k] = (beta != nullptr) ? U(gamma_v.val[k]) * (U(output_v.val[k]) - mu) * rsigma + U(beta_v.val[k]) : U(gamma_v.val[k]) * (U(output_v.val[k]) - mu) * rsigma; - } - *(reinterpret_cast(output + idx)) = output_v; - } - } -} - -template -__device__ inline void LayerNormSmall(const T* input_v, const hipcub::KeyValuePair& thread_data, - const int ld, const int idx, const V* beta, const V* gamma, - const U epsilon, V* output) { - // Assuming thread_data is already divided by ld - // Small settings: the block covers the leading dimension TPB >= ld. The input - // value is available in a register - using VecV = aligned_vector; - using BlockReduce = hipcub::BlockReduce, TPB>; - __shared__ typename BlockReduce::TempStorage temp_storage; - __shared__ U mu; // mean - __shared__ U rsigma; // 1 / std.dev. - - KeyValuePairSum pair_sum; - const hipcub::KeyValuePair sum_kv = BlockReduce(temp_storage).Reduce(thread_data, pair_sum); - - if (threadIdx.x == 0) { - mu = sum_kv.key; - rsigma = Rsqrt(sum_kv.value - mu * mu + epsilon); - } - __syncthreads(); - - if (ILP * threadIdx.x < ld) { - const VecV beta_v = (beta != nullptr) ? *reinterpret_cast(beta + threadIdx.x * ILP) : VecV(); - const VecV gamma_v = *reinterpret_cast(gamma + threadIdx.x * ILP); - VecV output_v; - -#pragma unroll - for (int i = 0; i < ILP; i++) { - output_v.val[i] = (beta != nullptr) ? U(gamma_v.val[i]) * (U(input_v[i]) - mu) * rsigma + U(beta_v.val[i]) : U(gamma_v.val[i]) * (U(input_v[i]) - mu) * rsigma; - } - *(reinterpret_cast(output + idx)) = output_v; - } -} - -template -__device__ inline void SimplifiedLayerNormSmall(const T* input_v, const U& thread_data, const int ld, const int idx, - const V* gamma, const U epsilon, V* output) { - // Assuming thread_data is already divided by ld - // Small settings: the block covers the leading dimension TPB >= ld. The input - // value is available in a register - using VecV = aligned_vector; - using BlockReduce = hipcub::BlockReduce; - __shared__ typename BlockReduce::TempStorage temp_storage; - __shared__ U rsigma; // 1 / std.dev. - - const U sum = BlockReduce(temp_storage).Sum(thread_data); - - if (threadIdx.x == 0) { - rsigma = Rsqrt(sum + epsilon); - } - __syncthreads(); - - if (ILP * threadIdx.x < ld) { - const VecV gamma_v = *reinterpret_cast(gamma + threadIdx.x * ILP); - VecV output_v; - -#pragma unroll - for (int i = 0; i < ILP; i++) { - output_v.val[i] = U(gamma_v.val[i]) * U(input_v[i]) * rsigma; - } - *(reinterpret_cast(output + idx)) = output_v; - } -} - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/multihead_attention.cu b/onnxruntime/contrib_ops/rocm/bert/multihead_attention.cu deleted file mode 100644 index 5d4ef53b8ba97..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/multihead_attention.cu +++ /dev/null @@ -1,286 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "contrib_ops/rocm/bert/multihead_attention.h" - -#include "contrib_ops/cpu/bert/multihead_attention_helper.h" -#include "contrib_ops/rocm/bert/attention_impl.h" -#include "contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh" -#include "core/platform/env_var_utils.h" -#include "core/providers/rocm/rocm_common.h" - -using namespace onnxruntime::rocm; -using namespace ::onnxruntime::common; -using namespace ONNX_NAMESPACE; - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -#define REGISTER_MHA_KERNEL_TYPED(T) \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - MultiHeadAttention, \ - kMSDomain, \ - 1, \ - T, \ - kRocmExecutionProvider, \ - (*KernelDefBuilder::Create()) \ - .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - MultiHeadAttention) - -REGISTER_MHA_KERNEL_TYPED(float); -REGISTER_MHA_KERNEL_TYPED(MLFloat16); - -static constexpr int kPastSequenceLengthInputIndex = 7; -static constexpr int kBeamWidthInputIndex = 8; -static constexpr int kPastInputIndex = 5; -static constexpr int kPresentOutputIndex = 1; - -#define REGISTER_DMMHA_KERNEL_TYPED(T) \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - DecoderMaskedMultiHeadAttention, \ - kMSDomain, \ - 1, \ - T, \ - kRocmExecutionProvider, \ - (*KernelDefBuilder::Create()) \ - .MayInplace(kPastInputIndex, kPresentOutputIndex) \ - .MayInplace(kPastInputIndex + 1, kPresentOutputIndex + 1) \ - .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ - .InputMemoryType(OrtMemTypeCPUInput, kPastSequenceLengthInputIndex) \ - .InputMemoryType(OrtMemTypeCPUInput, kBeamWidthInputIndex), \ - MultiHeadAttention) - -REGISTER_DMMHA_KERNEL_TYPED(float); -REGISTER_DMMHA_KERNEL_TYPED(MLFloat16); - -template -MultiHeadAttention::MultiHeadAttention(const OpKernelInfo& info) - : RocmKernel(info), - attn_type_(info.node().OpType() == "DecoderMaskedMultiHeadAttention" ? kDecoderMaskedMultiHeadAttention - : kMultiHeadAttention) { - int64_t num_heads = 0; - ORT_ENFORCE(info.GetAttr("num_heads", &num_heads).IsOK() && num_heads > 0); - num_heads_ = static_cast(num_heads); - - mask_filter_value_ = info.GetAttrOrDefault("mask_filter_value", -10000.0f); - - scale_ = info.GetAttrOrDefault("scale", 0.0f); - - past_present_share_buffer_ = info.GetAttrOrDefault("past_present_share_buffer", 0LL) != 0LL; - is_unidirectional_ = info.GetAttrOrDefault("unidirectional", 0) == 1; - - using HipT = typename ToHipType::MappedType; - using AttentionTunableOp = GemmSoftmaxGemmPermuteTunableOp; - tunable_op_ = std::make_shared(); -} - -template -Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { - ORT_ENFORCE( - GetTuningContext()->IsTunableOpEnabled(), - "MultiHeadAttention of ROCm EP is only supported if tunable op is used and tuning is enabled."); - - const Tensor* query = context->Input(0); - const Tensor* key = context->Input(1); - const Tensor* value = context->Input(2); - - const Tensor* bias{}; - const Tensor* key_padding_mask{}; - const Tensor* attention_bias{}; - const Tensor* past_key{}; - const Tensor* past_value{}; - const Tensor* past_seq_len{}; - - const Tensor* cache_indirection = nullptr; - - if (attn_type_ == kMultiHeadAttention) { - bias = context->Input(3); - key_padding_mask = context->Input(4); - attention_bias = context->Input(5); - past_key = context->Input(6); - past_value = context->Input(7); - } else if (attn_type_ == kDecoderMaskedMultiHeadAttention) { - key_padding_mask = context->Input(3); - attention_bias = context->Input(4); - past_key = context->Input(5); - past_value = context->Input(6); - past_seq_len = context->Input(kPastSequenceLengthInputIndex); - // const Tensor* beam_width = context->Input(8); // NOTE: not used - // const Tensor* cache_indirection = context->Input(9); // TODO: should not present for ROCm EP - bias = context->Input(10); - } - - if (nullptr != bias) { - return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, - "qkv_bias is not supported on ROCm EP. " - "User should fuse the qkv bias to qkv projection instead."); - } - - auto& device_prop = GetDeviceProp(); - RocmAttentionParameters attn; - ORT_RETURN_IF_ERROR(multihead_attention_helper::CheckInputs(query, - key, - value, - bias, - key_padding_mask, - attention_bias, - past_key, - past_value, - cache_indirection, - past_seq_len, - &attn, /* parameters */ - num_heads_, - mask_filter_value_, - scale_, - is_unidirectional_, - past_present_share_buffer_, - attn_type_, - device_prop.maxThreadsPerBlock)); - - if (attn_type_ == kDecoderMaskedMultiHeadAttention && attn.sequence_length != 1) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input sequence length should be 1 to use DecoderMaskedMultiHeadAttention"); - } - - TensorShapeVector output_shape(3); - output_shape[0] = static_cast(attn.batch_size); - output_shape[1] = static_cast(attn.sequence_length); - output_shape[2] = static_cast(attn.v_hidden_size); - Tensor* output = context->Output(0, output_shape); - - std::vector present_dims{ - attn.batch_size, - attn.num_heads, - past_present_share_buffer_ ? attn.max_sequence_length : attn.total_sequence_length, - attn.head_size, - }; - TensorShape present_shape(present_dims); - Tensor* present_key = context->Output(1, present_shape); - Tensor* present_value = context->Output(2, present_shape); - - ORT_RETURN_IF_ERROR(ClassifyAttentionMode( - attn_type_, &attn, - /*qkv=*/{query, key, value}, - /*past=*/{past_key, past_value}, - /*present=*/{present_key, present_value})); - - using HipT = typename ToHipType::MappedType; - using AttentionTunableOp = GemmSoftmaxGemmPermuteTunableOp; - auto workspace_bytes = AttentionTunableOp::GetWorkspaceNumBytes(&attn); - auto workspace = GetScratchBuffer(workspace_bytes, context->GetComputeStream()); - - hipStream_t stream = Stream(context); - if (nullptr != present_key) { // process past present concat - Strides dst_strides; - - int4 past_shape; - Strides past_src_strides; - const HipT* past_key_src; - const HipT* past_value_src; - HipT* past_key_dst{}; - HipT* past_value_dst{}; - - int4 add_shape; - Strides add_src_strides; - const HipT* add_key_src = reinterpret_cast(key->DataRaw()); - const HipT* add_value_src = reinterpret_cast(value->DataRaw()); - HipT* add_key_dst; - HipT* add_value_dst; - - if (attn.mode == BSNH_BLNH_BLNH_BNPH_BNPH_BNTH_BNTH || - attn.mode == BSNH_BNLH_BNLH_BNPH_BNPH_BNTH_BNTH) { - dst_strides = Strides::BNSHMemory(attn.batch_size, attn.num_heads, attn.total_sequence_length, attn.head_size); - - past_shape = {attn.batch_size, attn.num_heads, attn.past_sequence_length, attn.head_size}; - past_src_strides = Strides::BNSHMemory(attn.batch_size, attn.num_heads, attn.past_sequence_length, attn.head_size); - past_key_src = reinterpret_cast(past_key->DataRaw()); - past_value_src = reinterpret_cast(past_value->DataRaw()); - past_key_dst = reinterpret_cast(present_key->MutableDataRaw()); - past_value_dst = reinterpret_cast(present_value->MutableDataRaw()); - - if (attn.mode == BSNH_BLNH_BLNH_BNPH_BNPH_BNTH_BNTH) { - add_src_strides = Strides::BSNHMemory(attn.batch_size, attn.kv_sequence_length, attn.num_heads, attn.head_size); - } else if (attn.mode == BSNH_BNLH_BNLH_BNPH_BNPH_BNTH_BNTH) { - add_src_strides = Strides::BNSHMemory(attn.batch_size, attn.num_heads, attn.kv_sequence_length, attn.head_size); - } - } else if (attn.mode == BSNH_BLNH_BLNH_NONE_NONE_BNTH_BNTH || - attn.mode == BSNH_BNLH_BNLH_NONE_NONE_BNTH_BNTH) { - dst_strides = Strides::BNSHMemory(attn.batch_size, attn.num_heads, attn.total_sequence_length, attn.head_size); - - if (attn.mode == BSNH_BLNH_BLNH_NONE_NONE_BNTH_BNTH) { - add_src_strides = Strides::BSNHMemory(attn.batch_size, attn.kv_sequence_length, attn.num_heads, attn.head_size); - } else if (attn.mode == BSNH_BNLH_BNLH_NONE_NONE_BNTH_BNTH) { - add_src_strides = Strides::BNSHMemory(attn.batch_size, attn.num_heads, attn.kv_sequence_length, attn.head_size); - } - } else if ( - attn.mode == BSNH_BLNH_BLNH_NONE_NONE_BNMH_BNMH || - attn.mode == BSNH_BNLH_BNLH_NONE_NONE_BNMH_BNMH || - attn.mode == BSNH_BLNH_BLNH_BNMH_BNMH_BNMH_BNMH || - attn.mode == BSNH_BNLH_BNLH_BNMH_BNMH_BNMH_BNMH) { - dst_strides = Strides::BNSHMemory(attn.batch_size, attn.num_heads, attn.max_sequence_length, attn.head_size); - - if (attn.mode == BSNH_BLNH_BLNH_NONE_NONE_BNMH_BNMH || attn.mode == BSNH_BLNH_BLNH_BNMH_BNMH_BNMH_BNMH) { - add_src_strides = Strides::BSNHMemory(attn.batch_size, attn.kv_sequence_length, attn.num_heads, attn.head_size); - } else if (attn.mode == BSNH_BNLH_BNLH_NONE_NONE_BNMH_BNMH || attn.mode == BSNH_BNLH_BNLH_BNMH_BNMH_BNMH_BNMH) { - add_src_strides = Strides::BNSHMemory(attn.batch_size, attn.num_heads, attn.kv_sequence_length, attn.head_size); - } - } else { - return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, - "past present concatenation is not implemented for attention mode ", attn.mode); - } - add_shape = {attn.batch_size, attn.num_heads, attn.kv_sequence_length, attn.head_size}; // kernel in coord (b,n,s,h) - add_key_dst = reinterpret_cast(present_key->MutableDataRaw()) + dst_strides.OffsetAt(0, 0, attn.past_sequence_length, 0); - add_value_dst = reinterpret_cast(present_value->MutableDataRaw()) + dst_strides.OffsetAt(0, 0, attn.past_sequence_length, 0); - - if (past_key_dst) { - ORT_RETURN_IF_ERROR(LaunchStridedCopy( - stream, past_key_src, past_shape, past_src_strides.ForBNSHCoord(), - past_key_dst, dst_strides.ForBNSHCoord(), device_prop.maxThreadsPerBlock)); - } - if (past_value_dst) { - ORT_RETURN_IF_ERROR(LaunchStridedCopy( - stream, past_value_src, past_shape, past_src_strides.ForBNSHCoord(), - past_value_dst, dst_strides.ForBNSHCoord(), device_prop.maxThreadsPerBlock)); - } - - ORT_RETURN_IF_ERROR(LaunchStridedCopy( - stream, add_key_src, add_shape, add_src_strides.ForBNSHCoord(), - add_key_dst, dst_strides.ForBNSHCoord(), device_prop.maxThreadsPerBlock)); - ORT_RETURN_IF_ERROR(LaunchStridedCopy( - stream, add_value_src, add_shape, add_src_strides.ForBNSHCoord(), - add_value_dst, dst_strides.ForBNSHCoord(), device_prop.maxThreadsPerBlock)); - } - - GemmSoftmaxGemmPermuteParams params; - params.tuning_ctx = GetTuningContext(); - params.stream = context->GetComputeStream(); - params.handle = GetHipblasHandle(context); - params.attention = &attn; - params.device_prop = &device_prop; - params.scale = scale_ == 0 ? 1.0f / sqrt(attn.head_size) : scale_; - std::tie(params.q_buffer, params.k_buffer, params.v_buffer) = ConvertToOffsetedBufferViews( - &attn, - nullptr == query ? nullptr : reinterpret_cast(query->DataRaw()), - nullptr == key ? nullptr : reinterpret_cast(key->DataRaw()), - nullptr == value ? nullptr : reinterpret_cast(value->DataRaw()), - nullptr == present_key ? nullptr : reinterpret_cast(present_key->DataRaw()), - nullptr == present_value ? nullptr : reinterpret_cast(present_value->DataRaw())); - params.out_buffer = reinterpret_cast(output->MutableDataRaw()); - - if (key_padding_mask != nullptr) { - params.mask_index_buffer = key_padding_mask->Data(); - params.mask_index_dims = key_padding_mask->Shape().AsShapeVector(); - } - - if (attention_bias != nullptr) { - params.bias_buffer = reinterpret_cast(attention_bias->DataRaw()); - } - - params.workspace_buffer = reinterpret_cast(workspace.get()); - return (*std::static_pointer_cast(tunable_op_))(¶ms); -} - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/multihead_attention.h b/onnxruntime/contrib_ops/rocm/bert/multihead_attention.h deleted file mode 100644 index 1d676d7a7bcac..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/multihead_attention.h +++ /dev/null @@ -1,51 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include -#include "core/providers/rocm/rocm_kernel.h" -#include "contrib_ops/rocm/bert/attention_impl.h" - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -using namespace onnxruntime::rocm; - -template -class MultiHeadAttention final : public RocmKernel { - public: - MultiHeadAttention(const OpKernelInfo& info); - Status ComputeInternal(OpKernelContext* context) const override; - - protected: - AttentionType attn_type_; - int num_heads_; // number of attention heads - float mask_filter_value_; - float scale_; - bool past_present_share_buffer_{false}; - bool is_unidirectional_{false}; - - // type-erased GemmSoftmaxGemmPermuteTunableOp, the reason for this is: - // 1. We don't want to include the cuh file where GemmSoftmaxGemmPermuteTunableOp is defined. - // 2. We don't want to construct the object repeatly (which is expansive) during Compute. - std::shared_ptr tunable_op_; -}; - -template -class DecoderMaskedMultiHeadAttention final : public RocmKernel { - public: - DecoderMaskedMultiHeadAttention(const OpKernelInfo& info); - Status ComputeInternal(OpKernelContext* context) const override; - - protected: - AttentionType mha_type; - int num_heads_; // number of attention heads - float mask_filter_value_; - float scale_; -}; - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm.cc b/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm.cc deleted file mode 100644 index 9e649fb591896..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm.cc +++ /dev/null @@ -1,132 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "contrib_ops/rocm/bert/skip_layer_norm.h" - -#include "core/providers/rocm/rocm_common.h" -#include "contrib_ops/rocm/bert/skip_layer_norm_impl.h" -#include "contrib_ops/rocm/bert/transformer_common.h" - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -#define REGISTER_KERNEL_TYPED(T) \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - SkipLayerNormalization, \ - kMSDomain, \ - 1, \ - T, \ - kRocmExecutionProvider, \ - (*KernelDefBuilder::Create()) \ - .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - SkipLayerNorm); \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - SkipSimplifiedLayerNormalization, \ - kMSDomain, \ - 1, \ - T, \ - kRocmExecutionProvider, \ - (*KernelDefBuilder::Create()) \ - .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - SkipLayerNorm); - -REGISTER_KERNEL_TYPED(float) -REGISTER_KERNEL_TYPED(MLFloat16) - -using namespace ONNX_NAMESPACE; - -template -SkipLayerNorm::SkipLayerNorm(const OpKernelInfo& op_kernel_info) : RocmKernel(op_kernel_info) { - ORT_ENFORCE(op_kernel_info.GetAttr("epsilon", &epsilon_).IsOK()); - ORT_ENFORCE(epsilon_ >= 0); -} - -template -Status SkipLayerNorm::ComputeInternal(OpKernelContext* ctx) const { - const Tensor* input = ctx->Input(0); - const Tensor* skip = ctx->Input(1); - const Tensor* gamma = ctx->Input(2); - - const Tensor* beta = Simplified ? nullptr : ctx->Input(3); - const Tensor* bias = Simplified ? ctx->Input(3) : ctx->Input(4); - - Tensor* output = ctx->Output(0, input->Shape()); - - // For inferencing, we support one more optional output which is the sum - // of the input and skip tensors - Tensor* skip_input_bias_add_output = ctx->Output(3, input->Shape()); - - if (input->Shape() != skip->Shape()) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "skip is expected to have same shape as input"); - } - - if (input->Shape().Size() == 0) { - return Status::OK(); - } - - const auto& input_dims = input->Shape().GetDims(); - size_t input_dims_size = input_dims.size(); - if (input_dims_size != 3 && input_dims_size != 2) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "input is expected to have 3 or 2 dimensions, got ", input_dims_size); - } - - int hidden_size = static_cast(input_dims[input_dims_size - 1]); - - const auto& gamma_dims = gamma->Shape().GetDims(); - if (gamma_dims.size() != 1) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "gamma is expected to have 1 dimension, got ", gamma_dims.size()); - } - if (gamma_dims[0] != hidden_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Last dimension of gamma and input does not match"); - } - - if (nullptr != beta) { - const auto& beta_dims = beta->Shape().GetDims(); - if (beta_dims.size() != 1) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "beta is expected to have 1 dimension, got ", beta_dims.size()); - } - if (beta_dims[0] != hidden_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Last dimension of beta and input does not match"); - } - } - - if (nullptr != bias) { - const auto& bias_dims = bias->Shape().GetDims(); - if (bias_dims.size() != 1) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "bias is expected to have 1 dimension, got ", bias_dims.size()); - } - if (bias_dims[0] != hidden_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Last dimension of bias and input does not match"); - } - } - - int64_t element_count = input->Shape().Size(); - typedef typename ToHipType::MappedType HipT; - - return LaunchSkipLayerNormKernel( - GetTuningContext(), - ctx->GetComputeStream(), - reinterpret_cast(output->MutableData()), - skip_input_bias_add_output != nullptr ? reinterpret_cast(skip_input_bias_add_output->MutableData()) : nullptr, - reinterpret_cast(input->Data()), - reinterpret_cast(skip->Data()), - reinterpret_cast(gamma->Data()), - (beta != nullptr) ? reinterpret_cast(beta->Data()) : nullptr, - (bias != nullptr) ? reinterpret_cast(bias->Data()) : nullptr, - epsilon_, - hidden_size, - static_cast(element_count)); -} - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm.h b/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm.h deleted file mode 100644 index 02228bc59cedc..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm.h +++ /dev/null @@ -1,26 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once -#include "core/common/common.h" -#include "core/providers/rocm/rocm_kernel.h" - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -using namespace onnxruntime::rocm; - -template -class SkipLayerNorm final : public RocmKernel { - public: - SkipLayerNorm(const OpKernelInfo& op_kernel_info); - Status ComputeInternal(OpKernelContext* context) const override; - - private: - float epsilon_; -}; - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl.cu b/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl.cu deleted file mode 100644 index 8387c49a3310b..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl.cu +++ /dev/null @@ -1,86 +0,0 @@ -#include "hip/hip_runtime.h" -/* - The implementation of this file is based on skipLayerNorm plugin in TensorRT demo: - https://github.com/NVIDIA/TensorRT/tree/release/5.1/demo/BERT/ - -Copyright 2019 NVIDIA Corporation - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -// Modifications: Add SkipLayerNormKernelVec to -// leverage vectorized load/write. -// and templatize ComputeSkipLayerNorm for different -// data types. -// Copyright (c) Advanced Micro Devices, Inc. All rights reserved. -// Licensed under the MIT License. - -#include "contrib_ops/rocm/bert/skip_layer_norm_impl.h" - -#include - -#include "contrib_ops/rocm/bert/skip_layer_norm_impl_kernel.h" -#include "contrib_ops/rocm/bert/skip_layer_norm_tunable_op.h" - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -template -Status LaunchSkipLayerNormKernel( - RocmTuningContext* tuning_ctx, Stream* stream, V* output, T* skip_input_bias_add_output, const T* input, - const T* skip, const V* gamma, const V* beta, const T* bias, float epsilon, int ld, int element_count) { - // this must be true because element_count is the total size of the tensor - assert(element_count % ld == 0); - - SkipLayerNormParams params(tuning_ctx, stream, output, skip_input_bias_add_output, input, skip, - gamma, beta, bias, epsilon, ld, element_count); - - if (tuning_ctx->IsTunableOpEnabled()) { - static SkipLayerNormTunableOp op; - return op(¶ms); - } - - return SkipLayerNormStaticSelection(¶ms); -} - -template Status LaunchSkipLayerNormKernel( - RocmTuningContext* tuning_ctx, Stream* stream, float* output, float* skip_input_bias_add_output, const float* input, - const float* skip, const float* gamma, const float* beta, - const float* bias, float epsilon, int ld, - int element_count); - -template Status LaunchSkipLayerNormKernel( - RocmTuningContext* tuning_ctx, Stream* stream, half* output, half* skip_input_bias_add_output, const half* input, - const half* skip, const half* gamma, const half* beta, - const half* bias, float epsilon, int ld, - int element_count); - -template Status LaunchSkipLayerNormKernel( - RocmTuningContext* tuning_ctx, Stream* stream, float* output, float* skip_input_bias_add_output, const float* input, - const float* skip, const float* gamma, const float* beta, - const float* bias, float epsilon, int ld, - int element_count); - -template Status LaunchSkipLayerNormKernel( - RocmTuningContext* tuning_ctx, Stream* stream, half* output, half* skip_input_bias_add_output, const half* input, - const half* skip, const half* gamma, const half* beta, - const half* bias, float epsilon, int ld, - int element_count); - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl.h b/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl.h deleted file mode 100644 index 5e2a92447d2f5..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl.h +++ /dev/null @@ -1,31 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include "core/common/common.h" -#include "core/providers/rocm/tunable/rocm_tunable.h" - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -template -Status LaunchSkipLayerNormKernel( - RocmTuningContext* tuning, - Stream* stream, - V* output, // output tensor - T* skip_input_bias_add_output, // optional output tensor - const T* input, // input tensor - const T* skip, // skip tensor - const V* gamma, // Layer normalization gamma tensor - const V* beta, // Layer normalization beta tensor - const T* bias, // Layer normalization beta tensor - float epsilon, // Layer normalization epsilon - int hidden_size, // hidden size, it is the leading dimension (ld) - int element_count // number of elements in input tensor -); - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl_kernel.h b/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl_kernel.h deleted file mode 100644 index fcfbc8969e498..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl_kernel.h +++ /dev/null @@ -1,162 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include -#include "contrib_ops/rocm/bert/layer_norm.cuh" - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -template -T maybe2half(float x); - -template <> -float maybe2half(float x) { - return x; -} - -template <> -half maybe2half(float x) { - return __float2half_rn(x); -} - -template -__global__ void SkipLayerNormKernel( - const int ld, const T* input, const T* skip, const V* beta, const V* gamma, const T* bias, - const U epsilon, V* output, T* skip_input_bias_add_output) { - const U reverse_ld = U(1.f / ld); - const int offset = blockIdx.x * ld; - - KeyValuePairSum pair_sum; - // reduce x and x^2 - hipcub::KeyValuePair thread_data(U(0.f), U(0.f)); - - for (int i = threadIdx.x; i < ld; i += TPB) { - const int idx = offset + i; - const U val = (bias == nullptr) ? static_cast(input[idx]) + static_cast(skip[idx]) : static_cast(input[idx]) + static_cast(skip[idx]) + static_cast(bias[i]); - const U rldval = reverse_ld * val; - thread_data = pair_sum(thread_data, hipcub::KeyValuePair(rldval, rldval * val)); - - if (skip_input_bias_add_output != nullptr) { - skip_input_bias_add_output[idx] = static_cast(val); - } - - output[idx] = static_cast(val); - } - - if constexpr (Simplified) { - SimplifiedLayerNorm(thread_data.value, ld, offset, gamma, epsilon, output); - return; - } - - LayerNorm(thread_data, ld, offset, beta, gamma, epsilon, output); -} - -// Vectorized kernel -template -__global__ void SkipLayerNormKernelVec( - const int ld, const T* input, const T* skip, const V* beta, const V* gamma, - const T* bias, const U epsilon, V* output, T* skip_input_bias_add_output, - bool hasBias, bool hasSkipInputBiasAdditionOutput) { - const U reverse_ld = U(1.f / ld); - const int offset = blockIdx.x * ld; - - KeyValuePairSum pair_sum; - // reduce x and x^2 - hipcub::KeyValuePair thread_data(U(0.f), U(0.f)); - - using VecT = aligned_vector; - using VecV = aligned_vector; - if (threadIdx.x * ILP < ld) { - for (int i = threadIdx.x * ILP; i < ld; i += TPB * ILP) { - int idx = offset + i; - - const VecT input_v = *reinterpret_cast(input + idx); - const VecT skip_v = *reinterpret_cast(skip + idx); - const VecT bias_v = hasBias ? *reinterpret_cast(bias + i) : VecT(); - VecT skip_input_bias_add_output_v, output_v; - -#pragma unroll - for (int k = 0; k < ILP; k++) { - const U val = hasBias ? static_cast(input_v.val[k]) + static_cast(skip_v.val[k]) + static_cast(bias_v.val[k]) : static_cast(input_v.val[k]) + static_cast(skip_v.val[k]); - const U rldval = reverse_ld * val; - - if (hasSkipInputBiasAdditionOutput) { - skip_input_bias_add_output_v.val[k] = static_cast(val); - } - thread_data = pair_sum(thread_data, hipcub::KeyValuePair(rldval, rldval * val)); - output_v.val[k] = static_cast(val); - } - - if (hasSkipInputBiasAdditionOutput) { - *(reinterpret_cast(skip_input_bias_add_output + idx)) = skip_input_bias_add_output_v; - } - - *(reinterpret_cast(output + idx)) = output_v; - } - } - - if constexpr (Simplified) { - SimplifiedLayerNormVec(thread_data.value, ld, offset, gamma, epsilon, output); - return; - } - - LayerNormVec(thread_data, ld, offset, beta, gamma, epsilon, output); -} - -// Vectorized kernel -template -__global__ void SkipLayerNormKernelSmall( - const int ld, const T* input, const T* skip, const V* beta, const V* gamma, - const T* bias, const U epsilon, V* output, T* skip_input_bias_add_output, - bool hasBias, bool hasSkipInputBiasAdditionOutput) { - const U rld = U(1.f / ld); - const int idx = blockIdx.x * ld + threadIdx.x * ILP; // grid_size = n / ld - - using VecT = aligned_vector; - hipcub::KeyValuePair thread_data(U(0.f), U(0.f)); - - VecT input_v; - if (ILP * threadIdx.x < ld) { - input_v = *reinterpret_cast(input + idx); - const VecT skip_v = *reinterpret_cast(skip + idx); - const VecT bias_v = hasBias ? *reinterpret_cast(bias + threadIdx.x * ILP) : VecT(); - VecT skip_input_bias_add_output_v; - - U rldval_sum = U(0.f); - U rldvalsq_sum = U(0.f); -#pragma unroll - for (int i = 0; i < ILP; i++) { - const U val = hasBias ? static_cast(input_v.val[i]) + static_cast(skip_v.val[i]) + static_cast(bias_v.val[i]) : static_cast(input_v.val[i]) + static_cast(skip_v.val[i]); - - if (hasSkipInputBiasAdditionOutput) { - skip_input_bias_add_output_v.val[i] = static_cast(val); - } - - const U rldval = rld * val; - rldval_sum += rldval; - rldvalsq_sum += rldval * val; - input_v.val[i] = static_cast(val); - } - - if (hasSkipInputBiasAdditionOutput) { - *(reinterpret_cast(skip_input_bias_add_output + idx)) = skip_input_bias_add_output_v; - } - - thread_data = hipcub::KeyValuePair(rldval_sum, rldvalsq_sum); - } - - if constexpr (Simplified) { - SimplifiedLayerNormSmall(input_v.val, thread_data.value, ld, idx, gamma, epsilon, output); - return; - } - - LayerNormSmall(input_v.val, thread_data, ld, idx, beta, gamma, epsilon, output); -} - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_tunable_op.h b/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_tunable_op.h deleted file mode 100644 index 0391704ce1c56..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_tunable_op.h +++ /dev/null @@ -1,161 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include -#include -#include -#include - -#include "contrib_ops/rocm/bert/skip_layer_norm_impl_kernel.h" -#include "core/providers/rocm/cu_inc/common.cuh" -#include "core/providers/rocm/tunable/rocm_tunable.h" - -using onnxruntime::rocm::CeilDiv; - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -template -struct SkipLayerNormParams : OpParams { - SkipLayerNormParams(RocmTuningContext* tuning_ctx, onnxruntime::Stream* stream, V* output, T* skip_input_bias_add_output, const T* input, - const T* skip, const V* gamma, const V* beta, - const T* bias, float epsilon, int ld, int element_count) - : OpParams(tuning_ctx, stream), output(output), skip_input_bias_add_output(skip_input_bias_add_output), input(input), skip(skip), gamma(gamma), beta(beta), bias(bias), epsilon(epsilon), ld(ld), element_count(element_count) {} - - std::string Signature() const override { - std::string sig = std::to_string(ld) + "_" + std::to_string(element_count); - return sig; - } - - V* output; - T* skip_input_bias_add_output; - const T* input; - const T* skip; - const V* gamma; - const V* beta; - const T* bias; - float epsilon; - int ld; - int element_count; -}; - -template -Status SkipLayerNormSmallOp(const SkipLayerNormParams* params) { - // Loosen the hard constraint for ld (hidden_size) to include more possible *Small kernels, - // which could offer better performance in some combinations of ThreadsPerBlock and VecSize. - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - !((params->ld <= 8192 && params->ld % VecSize == 0 && - params->ld <= ThreadsPerBlock * VecSize && params->ld > (ThreadsPerBlock - GPU_WARP_SIZE) * VecSize))); - SkipLayerNormKernelSmall<<element_count, params->ld)), - dim3(ThreadsPerBlock), - 0, params->StreamHandle()>>>( - params->ld, params->input, params->skip, - params->beta, params->gamma, params->bias, static_cast(params->epsilon), params->output, params->skip_input_bias_add_output, - (params->bias == nullptr) ? false : true, (params->skip_input_bias_add_output == nullptr) ? false : true); - return HIP_CALL(hipGetLastError()); -} - -template -Status SkipLayerNormRegularOp(const SkipLayerNormParams* params) { - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - !((params->ld > 0 && params->ld % VecSize == 0 && - (params->ld >= ThreadsPerBlock * VecSize || - (params->ld < GPU_WARP_SIZE && params->ld > (ThreadsPerBlock - GPU_WARP_SIZE) * VecSize))))); - SkipLayerNormKernelVec<<element_count, params->ld)), - dim3(ThreadsPerBlock), - 0, params->StreamHandle()>>>( - params->ld, params->input, params->skip, - params->beta, params->gamma, params->bias, static_cast(params->epsilon), params->output, params->skip_input_bias_add_output, - (params->bias == nullptr) ? false : true, (params->skip_input_bias_add_output == nullptr) ? false : true); - return HIP_CALL(hipGetLastError()); -} - -template -Status SkipLayerNormStaticSelection(const SkipLayerNormParams* params) { - bool hasBias = (params->bias == nullptr) ? false : true; - bool hasSkipInputBiasAdditionOutput = (params->skip_input_bias_add_output == nullptr) ? false : true; - const int grid_size = params->element_count / params->ld; - const int block_size = 256; - -#define LAUNCH_SKIPLAYERNORM_SMALL_FORWARD(ELEMENTS, TPB, ILP) \ - if (params->ld <= ELEMENTS) { \ - SkipLayerNormKernelSmall<<StreamHandle()>>>( \ - params->ld, params->input, params->skip, params->beta, params->gamma, params->bias, \ - static_cast(params->epsilon), params->output, params->skip_input_bias_add_output, \ - hasBias, hasSkipInputBiasAdditionOutput); \ - break; \ - } - if (0 == (params->ld % 4)) { - do { - LAUNCH_SKIPLAYERNORM_SMALL_FORWARD(32, 32, 1) - LAUNCH_SKIPLAYERNORM_SMALL_FORWARD(64, 32, 2) - LAUNCH_SKIPLAYERNORM_SMALL_FORWARD(128, 32, 4) - LAUNCH_SKIPLAYERNORM_SMALL_FORWARD(384, 96, 4) - LAUNCH_SKIPLAYERNORM_SMALL_FORWARD(768, 192, 4) - LAUNCH_SKIPLAYERNORM_SMALL_FORWARD(1024, 256, 4) - - SkipLayerNormKernel<<StreamHandle()>>>( - params->ld, params->input, params->skip, params->beta, params->gamma, params->bias, - static_cast(params->epsilon), params->output, params->skip_input_bias_add_output); - } while (0); - } else { - do { - LAUNCH_SKIPLAYERNORM_SMALL_FORWARD(32, 32, 1) - LAUNCH_SKIPLAYERNORM_SMALL_FORWARD(64, 64, 1) - LAUNCH_SKIPLAYERNORM_SMALL_FORWARD(128, 128, 1) - LAUNCH_SKIPLAYERNORM_SMALL_FORWARD(384, 384, 1) - - SkipLayerNormKernel<<StreamHandle()>>>( - params->ld, params->input, params->skip, params->beta, params->gamma, params->bias, - static_cast(params->epsilon), params->output, params->skip_input_bias_add_output); - } while (0); - } - return HIP_CALL(hipPeekAtLastError()); -} // namespace rocm - -#define ADD_OP_FOR_ALL_VEC_SIZE(name, threads_per_block) \ - this->RegisterOp(name); \ - this->RegisterOp(name); \ - this->RegisterOp(name); \ - this->RegisterOp(name); \ - this->RegisterOp(name); - -#define ADD_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE(name) \ - ADD_OP_FOR_ALL_VEC_SIZE(name, 64) \ - ADD_OP_FOR_ALL_VEC_SIZE(name, 128) \ - ADD_OP_FOR_ALL_VEC_SIZE(name, 192) \ - ADD_OP_FOR_ALL_VEC_SIZE(name, 256) \ - ADD_OP_FOR_ALL_VEC_SIZE(name, 320) \ - ADD_OP_FOR_ALL_VEC_SIZE(name, 384) \ - ADD_OP_FOR_ALL_VEC_SIZE(name, 448) \ - ADD_OP_FOR_ALL_VEC_SIZE(name, 512) \ - ADD_OP_FOR_ALL_VEC_SIZE(name, 576) \ - ADD_OP_FOR_ALL_VEC_SIZE(name, 640) \ - ADD_OP_FOR_ALL_VEC_SIZE(name, 704) \ - ADD_OP_FOR_ALL_VEC_SIZE(name, 768) \ - ADD_OP_FOR_ALL_VEC_SIZE(name, 832) \ - ADD_OP_FOR_ALL_VEC_SIZE(name, 896) \ - ADD_OP_FOR_ALL_VEC_SIZE(name, 1024) - -template -class SkipLayerNormTunableOp : public TunableOp> { - public: - SkipLayerNormTunableOp() { - this->RegisterOp(SkipLayerNormStaticSelection); - ADD_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE(SkipLayerNormSmallOp) - ADD_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE(SkipLayerNormRegularOp) - - // NOTE: the 1st kernel is SkipLayerNorm Original implementation. - this->SetDefaultId(0); - } -}; - -#undef ADD_OP_FOR_ALL_VEC_SIZE -#undef ADD_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/transformer_common.cc b/onnxruntime/contrib_ops/rocm/bert/transformer_common.cc deleted file mode 100644 index 6ae8d1202d462..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/transformer_common.cc +++ /dev/null @@ -1,37 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. -#include -#include "core/providers/shared_library/provider_api.h" // Include this otherwise Windows build complains Env::Default() missing -#include "core/platform/env_var_utils.h" -#include "contrib_ops/rocm/bert/transformer_common.h" - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -// The environment variable is for testing purpose only, and it might be removed in the future. -// If you need some option in production, please file a feature request. -constexpr const char* kTransformerOptions = "ORT_TRANSFORMER_OPTIONS"; - -// Initialize the singleton instance -TransformerOptions TransformerOptions::instance; - -const TransformerOptions* TransformerOptions::GetInstance() { - if (!instance.initialized_) { - // We do not use critical section here since it is fine to initialize multiple times by different threads. - int value = ParseEnvironmentVariableWithDefault(kTransformerOptions, 0); - instance.Initialize(value); - - if (value > 0) - std::cout << "ORT_TRANSFORMER_OPTIONS: IsPrecisionMode=" << instance.IsPrecisionMode() - << ",DisablePersistentSoftmax=" << instance.DisablePersistentSoftmax() - << ",DisableHalf2=" << instance.DisableHalf2() - << std::endl; - } - - return &instance; -} - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/transformer_common.h b/onnxruntime/contrib_ops/rocm/bert/transformer_common.h deleted file mode 100644 index 6816b5b9d07ec..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/transformer_common.h +++ /dev/null @@ -1,46 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include "core/providers/rocm/rocm_common.h" - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -class TransformerOptions { - public: - static const TransformerOptions* GetInstance(); - - bool IsPrecisionMode() const { return is_precision_mode_; } - - bool DisablePersistentSoftmax() const { return disable_persistent_softmax_; } - - bool DisableHalf2() const { return disable_half2_; } - - void Initialize(int value) { - is_precision_mode_ = (value & 0x01) > 0; - disable_persistent_softmax_ = (value & 0x02) > 0; - disable_half2_ = (value & 0x04) > 0; - initialized_ = true; - } - - private: - // Default is false. If the mode is on, prefer precision than speed. - bool is_precision_mode_{false}; - - // Disable persistent softmax. - bool disable_persistent_softmax_{false}; - - // Disable half2 kernel. - bool disable_half2_{false}; - - bool initialized_{false}; - - static TransformerOptions instance; -}; - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck.cuh b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck.cuh deleted file mode 100644 index d0a0d09fcbae3..0000000000000 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck.cuh +++ /dev/null @@ -1,105 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include -#include -#include - -#ifdef USE_COMPOSABLE_KERNEL -#include "core/providers/rocm/composable_kernel_common.h" - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "contrib_ops/rocm/diffusion/group_norm_ck_impl/impl.cuh" -#endif // USE_COMPOSABLE_KERNEL - -#include "contrib_ops/rocm/diffusion/group_norm_common.h" - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -#ifdef USE_COMPOSABLE_KERNEL - -using onnxruntime::rocm::CKDataTypeAdaptor; - -// The SiLU function is a special case of Swish function, -// The Swish function is parametrized by b, which is set to 1.0 for SiLU. They are defined as: -// SiLU(x) = x * sigmoid(x) -// Swish(x) = x * sigmoid(bx) -// The default value of b is 1.0 in ck::tensor_operation::element_wise::Swish function. We treat them as the same function here. -using Silu = ck::tensor_operation::element_wise::Swish; -using Pass = ck::tensor_operation::element_wise::PassThrough; - -constexpr int Rank = 5; -constexpr int NumReduceDim = 3; - -template -auto GetCKGroupNormNHWCTypeStringAndOps() { - using XDataType = typename CKDataTypeAdaptor::type; - using YDataType = typename CKDataTypeAdaptor::type; - using SaveMeanInvStdDataType = typename CKDataTypeAdaptor::type; - using GammaDataType = float; - using BetaDataType = float; - - using Activation = std::conditional_t; - - std::vector>>> ret; - for (auto&& impl : internal::GetDeviceGroupNormInstances()) { - std::string silu_suffix = WithSilu ? "_Silu" : "_Pass"; - auto type_string = onnxruntime::MakeString(impl->GetTypeString()) + silu_suffix; - auto invoker = impl->MakeInvokerPointer(); - - auto ck_group_norm_op = [impl = std::move(impl), invoker = std::move(invoker)]( - const GroupNormNHWCTunableParams* params) -> Status { - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF((params->skip != nullptr || params->bias != nullptr), - "Input skip or bias is not supported by composable kernel."); - if constexpr (WithSilu) { - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - !params->use_silu, "Silu version only support groupnorm with silu"); - } else { - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - params->use_silu, "Pass version only support groupnorm without silu"); - } - std::vector in_lengths{params->n, params->h, params->w, params->groups, params->channels_per_group}; - std::vector in_out_strides{params->h * params->w * params->c, params->w * params->c, - params->c, params->channels_per_group, 1}; - std::vector gamma_beta_strides{0, 0, 0, params->channels_per_group, 1}; - std::vector reduce_dims{1, 2, 4}; - - auto activation = Activation{}; - - auto arg = impl->MakeArgumentPointer(in_lengths, // lengths - in_out_strides, // xStrides - gamma_beta_strides, // gammaStrides - gamma_beta_strides, // betaStrides - in_out_strides, // yStrides - {0, 0}, // saveMeanStrides - {0, 0}, // saveInvStdStrides - reduce_dims, // reduceDims - params->epsilon, - params->src, - params->gamma, - params->beta, - params->dst, - nullptr, - nullptr, - activation); - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!impl->IsSupportedArgument(arg.get()), - impl->GetTypeString(), " does not support the params"); - invoker->Run(arg.get(), StreamConfig{params->StreamHandle()}); - return Status::OK(); - }; - ret.emplace_back(std::make_pair(std::move(type_string), std::move(ck_group_norm_op))); - } - return ret; -} -#endif // USE_COMPOSABLE_KERNEL - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl.cuh b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl.cuh deleted file mode 100644 index 68f7d47282845..0000000000000 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl.cuh +++ /dev/null @@ -1,130 +0,0 @@ -// SPDX-License-Identifier: MIT -// Modifications Copyright (c) Microsoft Corporation. All rights reserved. -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#ifdef USE_COMPOSABLE_KERNEL -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_normalization_fwd.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_normalization_fwd_impl.hpp" -#include "ck/utility/data_type.hpp" - -namespace onnxruntime { -namespace contrib { -namespace rocm { -namespace internal { - -using F16 = ck::half_t; -using F32 = float; - -using Silu = ck::tensor_operation::element_wise::Swish; -using Pass = ck::tensor_operation::element_wise::PassThrough; - -using ck::tensor_operation::device::DeviceNormalizationFwd; // the interface -using ck::tensor_operation::device::DeviceNormalizationFwdImpl; // the implementation - -// See https://github.com/ROCmSoftwarePlatform/composable_kernel/blob/1fefd82ed8/library/src/tensor_operation_instance/gpu/normalization_fwd/normalization_fwd_instance_common.hpp - -template -using device_normalization_f32_instances = std::tuple< - // clang-format off - // XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, SaveMeanInvStdDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorDim, GammaSrcVectorSize, BetaSrcVectorDim, BetaSrcVectorSize, YDstVectorSize, SaveMeanInvStdScalarPerVector> - DeviceNormalizationFwdImpl, // irregular size - DeviceNormalizationFwdImpl, // irregular size - DeviceNormalizationFwdImpl, // irregular size - DeviceNormalizationFwdImpl, // irregular size - DeviceNormalizationFwdImpl, // irregular size - DeviceNormalizationFwdImpl, - DeviceNormalizationFwdImpl, - DeviceNormalizationFwdImpl, - DeviceNormalizationFwdImpl, - DeviceNormalizationFwdImpl, - DeviceNormalizationFwdImpl, - DeviceNormalizationFwdImpl, - DeviceNormalizationFwdImpl, - DeviceNormalizationFwdImpl, - DeviceNormalizationFwdImpl, - DeviceNormalizationFwdImpl, - DeviceNormalizationFwdImpl, - DeviceNormalizationFwdImpl, - DeviceNormalizationFwdImpl - // clang-format on - >; - -template -using device_normalization_f16_instances = - // clang-format off - std::tuple < - // XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, SaveMeanInvStdDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorDim, GammaSrcVectorSize, BetaSrcVectorDim, BetaSrcVectorSize, YDstVectorSize, SaveMeanInvStdScalarPerVector> - DeviceNormalizationFwdImpl, // irregular size - DeviceNormalizationFwdImpl, // irregular size - DeviceNormalizationFwdImpl, // irregular size - DeviceNormalizationFwdImpl, // irregular size - DeviceNormalizationFwdImpl, // irregular size - DeviceNormalizationFwdImpl, // irregular size - DeviceNormalizationFwdImpl, - DeviceNormalizationFwdImpl, - DeviceNormalizationFwdImpl, - DeviceNormalizationFwdImpl, - DeviceNormalizationFwdImpl, - DeviceNormalizationFwdImpl, - DeviceNormalizationFwdImpl, - DeviceNormalizationFwdImpl, - DeviceNormalizationFwdImpl, - DeviceNormalizationFwdImpl, - DeviceNormalizationFwdImpl, - DeviceNormalizationFwdImpl - // clang-format on - >; - -// Use this function to get implementation -template -std::vector>> -GetDeviceGroupNormInstances() { - return {}; -} - -template <> -std::vector>> -GetDeviceGroupNormInstances< - F16, F32, F32, F16, F32, Silu, 5, 3>(); - -template <> -std::vector>> -GetDeviceGroupNormInstances< - F16, F32, F32, F16, F32, Pass, 5, 3>(); - -template <> -std::vector>> -GetDeviceGroupNormInstances< - F32, F32, F32, F32, F32, Silu, 5, 3>(); - -template <> -std::vector>> -GetDeviceGroupNormInstances< - F32, F32, F32, F32, F32, Pass, 5, 3>(); - -} // namespace internal -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime -#endif // USE_COMPOSABLE_KERNEL diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl_fp16.cu b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl_fp16.cu deleted file mode 100644 index ad191314e5e4c..0000000000000 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl_fp16.cu +++ /dev/null @@ -1,39 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#ifdef USE_COMPOSABLE_KERNEL -#include "contrib_ops/rocm/diffusion/group_norm_ck_impl/impl.cuh" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace onnxruntime { -namespace contrib { -namespace rocm { -namespace internal { - -template <> -std::vector>> -GetDeviceGroupNormInstances() { - std::vector>> instances; - ck::tensor_operation::device::instance::add_device_operation_instances( - instances, - device_normalization_f16_instances{}); - - return instances; -} - -template <> -std::vector>> -GetDeviceGroupNormInstances() { - std::vector>> instances; - ck::tensor_operation::device::instance::add_device_operation_instances( - instances, - device_normalization_f16_instances{}); - - return instances; -} - -} // namespace internal -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime -#endif // USE_COMPOSABLE_KERNEL diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl_fp32.cu b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl_fp32.cu deleted file mode 100644 index ceb53ed442abc..0000000000000 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl_fp32.cu +++ /dev/null @@ -1,39 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#ifdef USE_COMPOSABLE_KERNEL -#include "contrib_ops/rocm/diffusion/group_norm_ck_impl/impl.cuh" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace onnxruntime { -namespace contrib { -namespace rocm { -namespace internal { - -template <> -std::vector>> -GetDeviceGroupNormInstances() { - std::vector>> instances; - ck::tensor_operation::device::instance::add_device_operation_instances( - instances, - device_normalization_f32_instances{}); - - return instances; -} - -template <> -std::vector>> -GetDeviceGroupNormInstances() { - std::vector>> instances; - ck::tensor_operation::device::instance::add_device_operation_instances( - instances, - device_normalization_f32_instances{}); - - return instances; -} - -} // namespace internal -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime -#endif // USE_COMPOSABLE_KERNEL diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_common.h b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_common.h deleted file mode 100644 index 7cff640db2f34..0000000000000 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_common.h +++ /dev/null @@ -1,56 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include - -#include "core/providers/rocm/cu_inc/common.cuh" -#include "core/providers/rocm/rocm_common.h" -#include "core/providers/rocm/tunable/rocm_tunable.h" -#include "contrib_ops/rocm/diffusion/group_norm_common_base.h" - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -template -struct GroupNormNHWCTunableParams : OpParams, GroupNormNHWCParams { - GroupNormNHWCTunableParams(RocmTuningContext* tuning_ctx, - onnxruntime::Stream* ort_stream, - T* output, - T* add_out, - const T* input, - const T* skip, - const T* bias, - const float* gamma, - const float* beta, - float* workspace, - float epsilon, - int batch_size, - int num_channels, - int height, - int width, - int num_groups, - bool use_silu, - bool broadcast_skip, - int channels_per_block) - : OpParams(tuning_ctx, ort_stream), - GroupNormNHWCParams(output, add_out, input, skip, bias, gamma, beta, workspace, epsilon, batch_size, - num_channels, height, width, num_groups, use_silu, broadcast_skip, channels_per_block) {} - - std::string Signature() const override { - std::string silu_suffix = this->use_silu ? "_silu" : "_pass"; - std::string skip_suffix = this->skip != nullptr ? "_skip" : "_noskip"; - std::string broadcast_suffix = this->broadcast_skip ? "_broadcast" : "_nobroadcast"; - std::string bias_suffix = this->bias != nullptr ? "_bias" : "_nobias"; - std::string sig = std::to_string(this->n) + "_" + std::to_string(this->h * this->w) + "_" + - std::to_string(this->c) + "_" + std::to_string(this->groups) + silu_suffix + - skip_suffix + broadcast_suffix + bias_suffix; - return sig; - } -}; - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_impl.cu b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_impl.cu deleted file mode 100644 index 142aaf14e8d2d..0000000000000 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_impl.cu +++ /dev/null @@ -1,76 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -// The ROCM kernel is hipified from CUDA kernel. -#include "contrib_ops/rocm/diffusion/group_norm_impl.h" - -#include -#include "contrib_ops/rocm/diffusion/group_norm_common.h" -#include "contrib_ops/rocm/diffusion/group_norm_tunable_op.h" - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -template -Status LaunchGroupNormKernel( - RocmTuningContext* tuning_ctx, - Stream* ort_stream, - T* output, - T* add_out, - const T* input, - const T* skip, - const T* bias, - const float* gamma, - const float* beta, - void* workspace, - float epsilon, - int batch_size, - int num_channels, - int height, - int width, - int num_groups, - bool use_silu, - bool broadcast_skip, - int channels_per_block) { - GroupNormNHWCTunableParams params(tuning_ctx, ort_stream, output, add_out, input, skip, bias, gamma, beta, - reinterpret_cast(workspace), epsilon, batch_size, num_channels, - height, width, num_groups, use_silu, broadcast_skip, channels_per_block); - - if (params.channels_per_block % params.channels_per_group != 0 || - params.channels_per_block > kMaxSize || - (params.channels_per_group % CHANNELS_PER_THREAD != 0)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, - "GroupNorm in ROCM does not support the input: n=", batch_size, - " h=", height, - " w=", width, - " c=", num_channels, - " groups=", num_groups); - } - - HIP_RETURN_IF_ERROR(hipMemsetAsync( - params.group_sum_buffer, 0, GetGroupNormWorkspaceSizeInBytes(batch_size, num_groups), params.StreamHandle())); - - if (tuning_ctx->IsTunableOpEnabled()) { - static GroupNormNHWCTunableOp op; - return op(¶ms); - } - - return GroupNormNHWCStaticSelection(¶ms); -} - -template Status LaunchGroupNormKernel(RocmTuningContext* tuning_ctx, Stream* stream, half* output, - half* add_out, const half* input, const half* skip, const half* bias, - const float* gamma, const float* beta, void* workspace, float epsilon, - int batch_size, int num_channels, int height, int width, int num_groups, - bool use_silu, bool broadcast_skip, int channels_per_block); - -template Status LaunchGroupNormKernel(RocmTuningContext* tuning_ctx, Stream* stream, float* output, - float* add_out, const float* input, const float* skip, const float* bias, - const float* gamma, const float* beta, void* workspace, float epsilon, - int batch_size, int num_channels, int height, int width, int num_groups, - bool use_silu, bool broadcast_skip, int channels_per_block); - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh deleted file mode 100644 index c6ca16bfdfc80..0000000000000 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh +++ /dev/null @@ -1,105 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include -#include -#include - -#include "contrib_ops/rocm/diffusion/group_norm_common.h" -#include "core/providers/rocm/triton_kernel.h" - -using namespace onnxruntime::rocm; - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -#ifdef USE_TRITON_KERNEL - -namespace { - -template -std::string GetGroupNormTritonGroupName() { - std::string ret = "GroupNormTriton_"; - std::string silu_suffix = WithSilu ? "Silu_" : "Pass_"; - ret += silu_suffix; - ret += GetDataTypeName(); - return ret; -} - -} // namespace - -template -auto GetTritonGroupNormNHWCTypeStringAndOps() { - std::vector>>> ret; - auto group_name = GetGroupNormTritonGroupName(); - auto* kernel_list = GetOrtTritonKernelByGroup(group_name); - if (kernel_list == nullptr) { - return ret; - } - - for (auto i : *kernel_list) { - // Check params match - auto* metadata = GetOrtTritonKernelMetadata(i); - auto block_size = metadata->constants.at("BLOCK_SIZE"); - auto hw_size = metadata->constants.at("HW_SIZE"); - auto impl = [i, block_size, hw_size](const GroupNormNHWCTunableParams* params) -> Status { - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - params->channels_per_group > block_size || params->channels_per_group * 2 <= block_size, - "Arg block_size (", block_size, ") is not the next power of 2 of channels_per_group (", - params->channels_per_group, ")."); - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - params->hw % hw_size != 0, "Arg hw_size (", hw_size, ") is not a divisor of hw (", params->hw, ")."); - if constexpr (WithSilu) { - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!params->use_silu, "Silu version does not support GN w/o silu."); - } else { - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(params->use_silu, "Pass version does not support GN w/ silu."); - } - // Construct args for launch kernel - struct { - const void* src; - const void* skip; - const void* bias; - void* out; - void* add_out; - const void* gamma; - const void* beta; - int hw; - int c; - int c_per_group; - float eps; - bool has_skip; - bool has_bias; - bool broadcast_skip; - } args = { - (const void*)params->src, - (const void*)params->skip, - (const void*)params->bias, - (void*)params->dst, - (void*)params->skip_workspace, - (const void*)params->gamma, - (const void*)params->beta, - params->hw, - params->c, - params->channels_per_group, - params->epsilon, - params->skip != nullptr, - params->bias != nullptr, - params->broadcast_skip, - }; - - // Grid dim is (batch_count, groups, 1) - return LaunchTritonKernel(params->StreamHandle(), i, params->n, params->groups, 1, &args, sizeof(args)); - }; - ret.emplace_back(std::make_pair(metadata->name, std::move(impl))); - } - return ret; -} - -#endif // USE_TRITON_KERNEL - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.py b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.py deleted file mode 100644 index 5ba96ebc117f0..0000000000000 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.py +++ /dev/null @@ -1,135 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. -# -------------------------------------------------------------------------- - -from itertools import product - -import triton -import triton.language as tl - - -@triton.jit -def group_norm_kernel( - input_ptr, - skip_ptr, - bias_ptr, - output_ptr, - add_out_ptr, - gamma_ptr, - beta_ptr, - img_size, - c, - c_per_group, - eps, - has_skip, - has_bias, - broadcast_skip, - BLOCK_SIZE: tl.constexpr, - HW_SIZE: tl.constexpr, - ACTIVATION_SILU: tl.constexpr, -): - row_x = tl.program_id(0) - row_y = tl.program_id(1) - stride = img_size * c - input_ptr += row_x * stride + row_y * c_per_group - output_ptr += row_x * stride + row_y * c_per_group - gamma_ptr += row_y * c_per_group - beta_ptr += row_y * c_per_group - - cols = tl.arange(0, BLOCK_SIZE) - hw = tl.arange(0, HW_SIZE) - offsets = hw[:, None] * c + cols[None, :] - mask = (cols < c_per_group)[None, :] - - bias = tl.zeros([BLOCK_SIZE], dtype=tl.float32) - if has_skip: - add_out_ptr += row_x * stride + row_y * c_per_group - if broadcast_skip: - broadcast_skip_ptr = skip_ptr + row_x * c + row_y * c_per_group - bias += tl.load(broadcast_skip_ptr + cols, mask=cols < c_per_group, other=0.0).to(tl.float32) - else: - skip_ptr += row_x * stride + row_y * c_per_group - if has_bias: - bias_ptr += row_y * c_per_group - bias += tl.load(bias_ptr + cols, mask=cols < c_per_group, other=0.0).to(tl.float32) - - # Calculate mean and variance - _sum = tl.zeros([HW_SIZE, BLOCK_SIZE], dtype=tl.float32) - _square_sum = tl.zeros([HW_SIZE, BLOCK_SIZE], dtype=tl.float32) - for i in range(tl.cdiv(img_size, HW_SIZE)): - x_ptr = input_ptr + i * HW_SIZE * c - a = tl.load(x_ptr + offsets, mask=mask, other=0.0).to(tl.float32) - if has_skip and not broadcast_skip: - s_ptr = skip_ptr + i * HW_SIZE * c - s = tl.load(s_ptr + offsets, mask=mask, other=0.0).to(tl.float32) - a += s - if has_bias or broadcast_skip: - a += bias - _sum += a - _square_sum += a * a - if has_skip: - add_y_ptr = add_out_ptr + i * HW_SIZE * c - tl.store(add_y_ptr + offsets, a, mask=mask) - - # Set axis=None (or leave it unspecified) to reduce all axes. - # TODO: In older Triton we have to reduce an axis at a time, but in our case - # for some configs it may have some issue when reducing sequentially along the axes. - group_mean = tl.sum(_sum, axis=None) / (img_size * c_per_group) - group_var = tl.sum(_square_sum, axis=None) / (img_size * c_per_group) - group_mean * group_mean - - rstd = 1 / tl.sqrt(group_var + eps) - - # Normalize and apply linear transformation - gamma = tl.load(gamma_ptr + cols, mask=cols < c_per_group).to(tl.float32) - beta = tl.load(beta_ptr + cols, mask=cols < c_per_group).to(tl.float32) - for i in range(tl.cdiv(img_size, HW_SIZE)): - y_ptr = output_ptr + i * HW_SIZE * c - if has_skip: - add_y_ptr = add_out_ptr + i * HW_SIZE * c - x = tl.load(add_y_ptr + offsets, mask=mask, other=0.0).to(tl.float32) - else: - x_ptr = input_ptr + i * HW_SIZE * c - x = tl.load(x_ptr + offsets, mask=mask, other=0.0).to(tl.float32) - x_hat = (x - group_mean) * rstd - y = x_hat * gamma + beta - if ACTIVATION_SILU: - y *= tl.sigmoid(y) - tl.store(y_ptr + offsets, y, mask=mask) - - -# We can have more combinations of blocks and hw_sizes, e.g., -# blocks = [16, 32, 64, 128, 256, 512] -# hw_sizes = [8, 16, 32, 64, 128, 256, 512] -# but this will result in too many functions and slow down the compilation. -with_silu = [True, False] -dtypes = ["fp32", "fp16"] -blocks = [16, 32, 64, 128] -hw_sizes = [8, 16, 32, 64, 128, 256] -warps = [1, 2, 4, 8, 16] -name_pattern = "GroupNormTriton_{}_{}_b{}_hw{}_w{}" -sig_pattern = "*{},*{},*{},*{},*{},*fp32,*fp32,i32,i32,i32,fp32,i1,i1,i1" -group_pattern = "GroupNormTriton_{}_{}" - - -def get_function_table(): - func_table = [] - - for silu, dtype, hw_size, warp, b in product(with_silu, dtypes, hw_sizes, warps, blocks): - silu_suffix = "Silu" if silu else "Pass" - name = name_pattern.format(silu_suffix, dtype, b, hw_size, warp) - group = group_pattern.format(silu_suffix, dtype) - sig = sig_pattern.format(dtype, dtype, dtype, dtype, dtype) - kwargs = { - "num_warps": warp, - "constants": {"BLOCK_SIZE": b, "HW_SIZE": hw_size, "ACTIVATION_SILU": int(silu)}, - } - func_desc = {"name": name, "group": group, "func": group_norm_kernel, "sig": sig, "kwargs": kwargs} - func_table.append(func_desc) - return func_table - - -if __name__ == "__main__": - func_table = get_function_table() - for func_desc in func_table: - print(func_desc) diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_tunable_op.h b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_tunable_op.h deleted file mode 100644 index e6831f764b418..0000000000000 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_tunable_op.h +++ /dev/null @@ -1,220 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include -#include -#include "core/providers/rocm/cu_inc/common.cuh" -#include "core/providers/rocm/rocm_common.h" -#include "contrib_ops/rocm/diffusion/group_norm_ck.cuh" -#include "contrib_ops/rocm/diffusion/group_norm_common.h" -#include "contrib_ops/rocm/diffusion/group_norm_impl.h" -#include "contrib_ops/rocm/diffusion/group_norm_impl_kernel.cuh" -#include "contrib_ops/rocm/diffusion/group_norm_triton.cuh" - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -using onnxruntime::rocm::GPU_WARP_SIZE; - -template -void GroupNormNHWCSum(const GroupNormNHWCTunableParams* params) { - dim3 grid; - - // The number of blocks to compute all the channels. - grid.x = DivUp(params->c, params->channels_per_block); - // The number of blocks to compute all the activations in a given instance. - grid.y = DivUp(params->hw, params->hw_per_block); - // The number of instances. - grid.z = params->n; - -#define LAUNCH_GROUPNORM_SUM(ThreadsPerBlock, VecSize) \ - GroupNormNHWCSumKernel \ - <<StreamHandle()>>>( \ - params->skip_workspace, params->group_sum_buffer, params->src, params->skip, params->bias, \ - params->channels_per_block, params->hw_per_block, params->hw, params->hwc, params->c, \ - params->channels_per_group, params->groups, params->groups_per_block, params->broadcast_skip); \ - break; - - // Threads_per_block is half of values in kSizes since CHANNELS_PER_THREAD = 2. - switch (params->threads_per_block) { - case 256: - LAUNCH_GROUPNORM_SUM(256, CHANNELS_PER_THREAD) - case 192: - LAUNCH_GROUPNORM_SUM(192, CHANNELS_PER_THREAD) - case 160: - LAUNCH_GROUPNORM_SUM(160, CHANNELS_PER_THREAD) - case 128: - LAUNCH_GROUPNORM_SUM(128, CHANNELS_PER_THREAD) - case 64: - LAUNCH_GROUPNORM_SUM(64, CHANNELS_PER_THREAD) - default: - ORT_NOT_IMPLEMENTED("Not implemented"); - } -} - -template -Status GroupNormNHWCSumOp(const GroupNormNHWCTunableParams* params) { - dim3 grid; - grid.x = DivUp(params->c, params->channels_per_block); - grid.y = DivUp(params->hw, params->hw_per_block); - grid.z = params->n; - - GroupNormNHWCSumKernel - <<StreamHandle()>>>( - params->skip_workspace, params->group_sum_buffer, params->src, params->skip, params->bias, - params->channels_per_block, params->hw_per_block, params->hw, params->hwc, params->c, - params->channels_per_group, params->groups, params->groups_per_block, params->broadcast_skip); - return HIP_CALL(hipGetLastError()); -} - -template -void GroupNormNHWCScale(const GroupNormNHWCTunableParams* params) { - dim3 grid; - - // The number of blocks to compute all the channels. - grid.x = DivUp(params->c, params->channels_per_block); - // The number of blocks to compute all the activations in a given instance. - grid.y = DivUp(params->hw, params->hw_per_block); - // The number of instances. - grid.z = params->n; - -#define LAUNCH_GROUPNORM_SCALE(ThreadsPerBlock, VecSize) \ - GroupNormNHWCScaleKernel \ - <<StreamHandle()>>>( \ - params->dst, params->src, params->skip, params->gamma, params->beta, params->skip_workspace, \ - params->group_sum_buffer, params->epsilon, params->c, params->channels_per_block, \ - params->channels_per_group, params->groups, params->hwc, params->inv_hw_channels_per_group, \ - params->hw, params->hw_per_block, params->use_silu); \ - break; - - // Threads_per_block is half of values in kSizes since CHANNELS_PER_THREAD = 2. - switch (params->threads_per_block) { - case 256: - LAUNCH_GROUPNORM_SCALE(256, CHANNELS_PER_THREAD) - case 192: - LAUNCH_GROUPNORM_SCALE(192, CHANNELS_PER_THREAD) - case 160: - LAUNCH_GROUPNORM_SCALE(160, CHANNELS_PER_THREAD) - case 128: - LAUNCH_GROUPNORM_SCALE(128, CHANNELS_PER_THREAD) - case 64: - LAUNCH_GROUPNORM_SCALE(64, CHANNELS_PER_THREAD) - default: - ORT_NOT_IMPLEMENTED("Not implemented"); - } -} - -template -Status GroupNormNHWCScaleOp(const GroupNormNHWCTunableParams* params) { - dim3 grid; - grid.x = DivUp(params->c, params->channels_per_block); - grid.y = DivUp(params->hw, params->hw_per_block); - grid.z = params->n; - - GroupNormNHWCScaleKernel - <<StreamHandle()>>>( - params->dst, params->src, params->skip, params->gamma, params->beta, params->skip_workspace, - params->group_sum_buffer, params->epsilon, params->c, params->channels_per_block, params->channels_per_group, - params->groups, params->hwc, params->inv_hw_channels_per_group, params->hw, params->hw_per_block, - params->use_silu); - return HIP_CALL(hipGetLastError()); -} - -template -class GroupNormNHWCOp { - public: - Status operator()(const GroupNormNHWCTunableParams* params) { - HIP_RETURN_IF_ERROR(hipMemsetAsync(params->group_sum_buffer, - 0, - GetGroupNormWorkspaceSizeInBytes(params->n, params->groups), - params->StreamHandle())); - auto status = GroupNormNHWCSumOp(params); - ORT_RETURN_IF_ERROR(status); - HIP_RETURN_IF_ERROR(hipGetLastError()); - status = GroupNormNHWCScaleOp(params); - ORT_RETURN_IF_ERROR(status); - HIP_RETURN_IF_ERROR(hipGetLastError()); - return Status::OK(); - } - - Status IsSupported(const GroupNormNHWCTunableParams* params) { - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - !(params->c % VecSize == 0 && params->channels_per_group % VecSize == 0), - "The number of channels (", params->c, ") or the number of channels per group (", params->channels_per_group, - ") isn't divisible by the number of vector size: ", VecSize); - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!(params->channels_per_block <= ThreadsPerBlock * VecSize && - params->channels_per_block > (ThreadsPerBlock - GPU_WARP_SIZE) * VecSize), - "Configuration: Threads (", ThreadsPerBlock, "), vector size (", - VecSize, ") is redundant for the number of channels per group: ", - params->channels_per_block); - - return Status::OK(); - } -}; - -template -Status GroupNormNHWCStaticSelection(const GroupNormNHWCTunableParams* params) { - HIP_RETURN_IF_ERROR(hipMemsetAsync(params->group_sum_buffer, - 0, - GetGroupNormWorkspaceSizeInBytes(params->n, params->groups), - params->StreamHandle())); - GroupNormNHWCSum(params); - HIP_RETURN_IF_ERROR(hipGetLastError()); - GroupNormNHWCScale(params); - HIP_RETURN_IF_ERROR(hipGetLastError()); - return Status::OK(); -} - -#define ADD_OP_FOR_ALL_VEC_SIZE(name, threads_per_block) \ - this->RegisterOp(name{}); \ - this->RegisterOp(name{}); \ - this->RegisterOp(name{}); - -#define ADD_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE(name) \ - ADD_OP_FOR_ALL_VEC_SIZE(name, 64) \ - ADD_OP_FOR_ALL_VEC_SIZE(name, 128) \ - ADD_OP_FOR_ALL_VEC_SIZE(name, 192) \ - ADD_OP_FOR_ALL_VEC_SIZE(name, 256) \ - ADD_OP_FOR_ALL_VEC_SIZE(name, 320) - -template -class GroupNormNHWCTunableOp : public TunableOp> { - public: - GroupNormNHWCTunableOp() { - this->RegisterOp(GroupNormNHWCStaticSelection); - ADD_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE(GroupNormNHWCOp) - -#ifdef USE_COMPOSABLE_KERNEL - for (auto&& [_, op] : GetCKGroupNormNHWCTypeStringAndOps()) { - ORT_UNUSED_PARAMETER(_); - this->RegisterOp(std::move(op)); - } - - for (auto&& [_, op] : GetCKGroupNormNHWCTypeStringAndOps()) { - ORT_UNUSED_PARAMETER(_); - this->RegisterOp(std::move(op)); - } -#endif // USE_COMPOSABLE_KERNEL - -#ifdef USE_TRITON_KERNEL - for (auto&& [_, op] : GetTritonGroupNormNHWCTypeStringAndOps()) { - ORT_UNUSED_PARAMETER(_); - this->RegisterOp(std::move(op)); - } - for (auto&& [_, op] : GetTritonGroupNormNHWCTypeStringAndOps()) { - ORT_UNUSED_PARAMETER(_); - this->RegisterOp(std::move(op)); - } -#endif - } -}; - -#undef ADD_OP_FOR_ALL_VEC_SIZE -#undef ADD_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/diffusion/nhwc_conv.cc b/onnxruntime/contrib_ops/rocm/diffusion/nhwc_conv.cc deleted file mode 100644 index 35427a02c631d..0000000000000 --- a/onnxruntime/contrib_ops/rocm/diffusion/nhwc_conv.cc +++ /dev/null @@ -1,27 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/providers/rocm/nn/conv.h" - -using namespace onnxruntime::rocm; - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -#define REGISTER_KERNEL_TYPED(T) \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - NhwcConv, \ - kMSDomain, \ - 1, \ - T, \ - kRocmExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - Conv); - -REGISTER_KERNEL_TYPED(float) -REGISTER_KERNEL_TYPED(MLFloat16) - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/fused_conv.cc b/onnxruntime/contrib_ops/rocm/fused_conv.cc deleted file mode 100644 index 4f3be98d97f80..0000000000000 --- a/onnxruntime/contrib_ops/rocm/fused_conv.cc +++ /dev/null @@ -1,439 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include -#include -#include "core/common/status.h" -#include "core/providers/rocm/nn/conv.h" -#include "core/providers/rocm/rocm_common.h" - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -namespace { - -// Copied from hipDNN/library/src/hcc_detail/hipdnn_miopen.cpp -miopenStatus_t _miopenAddTensor( - miopenHandle_t handle, - const void* alpha, - const miopenTensorDescriptor_t aDesc, - const void* A, - const void* beta, - const miopenTensorDescriptor_t cDesc, - void* C, - const void* zero_scalar) { - const miopenTensorOp_t tensorOp = miopenTensorOpAdd; - // Using miopenOpTensor to implement Add operator. - // opnd2 = Add ( 0.0 * opnd0, alpha * opnd1 ) + beta * opnd2 - return miopenOpTensor(handle, tensorOp, - zero_scalar, cDesc, C, - alpha, aDesc, A, - beta, cDesc, C); -} - -} // namespace - -template -struct FNVHash { - uint32_t GetValue() const { return value_; } - - void Hash(const void* in_ptr, size_t nbytes) { - auto ptr = reinterpret_cast(in_ptr); - for (size_t i = 0; i < nbytes; ++i) { - value_ ^= ptr[i]; - value_ *= PRIME; - } - } - - template ::value, size_t>::type = 0> - FNVHash& operator<<(const T& pod) { - Hash(&pod, sizeof(pod)); - return *this; - } - - template - FNVHash& operator<<(const std::vector& pod_array) { - for (const auto& pod : pod_array) { - (*this) << pod; - } - return *this; - } - - void HashTensor(miopenTensorDescriptor_t tdesc) { - int size = 0; - miopenGetTensorDescriptorSize(tdesc, &size); - (*this) << size; - std::vector dims(size); - std::vector strides(size); - miopenDataType_t dtype; - miopenGetTensorDescriptor(tdesc, &dtype, dims.data(), strides.data()); - (*this) << dtype; - (*this) << dims; - (*this) << strides; - } - - void HashConvolutionDescriptor(miopenConvolutionDescriptor_t cdesc) { - int spatial_dim = 1; -#if ROCM_VERSION >= 50500 - MIOPEN_CALL(miopenGetConvolutionSpatialDim(cdesc, &spatial_dim)); - std::vector pads{spatial_dim}; - std::vector strides{spatial_dim}; - std::vector dilations{spatial_dim}; - miopenConvolutionMode_t mode; - MIOPEN_CALL(miopenGetConvolutionNdDescriptor(cdesc, spatial_dim, &spatial_dim, pads.data(), strides.data(), dilations.data(), &mode)); -#else - // Previous versions of MIOpen doesn't provide API to probe the dimension of a - // miopenConvolutionDescriptor_t, so we have to guess. - // This algorithm is based on a specific behavior of miopenGetConvolutionNdDescriptor, - // which fails when requestedSpatialDim > the convolution's spatial dimension - constexpr const int kMaxSpatialDim = 5; - std::vector pads{kMaxSpatialDim}; - std::vector strides{kMaxSpatialDim}; - std::vector dilations{kMaxSpatialDim}; - miopenConvolutionMode_t mode; - bool spatial_dim_guessed = false; - for (int i = 0; i < kMaxSpatialDim; i++) { - if (miopenStatusSuccess == miopenGetConvolutionNdDescriptor( - cdesc, i, &spatial_dim, pads.data(), strides.data(), dilations.data(), &mode)) { - spatial_dim_guessed = true; - break; - } - } - ORT_ENFORCE(spatial_dim_guessed, "Failed to guess the actual spatial dimension"); - // Remove the extra dimension - pads.resize(spatial_dim); - strides.resize(spatial_dim); - dilations.resize(spatial_dim); -#endif - (*this) << spatial_dim; - (*this) << pads; - (*this) << strides; - (*this) << dilations; - (*this) << mode; - } - - private: - uint32_t value_ = BASIS; -}; - -template -class FusedConv : public onnxruntime::rocm::Conv { - public: - using Base = onnxruntime::rocm::Conv; - FusedConv(const OpKernelInfo& info) : onnxruntime::rocm::Conv(info) { - std::string activation; - ORT_THROW_IF_ERROR(info.GetAttr("activation", &activation)); - ORT_THROW_IF_ERROR(MapMode(activation)); - MIOPEN_CALL_THROW(miopenCreateActivationDescriptor(&activation_desc_)); - MIOPEN_CALL_THROW(miopenSetActivationDescriptor(activation_desc_, activation_mode_, 0.0, 0.0, 0.0)); - MIOPEN_CALL_THROW(miopenCreateOperatorArgs(&fusion_args_)); - } - - ORT_DISALLOW_COPY_AND_ASSIGNMENT(FusedConv); - - ~FusedConv() { - if (activation_desc_) { - MIOPEN_CALL_THROW(miopenDestroyActivationDescriptor(activation_desc_)); - activation_desc_ = nullptr; - } - - if (fusion_args_) { - miopenDestroyOperatorArgs(fusion_args_); - } - } - - Status ComputeInternal(OpKernelContext* context) const override { - std::lock_guard lock(Base::s_.mutex); - - ORT_RETURN_IF_ERROR(Base::UpdateState(context, true)); - if (Base::s_.Y->Shape().Size() == 0) { - return Status::OK(); - } - - bool has_z = nullptr != Base::s_.z_data; - bool has_b = nullptr != Base::s_.b_data; - auto factory = [this](FusedConvFusionData& fusion) { - return this->DoCreateFusionDesc(this->Node().Name(), fusion); - }; - auto& cached_item = plan_cache_.FindOrCreateFusionPlanCache(Hash(), - factory); - bool should_try_fusion_api = cached_item.Validate(this->GetMiopenHandle(context)); - - typedef typename onnxruntime::rocm::ToHipType::MappedType HipT; - const auto alpha = onnxruntime::rocm::Consts::One; - const auto beta = onnxruntime::rocm::Consts::Zero; - IAllocatorUniquePtr workspace = Base::GetWorkSpace(context->GetComputeStream()); - miopenStatus_t fusion_status = miopenStatusNotInitialized; - - if (should_try_fusion_api) { - auto& fusion_info = *cached_item.fusion; - MIOPEN_RETURN_IF_ERROR(miopenSetOpArgsConvForward(fusion_args_, - fusion_info.conv_op, - &alpha, - &beta, - Base::s_.w_data)); - if (has_z) { - MIOPEN_RETURN_IF_ERROR(miopenSetOpArgsBiasForward(fusion_args_, - fusion_info.bias_z_op, - &alpha, - &beta, - Base::s_.z_data)); - } - if (has_b) { - MIOPEN_RETURN_IF_ERROR(miopenSetOpArgsBiasForward(fusion_args_, - fusion_info.bias_b_op, - &alpha, - &beta, - Base::s_.b_data)); - } - if (activation_desc_) { - const float relu_notused = 0.0; - MIOPEN_RETURN_IF_ERROR(miopenSetOpArgsActivForward(fusion_args_, - fusion_info.act_op, - &alpha, - &beta, - relu_notused, - relu_notused, - relu_notused)); - } - fusion_status = miopenExecuteFusionPlan(this->GetMiopenHandle(context), - fusion_info.plan, - Base::s_.x_tensor, - Base::s_.x_data, - Base::s_.y_tensor, - Base::s_.y_data, - fusion_args_); - } - if (miopenStatusSuccess != fusion_status) { - MIOPEN_RETURN_IF_ERROR(miopenConvolutionForward(this->GetMiopenHandle(context), - &alpha, - Base::s_.x_tensor, - Base::s_.x_data, - Base::s_.w_desc, - Base::s_.w_data, - Base::s_.conv_desc, - Base::s_.fwd_algo, - &beta, - Base::s_.y_tensor, - Base::s_.y_data, - workspace.get(), - Base::s_.workspace_bytes)); - if (has_b) { - MIOPEN_RETURN_IF_ERROR(_miopenAddTensor(this->GetMiopenHandle(context), - &alpha, Base::s_.b_tensor, Base::s_.b_data, - &alpha, Base::s_.y_tensor, Base::s_.y_data, - &beta)); - } - if (has_z) { - MIOPEN_RETURN_IF_ERROR(_miopenAddTensor(this->GetMiopenHandle(context), - &alpha, Base::s_.z_tensor, Base::s_.z_data, - &alpha, Base::s_.y_tensor, Base::s_.y_data, - &beta)); - } - MIOPEN_RETURN_IF_ERROR(miopenActivationForward(this->GetMiopenHandle(context), - activation_desc_, - &alpha, - Base::s_.y_tensor, - Base::s_.y_data, - &beta, - Base::s_.y_tensor, - Base::s_.y_data)); - } - if (Base::s_.post_slicing_required) { - ORT_RETURN_IF_ERROR(onnxruntime::rocm::SliceOutUnwantedOutputSection( - this->Stream(context), - Base::s_.y_data, - Base::s_.y_dims_with_adjusted_pads, - Base::s_.Y->MutableDataRaw(), - Base::s_.y_dims.GetDims(), - Base::s_.slice_starts, - Base::s_.slice_ends, - Base::s_.slice_axes, - Base::s_.element_size)); - } - return Status::OK(); - } - - private: - Status MapMode(const std::string& activaton_mode) { - if (activaton_mode == "Relu") { - activation_mode_ = miopenActivationMode_t::miopenActivationRELU; - } else { - return ORT_MAKE_STATUS( - StatusCategory::ONNXRUNTIME, StatusCode::INVALID_ARGUMENT, - "unsupported conv activation mode \"", activaton_mode, "\""); - } - return Status::OK(); - } - miopenActivationMode_t activation_mode_; - miopenActivationDescriptor_t activation_desc_ = nullptr; - - miopenOperatorArgs_t fusion_args_ = nullptr; - - // MIOpen Fusion API - // TODO: create one fusion descriptor shared by multiple FusedConv - // objects - // - // Considerations: - // How to determine two FusedConv objects may share the same fusion - // descriptor? Hashing x_tensor,conv_desc, etc.? - struct FusedConvFusionData { - miopenFusionPlanDescriptor_t plan = nullptr; - miopenFusionOpDescriptor_t conv_op = nullptr; - miopenFusionOpDescriptor_t bias_b_op = nullptr; - miopenFusionOpDescriptor_t bias_z_op = nullptr; - miopenFusionOpDescriptor_t act_op = nullptr; - - // TODO: There is a potential problem. miopenHandle_t may be destroyed and - // re-created later, sharing the same address. Currently there is any way - // to detect it? - mutable std::unordered_set compiled_on; - - ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(FusedConvFusionData); - - FusedConvFusionData() {} - ~FusedConvFusionData() { - if (plan) { - miopenDestroyFusionPlan(plan); - } - } - }; - - struct FusionPlanCacheItem { - std::unique_ptr fusion; - Status creation_result; - // TODO: Add a timestamp for eviction - // std::chrono::time_point last_access; - - FusionPlanCacheItem() {} - - miopenStatus_t CompileOnHandle(miopenHandle_t handle) const { - if (!fusion->plan) { - return miopenStatusNotInitialized; - } - auto iter = fusion->compiled_on.find(handle); - if (iter != fusion->compiled_on.end()) { - return miopenStatusSuccess; - } - auto ret = miopenCompileFusionPlan(handle, fusion->plan); - if (miopenStatusSuccess == ret) { - fusion->compiled_on.insert(handle); - } else { - return ret; - } - return miopenStatusSuccess; - } - - bool Validate(miopenHandle_t handle) const { - if (Status::OK() != creation_result) { - return false; - } - if (!fusion || !fusion->plan) { - return false; - } - auto compiling_status = CompileOnHandle(handle); - if (miopenStatusSuccess != compiling_status) { - return false; - } - - return true; - } - }; - - struct FusionPlanCache { - mutable std::mutex mutex; - using HashKey = uint32_t; - std::unordered_map cache_directory_; - - FusionPlanCache() { - } - - FusionPlanCacheItem& FindOrCreateFusionPlanCache(HashKey key, - std::function factory) { - std::lock_guard lock(mutex); - auto iter = cache_directory_.find(key); - if (iter == cache_directory_.end()) { - cache_directory_[key].fusion = std::make_unique(); - cache_directory_[key].creation_result = factory(*cache_directory_[key].fusion); - if (Status::OK() != cache_directory_[key].creation_result) { - cache_directory_[key].fusion.reset(); - } - } - return cache_directory_[key]; - } - }; - - static FusionPlanCache plan_cache_; - - Status DoCreateFusionDesc(const std::string& node_name, FusedConvFusionData& fusion) const { - bool has_z = nullptr != Base::s_.z_data; - bool has_b = nullptr != Base::s_.b_data; - MIOPEN_RETURN_IF_ERROR(miopenCreateFusionPlan(&fusion.plan, - miopenVerticalFusion, - Base::s_.x_tensor)); - auto status = miopenCreateOpConvForward(fusion.plan, &fusion.conv_op, Base::s_.conv_desc, Base::s_.w_desc); - if (status == miopenStatusUnsupportedOp) { - auto msg = MakeString("MIOpen does not support the conv fusion for node \"", - node_name, "\", fallback to unfused implementation."); - LOGS_DEFAULT(WARNING) << msg; - return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, msg); - } - MIOPEN_RETURN_IF_ERROR(status); - - if (has_z) { - MIOPEN_RETURN_IF_ERROR(miopenCreateOpBiasForward(fusion.plan, - &fusion.bias_z_op, - Base::s_.z_tensor)); - } - if (has_b) { - MIOPEN_RETURN_IF_ERROR(miopenCreateOpBiasForward(fusion.plan, - &fusion.bias_b_op, - Base::s_.b_tensor)); - } - if (activation_desc_) { - MIOPEN_RETURN_IF_ERROR(miopenCreateOpActivationForward(fusion.plan, - &fusion.act_op, - activation_mode_)); - } - return Status::OK(); - } - - uint32_t Hash() const { - FNVHash hash; - bool has_z = nullptr != Base::s_.z_data; - bool has_b = nullptr != Base::s_.b_data; - hash.HashTensor(Base::s_.x_tensor); - hash.HashConvolutionDescriptor(Base::s_.conv_desc); - hash.HashTensor(Base::s_.w_desc); - if (has_z) { - hash.HashTensor(Base::s_.z_tensor); - } - if (has_b) { - hash.HashTensor(Base::s_.b_tensor); - } - if (activation_desc_) { - hash << static_cast(activation_mode_); - } - return hash.GetValue(); - } -}; - -template -typename FusedConv::FusionPlanCache FusedConv::plan_cache_; - -#define REGISTER_KERNEL_TYPED(T) \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - FusedConv, \ - kMSDomain, \ - 1, \ - T, \ - kRocmExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - FusedConv); - -REGISTER_KERNEL_TYPED(float); -REGISTER_KERNEL_TYPED(MLFloat16); -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/math/gemm_float8.cu b/onnxruntime/contrib_ops/rocm/math/gemm_float8.cu deleted file mode 100644 index 3539f32252944..0000000000000 --- a/onnxruntime/contrib_ops/rocm/math/gemm_float8.cu +++ /dev/null @@ -1,213 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/common/common.h" -#include "core/common/float16.h" -#include "core/providers/rocm/rocm_kernel.h" -#include "contrib_ops/rocm/math/gemm_float8_ck.cuh" - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -using namespace onnxruntime::rocm; -using namespace onnxruntime::rocm::tunable::blas; - -class GemmFloat8 final : public RocmKernel { - public: - GemmFloat8(const OpKernelInfo& info) : RocmKernel(info) { - transA_ = info.GetAttrOrDefault("transA", 0); - transB_ = info.GetAttrOrDefault("transB", 0); - dtype_ = info.GetAttrOrDefault("dtype", onnx::TensorProto_DataType_FLOAT16); - alpha_ = info.GetAttrOrDefault("alpha", 1); - beta_ = info.GetAttrOrDefault("beta", 0); - } - Status ComputeInternal(OpKernelContext* ctx) const override; - - private: -#if !defined(DISABLE_FLOAT8_TYPES) - template - Status ComputeFp8Fp16Fp16(OpKernelContext* ctx, int64_t m, int64_t n, int64_t k, - const Tensor* A, const Tensor* scaleA, const Tensor* B, Tensor* C) const; - template - Status ComputeFp16Fp8Fp16(OpKernelContext* ctx, int64_t m, int64_t n, int64_t k, - const Tensor* A, const Tensor* B, const Tensor* scaleB, Tensor* C) const; - - template - [[nodiscard]] inline auto* GetOp() const { - using OpT = GemmFloat8TunableOp; - if (tunable_op_) { - return static_cast(tunable_op_.get()); - } - - auto create = std::make_unique(); // avoid new - tunable_op_ = std::shared_ptr(create.release(), [](void* ptr) { - auto release = std::unique_ptr(); // avoid delete - release.reset(static_cast(ptr)); - }); - - return static_cast(tunable_op_.get()); - } -#endif - - float alpha_; - float beta_; - bool transA_; - bool transB_; - int64_t dtype_; - - // fully type erased - mutable std::shared_ptr tunable_op_; -}; - -Status GemmFloat8::ComputeInternal(OpKernelContext* ctx) const { -#if defined(DISABLE_FLOAT8_TYPES) - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "DISABLE_FLOAT8_TYPES"); -#else - const Tensor* A = ctx->Input(0); - const Tensor* B = ctx->Input(1); - const Tensor* C = ctx->Input(2); // bias - const Tensor* scale_a = ctx->Input(3); - const Tensor* scale_b = ctx->Input(4); - const Tensor* scale_y = ctx->Input(5); - - auto a_shape = A->Shape(); - auto b_shape = B->Shape(); - ORT_ENFORCE(a_shape.NumDimensions() == 2); - ORT_ENFORCE(b_shape.NumDimensions() == 2); - - auto m = !transA_ ? a_shape[0] : a_shape[1]; - auto k = !transA_ ? a_shape[1] : a_shape[0]; - ORT_ENFORCE(k == (!transB_ ? b_shape[0] : b_shape[1])); // k is compatible - auto n = !transB_ ? b_shape[1] : b_shape[0]; - - TensorShapeVector output_shape = {m, n}; - Tensor* Y = ctx->Output(0, output_shape); - - ORT_ENFORCE(!transA_, "ROCm GemmFloat8 does not support input A transpose"); - ORT_ENFORCE(dtype_ == onnx::TensorProto_DataType_FLOAT16, "ROCm GemmFloat8 only supports output float16"); - ORT_ENFORCE(C == nullptr, "ROCm GemmFloat8 does not support bias input"); - ORT_ENFORCE(scale_y == nullptr, "ROCm GemmFloat8 does not support output scaling"); - - if (A->IsDataType()) { - return ComputeFp8Fp16Fp16(ctx, m, n, k, A, scale_a, B, Y); - } else if (A->IsDataType()) { - return ComputeFp8Fp16Fp16(ctx, m, n, k, A, scale_a, B, Y); - } else if (B->IsDataType()) { - return ComputeFp16Fp8Fp16(ctx, m, n, k, A, B, scale_b, Y); - } else if (B->IsDataType()) { - return ComputeFp16Fp8Fp16(ctx, m, n, k, A, B, scale_b, Y); - } - - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unhandled type combination of GemmFloat8"); -#endif -} - -#if !defined(DISABLE_FLOAT8_TYPES) -template -Status GemmFloat8::ComputeFp8Fp16Fp16( - OpKernelContext* ctx, int64_t m, int64_t n, int64_t k, - const Tensor* A, const Tensor* scale_a, const Tensor* B, Tensor* C) const { - ORT_ENFORCE(A->IsDataType() && scale_a->IsDataType() && B->IsDataType()); - - onnxruntime::rocm::tunable::blas::GemmFloat8Params params{}; - params.tuning_ctx = GetTuningContext(); - params.stream = ctx->GetComputeStream(); - params.handle = GetHipblasHandle(ctx); - params.opa = transA_ ? tunable::blas::BlasOp::Trans : tunable::blas::BlasOp::NonTrans; - params.opb = transB_ ? tunable::blas::BlasOp::Trans : tunable::blas::BlasOp::NonTrans; - - params.m = m; - params.n = n; - params.k = k; - - params.a = static_cast(A->DataRaw()); - params.lda = transA_ ? m : k; - params.scale_a = alpha_; - params.scale_a_dev = static_cast(scale_a->DataRaw()); - - params.b = static_cast(B->DataRaw()); - params.ldb = transB_ ? k : n; - params.scale_b = 1.0f; // NOTE: not used - params.scale_b_dev = nullptr; // NOTE: not used - - params.c = static_cast(C->MutableDataRaw()); - params.ldc = n; - params.scale_c = 1.0f; // NOTE: not implemented - params.scale_c_dev = nullptr; // NOTE: not implemented - - if (!transA_ && !transB_) { - return (*GetOp())(¶ms); - } else if (transA_ && !transB_) { - ORT_NOT_IMPLEMENTED("transA is not implemented"); - } else if (!transA_ && transB_) { - ORT_NOT_IMPLEMENTED("transB is not implemented"); - } else if (transA_ && transB_) { - ORT_NOT_IMPLEMENTED("transA & transB is not implemented"); - } - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unreachable"); -} - -template -Status GemmFloat8::ComputeFp16Fp8Fp16( - OpKernelContext* ctx, int64_t m, int64_t n, int64_t k, - const Tensor* A, const Tensor* B, const Tensor* scale_b, Tensor* C) const { - ORT_ENFORCE(A->IsDataType() && B->IsDataType() && scale_b->IsDataType()); - - onnxruntime::rocm::tunable::blas::GemmFloat8Params params{}; - params.tuning_ctx = GetTuningContext(); - params.stream = ctx->GetComputeStream(); - params.handle = GetHipblasHandle(ctx); - params.opa = transA_ ? tunable::blas::BlasOp::Trans : tunable::blas::BlasOp::NonTrans; - params.opb = transB_ ? tunable::blas::BlasOp::Trans : tunable::blas::BlasOp::NonTrans; - - params.m = m; - params.n = n; - params.k = k; - - params.a = static_cast(A->DataRaw()); - params.lda = transA_ ? m : k; - params.scale_a = 1.0f; // NOTE: not used - params.scale_a_dev = nullptr; // NOTE: not used - - params.b = static_cast(B->DataRaw()); - params.ldb = transB_ ? k : n; - params.scale_b = alpha_; - params.scale_b_dev = static_cast(scale_b->DataRaw()); - - params.c = static_cast(C->MutableDataRaw()); - params.ldc = n; - params.scale_c = 1.0f; // NOTE: not implemented - params.scale_c_dev = nullptr; // NOTE: not implemented - - if (!transA_ && !transB_) { - return (*GetOp())(¶ms); - } else if (transA_ && !transB_) { - ORT_NOT_IMPLEMENTED("transA is not implemented"); - } else if (!transA_ && transB_) { - return (*GetOp())(¶ms); - } else if (transA_ && transB_) { - ORT_NOT_IMPLEMENTED("transA & transB is not implemented"); - } - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unreachable"); -} -#define GEMM_FLOAT8_CONSTRAINTS BuildKernelDefConstraints() -#else -#define GEMM_FLOAT8_CONSTRAINTS BuildKernelDefConstraints() -#endif - -ONNX_OPERATOR_KERNEL_EX( - GemmFloat8, - kMSDomain, - 1, - kRocmExecutionProvider, - (*KernelDefBuilder::Create()) - .TypeConstraint("TA", GEMM_FLOAT8_CONSTRAINTS) - .TypeConstraint("TB", GEMM_FLOAT8_CONSTRAINTS) - .TypeConstraint("TR", BuildKernelDefConstraints()) - .TypeConstraint("TS", BuildKernelDefConstraints()), - GemmFloat8); - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck.cuh b/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck.cuh deleted file mode 100644 index b545eb1f2a149..0000000000000 --- a/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck.cuh +++ /dev/null @@ -1,276 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include -#include -#include - -#if defined(USE_COMPOSABLE_KERNEL) - -#include "core/providers/rocm/composable_kernel_common.h" - -#include "ck/ck.hpp" -#include "ck/utility/functional3.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/device_gemm_splitk.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#endif - -#if !defined(DISABLE_FLOAT8_TYPES) -#include "core/common/float8.h" -#endif -#include "core/providers/rocm/tunable/gemm_common.h" - -namespace onnxruntime { -namespace rocm { -namespace tunable { - -#if defined(USE_COMPOSABLE_KERNEL) && !defined(DISABLE_FLOAT8_TYPES) -using F8 = ck::f8_t; -using F16 = ck::half_t; -using F32 = float; - -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; - -template -constexpr bool always_false = false; - -template -struct Scale { - constexpr const static bool is_pack2_invocable = true; - constexpr const static bool is_pack4_invocable = true; - - explicit Scale(float scale_value, const float* dev_scale_ptr) : scale_value_{scale_value}, dev_scale_ptr_{dev_scale_ptr} {} - - template - __forceinline__ __host__ __device__ Y fast_type_convert(X x) const { - static_assert(always_false, "not implemented"); - (void)x; - } - - template <> - __forceinline__ __host__ __device__ ck::half_t fast_type_convert(ck::f8_t x) const { - // https://github.com/ROCmSoftwarePlatform/triton/blob/0cc3f8b84a16892396f6e08a04991034d67e32b1/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp#L220-L233 - constexpr const uint16_t mask = 0x7fff; - constexpr const uint16_t sign_mask = 0x8000; - constexpr const uint16_t exp_compensate = []() { - if constexpr (std::is_same_v) { - return 0x2000; - } else if constexpr (std::is_same_v) { - return 0x1c00; - } - }(); - - uint8_t x_u8 = reinterpret_cast(x); - uint16_t x_u16 = static_cast(x_u8) << 8; - uint16_t exp = (x_u16 & mask) >> 1; - uint16_t y = (x_u16 & sign_mask) | (exp + exp_compensate); - return reinterpret_cast(y); - } - - __forceinline__ __host__ __device__ void operator()(ck::half_t& y, const ck::f8_t& x) const { - float scale = scale_value_ * (*dev_scale_ptr_); - y = ck::type_convert(scale * fast_type_convert(x)); - } - - __forceinline__ __host__ __device__ void operator()(ck::half2_t& ys, const ck::f8x2_t& xs) const { - float scale = scale_value_ * (*dev_scale_ptr_); - constexpr const uint32_t mask = 0x7fff7fff; - constexpr const uint32_t sign_mask = 0x80008000; - constexpr const uint32_t exp_compensate = []() { - if constexpr (std::is_same_v) { - return 0x20002000; - } else if constexpr (std::is_same_v) { - return 0x1c001c00; - } - }(); - - const uchar2& x2_u8 = reinterpret_cast(xs); - uchar4 x{0, x2_u8.x, 0, x2_u8.y}; - uint32_t x_u32 = reinterpret_cast(x); - - uint32_t exp = (x_u32 & mask) >> 1; - uint32_t v = (x_u32 & sign_mask) | (exp + exp_compensate); - ys = scale * reinterpret_cast(v); - } - - __forceinline__ __host__ __device__ void operator()(ck::half4_t& ys, const ck::f8x4_t& xs) const { - float scale = scale_value_ * (*dev_scale_ptr_); - constexpr const uint32_t mask = 0x7fff7fff; - constexpr const uint32_t sign_mask = 0x80008000; - constexpr const uint32_t exp_compensate = []() { - if constexpr (std::is_same_v) { - return 0x20002000; - } else if constexpr (std::is_same_v) { - return 0x1c001c00; - } - }(); - - uint32_t xs_u32 = reinterpret_cast(xs); - uint32_t x_u32_0 = __byte_perm(xs_u32, 0, 0x1504); - uint32_t x_u32_1 = __byte_perm(xs_u32, 0, 0x3726); - uint32_t exp_0 = (x_u32_0 & mask) >> 1; - uint32_t exp_1 = (x_u32_1 & mask) >> 1; - uint32_t v_0 = (x_u32_0 & sign_mask) | (exp_0 + exp_compensate); - uint32_t v_1 = (x_u32_1 & sign_mask) | (exp_1 + exp_compensate); - uint64_t v = v_0 | uint64_t(v_1) << 32; - ys = scale * reinterpret_cast(v); - } - - float scale_value_; - const float* const dev_scale_ptr_; -}; -#endif - -namespace blas { - -template -struct GemmFloat8Params : tunable::OpParams { - std::string Signature() const override { - return MakeString(BlasOpToString(opa), BlasOpToString(opb), "_", m, "_", n, "_", k); - } - - hipblasHandle_t handle; - BlasOp opa; - BlasOp opb; - int64_t m; - int64_t n; - int64_t k; - float scale_a{}; - const float* scale_a_dev{}; - const TA* a; - int64_t lda; - float scale_b{}; - const float* scale_b_dev{}; - const TB* b; - int64_t ldb; - TC* c; - float scale_c{}; - const float* scale_c_dev{}; - int64_t ldc; -}; - -#if defined(USE_COMPOSABLE_KERNEL) && !defined(DISABLE_FLOAT8_TYPES) - -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; - -using Nop = ck::tensor_operation::element_wise::PassThrough; - -void add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances( - std::vector, Nop, Nop>>>& instances); - -void add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances( - std::vector, Nop, Nop>>>& instances); - -void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances( - std::vector, Nop>>>& instances); - -void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances( - std::vector, Nop>>>& instances); - -void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances( - std::vector, Nop>>>& instances); - -void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances( - std::vector, Nop>>>& instances); - -template -auto CreateOp(float scale, const float* dev_scale) { - if constexpr (std::is_same_v) { - return Scale(scale, dev_scale); - } else if constexpr (std::is_same_v) { - return Scale(scale, dev_scale); - } else { - return Nop{}; - } -} - -template -auto GetCKF8SplitKGemmTypeStringAndOps() { - using CKTA = typename CKDataTypeAdaptor::type; - using CKTB = typename CKDataTypeAdaptor::type; - using CKTC = typename CKDataTypeAdaptor::type; - - using CKLayoutA = typename CKBlasOpAdaptor::type; - using CKLayoutB = typename CKBlasOpAdaptor::type; - - using OpA = std::conditional_t, Scale, Nop>; - using OpB = std::conditional_t, Scale, Nop>; - using OpC = std::conditional_t, Scale, Nop>; - - using DeviceGemm = ck::tensor_operation::device::DeviceGemmSplitK< - CKLayoutA, CKLayoutB, Row, - CKTA, CKTB, CKTC, - OpA, OpB, OpC>; - - std::vector>>> ret; - - for (auto num_split : {1, 4, 16, 64}) { - std::vector> instances{}; - if constexpr (std::is_same_v && std::is_same_v && std::is_same_v && - std::is_same_v && std::is_same_v) { - add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances(instances); - } else if constexpr (std::is_same_v && std::is_same_v && std::is_same_v && - std::is_same_v && std::is_same_v) { - add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances(instances); - } else if constexpr (std::is_same_v && std::is_same_v && std::is_same_v && - std::is_same_v && std::is_same_v) { - add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances(instances); - } else { - static_assert(always_false, "no instances for the type combination"); - LOGS_DEFAULT(FATAL) << "no instances for the type combination"; - } - for (auto&& impl : instances) { - auto type_string = std::to_string(ret.size()) + "_" + impl->GetTypeString() + "_SplitK" + std::to_string(num_split); - auto invoker = impl->MakeInvokerPointer(); - auto ck_gemm_op = [num_split, impl = std::move(impl), invoker = std::move(invoker)](const GemmFloat8Params* params) -> Status { - OpA op_a = CreateOp(params->scale_a, params->scale_a_dev); - OpB op_b = CreateOp(params->scale_b, params->scale_b_dev); - OpC op_c = CreateOp(params->scale_c, params->scale_c_dev); - - auto arg = impl->MakeArgumentPointer(params->a, params->b, params->c, - params->m, params->n, params->k, - params->lda, params->ldb, params->ldc, - op_a, op_b, op_c, num_split); - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!impl->IsSupportedArgument(arg.get()), - impl->GetTypeString(), " does not support ", params->Signature()); - invoker->Run(arg.get(), StreamConfig{params->StreamHandle()}); - return Status::OK(); - }; - ret.emplace_back(std::make_pair(std::move(type_string), std::move(ck_gemm_op))); - } - } - return ret; -} - -#endif // USE_COMPOSABLE_KERNEL - -template -class GemmFloat8TunableOp : public TunableOp> { - public: - GemmFloat8TunableOp() { -#if defined(USE_COMPOSABLE_KERNEL) && !defined(DISABLE_FLOAT8_TYPES) - for (auto&& [_, op] : GetCKF8SplitKGemmTypeStringAndOps()) { - ORT_UNUSED_PARAMETER(_); - this->RegisterOp(std::move(op)); - } -#else - ORT_ENFORCE(false, "CK is required to support GemmFloat8 computing"); -#endif // USE_COMPOSABLE_KERNEL - } -}; - -} // namespace blas -} // namespace tunable -} // namespace rocm -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/add_instance.cu b/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/add_instance.cu deleted file mode 100644 index 4c691dd18f2e9..0000000000000 --- a/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/add_instance.cu +++ /dev/null @@ -1,124 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include - -#if defined(USE_COMPOSABLE_KERNEL) && !defined(DISABLE_FLOAT8_TYPES) - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp" - -#include "contrib_ops/rocm/math/gemm_float8_ck.cuh" - -namespace onnxruntime { -namespace rocm { -namespace tunable { -namespace blas { - -using F8 = ck::f8_t; -using F16 = ck::half_t; -using F32 = float; - -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; - -template -using S = ck::Sequence; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -namespace internal { -void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ck( - std::vector, PassThrough>>>& instances); - -void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ck( - std::vector, PassThrough>>>& instances); - -void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ort( - std::vector, PassThrough>>>& instances); - -void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ort( - std::vector, PassThrough>>>& instances); -} // namespace internal - -void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances( - std::vector, PassThrough>>>& instances) { - internal::add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ck(instances); - internal::add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ort(instances); -} - -void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances( - std::vector, PassThrough>>>& instances) { - internal::add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ck(instances); - internal::add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ort(instances); -} - -namespace internal { -void add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ck( - std::vector, PassThrough, PassThrough>>>& instances); - -void add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ck( - std::vector, PassThrough, PassThrough>>>& instances); - -// TODO: The first try of derivation does not going well due to various constraints. -// void add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ort( -// std::vector, PassThrough, PassThrough>>>& instances); - -// TODO: The first try of derivation does not going well due to various constraints. -// void add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ort( -// std::vector, PassThrough, PassThrough>>>& instances); -} // namespace internal - -void add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances( - std::vector, PassThrough, PassThrough>>>& instances) { - internal::add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ck(instances); - // internal::add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ort(instances); // TODO: -} - -void add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances( - std::vector, PassThrough, PassThrough>>>& instances) { - internal::add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ck(instances); - // internal::add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ort(instances); // TODO: -} - -namespace internal { -void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances_ck( - std::vector, PassThrough>>>& instances); - -void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances_ck( - std::vector, PassThrough>>>& instances); -} // namespace internal - -void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances( - std::vector, PassThrough>>>& instances) { - internal::add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances_ck(instances); -} - -void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances( - std::vector, PassThrough>>>& instances) { - internal::add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances_ck(instances); -} - -} // namespace blas -} // namespace tunable -} // namespace rocm -} // namespace onnxruntime - -#endif diff --git a/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instance.cu b/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instance.cu deleted file mode 100644 index 49463e58886f8..0000000000000 --- a/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instance.cu +++ /dev/null @@ -1,97 +0,0 @@ -// SPDX-License-Identifier: MIT -// Modifications Copyright (c) Microsoft. -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#include - -#if defined(USE_COMPOSABLE_KERNEL) && !defined(DISABLE_FLOAT8_TYPES) - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp" - -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -#include "contrib_ops/rocm/math/gemm_float8_ck.cuh" - -namespace onnxruntime { -namespace rocm { -namespace tunable { -namespace blas { -namespace internal { - -template -using S = ck::Sequence; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; -static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; - -using ck::tensor_operation::device::DeviceGemmXdlSplitKCShuffle; - -template -using device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_generic = std::tuple< - // clang-format off - //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| - //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| - //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNKPadding, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 1, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 16, 1, 8>, 2>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNKPadding, 64, 32, 32, 4, 8, 32, 32, 1, 1, S<1, 2, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 1, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 2> - // clang-format on - >; - -// The derived version is simply double BBlockTransferSrcScalarPerVector and adjust other values correspondingly -template -using device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ort = std::tuple< - // clang-format off - //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute| - //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Type| - //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| | - //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 256, 128, 8, 4, 32, 32, 4, 2, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 8, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 128, 256, 8, 4, 32, 32, 2, 4, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 8, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 128, 128, 8, 4, 32, 32, 4, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 64, 192, 8, 4, 32, 32, 1, 3, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 8, 24, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 192, 64, 8, 4, 32, 32, 3, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 128, 128, 8, 4, 32, 32, 2, 2, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 8, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 64, 128, 8, 4, 32, 32, 2, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 128, 64, 8, 4, 32, 32, 2, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 8, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, 1, 1, S<1, 16, 1, 4>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 64, 128, 8, 4, 32, 32, 1, 2, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 8, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 32, 192, 8, 4, 32, 32, 1, 3, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 8, 12, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 16, 4, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 192, 32, 8, 4, 32, 32, 3, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, 1, 1, S<1, 32, 1, 4>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 32, 64, 8, 4, 32, 32, 1, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 64, 32, 8, 4, 32, 32, 1, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, 1, 1, S<1, 32, 1, 4>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 32, 128, 8, 4, 32, 32, 1, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 128, 32, 8, 4, 32, 32, 2, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, 1, 1, S<1, 32, 1, 4>, 8, F16> - // clang-format on - >; - -void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ort( - std::vector, PassThrough>>>& instances) { - ck::tensor_operation::device::instance::add_device_operation_instances( - instances, device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ort{}); - ck::tensor_operation::device::instance::add_device_operation_instances( - instances, device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_generic{}); -} - -void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ort( - std::vector, PassThrough>>>& instances) { - ck::tensor_operation::device::instance::add_device_operation_instances( - instances, device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ort{}); - ck::tensor_operation::device::instance::add_device_operation_instances( - instances, device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_generic{}); -} - -} // namespace internal -} // namespace blas -} // namespace tunable -} // namespace rocm -} // namespace onnxruntime - -#endif diff --git a/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instance_original.cu b/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instance_original.cu deleted file mode 100644 index 236e5555051fc..0000000000000 --- a/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instance_original.cu +++ /dev/null @@ -1,80 +0,0 @@ -// SPDX-License-Identifier: MIT -// Modifications Copyright (c) Microsoft. -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#include - -#if defined(USE_COMPOSABLE_KERNEL) && !defined(DISABLE_FLOAT8_TYPES) - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp" - -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -#include "contrib_ops/rocm/math/gemm_float8_ck.cuh" - -namespace onnxruntime { -namespace rocm { -namespace tunable { -namespace blas { -namespace internal { - -template -using S = ck::Sequence; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; - -using ck::tensor_operation::device::DeviceGemmXdlSplitKCShuffle; - -// Compilation parameters for a[m, k] * b[k, n] = c[m, n] -template -using device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ck = std::tuple< - // clang-format off - //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute| - //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Type| - //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| | - //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 64, 192, 4, 8, 32, 32, 1, 3, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 48, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 192, 64, 4, 8, 32, 32, 3, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 32, 192, 4, 8, 32, 32, 1, 3, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 24, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 192, 32, 4, 8, 32, 32, 3, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 32, 64, 4, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 64, 32, 4, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16> - // clang-format on - >; - -void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ck( - std::vector, PassThrough>>>& instances) { - ck::tensor_operation::device::instance::add_device_operation_instances( - instances, device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ck{}); -} - -void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ck( - std::vector, PassThrough>>>& instances) { - ck::tensor_operation::device::instance::add_device_operation_instances( - instances, device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ck{}); -} - -} // namespace internal -} // namespace blas -} // namespace tunable -} // namespace rocm -} // namespace onnxruntime - -#endif diff --git a/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instance.cu b/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instance.cu deleted file mode 100644 index 1a0d45df82a71..0000000000000 --- a/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instance.cu +++ /dev/null @@ -1,94 +0,0 @@ -// SPDX-License-Identifier: MIT -// Modifications Copyright (c) Microsoft. -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#include - -#if defined(USE_COMPOSABLE_KERNEL) && !defined(DISABLE_FLOAT8_TYPES) - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp" - -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -#include "contrib_ops/rocm/math/gemm_float8_ck.cuh" - -namespace onnxruntime { -namespace rocm { -namespace tunable { -namespace blas { -namespace internal { - -template -using S = ck::Sequence; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; -static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; - -using ck::tensor_operation::device::DeviceGemmXdlSplitKCShuffle; - -template -using device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances_generic = std::tuple< - // clang-format off - //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| - //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| - //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNKPadding, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 1, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 2, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNKPadding, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 1, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 2, F16> - // clang-format on - >; - -// Compilation parameters for a[m, k] * b[k, n] = c[m, n] -template -using device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances = std::tuple< - // clang-format off - //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| - //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| - //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 32, 1, 4>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 16, 1, 4>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 32, 1, 4>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 16, 1, 4>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 16, 1, 4>, 8, F16> - // clang-format on - >; - -void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances_ck( - std::vector, PassThrough>>>& instances) { - ck::tensor_operation::device::instance::add_device_operation_instances( - instances, device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances{}); - ck::tensor_operation::device::instance::add_device_operation_instances( - instances, device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances_generic{}); -} - -void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances_ck( - std::vector, PassThrough>>>& instances) { - ck::tensor_operation::device::instance::add_device_operation_instances( - instances, device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances{}); - ck::tensor_operation::device::instance::add_device_operation_instances( - instances, device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances_generic{}); -} - -} // namespace internal -} // namespace blas -} // namespace tunable -} // namespace rocm -} // namespace onnxruntime - -#endif diff --git a/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instance_original.cu b/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instance_original.cu deleted file mode 100644 index a0628802ec09e..0000000000000 --- a/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instance_original.cu +++ /dev/null @@ -1,97 +0,0 @@ -// SPDX-License-Identifier: MIT -// Modifications Copyright (c) Microsoft. -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#include - -#if defined(USE_COMPOSABLE_KERNEL) && !defined(DISABLE_FLOAT8_TYPES) - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp" - -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -#include "contrib_ops/rocm/math/gemm_float8_ck.cuh" - -namespace onnxruntime { -namespace rocm { -namespace tunable { -namespace blas { -namespace internal { - -template -using S = ck::Sequence; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; -static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; - -using ck::tensor_operation::device::DeviceGemmXdlSplitKCShuffle; - -template -using device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_generic = std::tuple< - // clang-format off - //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| - //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| - //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNKPadding, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 1, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 16, 1, 8>, 2>, - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNKPadding, 64, 32, 32, 4, 8, 32, 32, 1, 1, S<1, 2, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 1, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 2> - // clang-format on - >; - -// Compilation parameters for a[m, k] * b[k, n] = c[m, n] -template -using device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ck = std::tuple< - // clang-format off - //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute| - //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Type| - //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| | - //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNPadding, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNPadding, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNPadding, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNPadding, 256, 64, 192, 4, 8, 32, 32, 1, 3, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 48, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNPadding, 256, 192, 64, 4, 8, 32, 32, 3, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNPadding, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNPadding, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNPadding, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNPadding, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNPadding, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNPadding, 128, 32, 192, 4, 8, 32, 32, 1, 3, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 24, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNPadding, 128, 192, 32, 4, 8, 32, 32, 3, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNPadding, 128, 32, 64, 4, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNPadding, 128, 64, 32, 4, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNPadding, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNPadding, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16> - // clang-format on - >; - -void add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ck( - std::vector, PassThrough, PassThrough>>>& instances) { - ck::tensor_operation::device::instance::add_device_operation_instances( - instances, device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ck{}); - ck::tensor_operation::device::instance::add_device_operation_instances( - instances, device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_generic{}); -} - -void add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ck( - std::vector, PassThrough, PassThrough>>>& instances) { - ck::tensor_operation::device::instance::add_device_operation_instances( - instances, device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ck{}); - ck::tensor_operation::device::instance::add_device_operation_instances( - instances, device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_generic{}); -} - -} // namespace internal -} // namespace blas -} // namespace tunable -} // namespace rocm -} // namespace onnxruntime - -#endif diff --git a/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc b/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc deleted file mode 100644 index 7dbb24463961e..0000000000000 --- a/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc +++ /dev/null @@ -1,347 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/providers/shared_library/provider_api.h" -#include "core/providers/rocm/rocm_common.h" - -using namespace onnxruntime::common; - -namespace onnxruntime { -namespace contrib { -namespace rocm { -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, GridSample); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, FastGelu); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, FastGelu); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, FastGelu); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, FastGelu); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, Gelu); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, Gelu); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, Gelu); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, Gelu); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BiasGelu); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, BiasSplitGelu); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, BiasSplitGelu); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, BiasAdd); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, BiasAdd); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, QuickGelu); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, QuickGelu); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, QuickGelu); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, TransposeMatMul); // backward compatibility -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, TransposeMatMul); // backward compatibility -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, TransposeMatMul); // backward compatibility -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, FusedMatMul); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, FusedMatMul); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, FusedMatMul); -// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, RelativePositionBias); -// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, RelativePositionBias); -// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, GatedRelativePositionBias); -// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, GatedRelativePositionBias); -// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, RemovePadding); -// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, RemovePadding); -// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, RestorePadding); -// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, RestorePadding); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, Rfft); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, Rfft); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, Rfft); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, Irfft); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, Irfft); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, Irfft); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, ComplexMul); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, ComplexMul); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, ComplexMulConj); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, ComplexMulConj); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BiasSoftmax); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BiasDropout); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BitmaskDropout); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BitmaskBiasDropout); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, NGramRepeatBlock); - -// These ops were experimental ops in onnx domain which have been removed now. We add them here as -// contrib ops to maintain backward compatibility -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, float, Affine); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, double, Affine); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MLFloat16, Affine); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, Attention); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, Attention); -// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, PackedAttention); -// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, PackedAttention); -// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, PackedMultiHeadAttention); -// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, PackedMultiHeadAttention); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BeamSearch); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, ConvTransposeWithDynamicPads); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, float, Crop); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, double, Crop); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MLFloat16, Crop); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, MultiHeadAttention); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, MultiHeadAttention); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, GroupQueryAttention); -// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, GroupQueryAttention); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, DecoderAttention); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, DecoderAttention); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, int32_t, DynamicSlice); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, int64_t, DynamicSlice); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, EmbedLayerNormalization); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, EmbedLayerNormalization); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, GreedySearch); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, GroupNorm); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, NhwcConv); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, NhwcConv); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, float, ImageScaler); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, double, ImageScaler); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MLFloat16, ImageScaler); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, LongformerAttention); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, LongformerAttention); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, float, ParametricSoftplus); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, double, ParametricSoftplus); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MLFloat16, ParametricSoftplus); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, Sampling); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, float, ScaledTanh); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, double, ScaledTanh); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MLFloat16, ScaledTanh); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, SkipGroupNorm); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, SkipLayerNormalization); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, SkipLayerNormalization); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, SkipSimplifiedLayerNormalization); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, SkipSimplifiedLayerNormalization); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, float, ThresholdedRelu); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, double, ThresholdedRelu); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MLFloat16, ThresholdedRelu); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 16, float_float_float, LayerNormalization); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 16, double_double_double, LayerNormalization); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 16, MLFloat16_float_MLFloat16, LayerNormalization); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 16, float_float_MLFloat16, LayerNormalization); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 16, MLFloat16_float_float, LayerNormalization); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 16, BFloat16_float_BFloat16, LayerNormalization); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, float_float_float, SimplifiedLayerNormalization); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, double_double_double, SimplifiedLayerNormalization); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MLFloat16_float_MLFloat16, SimplifiedLayerNormalization); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, float_float_MLFloat16, SimplifiedLayerNormalization); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MLFloat16_float_float, SimplifiedLayerNormalization); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, Inverse); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, MatMulNBits); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, MatMulNBits); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, Trilu); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, int8_t_MLFloat16, QuantizeLinear); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, uint8_t_MLFloat16, QuantizeLinear); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, int8_t_MLFloat16, DequantizeLinear); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, uint8_t_MLFloat16, DequantizeLinear); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float_int8_t, QAttention); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16_int8_t, QAttention); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, FusedConv); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, FusedConv); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, TransposeMatMul); // backward compatibility -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, FusedMatMul); -// class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, QOrderedMatMul); -// class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, QOrderedLayerNormalization); -// class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, QOrderedGelu); -// class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, QuantizeWithOrder); -// class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, DequantizeWithOrder); -// class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, QOrderedAttention); -// class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, QOrderedLongformerAttention); -// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, DecoderMaskedSelfAttention); -// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, DecoderMaskedSelfAttention); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, DecoderMaskedMultiHeadAttention); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, DecoderMaskedMultiHeadAttention); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, GemmFastGelu); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, GemmFastGelu); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, GemmFastGelu); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, GemmFloat8); - -#ifdef ENABLE_ATEN -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kPytorchAtenDomain, 1, ATen); -#endif - -#ifdef ENABLE_TRAINING_OPS -// Should remove the shrunken_gather include from ENABLE_TRAINING_OPS once 1). compute optimizer is enabled for inference or -// 2). this is needed by inference for other purpose. -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, ShrunkenGather); -#endif - -#ifdef ORT_USE_NCCL -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, AllReduce); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, AllGather); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, AllToAll); -#endif - -template <> -KernelCreateInfo BuildKernelCreateInfo() { - KernelCreateInfo info; - return info; -} - -// clang-format off -Status RegisterRocmContribKernels(KernelRegistry& kernel_registry) { - static const BuildKernelCreateInfoFn function_table[] = { - BuildKernelCreateInfo, // default entry to avoid the list become empty after ops-reducing - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, // backward compatibility - BuildKernelCreateInfo, // backward compatibility - BuildKernelCreateInfo, // backward compatibility - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - - // These ops were experimental ops in onnx domain which have been removed now. We add them here as - // contrib ops to maintain backward compatibility - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - // TransposedMatMul is still here for backward compatibility - BuildKernelCreateInfo, // backward compatibility - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - -#ifdef ENABLE_ATEN - BuildKernelCreateInfo, -#endif - -#ifdef ENABLE_TRAINING_OPS - // Should remove the shrunken_gather include from ENABLE_TRAINING_OPS once 1). compute optimizer is enabled for inference or - // 2). this is needed by inference for other purpose. - BuildKernelCreateInfo, -#endif - -#ifdef ORT_USE_NCCL - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, -#endif - - }; - - for (auto& function_table_entry : function_table) { - KernelCreateInfo info = function_table_entry(); - if (info.kernel_def != nullptr) { // filter disabled entries where type is void - ORT_RETURN_IF_ERROR(kernel_registry.Register(std::move(info))); - } - } - - return Status::OK(); -} -// clang-format on - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.h b/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.h deleted file mode 100644 index db9a5d4fcd83e..0000000000000 --- a/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.h +++ /dev/null @@ -1,14 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -Status RegisterRocmContribKernels(KernelRegistry& kernel_registry); - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention.cc b/onnxruntime/contrib_ops/webgpu/bert/attention.cc index a5ab63d74df24..130dd0c25a880 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/attention.cc @@ -165,7 +165,7 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const { << " let query_pos = m + local_id.y + past_sequence_length;\n" << " let key_pos = n + local_id.x;\n" << " if (key_pos > query_pos) {\n" - << " sum = -3.40282e+38; // Set to very negative value for masking\n" + << " sum = -3.4028234663852886e+38; // Set to very negative value for masking\n" << " }\n"; } @@ -272,7 +272,7 @@ Status InPlaceSoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const { << "let effective_seq_length = seq_causal_length;\n"; } shader.MainFunctionBody() - << "var thread_max_vector = f32_val_t(-3.402823e+38f);\n" + << "var thread_max_vector = f32_val_t(-3.4028234663852886e+38f);\n" << "for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < effective_seq_length; i++) {\n" << " let actual_pos = local_offset + i + start_offset;\n" << " if (!should_apply_local_window || actual_pos < seq_causal_length) {\n" @@ -289,7 +289,7 @@ Status InPlaceSoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const { } else if (use_smooth_softmax_) { shader.MainFunctionBody() << "var max_value: f32 = 0.0;\n"; } else { - shader.MainFunctionBody() << "var max_value = f32(-3.402823e+38f);\n"; + shader.MainFunctionBody() << "var max_value = f32(-3.4028234663852886e+38f);\n"; } shader.MainFunctionBody() << "for (var i = 0u; i < " << work_group_size_ << "; i++) {\n" diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc index 606dbfde15c2c..2a67dfdb07912 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc @@ -421,7 +421,7 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co indirect_buffer_ptr, tile_size)); Q = &query_output; } else { - ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value, tile_size, use_indirect_dispatch ? seqlen_k : nullptr, indirect_buffer_ptr)); + ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value, tile_size, use_seqlen_k ? seqlen_k : nullptr, indirect_buffer_ptr)); } if (parameters.sequence_length_ > 1) { @@ -571,8 +571,8 @@ Status RunSplitPackedQKVWithRotaryEmbeddingAndCopyKV(onnxruntime::webgpu::Comput {static_cast(params.kv_hidden_size_ / components)}, {static_cast(params.num_heads_)}, {static_cast(params.kv_num_heads_)}, - {head_size_vec}, - {half_rotary_embedding_dim_vec}, + {static_cast(head_size_vec)}, + {static_cast(half_rotary_embedding_dim_vec)}, {present_sequence_length}, {tile_size}, {static_cast(dispatch_size)}, diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.wgsl.template b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.wgsl.template index a5922ec9512fd..ff8e4ecc08bab 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.wgsl.template +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.wgsl.template @@ -26,7 +26,7 @@ fn get_total_sequence_length() -> u32 { #if is_fp16 const min_value = q_element_t(-65504.0); #else -const min_value = q_element_t(-3.402823e+38f); +const min_value = q_element_t(-3.4028234663852886e+38f); #endif // For max performance max_k_step should be the same as sg_size, however we might run out of registers diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_qkt.wgsl.template b/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_qkt.wgsl.template index c6f768beffa0f..ac9a157492007 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_qkt.wgsl.template +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_qkt.wgsl.template @@ -93,7 +93,7 @@ $MAIN { if (local_idx == 0u) { // Calculate the max and sum in current split. - var l_max = f32(-3.402823e+38f); + var l_max = f32(-3.4028234663852886e+38f); var l_sum = f32(0); for (var i = 0u; i < tile_size && (total_seq_offset + i) < total_sequence_length; i++) { l_max = max(l_max, f32(tile_qk[i])); diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_split_vx.wgsl.template b/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_split_vx.wgsl.template index 37cf7e8f11b1f..a113e96130985 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_split_vx.wgsl.template +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_split_vx.wgsl.template @@ -54,7 +54,7 @@ $MAIN { // Calculate the global max and sum in qk. if (head_idx < uniforms.num_heads) { - var g_max = f32(-3.402823e+38f); + var g_max = f32(-3.4028234663852886e+38f); var g_sum = f32(0); for (var i = 0u; i < num_total_seq_length_tile; i++) { diff --git a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc index 05717fd2fe686..416a895e61745 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc @@ -128,8 +128,8 @@ Status RunSplitPackedQKVWithRotaryEmbedding(onnxruntime::webgpu::ComputeContext& {static_cast(params.kv_hidden_size_ / components)}, {static_cast(params.num_heads_)}, {static_cast(params.kv_num_heads_)}, - {head_size_vec}, - {half_rotary_embedding_dim_vec}, + {static_cast(head_size_vec)}, + {static_cast(half_rotary_embedding_dim_vec)}, {static_cast(dispatch_size)}, }) .SetDispatchGroupSize((dispatch_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE); diff --git a/onnxruntime/contrib_ops/webgpu/moe/gate.wgsl.template b/onnxruntime/contrib_ops/webgpu/moe/gate.wgsl.template index 1214777009a8d..6e0d4c7299793 100644 --- a/onnxruntime/contrib_ops/webgpu/moe/gate.wgsl.template +++ b/onnxruntime/contrib_ops/webgpu/moe/gate.wgsl.template @@ -18,7 +18,7 @@ const K: u32 = k; #if is_fp16 const MAX_FLOAT: f16 = 65504.0; #else -const MAX_FLOAT: f32 = 3.402823466e+38; +const MAX_FLOAT: f32 = 3.4028234663852886e+38; #endif var shared_vals: array; diff --git a/onnxruntime/core/framework/allocation_planner.cc b/onnxruntime/core/framework/allocation_planner.cc index e77496b6e8196..1c80d83f99feb 100644 --- a/onnxruntime/core/framework/allocation_planner.cc +++ b/onnxruntime/core/framework/allocation_planner.cc @@ -499,8 +499,7 @@ class PlannerImpl { /*! \brief Given a tensor-type, return the size of an element of the tensor. */ static size_t GetElementSize(const DataType& tensor_type) { - const TypeProto& type_proto = ONNX_NAMESPACE::Utils::DataTypeUtils::ToTypeProto(tensor_type); - MLDataType ml_data_type = DataTypeImpl::TypeFromProto(type_proto); + MLDataType ml_data_type = DataTypeImpl::GetDataType(*tensor_type); const TensorTypeBase* tensor_type_base = ml_data_type->AsTensorType(); ORT_ENFORCE(nullptr != tensor_type_base); MLDataType elt_type = tensor_type_base->GetElementType(); diff --git a/onnxruntime/core/framework/ort_value_name_idx_map.h b/onnxruntime/core/framework/ort_value_name_idx_map.h index 76e7e369514d4..6035dc4e85242 100644 --- a/onnxruntime/core/framework/ort_value_name_idx_map.h +++ b/onnxruntime/core/framework/ort_value_name_idx_map.h @@ -33,7 +33,7 @@ class OrtValueNameIdxMap { common::Status GetIdx(std::string_view name, int& idx) const { idx = -1; - auto it = map_.find(std::string(name)); + auto it = map_.find(name); if (it == map_.end()) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Could not find OrtValue with name '", name, "'"); } diff --git a/onnxruntime/core/graph/contrib_ops/nhwc_inference_context.h b/onnxruntime/core/graph/contrib_ops/nhwc_inference_context.h index bc52a45adfd43..94ef87fb069af 100644 --- a/onnxruntime/core/graph/contrib_ops/nhwc_inference_context.h +++ b/onnxruntime/core/graph/contrib_ops/nhwc_inference_context.h @@ -83,7 +83,8 @@ class NhwcInferenceContext : public ONNX_NAMESPACE::InferenceContext { const int rank = nchw_shape.dim_size(); // N and C dims are required. Some operators like AveragePool allow 1D input if (rank < 3) { - fail_shape_inference("Output tensor must have at least 3 dimensions"); + *nhwc_tp.mutable_tensor_type()->mutable_shape() = nchw_shape; + return; } // Convert output shape from N, C, H {, W, ...} to N, H {, W, ...}, C. @@ -105,8 +106,8 @@ class NhwcInferenceContext : public ONNX_NAMESPACE::InferenceContext { const int rank = nhwc_shape.dim_size(); // N and C dims are required. Some operators like AveragePool allow 1D input. if (rank < 3) { - fail_shape_inference( - "Tensor must have at least 3 dimensions to convert between channels first and channels last."); + *nchw_tp.mutable_tensor_type()->mutable_shape() = nhwc_shape; + return; } // Convert input shape from {N, H, W, ..., C} to {N, C, H, W, ...}. diff --git a/onnxruntime/core/platform/telemetry.cc b/onnxruntime/core/platform/telemetry.cc index 6cbbdd4e0a7ef..1eb03af3befa4 100644 --- a/onnxruntime/core/platform/telemetry.cc +++ b/onnxruntime/core/platform/telemetry.cc @@ -81,6 +81,10 @@ void Telemetry::LogSessionCreation(uint32_t session_id, int64_t ir_version, cons ORT_UNUSED_PARAMETER(captureState); } +void Telemetry::LogCompileModel(uint32_t session_id) const { + ORT_UNUSED_PARAMETER(session_id); +} + void Telemetry::LogRuntimeError(uint32_t session_id, const common::Status& status, const char* file, const char* function, uint32_t line) const { ORT_UNUSED_PARAMETER(session_id); diff --git a/onnxruntime/core/platform/telemetry.h b/onnxruntime/core/platform/telemetry.h index b60345e1b8a80..9c2859f7634b6 100644 --- a/onnxruntime/core/platform/telemetry.h +++ b/onnxruntime/core/platform/telemetry.h @@ -66,6 +66,8 @@ class Telemetry { const std::string& loadedFrom, const std::vector& execution_provider_ids, bool use_fp16, bool captureState) const; + virtual void LogCompileModel(uint32_t session_id) const; + virtual void LogRuntimeError(uint32_t session_id, const common::Status& status, const char* file, const char* function, uint32_t line) const; diff --git a/onnxruntime/core/platform/windows/telemetry.cc b/onnxruntime/core/platform/windows/telemetry.cc index 2e5d334856278..693e265af46b1 100644 --- a/onnxruntime/core/platform/windows/telemetry.cc +++ b/onnxruntime/core/platform/windows/telemetry.cc @@ -334,6 +334,20 @@ void WindowsTelemetry::LogSessionCreation(uint32_t session_id, int64_t ir_versio } } +void WindowsTelemetry::LogCompileModel(uint32_t session_id) const { + if (global_register_count_ == 0 || enabled_ == false) + return; + + TraceLoggingWrite(telemetry_provider_handle, + "CompileModel", + TraceLoggingBool(true, "UTCReplace_AppSessionGuid"), + TelemetryPrivacyDataTag(PDT_ProductAndServiceUsage), + TraceLoggingKeyword(MICROSOFT_KEYWORD_MEASURES), + // Telemetry info + TraceLoggingUInt8(0, "schemaVersion"), + TraceLoggingUInt32(session_id, "sessionId")); +} + void WindowsTelemetry::LogRuntimeError(uint32_t session_id, const common::Status& status, const char* file, const char* function, uint32_t line) const { if (global_register_count_ == 0 || enabled_ == false) diff --git a/onnxruntime/core/platform/windows/telemetry.h b/onnxruntime/core/platform/windows/telemetry.h index 261d14a7fed8c..044feec071223 100644 --- a/onnxruntime/core/platform/windows/telemetry.h +++ b/onnxruntime/core/platform/windows/telemetry.h @@ -59,6 +59,8 @@ class WindowsTelemetry : public Telemetry { const std::string& loadedFrom, const std::vector& execution_provider_ids, bool use_fp16, bool captureState) const override; + void LogCompileModel(uint32_t session_id) const override; + void LogRuntimeError(uint32_t session_id, const common::Status& status, const char* file, const char* function, uint32_t line) const override; diff --git a/onnxruntime/core/providers/js/operators/unary.cc b/onnxruntime/core/providers/js/operators/unary.cc index ef977161bcc37..26144e6ba3995 100644 --- a/onnxruntime/core/providers/js/operators/unary.cc +++ b/onnxruntime/core/providers/js/operators/unary.cc @@ -126,7 +126,7 @@ JSEP_ELEMENTWISE_TYPED_KERNEL(Not, 1, bool, Not) // activation -JSEP_CLASS_IMPL_ATTRIBUTE_FLOAT_2_DEFAULT(ClipV10, Clip, min, 3.402823e+38f, max, -3.402823e+38f) +JSEP_CLASS_IMPL_ATTRIBUTE_FLOAT_2_DEFAULT(ClipV10, Clip, min, 3.4028234663852886e+38f, max, -3.4028234663852886e+38f) JSEP_ELEMENTWISE_VERSIONED_KERNEL(Clip, 6, 10, ClipV10) JSEP_KERNEL_IMPL(Clip, Clip) ONNX_OPERATOR_VERSIONED_KERNEL_EX(Clip, kOnnxDomain, 11, 11, kJsExecutionProvider, diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc index e2a8005aba1da..d148c4191d5d7 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc @@ -1407,9 +1407,30 @@ std::unique_ptr NvExecutionProvider::GetSubGraph(SubGraph_t gra } // Find inputs and outputs of the subgraph + std::unique_ptr sub_graph = onnxruntime::IndexedSubGraph::Create(); - std::unordered_map original_inputs, fused_inputs, fused_outputs, fused_outputs_to_add, graph_outputs_to_add; + std::unordered_map original_inputs; + + // These maps store the inputs and outputs of the subgraph. + // Please note that the inputs and outputs of the maps will be dynamically updated during node iteration + // to determine the final inputs and outputs of the subgraph. + std::unordered_map fused_inputs, fused_outputs; + + // This map stores the node's output that will be consumed by another node outside of this subgraph. + // So the node's output should be put into the subgraph's output list. + std::unordered_map fused_outputs_to_add; + + // This map stores the node's output that is original graph's output. + // So the node's output should be put into the subgraph's output list. + std::unordered_map graph_outputs_to_add; + std::unordered_set erased; + + // This is the relative ordering that ensures node's input or output being added to the 'fused_inputs', + // 'fused_outputs', 'fused_outputs_to_add' and 'graph_outputs_to_add' maps is associated with a relative order index. + // Items added earlier receive a smaller order index than items added later. + // When constructing the final sub_graph's input or output lists, entries with smaller + // order indices will appear before those with larger indices. int input_order = 0; int output_order = 0; @@ -1428,7 +1449,7 @@ std::unique_ptr NvExecutionProvider::GetSubGraph(SubGraph_t gra erased.insert(input); } else if (erased.find(input) == erased.end()) { // Only when input is neither in output list nor erased list, add the input to input list - fused_inputs[input] = input_order++; + fused_inputs.insert({input, input_order++}); } } @@ -1443,7 +1464,7 @@ std::unique_ptr NvExecutionProvider::GetSubGraph(SubGraph_t gra erased.insert(input); } else if (erased.find(input) == erased.end()) { // Only when input is neither in output list nor erased list, add the input to input list - fused_inputs[input] = input_order++; + fused_inputs.insert({input, input_order++}); } } @@ -1464,39 +1485,33 @@ std::unique_ptr NvExecutionProvider::GetSubGraph(SubGraph_t gra } else { output = (it->GetNode()).ImplicitInputDefs()[it->GetDstArgIndex() - static_cast(it->GetNode().InputDefs().size())]; } - if (node_set.find(node_idx) != node_set.end()) { - const auto& iter = fused_inputs.find(output); - if (iter != fused_inputs.end()) { - fused_inputs.erase(iter); - erased.insert(output); - } else if (erased.find(output) == erased.end()) { - if (graph_output_names.find(output->Name()) != graph_output_names.end()) { - graph_outputs_to_add[output] = output_order; - } - fused_outputs[output] = output_order++; - } - } else { - fused_outputs_to_add[output] = output_order++; + + if (node_set.find(node_idx) == node_set.end()) { + // This output will be consumed by another node outside of this subgraph. + // So the output should be put into the subgraph's output list. + fused_outputs_to_add.insert({output, output_order++}); } } - } else { - for (const auto& output : node->OutputDefs()) { - const auto& it = fused_inputs.find(output); - if (it != fused_inputs.end()) { - fused_inputs.erase(it); - erased.insert(output); - } - // Only when output is neither in input list nor erased list, and the output is consumed by another node, add the output to output list - else if (erased.find(output) == erased.end()) { - if (graph_output_names.find(output->Name()) != graph_output_names.end()) { - graph_outputs_to_add[output] = output_order; - } + } - if (graph.GetGraph().GetConsumerNodes(output->Name()).size() > 0) { - fused_outputs[output] = output_order++; - } + for (const auto& output : node->OutputDefs()) { + const auto& it = fused_inputs.find(output); + if (it != fused_inputs.end()) { + fused_inputs.erase(it); + erased.insert(output); + } else if (erased.find(output) == erased.end()) { + if (graph.GetGraph().GetConsumerNodes(output->Name()).size() > 0) { + // Only when output is neither in input list nor erased list, + // and the output is consumed by another node, add the output to output list + fused_outputs.insert({output, output_order++}); } } + + if (graph_output_names.find(output->Name()) != graph_output_names.end()) { + // This output is the graph's output. + // So the output should be put into the subgraph's output list. + graph_outputs_to_add.insert({output, output_order++}); + } } } diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc index 4d183b95bd938..0bb3accb4d754 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc @@ -76,6 +76,9 @@ Status BaseOpBuilder::ProcessDataTypes(QnnModelWrapper& qnn_model_wrapper, return CheckHtpDataTypes(input_qnn_dtypes, output_qnn_dtypes); } else if (IsGpuBackend(qnn_model_wrapper.GetQnnBackendType())) { return CheckGpuDataTypes(input_qnn_dtypes, output_qnn_dtypes); + } else if (IsIrBackend(qnn_model_wrapper.GetQnnBackendType())) { + // TODO: CheckIrDataTypes + return Status::OK(); } return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Only support backend: CPU, HTP and GPU"); } diff --git a/onnxruntime/core/providers/qnn/builder/qnn_def.cc b/onnxruntime/core/providers/qnn/builder/qnn_def.cc index f3d81d7d2fdd7..9f28e2609faa1 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_def.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_def.cc @@ -574,6 +574,10 @@ bool QnnOpConfigWrapper::CreateQnnGraphOp(const QNN_INTERFACE_VER_TYPE& qnn_inte return true; } +bool IsIrBackend(QnnBackendType backend_type) { + return backend_type == QnnBackendType::SERIALIZER; +} + bool IsNpuBackend(QnnBackendType backend_type) { return backend_type == QnnBackendType::HTP || backend_type == QnnBackendType::DSP; } diff --git a/onnxruntime/core/providers/qnn/builder/qnn_def.h b/onnxruntime/core/providers/qnn/builder/qnn_def.h index 42f4d7bb60f34..77508f3934a20 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_def.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_def.h @@ -96,6 +96,8 @@ enum class QnnBackendType : uint8_t { SERIALIZER, }; +bool IsIrBackend(QnnBackendType backend_type); + bool IsCpuBackend(QnnBackendType backend_type); bool IsNpuBackend(QnnBackendType backend_type); diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model.cc b/onnxruntime/core/providers/qnn/builder/qnn_model.cc index 85901ab6fdfec..8973a4efa8ba1 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_model.cc @@ -222,14 +222,14 @@ Status QnnModel::SetupQnnInputOutput(const logging::Logger& logger) { auto result = SetupTensors(qnn_input_infos_, graph_info_->InputTensors()); if (Status::OK() != result) { - const std::string message = "Failed to setup QNN input tensors for graph: " + graph_info_->Name(); + const std::string message = "Failed to setup QNN input tensors for graph: " + graph_info_->Name() + ". " + result.ErrorMessage(); LOGS(logger, ERROR) << message; return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, message); } result = SetupTensors(qnn_output_infos_, graph_info_->OutputTensors(), false); if (Status::OK() != result) { - const std::string message = "Failed to setup QNN output tensors for graph: " + graph_info_->Name(); + const std::string message = "Failed to setup QNN output tensors for graph: " + graph_info_->Name() + ". " + result.ErrorMessage(); LOGS(logger, ERROR) << message; return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, message); } diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index cd0c0e4bffdb5..e5b48da33fbc3 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -2035,9 +2035,30 @@ std::unique_ptr TensorrtExecutionProvider::GetSubGraph(SubGraph } // Find inputs and outputs of the subgraph + std::unique_ptr sub_graph = onnxruntime::IndexedSubGraph::Create(); - std::unordered_map original_inputs, fused_inputs, fused_outputs, fused_outputs_to_add, graph_outputs_to_add; + std::unordered_map original_inputs; + + // These maps store the inputs and outputs of the subgraph. + // Please note that the inputs and outputs of the maps will be dynamically updated during node iteration + // to determine the final inputs and outputs of the subgraph. + std::unordered_map fused_inputs, fused_outputs; + + // This map stores the node's output that will be consumed by another node outside of this subgraph. + // So the node's output should be put into the subgraph's output list. + std::unordered_map fused_outputs_to_add; + + // This map stores the node's output that is original graph's output. + // So the node's output should be put into the subgraph's output list. + std::unordered_map graph_outputs_to_add; + std::unordered_set erased; + + // This is the relative ordering that ensures node's input or output being added to the 'fused_inputs', + // 'fused_outputs', 'fused_outputs_to_add' and 'graph_outputs_to_add' maps is associated with a relative order index. + // Items added earlier receive a smaller order index than items added later. + // When constructing the final sub_graph's input or output lists, entries with smaller + // order indices will appear before those with larger indices. int input_order = 0; int output_order = 0; @@ -2056,7 +2077,7 @@ std::unique_ptr TensorrtExecutionProvider::GetSubGraph(SubGraph erased.insert(input); } else if (erased.find(input) == erased.end()) { // Only when input is neither in output list nor erased list, add the input to input list - fused_inputs[input] = input_order++; + fused_inputs.insert({input, input_order++}); } } @@ -2071,7 +2092,7 @@ std::unique_ptr TensorrtExecutionProvider::GetSubGraph(SubGraph erased.insert(input); } else if (erased.find(input) == erased.end()) { // Only when input is neither in output list nor erased list, add the input to input list - fused_inputs[input] = input_order++; + fused_inputs.insert({input, input_order++}); } } @@ -2092,39 +2113,33 @@ std::unique_ptr TensorrtExecutionProvider::GetSubGraph(SubGraph } else { output = (it->GetNode()).ImplicitInputDefs()[it->GetDstArgIndex() - static_cast(it->GetNode().InputDefs().size())]; } - if (node_set.find(node_idx) != node_set.end()) { - const auto& iter = fused_inputs.find(output); - if (iter != fused_inputs.end()) { - fused_inputs.erase(iter); - erased.insert(output); - } else if (erased.find(output) == erased.end()) { - if (graph_output_names.find(output->Name()) != graph_output_names.end()) { - graph_outputs_to_add[output] = output_order; - } - fused_outputs[output] = output_order++; - } - } else { - fused_outputs_to_add[output] = output_order++; + + if (node_set.find(node_idx) == node_set.end()) { + // This output will be consumed by another node outside of this subgraph. + // So the output should be put into the subgraph's output list. + fused_outputs_to_add.insert({output, output_order++}); } } - } else { - for (const auto& output : node->OutputDefs()) { - const auto& it = fused_inputs.find(output); - if (it != fused_inputs.end()) { - fused_inputs.erase(it); - erased.insert(output); - } - // Only when output is neither in input list nor erased list, and the output is consumed by another node, add the output to output list - else if (erased.find(output) == erased.end()) { - if (graph_output_names.find(output->Name()) != graph_output_names.end()) { - graph_outputs_to_add[output] = output_order; - } + } - if (graph.GetGraph().GetConsumerNodes(output->Name()).size() > 0) { - fused_outputs[output] = output_order++; - } + for (const auto& output : node->OutputDefs()) { + const auto& it = fused_inputs.find(output); + if (it != fused_inputs.end()) { + fused_inputs.erase(it); + erased.insert(output); + } else if (erased.find(output) == erased.end()) { + if (graph.GetGraph().GetConsumerNodes(output->Name()).size() > 0) { + // Only when output is neither in input list nor erased list, + // and the output is consumed by another node, add the output to output list + fused_outputs.insert({output, output_order++}); } } + + if (graph_output_names.find(output->Name()) != graph_output_names.end()) { + // This output is the graph's output. + // So the output should be put into the subgraph's output list. + graph_outputs_to_add.insert({output, output_order++}); + } } } diff --git a/onnxruntime/core/providers/vsinpu/builders/impl/clip_op_builder.cc b/onnxruntime/core/providers/vsinpu/builders/impl/clip_op_builder.cc index 85096d0e262d7..9948069c6779b 100644 --- a/onnxruntime/core/providers/vsinpu/builders/impl/clip_op_builder.cc +++ b/onnxruntime/core/providers/vsinpu/builders/impl/clip_op_builder.cc @@ -78,8 +78,8 @@ bool ClipOpBuilder::HandleBuildOp(vsi::npu::GraphEP* graph_ep, LOGS_DEFAULT(INFO) << "Creating Clip Op."; if (node_unit.SinceVersion() <= 6) { NodeAttrHelper helper(node_unit.GetNode()); - auto min = helper.Get("min", -3.402e+38f); - auto max = helper.Get("max", 3.402e+38f); + auto min = helper.Get("min", -3.4028234663852886e+38f); + auto max = helper.Get("max", 3.4028234663852886e+38f); auto op = graph_ep->GetGraph()->CreateOperation(min, max); (*op).BindInputs(inputs).BindOutputs(outputs); graph_ep->GetOps().push_back(std::move(op)); diff --git a/onnxruntime/core/providers/webgpu/allocator.cc b/onnxruntime/core/providers/webgpu/allocator.cc index b3eb4b5061423..3e1b87821fe2f 100644 --- a/onnxruntime/core/providers/webgpu/allocator.cc +++ b/onnxruntime/core/providers/webgpu/allocator.cc @@ -13,7 +13,7 @@ GpuBufferAllocator::GpuBufferAllocator(const BufferManager& buffer_manager, bool OrtMemoryInfo(WEBGPU_BUFFER, is_read_only_allocator ? OrtAllocatorType::OrtReadOnlyAllocator : OrtAllocatorType::OrtDeviceAllocator, - OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NONE, 0), + WebGpuDevice, OrtMemTypeDefault)), buffer_manager_{buffer_manager}, mapped_at_creation_{is_read_only_allocator && buffer_manager.SupportsUMA()} { diff --git a/onnxruntime/core/providers/webgpu/allocator.h b/onnxruntime/core/providers/webgpu/allocator.h index 7c38b4557e078..74b3d669fcf3b 100644 --- a/onnxruntime/core/providers/webgpu/allocator.h +++ b/onnxruntime/core/providers/webgpu/allocator.h @@ -11,6 +11,11 @@ namespace webgpu { class BufferManager; +inline constexpr OrtDevice WebGpuDevice{OrtDevice::GPU, + OrtDevice::MemType::DEFAULT, + OrtDevice::VendorIds::NONE, + 0}; + class GpuBufferAllocator : public IAllocator { public: GpuBufferAllocator(const BufferManager& buffer_manager, bool is_read_only_allocator); diff --git a/onnxruntime/core/providers/webgpu/compute_context.cc b/onnxruntime/core/providers/webgpu/compute_context.cc index ebe71c6ccfacd..d1a2011c8e191 100644 --- a/onnxruntime/core/providers/webgpu/compute_context.cc +++ b/onnxruntime/core/providers/webgpu/compute_context.cc @@ -6,22 +6,25 @@ namespace onnxruntime { namespace webgpu { -ComputeContext::ComputeContext(OpKernelContext& kernel_context, - const OpKernel& op_kernel, - const WebGpuExecutionProvider& ep, - WebGpuContext& webgpu_context) + +ComputeContextBase::ComputeContextBase(WebGpuContext& webgpu_context, + const WebGpuExecutionProvider& ep, + const OpKernel& op_kernel) : webgpu_context_{webgpu_context}, - kernel_context_{kernel_context}, - op_kernel_{op_kernel}, - ep_{ep} { + ep_{ep}, + op_kernel_{op_kernel} { } -const webgpu::BufferManager& ComputeContext::BufferManagerAccessor::Get(const ComputeContext& context) { +const webgpu::BufferManager& ComputeContextBase::BufferManagerAccessor::Get(const ComputeContextBase& context) { return context.ep_.BufferManager(); } -const SplitKConfig& ComputeContext::GetSplitKConfig() { - return webgpu_context_.GetSplitKConfig(); +ComputeContext::ComputeContext(WebGpuContext& webgpu_context, + const WebGpuExecutionProvider& ep, + const OpKernel& op_kernel, + OpKernelContext& kernel_context) + : ComputeContextBase(webgpu_context, ep, op_kernel), + kernel_context_{kernel_context} { } } // namespace webgpu diff --git a/onnxruntime/core/providers/webgpu/compute_context.h b/onnxruntime/core/providers/webgpu/compute_context.h index ed16f2f0a1345..fdf89854469d6 100644 --- a/onnxruntime/core/providers/webgpu/compute_context.h +++ b/onnxruntime/core/providers/webgpu/compute_context.h @@ -24,7 +24,13 @@ namespace webgpu { class WebGpuContext; class BufferManager; -class ComputeContext final { +// +// Class ComputeContextBase is designed to provide basic context information +// for running a compute shader program. +// +// An instance of ComputeContextBase does not depend on OpKernelContext, which needs an execution frame to be created. +// +class ComputeContextBase { public: // Nested accessor class to provide controlled access to BufferManager class BufferManagerAccessor { @@ -34,18 +40,31 @@ class ComputeContext final { friend class WebGpuContext; private: - static const webgpu::BufferManager& Get(const ComputeContext& context); + static const webgpu::BufferManager& Get(const ComputeContextBase& context); }; - ComputeContext(OpKernelContext& kernel_context, - const OpKernel& op_kernel, - const WebGpuExecutionProvider& ep, - WebGpuContext& webgpu_context); + ComputeContextBase(WebGpuContext& webgpu_context, + const WebGpuExecutionProvider& ep, + const OpKernel& op_kernel); - ~ComputeContext() = default; + ~ComputeContextBase() = default; + + // + // Get the node name. + // + inline decltype(auto) NodeName() const { + return op_kernel_.Node().Name(); + } + + // + // Get the operator type. + // + inline decltype(auto) OpType() const { + return op_kernel_.Node().OpType(); + } // - // Get various information from the context. + // Get various information from the WebGPU context. // inline const wgpu::AdapterInfo& AdapterInfo() const { @@ -57,9 +76,6 @@ class ComputeContext final { inline bool HasFeature(wgpu::FeatureName feature) const { return webgpu_context_.DeviceHasFeature(feature); } - inline bool IsGraphCaptureEnabled() const { - return ep_.IsGraphCaptureEnabled(); - } #if !defined(__wasm__) inline const wgpu::AdapterPropertiesSubgroupMatrixConfigs& SubgroupMatrixConfigs() const { return webgpu_context_.SubgroupMatrixConfigs(); @@ -67,17 +83,57 @@ class ComputeContext final { #endif // - // Get the kernel context. + // Get Split-K configuration. // - inline OpKernelContext& KernelContext() { - return kernel_context_; + inline const SplitKConfig& GetSplitKConfig() const { + return webgpu_context_.GetSplitKConfig(); + } + + // + // Get whether graph capture is enabled. + // + inline bool IsGraphCaptureEnabled() const { + return ep_.IsGraphCaptureEnabled(); } // // Get the logger. // inline const logging::Logger& Logger() const { - return kernel_context_.Logger(); + return *ep_.GetLogger(); + } + + // + // Run a compute shader program. + // + inline Status RunProgram(const ProgramBase& program) { + return webgpu_context_.Run(*this, program); + } + + protected: + WebGpuContext& webgpu_context_; + const WebGpuExecutionProvider& ep_; + const OpKernel& op_kernel_; +}; + +// +// Class ComputeContext provides all information a `ComputeContextBase` provides, and also +// access to `OpKernelContext` for input and output tensors. +// +class ComputeContext final : public ComputeContextBase { + public: + ComputeContext(WebGpuContext& webgpu_context, + const WebGpuExecutionProvider& ep, + const OpKernel& op_kernel, + OpKernelContext& kernel_context); + + ~ComputeContext() = default; + + // + // Get the kernel context. + // + inline OpKernelContext& KernelContext() { + return kernel_context_; } // @@ -145,25 +201,8 @@ class ComputeContext final { return op_kernel_.Info().GetDataTransferManager().CopyTensor(src, dst); } - // - // Run a compute shader program. - // - inline Status RunProgram(const ProgramBase& program) { - return webgpu_context_.Run(*this, program); - } - - // - // Get Split-K configuration. - // - // `split_k_config_` won't be initialized until the first call to this method. - // - const SplitKConfig& GetSplitKConfig(); - private: - WebGpuContext& webgpu_context_; OpKernelContext& kernel_context_; - const OpKernel& op_kernel_; - const WebGpuExecutionProvider& ep_; }; } // namespace webgpu diff --git a/onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.cc b/onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.cc index 82645e30082e6..3c974ef5133c0 100644 --- a/onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.cc +++ b/onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.cc @@ -322,11 +322,14 @@ std::string GetPowImpl(int lhs_element_type, int /* rhs_element_type */) { if (lhs_element_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32) { round_str = "round"; } - std::string use_sqrt_for_pow; + std::string use_pow_shortcut; if (lhs_element_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT || lhs_element_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16) { + // use multiplication instead of pow when base (a) is a float and exponent (b) is 2.0 // use sqrt instead of pow when base (a) is a positive float and exponent (b) is 0.5 - use_sqrt_for_pow = - " else if (a >= input_a_element_t(0.0) && b == 0.5) {\n" + use_pow_shortcut = + " else if (b == 2.0) {\n" + " return a * a;\n" + " } else if (a >= input_a_element_t(0.0) && b == 0.5) {\n" " return sqrt(a);\n" " }\n"; } @@ -337,7 +340,7 @@ std::string GetPowImpl(int lhs_element_type, int /* rhs_element_type */) { " } else if (a < input_a_element_t(0.0) && b != floor(b)) {\n" " return input_a_element_t(pow(f32(a), b)); // NaN\n" " }\n" - << use_sqrt_for_pow + << use_pow_shortcut << " return select(sign(a), input_a_element_t(1.0), round(abs(b) % 2.0) != 1.0) * input_a_element_t(" << round_str << "(pow(f32(abs(a)), b)));\n" << "}\n" "fn pow_v(a : vec4, b : vec4) -> vec4 {\n" diff --git a/onnxruntime/core/providers/webgpu/math/gemm_packed.cc b/onnxruntime/core/providers/webgpu/math/gemm_packed.cc index 6aefa90a59285..c26b58a7af1f4 100644 --- a/onnxruntime/core/providers/webgpu/math/gemm_packed.cc +++ b/onnxruntime/core/providers/webgpu/math/gemm_packed.cc @@ -93,18 +93,21 @@ Status ApplyGemmPacked(const Tensor* a, } const uint32_t TILE_SIZE = 32; - const uint32_t num_tile_n = (N + TILE_SIZE - 1) / TILE_SIZE; - const uint32_t num_tile_m = (M + TILE_SIZE - 1) / TILE_SIZE; + const uint32_t dispatch_x = (N + TILE_SIZE - 1) / TILE_SIZE; + const uint32_t dispatch_y = (M + TILE_SIZE - 1) / TILE_SIZE; program.CacheHint(alpha, transA, transB, c_is_scalar) .AddOutputs({{y, ProgramTensorMetadataDependency::TypeAndRank, output_components}}) - .SetDispatchGroupSize(num_tile_n, num_tile_m, 1) + .SetDispatchGroupSize(dispatch_x, dispatch_y, 1u) .SetWorkgroupSize(GemmProgram::MATMUL_PACKED_WORKGROUP_SIZE_X, GemmProgram::MATMUL_PACKED_WORKGROUP_SIZE_Y, GemmProgram::MATMUL_PACKED_WORKGROUP_SIZE_Z) .AddUniformVariables({{alpha}, {beta}, - {M}, /* dim_a_outer */ - {N}, /* dim_b_outer */ - {K}} /*dim_inner */ + {M}, /* dim_a_outer */ + {N}, /* dim_b_outer */ + {K}, /*dim_inner */ + {dispatch_x}, /* logical_dispatch_x */ + {dispatch_y}, /* logical_dispatch_y */ + {1u}} /* logical_dispatch_z */ ); return context.RunProgram(program); diff --git a/onnxruntime/core/providers/webgpu/math/gemm_packed.h b/onnxruntime/core/providers/webgpu/math/gemm_packed.h index dce5164693aa8..cb89ccefba313 100644 --- a/onnxruntime/core/providers/webgpu/math/gemm_packed.h +++ b/onnxruntime/core/providers/webgpu/math/gemm_packed.h @@ -32,7 +32,10 @@ class GemmProgram final : public Program { {"beta", ProgramUniformVariableDataType::Float32}, {"dim_a_outer", ProgramUniformVariableDataType::Uint32}, {"dim_b_outer", ProgramUniformVariableDataType::Uint32}, - {"dim_inner", ProgramUniformVariableDataType::Uint32}); + {"dim_inner", ProgramUniformVariableDataType::Uint32}, + {"logical_dispatch_x", ProgramUniformVariableDataType::Uint32}, + {"logical_dispatch_y", ProgramUniformVariableDataType::Uint32}, + {"logical_dispatch_z", ProgramUniformVariableDataType::Uint32}); constexpr static uint32_t MATMUL_PACKED_WORKGROUP_SIZE_X = 8; constexpr static uint32_t MATMUL_PACKED_WORKGROUP_SIZE_Y = 8; diff --git a/onnxruntime/core/providers/webgpu/math/gemm_utils.cc b/onnxruntime/core/providers/webgpu/math/gemm_utils.cc index 7cbc7f6a4a821..89718149cea88 100644 --- a/onnxruntime/core/providers/webgpu/math/gemm_utils.cc +++ b/onnxruntime/core/providers/webgpu/math/gemm_utils.cc @@ -117,6 +117,20 @@ void HandleMatMulWithSplitK( } } +// Compute `logical_workgroup_id` and `logical_global_id` because the dispatch workgroup size in +// `ProgramBase.SetDispatchGroupSize()` may be normalized in +// `ProgramManager::NormalizeDispatchGroupSize()`. In the shader we should always use +// `logical_workgroup_id` and `logical_global_id` instead of `workgroup_id` and `global_id`. +void InitializeLogicalWorkgroupIDAndGlobalID(ShaderHelper& shader) { + shader.MainFunctionBody() + << " let logical_workgroup_id_z = workgroup_idx / (uniforms.logical_dispatch_x * uniforms.logical_dispatch_y);\n" + << " let logical_workgroup_id_y = (workgroup_idx % (uniforms.logical_dispatch_x * uniforms.logical_dispatch_y)) / uniforms.logical_dispatch_x;\n" + << " let logical_workgroup_id_x = (workgroup_idx % (uniforms.logical_dispatch_x * uniforms.logical_dispatch_y)) % uniforms.logical_dispatch_x;\n" + << " let logical_workgroup_id = vec3u(logical_workgroup_id_x, logical_workgroup_id_y, logical_workgroup_id_z);\n" + << " const workgroupSize = vec3u(workgroup_size_x, workgroup_size_y, workgroup_size_z);\n" + << " let logical_global_id = logical_workgroup_id * workgroupSize + local_id;\n"; +} + } // namespace void MatMulReadFnSource(ShaderHelper& shader, @@ -274,20 +288,22 @@ Status MakeMatMulPackedVec4Source(ShaderHelper& shader, << "const innerElementSize = " << inner_elements_size << ";\n" << "const tileInner = " << tile_inner << ";\n"; + InitializeLogicalWorkgroupIDAndGlobalID(shader); + shader.MainFunctionBody() << " let localRow = i32(local_id.y);\n" << " let tileRow = localRow * rowPerThread;\n" << " let tileCol = i32(local_id.x);\n" - << " let globalRow = i32(global_id.y) * rowPerThread;\n" - << " let globalCol = i32(global_id.x);\n" - << " let globalRowStart = i32(workgroup_id.y) * " << tile_a_outer << ";\n" - << " let globalColStart = i32(workgroup_id.x) * " << tile_b_outer << ";\n" + << " let globalRow = i32(logical_global_id.y) * rowPerThread;\n" + << " let globalCol = i32(logical_global_id.x);\n" + << " let globalRowStart = i32(logical_workgroup_id.y) * " << tile_a_outer << ";\n" + << " let globalColStart = i32(logical_workgroup_id.x) * " << tile_b_outer << ";\n" << " var acc: array, rowPerThread>;\n"; if (split_k) { // With Split-K, the original "workgroup" (with dispatch_z == 1 in API side) is split into // multiple ones, and in the current workgroup we only compute `kSplitK` elements starting from - // `kSplitK * i32(global_id.z)`. + // `kSplitK * i32(logical_global_id.z)`. // // For example: considering computing Y = (X * W + B) in one workgroup. // Let kSplitK = 2, B = [d1, d2] @@ -305,15 +321,15 @@ Status MakeMatMulPackedVec4Source(ShaderHelper& shader, // Workgroup1: compute (A1 * A2) Workgroup2: compute (B1 * B2) // Workgroup3: compute (C1 * C2) // In each workgroup: - // - `num_tiles` is computed with `kSplitK`, and `kStart` is computed with `global_id.z` + // - `num_tiles` is computed with `kSplitK`, and `kStart` is computed with `logical_global_id.z` // - When the computation in each workgroup is completed, add the result to Y with several // atomic built-in functions in `HandleMatMulWithSplitK()`. shader.MainFunctionBody() << "const kSplitK = " << split_dim_inner << ";\n" << " let num_tiles = (kSplitK - 1) / tileInner + 1;\n" - << " var kStart = kSplitK * i32(global_id.z);\n" + << " var kStart = kSplitK * i32(logical_global_id.z);\n" - // When Split-K is used, `batch` should always be 0 and `global_id.z` is used to indicate + // When Split-K is used, `batch` should always be 0 and `logical_global_id.z` is used to indicate // the index of split-k instead of batch. << " let batch = 0;\n" << " let batchIndices = 0u;\n"; @@ -321,7 +337,7 @@ Status MakeMatMulPackedVec4Source(ShaderHelper& shader, shader.MainFunctionBody() << " let num_tiles = (uniforms.dim_inner - 1) / tileInner + 1;\n" << " var kStart = 0;\n" - << " let batch = i32(global_id.z);\n" + << " let batch = i32(logical_global_id.z);\n" << (nullptr != batch_dims ? " let batchIndices = " + batch_dims->OffsetToIndices("u32(batch)") + ";\n" : ""); } @@ -498,7 +514,9 @@ Status MakeMatMulPackedSource(ShaderHelper& shader, << "const colPerThread = " << elements_per_thread_x << ";\n" << "const tileInner = " << tile_inner << ";\n"; - shader.MainFunctionBody() << " let batch = i32(global_id.z);\n" + InitializeLogicalWorkgroupIDAndGlobalID(shader); + + shader.MainFunctionBody() << " let batch = i32(logical_global_id.z);\n" << (nullptr != batch_dims ? " let batchIndices = " + batch_dims->OffsetToIndices("u32(batch)") + ";\n" : "") << " let num_tiles = (uniforms.dim_inner - 1) / tileInner + 1;\n" << " var kStart = 0;\n" @@ -507,10 +525,10 @@ Status MakeMatMulPackedSource(ShaderHelper& shader, shader.MainFunctionBody() << "let tileRow = i32(local_id.y) * rowPerThread;\n" << "let tileCol = i32(local_id.x) * colPerThread;\n" - << "let globalRow = i32(global_id.y) * rowPerThread;\n" - << "let globalCol = i32(global_id.x) * colPerThread;\n" - << "let globalRowStart = i32(workgroup_id.y) * " << tile_a_outer << ";\n" - << "let globalColStart = i32(workgroup_id.x) * " << tile_b_outer << ";\n" + << "let globalRow = i32(logical_global_id.y) * rowPerThread;\n" + << "let globalCol = i32(logical_global_id.x) * colPerThread;\n" + << "let globalRowStart = i32(logical_workgroup_id.y) * " << tile_a_outer << ";\n" + << "let globalColStart = i32(logical_workgroup_id.x) * " << tile_b_outer << ";\n" << "let tileRowA = i32(local_id.y) * " << row_per_thread_a << ";\n" << "let tileColA = i32(local_id.x) * " << col_per_thread_a << ";\n" << "let tileRowB = i32(local_id.y) * " << row_per_thread_b << ";\n"; diff --git a/onnxruntime/core/providers/webgpu/math/matmul.cc b/onnxruntime/core/providers/webgpu/math/matmul.cc index 55c2c5773cc1f..72dd235eb820a 100644 --- a/onnxruntime/core/providers/webgpu/math/matmul.cc +++ b/onnxruntime/core/providers/webgpu/math/matmul.cc @@ -256,8 +256,6 @@ Status ComputeMatMul(ComputeContext* context, // With Split-K, `dim_inner` will be split into multiple parts and `dispatch_z` will be the // number of splits along `dim_inner`. - // TODO: avoid using `global_id.xxx` or `workgroup_id.xxx` in `MatMulProgram` when we normalize - // the dispatch size with `ProgramManager::NormalizeDispatchGroupSize()` for `MatMulProgram`. split_dim_inner = split_k_config.GetSplitDimInner(); dispatch_z = (dim_inner + split_dim_inner - 1) / split_dim_inner; @@ -271,7 +269,7 @@ Status ComputeMatMul(ComputeContext* context, .CacheHint(activation.ToString(), absl::StrJoin(elements_per_thread, "-"), std::to_string(is_vec4), components, is_channels_last, split_dim_inner) .AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, a_shape_temp, components}, {b, ProgramTensorMetadataDependency::TypeAndRank, b_shape_temp, components}}) - .AddUniformVariables({{dim_a_outer}, {dim_b_outer}, {dim_inner}}) + .AddUniformVariables({{dim_a_outer}, {dim_b_outer}, {dim_inner}, {dispatch_x}, {dispatch_y}, {dispatch_z}}) .AddIndices(outer_dims) .SetDispatchGroupSize(dispatch_x, dispatch_y, dispatch_z) .SetWorkgroupSize(MatMul::MATMUL_PACKED_WORKGROUP_SIZE_X, MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Y, MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Z) diff --git a/onnxruntime/core/providers/webgpu/math/matmul_packed.h b/onnxruntime/core/providers/webgpu/math/matmul_packed.h index 143ba61c99e13..dbd193bc38f58 100644 --- a/onnxruntime/core/providers/webgpu/math/matmul_packed.h +++ b/onnxruntime/core/providers/webgpu/math/matmul_packed.h @@ -24,7 +24,10 @@ class MatMulProgram final : public Program { Status GenerateShaderCode(ShaderHelper& sh) const override; WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"dim_a_outer", ProgramUniformVariableDataType::Uint32}, {"dim_b_outer", ProgramUniformVariableDataType::Uint32}, - {"dim_inner", ProgramUniformVariableDataType::Uint32}); + {"dim_inner", ProgramUniformVariableDataType::Uint32}, + {"logical_dispatch_x", ProgramUniformVariableDataType::Uint32}, + {"logical_dispatch_y", ProgramUniformVariableDataType::Uint32}, + {"logical_dispatch_z", ProgramUniformVariableDataType::Uint32}); bool NeedSplitK() const; diff --git a/onnxruntime/core/providers/webgpu/math/softmax.cc b/onnxruntime/core/providers/webgpu/math/softmax.cc index 2f34aa21c8309..bf3bb53341418 100644 --- a/onnxruntime/core/providers/webgpu/math/softmax.cc +++ b/onnxruntime/core/providers/webgpu/math/softmax.cc @@ -64,7 +64,7 @@ Status SoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const { int components = input.NumComponents(); const std::string thread_max_decl = is_fp32_ - ? "var thread_max = x_value_t(-3.402823e+38f);\n" + ? "var thread_max = x_value_t(-3.4028234663852886e+38f);\n" : "var thread_max = x_value_t(-65504.0h);\n"; // Define shared memory for row max and row sum diff --git a/onnxruntime/core/providers/webgpu/nn/conv.cc b/onnxruntime/core/providers/webgpu/nn/conv.cc index 77fa46cb87518..4fff736fd2f32 100644 --- a/onnxruntime/core/providers/webgpu/nn/conv.cc +++ b/onnxruntime/core/providers/webgpu/nn/conv.cc @@ -216,6 +216,46 @@ Status Conv::ComputeInternal(ComputeContext& context return context.RunProgram(conv2d_mm_program); } +template +Status Conv::PrePackInternal(ComputeContextBase& /* context */, + const Tensor& tensor, + int input_idx, + AllocatorPtr /* alloc */, + /*out*/ bool& is_packed) { + is_packed = false; + + if constexpr (is_channels_last) { + if (input_idx == 1 && tensor.Shape().NumDimensions() == 4) { + // only deal with 4D NHWC weights + + // TODO: implement weight transpose for pre-pack here + // Conv::ComputeInternal() should be updated to reflect the change: + // - if the initializer is packed, `context.Input(1)` will be nullptr. + // - in this case, use `transposed_kernel_` instead. + + // // Step.1 - calculate transposed weight shape + // TensorShape transposed_kernel_shape{tensor.Shape()[2], + // tensor.Shape()[3], + // tensor.Shape()[1], + // tensor.Shape()[0]}; + + // // Step.2 - create transposed weight tensor + // transposed_kernel_ = std::make_unique(tensor.DataType(), transposed_kernel_shape, alloc); + + // // Step.3 - do transpose + // size_t perm[] = {2, 3, 1, 0}; + // ORT_RETURN_IF_ERROR(Transpose::DoTranspose(context, + // perm, + // tensor, + // *transposed_kernel_)); + + // is_packed = true; // set this flag to true so that ORT will release the initializer tensor + } + } + + return Status::OK(); +} + // Explicit template instantiation for FusedConv template class Conv; template class Conv; diff --git a/onnxruntime/core/providers/webgpu/nn/conv.h b/onnxruntime/core/providers/webgpu/nn/conv.h index cafaa272c0613..5bf94a459a44a 100644 --- a/onnxruntime/core/providers/webgpu/nn/conv.h +++ b/onnxruntime/core/providers/webgpu/nn/conv.h @@ -23,9 +23,16 @@ class Conv : public WebGpuKernel { } Status ComputeInternal(ComputeContext& context) const override; + Status PrePackInternal(ComputeContextBase& context, + const Tensor& tensor, + int input_idx, + AllocatorPtr alloc, + /*out*/ bool& is_packed) override; + protected: ConvAttributes conv_attrs_; Activation activation_; + std::unique_ptr transposed_kernel_; // should only have value when `is_initializer` AND `is_4D` AND `is_NHWC` }; Status TransposeKernel(ComputeContext& context, const Tensor* kernel, const TensorShape& kernel_shape, Tensor* transposed_kernel, const InlinedVector& perm); diff --git a/onnxruntime/core/providers/webgpu/nn/conv2d_mm.cc b/onnxruntime/core/providers/webgpu/nn/conv2d_mm.cc index 2d5424c52a3f2..c66f2cbd582d9 100644 --- a/onnxruntime/core/providers/webgpu/nn/conv2d_mm.cc +++ b/onnxruntime/core/providers/webgpu/nn/conv2d_mm.cc @@ -226,7 +226,10 @@ Conv2dMMProgram CreateConv2dMMProgram(const Activation& activation, const std::v {static_cast(dim_inner)}, {pads}, {strides}, - {dilations}}); + {dilations}, + {dispatch[0]}, + {dispatch[1]}, + {dispatch[2]}}); return program; } diff --git a/onnxruntime/core/providers/webgpu/nn/conv2d_mm.h b/onnxruntime/core/providers/webgpu/nn/conv2d_mm.h index d7cc08aae26f3..e161bffb0c503 100644 --- a/onnxruntime/core/providers/webgpu/nn/conv2d_mm.h +++ b/onnxruntime/core/providers/webgpu/nn/conv2d_mm.h @@ -38,7 +38,10 @@ class Conv2dMMProgram final : public Program { {"dim_inner", ProgramUniformVariableDataType::Uint32}, {"pads", ProgramUniformVariableDataType::Uint32}, {"strides", ProgramUniformVariableDataType::Uint32}, - {"dilations", ProgramUniformVariableDataType::Uint32}); + {"dilations", ProgramUniformVariableDataType::Uint32}, + {"logical_dispatch_x", ProgramUniformVariableDataType::Uint32}, + {"logical_dispatch_y", ProgramUniformVariableDataType::Uint32}, + {"logical_dispatch_z", ProgramUniformVariableDataType::Uint32}); private: const Activation& activation_; diff --git a/onnxruntime/core/providers/webgpu/tensor/slice.cc b/onnxruntime/core/providers/webgpu/tensor/slice.cc index 7e8b434431781..5f59fecc425e2 100644 --- a/onnxruntime/core/providers/webgpu/tensor/slice.cc +++ b/onnxruntime/core/providers/webgpu/tensor/slice.cc @@ -92,14 +92,28 @@ Status SliceProgram::GenerateShaderCode(ShaderHelper& shader) const { return Status::OK(); } +static std::vector getInt64Input(const Tensor* tensor) { + if (tensor->IsDataType()) { + return std::vector(tensor->DataAsSpan().begin(), tensor->DataAsSpan().end()); + } + ORT_ENFORCE(tensor->IsDataType(), "Expected tensor of type int32 or int64"); + std::vector result; + auto span = tensor->DataAsSpan(); + result.reserve(span.size()); + for (auto v : span) { + result.push_back(static_cast(v)); + } + return result; +} + Status Slice::ComputeInternal(ComputeContext& context) const { // READ INPUTS const Tensor* input_tensor = context.Input(0); const TensorShape& input_shape = input_tensor->Shape(); auto input_rank = input_shape.NumDimensions(); - auto starts_raw = attr_starts_.empty() ? context.Input(1)->DataAsSpan() : gsl::make_span(attr_starts_); - auto ends_raw = attr_ends_.empty() ? context.Input(2)->DataAsSpan() : gsl::make_span(attr_ends_); + auto starts_raw = attr_starts_.empty() ? getInt64Input(context.Input(1)) : attr_starts_; + auto ends_raw = attr_ends_.empty() ? getInt64Input(context.Input(2)) : attr_ends_; ORT_ENFORCE(starts_raw.size() == ends_raw.size(), "starts and ends must have the same size"); @@ -126,7 +140,7 @@ Status Slice::ComputeInternal(ComputeContext& context) const { axes_default.push_back(i); } } - auto axes_raw = attr_axes_.empty() ? (axes_tensor == nullptr ? gsl::make_span(axes_default) : axes_tensor->DataAsSpan()) : gsl::make_span(attr_axes_); + auto axes_raw = attr_axes_.empty() ? (axes_tensor == nullptr ? axes_default : getInt64Input(axes_tensor)) : attr_axes_; std::vector steps_default; if (steps_tensor == nullptr) { @@ -135,7 +149,7 @@ Status Slice::ComputeInternal(ComputeContext& context) const { steps_default.push_back(1); } } - auto steps_raw = steps_tensor == nullptr ? gsl::make_span(steps_default) : steps_tensor->DataAsSpan(); + auto steps_raw = steps_tensor == nullptr ? steps_default : getInt64Input(steps_tensor); // get final axes std::vector axes, axes_fixed; diff --git a/onnxruntime/core/providers/webgpu/tensor/transpose.cc b/onnxruntime/core/providers/webgpu/tensor/transpose.cc index cec321d0da80e..5415d4a5ead5b 100644 --- a/onnxruntime/core/providers/webgpu/tensor/transpose.cc +++ b/onnxruntime/core/providers/webgpu/tensor/transpose.cc @@ -108,7 +108,7 @@ Status TransposeProgram::GenerateShaderCode(ShaderHelper& shader) const { return Status::OK(); } -Status Transpose::DoTranspose(onnxruntime::webgpu::ComputeContext& context, +Status Transpose::DoTranspose(onnxruntime::webgpu::ComputeContextBase& context, gsl::span permutations, const Tensor& input, Tensor& output) { const auto& input_shape = input.Shape(); diff --git a/onnxruntime/core/providers/webgpu/tensor/transpose.h b/onnxruntime/core/providers/webgpu/tensor/transpose.h index b62a419fa12bc..5e9ccc6750cd6 100644 --- a/onnxruntime/core/providers/webgpu/tensor/transpose.h +++ b/onnxruntime/core/providers/webgpu/tensor/transpose.h @@ -16,7 +16,7 @@ class Transpose final : public WebGpuKernel, public TransposeBase { Transpose(const OpKernelInfo& info) : WebGpuKernel{info}, TransposeBase{info} { } Status ComputeInternal(ComputeContext& context) const override; - static Status DoTranspose(onnxruntime::webgpu::ComputeContext& context, gsl::span permutations, const Tensor& input, Tensor& output); + static Status DoTranspose(onnxruntime::webgpu::ComputeContextBase& context, gsl::span permutations, const Tensor& input, Tensor& output); constexpr static uint32_t TILE_SIZE = 16; }; diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.cc b/onnxruntime/core/providers/webgpu/webgpu_context.cc index 28decb076951e..b8d5adc421124 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_context.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_context.cc @@ -147,6 +147,9 @@ void WebGpuContext::Initialize(const WebGpuBufferCacheConfig& buffer_cache_confi // create program manager program_mgr_ = std::make_unique(*this); + // create split-k config + split_k_config_ = std::make_unique(adapter_info_); + // set query type #if !defined(__wasm__) if (DeviceHasFeature(wgpu::FeatureName::ChromiumExperimentalTimestampQueryInsidePasses)) { @@ -178,7 +181,7 @@ Status WebGpuContext::Wait(wgpu::Future f) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to wait for the operation:", uint32_t(status)); } -Status WebGpuContext::Run(ComputeContext& context, const ProgramBase& program) { +Status WebGpuContext::Run(ComputeContextBase& context, const ProgramBase& program) { const auto& inputs = program.Inputs(); const auto& outputs = program.Outputs(); @@ -288,8 +291,8 @@ Status WebGpuContext::Run(ComputeContext& context, const ProgramBase& program) { auto key = CalculateProgramCacheKey(program, inputs_segments, outputs_segments, is_1d_dispatch); if (is_profiling_) { - PendingKernelInfo pending_kernel_info(context.KernelContext().GetNodeName(), - context.KernelContext().GetOpType(), + PendingKernelInfo pending_kernel_info(context.NodeName(), + context.OpType(), program.Name(), key, inputs, @@ -442,7 +445,7 @@ Status WebGpuContext::Run(ComputeContext& context, const ProgramBase& program) { const size_t uniform_buffer_total_size = (current_offset + max_alignment_of_field - 1) / max_alignment_of_field * max_alignment_of_field; WGPUBuffer uniform_buffer = nullptr; - const webgpu::BufferManager& buffer_mgr = ComputeContext::BufferManagerAccessor::Get(context); + const webgpu::BufferManager& buffer_mgr = ComputeContextBase::BufferManagerAccessor::Get(context); if (uniform_buffer_total_size > 0) { std::vector uniform_data_buffer(uniform_buffer_total_size); @@ -910,13 +913,6 @@ void WebGpuContext::ReleaseGraphResources(std::vector WebGpuContextFactory::contexts_; std::mutex WebGpuContextFactory::mutex_; std::once_flag WebGpuContextFactory::init_default_flag_; diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.h b/onnxruntime/core/providers/webgpu/webgpu_context.h index bd7dae75f2e2d..84dfb47ef4687 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_context.h +++ b/onnxruntime/core/providers/webgpu/webgpu_context.h @@ -5,7 +5,6 @@ #include #include -#include #include "core/providers/webgpu/webgpu_external_header.h" @@ -23,7 +22,7 @@ class Tensor; namespace webgpu { class WebGpuContext; -class ComputeContext; +class ComputeContextBase; class ProgramBase; // Definition for CapturedCommandInfo in the webgpu namespace @@ -152,6 +151,13 @@ class WebGpuContext final { return validation_mode_; } + // + // Get Split-K configuration. + // + const SplitKConfig& GetSplitKConfig() const { + return *split_k_config_; + } + void StartProfiling(); void CollectProfilingData(profiling::Events& events); void EndProfiling(TimePoint, profiling::Events& events, profiling::Events& cached_events); @@ -170,16 +176,9 @@ class WebGpuContext final { // Status PopErrorScope(); - Status Run(ComputeContext& context, const ProgramBase& program); + Status Run(ComputeContextBase& context, const ProgramBase& program); void OnRunEnd(); - // - // Get Split-K configuration. - // - // `split_k_config_` won't be initialized until the first call to this method. - // - const SplitKConfig& GetSplitKConfig(); - private: enum class TimestampQueryType { None = 0, @@ -277,7 +276,7 @@ class WebGpuContext final { uint32_t num_pending_dispatches_ = 0; const uint32_t max_num_pending_dispatches_ = 16; - std::optional split_k_config_; + std::unique_ptr split_k_config_; // profiling TimestampQueryType query_type_; diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc index e0b84fef51f1f..6b764d51bcf75 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc @@ -794,8 +794,7 @@ using namespace webgpu; WebGpuExecutionProvider::WebGpuExecutionProvider(int context_id, WebGpuContext& context, WebGpuExecutionProviderConfig&& config) - : IExecutionProvider{kWebGpuExecutionProvider, - OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NONE, 0)}, + : IExecutionProvider{kWebGpuExecutionProvider, WebGpuDevice}, context_id_{context_id}, context_{context}, preferred_data_layout_{config.data_layout}, @@ -935,13 +934,14 @@ std::unique_ptr WebGpuExecutionProvider::GetEx std::optional WebGpuExecutionProvider::ShouldConvertDataLayoutForOp(std::string_view node_domain, std::string_view node_op_type, DataLayout target_data_layout) const { - if (target_data_layout != DataLayout::NHWC) { - return std::nullopt; - } - // NHWC for Resize operator is not implemented on kWebGpuExecutionProvider if (node_domain == kOnnxDomain && node_op_type == "Resize") { - return false; + return target_data_layout != DataLayout::NHWC; + } + + // WebGPU perfer NCHW for InstanceNormalization due to a better performance + if (node_domain == kOnnxDomain && node_op_type == "InstanceNormalization") { + return target_data_layout != DataLayout::NHWC; } return std::nullopt; diff --git a/onnxruntime/core/providers/webgpu/webgpu_kernel.cc b/onnxruntime/core/providers/webgpu/webgpu_kernel.cc index 8d6ae6caeaf83..ea38e9415e1fe 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_kernel.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_kernel.cc @@ -11,25 +11,58 @@ namespace webgpu { WebGpuKernel::WebGpuKernel(const OpKernelInfo& info) : OpKernel(info), - ep_(*static_cast(info.GetExecutionProvider())) { + ep_(*static_cast(info.GetExecutionProvider())), + webgpu_context_(WebGpuContextFactory::GetContext(ep_.GetDeviceId())) { } Status WebGpuKernel::Compute(OpKernelContext* p_op_kernel_context) const { - WebGpuContext& webgpu_context = WebGpuContextFactory::GetContext(ep_.GetDeviceId()); - ComputeContext context{*p_op_kernel_context, *this, ep_, webgpu_context}; + ComputeContext context{webgpu_context_, + ep_, + *this, + *p_op_kernel_context}; - if (webgpu_context.ValidationMode() >= ValidationMode::Full) { - webgpu_context.PushErrorScope(); + if (webgpu_context_.ValidationMode() >= ValidationMode::Full) { + webgpu_context_.PushErrorScope(); } Status s = ComputeInternal(context); - if (webgpu_context.ValidationMode() >= ValidationMode::Full) { - ORT_RETURN_IF_ERROR(webgpu_context.PopErrorScope()); + if (webgpu_context_.ValidationMode() >= ValidationMode::Full) { + ORT_RETURN_IF_ERROR(webgpu_context_.PopErrorScope()); } return s; } +Status WebGpuKernel::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + /*out*/ bool& is_packed, /*out*/ PrePackedWeights* /* prepacked_weights */) { + ComputeContextBase context{webgpu_context_, ep_, *this}; + + if (webgpu_context_.ValidationMode() >= ValidationMode::Full) { + webgpu_context_.PushErrorScope(); + } + + // Currently, ORT does not allow using prepacked weights in non-CPU EPs. + // So we do not pass prepacked_weights to PrePackInternal. + // Kernel implementation that supports prepacking should manage its own storage. + + Status s = PrePackInternal(context, tensor, input_idx, alloc, is_packed); + + if (webgpu_context_.ValidationMode() >= ValidationMode::Full) { + ORT_RETURN_IF_ERROR(webgpu_context_.PopErrorScope()); + } + + return s; +} + +Status WebGpuKernel::PrePackInternal(ComputeContextBase& /*context*/, + const Tensor& /*tensor*/, + int /*input_idx*/, + AllocatorPtr /*alloc*/, + /*out*/ bool& is_packed) { + is_packed = false; + return Status::OK(); +} + } // namespace webgpu } // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/webgpu_kernel.h b/onnxruntime/core/providers/webgpu/webgpu_kernel.h index 3c750e305421c..2c57991c6ee35 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_kernel.h +++ b/onnxruntime/core/providers/webgpu/webgpu_kernel.h @@ -23,8 +23,41 @@ class WebGpuKernel : public OpKernel { virtual Status ComputeInternal(ComputeContext& context) const = 0; + // Overrides OpKernel::PrePack to handle constant tensor pre-processing for WebGPU kernels. + // This method creates a ComputeContextBase and delegates to PrePackInternal. + // + // NOTE: Currently, ORT does not allow using prepacked weights in non-CPU EPs, so the + // prepacked_weights parameter is not passed to PrePackInternal. Kernel implementations + // that support prepacking should manage their own storage. + Status PrePack(const Tensor& tensor, + int input_idx, + AllocatorPtr alloc, + /*out*/ bool& is_packed, + /*out*/ PrePackedWeights* prepacked_weights) override; + + // Virtual method that allows derived kernels to pre-process constant tensors during initialization. + // + // This method is called during kernel initialization when constant tensors are available, + // allowing kernels to perform operations like tensor transposition or format conversion + // before the first Compute call. + // + // @param context The WebGPU compute context base providing access to the execution environment. + // @param tensor The constant tensor to potentially pre-process. + // @param input_idx The index of this input in the kernel's input list. + // @param alloc The allocator to use for any new tensor allocations. + // @param is_packed Output parameter. Set to true if the tensor was pre-packed/processed, + // false otherwise. The default implementation sets this to false. + // + // @return Status::OK() on success, or an error status on failure. + virtual Status PrePackInternal(ComputeContextBase& context, + const Tensor& tensor, + int input_idx, + AllocatorPtr alloc, + /*out*/ bool& is_packed); + private: const WebGpuExecutionProvider& ep_; + WebGpuContext& webgpu_context_; }; } // namespace webgpu diff --git a/onnxruntime/core/providers/webgpu/webgpu_utils.cc b/onnxruntime/core/providers/webgpu/webgpu_utils.cc index 568d29a96cb88..5fd24b2bff037 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_utils.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_utils.cc @@ -21,27 +21,24 @@ TensorShape ReduceShapeByComponents(const TensorShape& shape, int64_t components return TensorShape(shape_vector); } -SplitKConfig SplitKConfig::GetSplitKConfig(const wgpu::AdapterInfo& adapter_info) { - SplitKConfig config = {}; - +SplitKConfig::SplitKConfig(const wgpu::AdapterInfo& adapter_info) { if (adapter_info.vendor == std::string_view{"intel"}) { if (adapter_info.architecture == std::string_view{"xe-2lpg"} || adapter_info.architecture == std::string_view{"xe-2hpg"} || adapter_info.architecture == std::string_view{"xe-lpg"} || adapter_info.architecture == std::string_view{"gen-12hp"}) { - config.enable_split_k_ = true; + enable_split_k_ = true; // Below thresholds are only verified on the above Intel GPUs without any regressions. The // proper value of `max_dim_a_outer_multiplies_dim_b_outer_divides_dim_inner_` may be // reduced when we support a larger `dim_inner` because larger `dim_inner` will bring more // atomic calls for each output value. - config.split_dim_inner_ = 256; - config.min_dim_inner_with_split_k_ = config.split_dim_inner_ * 2; - config.max_dim_inner_with_split_k_ = config.split_dim_inner_ * 9; - config.max_dim_a_outer_multiplies_dim_b_outer_divides_dim_inner_ = 35.0f; + split_dim_inner_ = 256; + min_dim_inner_with_split_k_ = split_dim_inner_ * 2; + max_dim_inner_with_split_k_ = split_dim_inner_ * 9; + max_dim_a_outer_multiplies_dim_b_outer_divides_dim_inner_ = 35.0f; } } - return config; } bool SplitKConfig::UseSplitK( diff --git a/onnxruntime/core/providers/webgpu/webgpu_utils.h b/onnxruntime/core/providers/webgpu/webgpu_utils.h index d45b9bf4dd119..7d5ab5fea8006 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_utils.h +++ b/onnxruntime/core/providers/webgpu/webgpu_utils.h @@ -91,9 +91,12 @@ inline Tensor CreateTensorView(const Tensor& tensor, MLDataType new_data_type, c return {new_data_type, new_shape, const_cast(tensor.DataRaw()), tensor.Location()}; } +/** + * Configuration for Split-K optimization (Conv|MatMul). + */ class SplitKConfig { public: - static SplitKConfig GetSplitKConfig(const wgpu::AdapterInfo& adapter_info); + explicit SplitKConfig(const wgpu::AdapterInfo& adapter_info); bool UseSplitK( bool is_vec4, ActivationKind activation_kind, uint64_t batch_size, diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 4d4dea9cb444c..ab3932e7abfb4 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -2943,6 +2943,8 @@ Status InferenceSession::Run(const RunOptions& run_options, << cached_execution_provider_for_graph_replay_.Type() << " CUDA Graph for this model with tag: " << run_options.run_tag << " with graph annotation id: " << graph_annotation_id; + // log evaluation start to trace logging provider + env.GetTelemetryProvider().LogEvaluationStart(session_id_); ORT_RETURN_IF_ERROR_SESSIONID_(cached_execution_provider_for_graph_replay_.ReplayGraph(graph_annotation_id)); } else { InlinedVector exec_providers_to_stop; diff --git a/onnxruntime/core/session/utils.cc b/onnxruntime/core/session/utils.cc index 6189e6ca7f012..4cb21b80109c8 100644 --- a/onnxruntime/core/session/utils.cc +++ b/onnxruntime/core/session/utils.cc @@ -404,6 +404,7 @@ Status CompileModel(const Environment& env, const ModelCompilationOptions& model session))); } + Env::Default().GetTelemetryProvider().LogCompileModel(session->GetCurrentSessionId()); ORT_RETURN_IF_ERROR(ToStatusAndRelease(InitializeSession(session_options, *session))); return Status::OK(); } diff --git a/onnxruntime/test/providers/cuda/cuda_mempool_arena_test.cc b/onnxruntime/test/providers/cuda/cuda_mempool_arena_test.cc index 70c7a5b2bcdcb..5deef01cd783e 100644 --- a/onnxruntime/test/providers/cuda/cuda_mempool_arena_test.cc +++ b/onnxruntime/test/providers/cuda/cuda_mempool_arena_test.cc @@ -22,10 +22,17 @@ namespace test { // --------- Helpers --------- +// cuda errors are sticky and may affect subsequent API calls. +// we want to clear the error if when supported check fails. +void ClearCudaError() { + ORT_IGNORE_RETURN_VALUE(::cudaGetLastError()); +} + static bool IsCudaMemPoolSupported() { int ort_cuda_rt_version = 0; cudaError_t cuda_status = cudaRuntimeGetVersion(&ort_cuda_rt_version); if (cuda_status != cudaSuccess) { + ClearCudaError(); return false; } @@ -36,6 +43,7 @@ static bool IsCudaMemPoolSupported() { int ort_cuda_driver_version = 0; cuda_status = cudaDriverGetVersion(&ort_cuda_driver_version); if (cuda_status != cudaSuccess) { + ClearCudaError(); return false; } @@ -65,9 +73,10 @@ static bool IsCudaMemPoolSupported() { cudaMemPool_t pool; auto cuda_error = cudaMemPoolCreate(&pool, &props); if (cuda_error != cudaSuccess) { + ClearCudaError(); return false; } - cuda_error = cudaMemPoolDestroy(pool); + ORT_IGNORE_RETURN_VALUE(cudaMemPoolDestroy(pool)); return true; } @@ -80,7 +89,9 @@ static ::cudaStream_t NewCudaStream() { } static void DestroyCudaStream(::cudaStream_t s) { - if (s) (void)::cudaStreamDestroy(s); + if (s) { + EXPECT_EQ(cudaSuccess, ::cudaStreamDestroy(s)); + } } static void TouchDevice(void* p, size_t bytes, ::cudaStream_t s, unsigned char value = 0xAB) { diff --git a/onnxruntime/test/providers/nv_tensorrt_rtx/nv_basic_test.cc b/onnxruntime/test/providers/nv_tensorrt_rtx/nv_basic_test.cc index d8cc56d738175..af9706855ee3c 100644 --- a/onnxruntime/test/providers/nv_tensorrt_rtx/nv_basic_test.cc +++ b/onnxruntime/test/providers/nv_tensorrt_rtx/nv_basic_test.cc @@ -203,6 +203,48 @@ TEST_P(TypeTests, IOTypes) { } } +TEST(NvExecutionProviderTest, TestSessionOutputs) { + /* + * Model #1: + * + * "input" ---> TopK --- + * |---> "scores" + * |--- Less ---> "Less_output_0" + * |--- Div ---> "Div_output_0" + * |--- Mod ---> "labels" + */ + { + Ort::SessionOptions session_options; + session_options.AppendExecutionProvider(kNvTensorRTRTXExecutionProvider, {}); + + auto model_path = ORT_TSTR("testdata/topk_and_multiple_graph_outputs.onnx"); + Ort::Session session(*ort_env, model_path, session_options); + + size_t output_count = session.GetOutputCount(); + ASSERT_TRUE(output_count == 4); + } + + /* + * Model #2: + * + * "X" ---> Dropout ---> MatMul ---> "Y" + * ^ | + * | | + * "W" ------ ----> Can't be graph's output + * + */ + { + Ort::SessionOptions session_options; + session_options.AppendExecutionProvider(kNvTensorRTRTXExecutionProvider, {}); + + auto model_path = ORT_TSTR("testdata/node_output_not_used.onnx"); + Ort::Session session(*ort_env, model_path, session_options); + + size_t output_count = session.GetOutputCount(); + ASSERT_TRUE(output_count == 1); + } +} + INSTANTIATE_TEST_SUITE_P(NvExecutionProviderTest, TypeTests, ::testing::Values(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16, diff --git a/onnxruntime/test/providers/qnn/README.md b/onnxruntime/test/providers/qnn/README.md new file mode 100644 index 0000000000000..c3d0c720a1aa4 --- /dev/null +++ b/onnxruntime/test/providers/qnn/README.md @@ -0,0 +1,70 @@ +# ONNX Runtime QNN Execution Provider Tests +## Overview +1. The `onnxruntime/test/providers/qnn` directory contains integration tests for the Qualcomm Neural Network (QNN) execution provider. +2. Most testcases run an ONNX model through the QNN-EP, then verifies the inference result against the one on CPU-EP + +## Building the Tests +The tests are built as part of the regular ONNX Runtime build. After a successful build you will have an executable named +- onnxruntime_provider_test.exe (Windows) +- onnxruntime_provider_test (Linux/macOS) + +## Running the Tests +1. QNN supports several backends. You can use the standard Google‑Test syntax for filtering: + - `onnxruntime_provider_test.exe --gtest_filter=QnnCPUBackendTests.*` + - `onnxruntime_provider_test.exe --gtest_filter=QnnHTPBackendTests.*` + - `onnxruntime_provider_test.exe --gtest_filter=QnnGPUBackendTests.*` + - `onnxruntime_provider_test.exe --gtest_filter=QnnIRBackendTests.*` +2. Saving Test Artifacts + - For debugging it is often helpful to keep the intermediate files that the tests generate. The following environment + variables are recognized by the test binary: + - `QNN_DUMP_ONNX`: Saves the input ONNX model used for the test + - `QNN_DUMP_JSON`: Save json qnn graph with provider_option `dump_json_qnn_graph` + - `QNN_DUMP_DLC`: Saves the compiled QNN DLC file by specifying the provider_option `backend_path` to `QnnIr.dll` + - The artifacts will be saved to a directory named with `_` + ``` + . + ├── QnnCPUBackendTests_BatchNorm2D_fp32 # RunQnnModelTest + │ ├── dumped_f32_model.onnx # float32 ONNX model + │ ├── QNNExecutionProvider_QNN_XXXX_X_X.dlc + │ └── QNNExecutionProvider_QNN_XXXX_X_X.json + ├── QnnHTPBackendTests_BatchNorm_FP16 # TestFp16ModelAccuracy + │ ├── dumped_f16_model.onnx # float16 ONNX model + │ ├── dumped_f32_model.onnx # float32 ONNX model + │ ├── QNNExecutionProvider_QNN_XXXX_X_X.dlc + │ └── QNNExecutionProvider_QNN_XXXX_X_X.json + └── QnnHTPBackendTests_BatchNorm2D_U8U8S32 # TestQDQModelAccuracy + ├── dumped_f32_model.onnx # float32 ONNX model + ├── dumped_qdq_model.onnx # QDQ ONNX model + ├── QNNExecutionProvider_QNN_XXXX_X_X.dlc + └── QNNExecutionProvider_QNN_XXXX_X_X.json + + # All artifact files are placed under the current working directory from which the test binary is invoked. + ``` +3. Verbose + - `QNN_VERBOSE`: Sets the ONNX Runtime log level to `ORT_LOGGING_LEVEL_VERBOSE` + +4. You can enable any combination of these environment variables, for example: + - On Linux/macOS + ```bash + export QNN_DUMP_ONNX=1 + export QNN_DUMP_JSON=1 + export QNN_DUMP_DLC=1 + export QNN_VERBOSE=1 + ``` + - On Windows + ```cmd + set QNN_DUMP_ONNX=1 + set QNN_DUMP_JSON=1 + set QNN_DUMP_DLC=1 + set QNN_VERBOSE=1 + ``` + ```ps1 + $Env:QNN_DUMP_ONNX = "1" + $Env:QNN_DUMP_JSON = "1" + $Env:QNN_DUMP_DLC = "1" + $Env:QNN_VERBOSE = "1" + ``` + +# Note +- An issue on QNN backends can prevent the test artifacts from being successfully saved. +- The `onnxruntime_provider_test.exe` does not automatically delete the artifact directories, so you may want to prune them after a debugging session. diff --git a/onnxruntime/test/providers/qnn/qnn_test_utils.cc b/onnxruntime/test/providers/qnn/qnn_test_utils.cc index 1c70f4012090e..15a9132aaa16c 100644 --- a/onnxruntime/test/providers/qnn/qnn_test_utils.cc +++ b/onnxruntime/test/providers/qnn/qnn_test_utils.cc @@ -101,6 +101,12 @@ void RunQnnModelTest(const GetTestModelFn& build_test_case, ProviderOptions prov int opset_version, ExpectedEPNodeAssignment expected_ep_assignment, float fp32_abs_err, logging::Severity log_severity, bool verify_outputs, std::function* ep_graph_checker) { + std::filesystem::path output_dir; + if (QNNTestEnvironment::GetInstance().dump_onnx() || + QNNTestEnvironment::GetInstance().dump_json() || + QNNTestEnvironment::GetInstance().dump_dlc()) { + output_dir = QNNTestEnvironment::GetInstance().CreateTestcaseDirs(); + } EPVerificationParams verification_params; verification_params.ep_node_assignment = expected_ep_assignment; verification_params.fp32_abs_err = fp32_abs_err; @@ -110,6 +116,10 @@ void RunQnnModelTest(const GetTestModelFn& build_test_case, ProviderOptions prov auto& logging_manager = DefaultLoggingManager(); logging_manager.SetDefaultLoggerSeverity(log_severity); + if (QNNTestEnvironment::GetInstance().verbose()) { + logging_manager.RemoveSink(logging::SinkType::EtwSink); + logging_manager.SetDefaultLoggerSeverity(logging::Severity::kVERBOSE); + } onnxruntime::Model model("QNN_EP_TestModel", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, {}, @@ -123,7 +133,27 @@ void RunQnnModelTest(const GetTestModelFn& build_test_case, ProviderOptions prov // Serialize the model to a string. std::string model_data; model.ToProto().SerializeToString(&model_data); + + if (QNNTestEnvironment::GetInstance().dump_onnx()) { + auto dump_path = output_dir / ToPathString("dumped_f32_model.onnx"); + LOGS(logging_manager.DefaultLogger(), VERBOSE) << "Save onnx model at: " << dump_path; + ASSERT_STATUS_OK(onnxruntime::Model::Save(model, dump_path)); + } + TryEnableQNNSaver(provider_options); + if (QNNTestEnvironment::GetInstance().dump_dlc()) { + provider_options["dump_qnn_ir_dlc"] = "1"; + provider_options["dump_qnn_ir_dlc_dir"] = output_dir.string(); +#if defined(_WIN32) + provider_options["qnn_ir_backend_path"] = "QnnIr.dll"; +#else + provider_options["qnn_ir_backend_path"] = "libQnnIr.so"; +#endif // defined(_WIN32) + } + if (QNNTestEnvironment::GetInstance().dump_json()) { + provider_options["dump_json_qnn_graph"] = "1"; + provider_options["json_qnn_graph_dir"] = output_dir.string(); + } RunAndVerifyOutputsWithEP(AsByteSpan(model_data.data(), model_data.size()), "QNN_EP_TestLogID", QnnExecutionProviderWithOptions(provider_options), helper.feeds_, verification_params, @@ -134,11 +164,21 @@ void RunQnnModelTestHTPNoVerify(const GetTestModelFn& build_test_case, ProviderO int opset_version, ExpectedEPNodeAssignment expected_ep_assignment, logging::Severity log_severity, std::function* ep_graph_checker) { + std::filesystem::path output_dir; + if (QNNTestEnvironment::GetInstance().dump_onnx() || + QNNTestEnvironment::GetInstance().dump_dlc() || + QNNTestEnvironment::GetInstance().dump_json()) { + output_dir = QNNTestEnvironment::GetInstance().CreateTestcaseDirs(); + } // Add kMSDomain to cover contrib op like Gelu const std::unordered_map domain_to_version = {{"", opset_version}, {kMSDomain, 1}}; auto& logging_manager = DefaultLoggingManager(); logging_manager.SetDefaultLoggerSeverity(log_severity); + if (QNNTestEnvironment::GetInstance().verbose()) { + logging_manager.RemoveSink(logging::SinkType::EtwSink); + logging_manager.SetDefaultLoggerSeverity(logging::Severity::kVERBOSE); + } onnxruntime::Model model("QNN_EP_TestModel", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, {}, @@ -152,7 +192,27 @@ void RunQnnModelTestHTPNoVerify(const GetTestModelFn& build_test_case, ProviderO // Serialize the model to a string. std::string model_data; model.ToProto().SerializeToString(&model_data); + + if (QNNTestEnvironment::GetInstance().dump_onnx()) { + auto dump_path = output_dir / ToPathString("dumped_f32_model.onnx"); + LOGS(logging_manager.DefaultLogger(), VERBOSE) << "Save onnx model at: " << dump_path; + ASSERT_STATUS_OK(onnxruntime::Model::Save(model, dump_path)); + } + TryEnableQNNSaver(provider_options); + if (QNNTestEnvironment::GetInstance().dump_dlc()) { + provider_options["dump_qnn_ir_dlc"] = "1"; + provider_options["dump_qnn_ir_dlc_dir"] = output_dir.string(); +#if defined(_WIN32) + provider_options["qnn_ir_backend_path"] = "QnnIr.dll"; +#else + provider_options["qnn_ir_backend_path"] = "libQnnIr.so"; +#endif // defined(_WIN32) + } + if (QNNTestEnvironment::GetInstance().dump_json()) { + provider_options["dump_json_qnn_graph"] = "1"; + provider_options["json_qnn_graph_dir"] = output_dir.string(); + } SessionOptions so; so.session_logid = "QNN_EP_TestLogID"; diff --git a/onnxruntime/test/providers/qnn/qnn_test_utils.h b/onnxruntime/test/providers/qnn/qnn_test_utils.h index aeb3a9a114871..4d4f795d161b1 100644 --- a/onnxruntime/test/providers/qnn/qnn_test_utils.h +++ b/onnxruntime/test/providers/qnn/qnn_test_utils.h @@ -499,6 +499,77 @@ struct QDQTolerance { float value; }; +class QNNTestEnvironment { + public: + // Delete copy constructor and assignment operator + QNNTestEnvironment(const QNNTestEnvironment&) = delete; + QNNTestEnvironment& operator=(const QNNTestEnvironment&) = delete; + + // Static method to get the singleton instance + static QNNTestEnvironment& GetInstance() { + static QNNTestEnvironment instance; + return instance; + } + + bool dump_onnx() const { return dump_onnx_; } + bool dump_json() const { return dump_json_; } + bool dump_dlc() const { return dump_dlc_; } + bool verbose() const { return verbose_; } + + std::filesystem::path CreateTestcaseDirs() { + std::string test_suite_name = ::testing::UnitTest::GetInstance()->current_test_info()->test_suite_name(); + std::string test_name = ::testing::UnitTest::GetInstance()->current_test_info()->name(); + std::filesystem::path output_dir = std::filesystem::current_path() / (test_suite_name + "_" + test_name); + std::filesystem::create_directories(output_dir); + + return output_dir; + } + + private: + // Private constructor for singleton + QNNTestEnvironment() { + ParseEnvironmentVars(); + } + + // Helper function to check if an environment variable is set + bool IsEnvVarSet(const char* name) { + const char* value = std::getenv(name); + if (value == nullptr) { + return false; + } + + // Consider the variable set if it's not empty and not "0" + return *value != '\0' && *value != '0'; + } + + void ParseEnvironmentVars() { + if (IsEnvVarSet("QNN_DUMP_ONNX")) { + std::cout << "[QNN only] ONNX model dumping enabled via environment variable." << std::endl; + dump_onnx_ = true; + } + + if (IsEnvVarSet("QNN_DUMP_JSON")) { + std::cout << "[QNN only] Json QNN Graph dumping enabled via environment variable." << std::endl; + dump_json_ = true; + } + + if (IsEnvVarSet("QNN_DUMP_DLC")) { + std::cout << "[QNN only] DLC dumping enabled via environment variable." << std::endl; + dump_dlc_ = true; + } + + if (IsEnvVarSet("QNN_VERBOSE")) { + std::cout << "Verbose enabled via environment variable." << std::endl; + verbose_ = true; + } + } + + bool dump_onnx_ = false; + bool dump_json_ = false; + bool dump_dlc_ = false; + bool verbose_ = false; +}; + /** * Tests the accuracy of a QDQ model on QNN EP by runnning 3 inferences: * @@ -529,15 +600,21 @@ inline void TestQDQModelAccuracy(const GetTestModelFn& f32_model_fn, const GetTe const std::string& qnn_ctx_model_path = "", const std::unordered_map& session_option_pairs = {}, std::function* qnn_ep_graph_checker = nullptr) { + std::filesystem::path output_dir; + if (QNNTestEnvironment::GetInstance().dump_onnx() || + QNNTestEnvironment::GetInstance().dump_dlc() || + QNNTestEnvironment::GetInstance().dump_json()) { + output_dir = QNNTestEnvironment::GetInstance().CreateTestcaseDirs(); + } // Add kMSDomain to cover contrib op like Gelu const std::unordered_map domain_to_version = {{"", opset_version}, {kMSDomain, 1}}; auto& logging_manager = DefaultLoggingManager(); - - // Uncomment to dump LOGGER() output to stdout. - // logging_manager.RemoveSink(logging::SinkType::EtwSink); - logging_manager.SetDefaultLoggerSeverity(log_severity); + if (QNNTestEnvironment::GetInstance().verbose()) { + logging_manager.RemoveSink(logging::SinkType::EtwSink); + logging_manager.SetDefaultLoggerSeverity(logging::Severity::kVERBOSE); + } // Create float model and serialize it to a string. onnxruntime::Model f32_model("f32_model", false, ModelMetaData(), PathString(), @@ -551,8 +628,11 @@ inline void TestQDQModelAccuracy(const GetTestModelFn& f32_model_fn, const GetTe ASSERT_STATUS_OK(f32_model.MainGraph().Resolve()); f32_model.ToProto().SerializeToString(&f32_model_data); - // Uncomment to save f32 model to disk for debugging. - // ASSERT_STATUS_OK(onnxruntime::Model::Save(f32_model, ToPathString("cmp_accuracy.f32.onnx"))); + if (QNNTestEnvironment::GetInstance().dump_onnx()) { + auto dump_path = output_dir / ToPathString("dumped_f32_model.onnx"); + LOGS(logging_manager.DefaultLogger(), VERBOSE) << "Save onnx float32 model at: " << dump_path; + ASSERT_STATUS_OK(onnxruntime::Model::Save(f32_model, dump_path)); + } // Run f32 model on CPU EP and collect outputs. std::vector cpu_f32_outputs; @@ -594,11 +674,27 @@ inline void TestQDQModelAccuracy(const GetTestModelFn& f32_model_fn, const GetTe ASSERT_STATUS_OK(qdq_model.MainGraph().Resolve()); qdq_model.ToProto().SerializeToString(&qdq_model_data); - // Uncomment to save QDQ model to disk for debugging. - // ASSERT_STATUS_OK(onnxruntime::Model::Save(qdq_model, ToPathString("cmp_accuracy.qdq.onnx"))); + if (QNNTestEnvironment::GetInstance().dump_onnx()) { + auto dump_path = output_dir / ToPathString("dumped_qdq_model.onnx"); + LOGS(logging_manager.DefaultLogger(), VERBOSE) << "Save onnx QDQ model at: " << dump_path; + ASSERT_STATUS_OK(onnxruntime::Model::Save(qdq_model, dump_path)); + } bool is_qnn_ep = true; TryEnableQNNSaver(qnn_options); + if (QNNTestEnvironment::GetInstance().dump_dlc()) { + qnn_options["dump_qnn_ir_dlc"] = "1"; + qnn_options["dump_qnn_ir_dlc_dir"] = output_dir.string(); +#if defined(_WIN32) + qnn_options["qnn_ir_backend_path"] = "QnnIr.dll"; +#else + qnn_options["qnn_ir_backend_path"] = "libQnnIr.so"; +#endif // defined(_WIN32) + } + if (QNNTestEnvironment::GetInstance().dump_json()) { + qnn_options["dump_json_qnn_graph"] = "1"; + qnn_options["json_qnn_graph_dir"] = output_dir.string(); + } std::vector qnn_qdq_outputs; if (!qnn_ctx_model_path.empty()) { onnx::ModelProto model_proto; @@ -743,11 +839,21 @@ inline void TestFp16ModelAccuracy(const GetTestModelFn& f32_model_fn, logging::Severity log_severity = logging::Severity::kERROR, const std::string& qnn_ctx_model_path = "", const std::unordered_map& session_option_pairs = {}) { + std::filesystem::path output_dir; + if (QNNTestEnvironment::GetInstance().dump_onnx() || + QNNTestEnvironment::GetInstance().dump_dlc() || + QNNTestEnvironment::GetInstance().dump_json()) { + output_dir = QNNTestEnvironment::GetInstance().CreateTestcaseDirs(); + } // Add kMSDomain to cover contrib op like Gelu const std::unordered_map domain_to_version = {{"", opset_version}, {kMSDomain, 1}}; auto& logging_manager = DefaultLoggingManager(); logging_manager.SetDefaultLoggerSeverity(log_severity); + if (QNNTestEnvironment::GetInstance().verbose()) { + logging_manager.RemoveSink(logging::SinkType::EtwSink); + logging_manager.SetDefaultLoggerSeverity(logging::Severity::kVERBOSE); + } // Create float model and serialize it to a string. onnxruntime::Model f32_model("f32_model", false, ModelMetaData(), PathString(), @@ -760,6 +866,12 @@ inline void TestFp16ModelAccuracy(const GetTestModelFn& f32_model_fn, ASSERT_STATUS_OK(f32_model.MainGraph().Resolve()); f32_model.ToProto().SerializeToString(&f32_model_data); + if (QNNTestEnvironment::GetInstance().dump_onnx()) { + auto dump_path = output_dir / ToPathString("dumped_f32_model.onnx"); + LOGS(logging_manager.DefaultLogger(), VERBOSE) << "Save onnx float32 model at: " << dump_path; + ASSERT_STATUS_OK(onnxruntime::Model::Save(f32_model, dump_path)); + } + // Run f32 model on CPU EP and collect outputs. std::vector cpu_f32_outputs; InferenceModel(f32_model_data, "f32_model_logger", {}, ExpectedEPNodeAssignment::All, @@ -796,8 +908,27 @@ inline void TestFp16ModelAccuracy(const GetTestModelFn& f32_model_fn, ASSERT_STATUS_OK(f16_model.MainGraph().Resolve()); f16_model.ToProto().SerializeToString(&f16_model_data); + if (QNNTestEnvironment::GetInstance().dump_onnx()) { + auto dump_path = output_dir / ToPathString("dumped_f16_model.onnx"); + LOGS(logging_manager.DefaultLogger(), VERBOSE) << "Save onnx float16 model at: " << dump_path; + ASSERT_STATUS_OK(onnxruntime::Model::Save(f16_model, dump_path)); + } + bool is_qnn_ep = true; TryEnableQNNSaver(qnn_options); + if (QNNTestEnvironment::GetInstance().dump_dlc()) { + qnn_options["dump_qnn_ir_dlc"] = "1"; + qnn_options["dump_qnn_ir_dlc_dir"] = output_dir.string(); +#if defined(_WIN32) + qnn_options["qnn_ir_backend_path"] = "QnnIr.dll"; +#else + qnn_options["qnn_ir_backend_path"] = "libQnnIr.so"; +#endif // defined(_WIN32) + } + if (QNNTestEnvironment::GetInstance().dump_json()) { + qnn_options["dump_json_qnn_graph"] = "1"; + qnn_options["json_qnn_graph_dir"] = output_dir.string(); + } std::vector qnn_f16_outputs; if (!qnn_ctx_model_path.empty()) { onnx::ModelProto model_proto; diff --git a/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc b/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc index 6a6545c68cb4f..dce0d570ec238 100644 --- a/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc +++ b/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc @@ -1,5 +1,6 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include "onnxruntime_cxx_api.h" #include "core/graph/onnx_protobuf.h" #include "core/session/inference_session.h" #include "test/providers/provider_test_utils.h" @@ -18,6 +19,8 @@ using namespace std; using namespace ONNX_NAMESPACE; using namespace ::onnxruntime::logging; +extern std::unique_ptr ort_env; + namespace onnxruntime { namespace test { @@ -1360,5 +1363,49 @@ TEST(TensorrtExecutionProviderTest, RemoveCycleTest) { ASSERT_STATUS_OK(session_object.Run(run_options, feeds, output_names, &fetches)); VerifyOutputs(fetches, expected_dims_mul_m, expected_values_mul_m); } + +TEST(TensorrtExecutionProviderTest, TestSessionOutputs) { + /* + * Model #1: + * + * "input" ---> TopK --- + * |---> "scores" + * |--- Less ---> "Less_output_0" + * |--- Div ---> "Div_output_0" + * |--- Mod ---> "labels" + */ + { + OrtTensorRTProviderOptionsV2 provider_options; + Ort::SessionOptions session_options; + session_options.AppendExecutionProvider_TensorRT_V2(provider_options); + + auto model_path = ORT_TSTR("testdata/topk_and_multiple_graph_outputs.onnx"); + Ort::Session session(*ort_env, model_path, session_options); + + size_t output_count = session.GetOutputCount(); + ASSERT_TRUE(output_count == 4); + } + + /* + * Model #2: + * + * "X" ---> Dropout ---> MatMul ---> "Y" + * ^ | + * | | + * "W" ------ ----> Can't be graph's output + * + */ + { + OrtTensorRTProviderOptionsV2 provider_options; + Ort::SessionOptions session_options; + session_options.AppendExecutionProvider_TensorRT_V2(provider_options); + + auto model_path = ORT_TSTR("testdata/node_output_not_used.onnx"); + Ort::Session session(*ort_env, model_path, session_options); + + size_t output_count = session.GetOutputCount(); + ASSERT_TRUE(output_count == 1); + } +} } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/testdata/node_output_not_used.onnx b/onnxruntime/test/testdata/node_output_not_used.onnx new file mode 100644 index 0000000000000000000000000000000000000000..e2726182fddc2c265752e46346735c26e33add4b GIT binary patch literal 189 zcmd=lo3kgAWK)CKji3J%^!XPX8xOg}ig*dpFIGBN$2_zVfB*+Ak RNCFB*q6<2)a4`t*0ss-ID|-L{ literal 0 HcmV?d00001 diff --git a/onnxruntime/test/testdata/node_output_not_used.py b/onnxruntime/test/testdata/node_output_not_used.py new file mode 100644 index 0000000000000..d36d5e9cfd2f8 --- /dev/null +++ b/onnxruntime/test/testdata/node_output_not_used.py @@ -0,0 +1,43 @@ +import onnx +from onnx import TensorProto, helper + + +def create_model_with_node_output_not_used(model_path): + # Create graph + x = helper.make_tensor_value_info("X", TensorProto.FLOAT, [3, 2]) + w = helper.make_tensor_value_info("W", TensorProto.FLOAT, [2, 3]) + y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 3]) + + # Dropout node (two outputs) + dropout_node = helper.make_node( + "Dropout", + inputs=["X"], + outputs=["dropout_out", "dropout_mask"], + name="DropoutNode", + ) + + # MatMul node + matmul_node = helper.make_node( + "MatMul", + inputs=["dropout_out", "W"], + outputs=["Y"], + name="MatMulNode", + ) + + graph = helper.make_graph( + nodes=[dropout_node, matmul_node], + name="DropoutMatMulGraph", + inputs=[x, w], + outputs=[y], + ) + + model = helper.make_model(graph, opset_imports=[helper.make_operatorsetid("", 13)]) + + onnx.checker.check_model(model) + onnx.save(model, model_path) + + print(f"Model saved to: {model_path}") + + +if __name__ == "__main__": + create_model_with_node_output_not_used("node_output_not_used.onnx") diff --git a/onnxruntime/test/testdata/topk_and_multiple_graph_outputs.onnx b/onnxruntime/test/testdata/topk_and_multiple_graph_outputs.onnx new file mode 100644 index 0000000000000000000000000000000000000000..340c3d420d5746844be0bd3769a174b4e69de801 GIT binary patch literal 393 zcmdW?8(U zIYJ{dP(TSp09}P@53)A4oW!KmoMI_v-~1FM5Fx|~a-n-sVnK!$HwU8tyA{(KCMQO3 zEp8x_k--V -B -V [-H ] " 1>&2; exit 1; } - -ROCM_HOME=/opt/rocm - -while getopts S:B:V:H:I:P: parameter_Option; do - case "${parameter_Option}" in - S) SOURCE_DIR=${OPTARG};; - B) BINARY_DIR=${OPTARG};; - V) ROCM_VERSION=${OPTARG};; - H) ROCM_HOME=${OPTARG};; - I) IMAGE=${OPTARG};; - P) PYTHON_BIN=${OPTARG};; - *) usage ;; - esac -done - -EXIT_CODE=1 - -docker run -e SYSTEM_COLLECTIONURI --rm \ - --security-opt seccomp=unconfined \ - --shm-size=1024m \ - --user $UID:$(id -g $USER) \ - -e NIGHTLY_BUILD \ - --volume $SOURCE_DIR:/onnxruntime_src \ - --volume $BINARY_DIR:/build \ - --volume /data/models:/build/models:ro \ - --volume /data/onnx:/data/onnx:ro \ - --workdir /onnxruntime_src \ - $IMAGE \ - /bin/bash -c "${PYTHON_BIN:-python} /onnxruntime_src/tools/ci_build/build.py --config Release --build_dir /build --parallel --use_rocm --use_binskim_compliant_compile_flags --rocm_version=$ROCM_VERSION --rocm_home $ROCM_HOME --nccl_home $ROCM_HOME --build_shared_lib --skip_submodule_sync --skip_tests --cmake_extra_defines FETCHCONTENT_TRY_FIND_PACKAGE_MODE=NEVER && cd /build/Release && make install DESTDIR=/build/installed" - - -EXIT_CODE=$? - -set -e -exit $EXIT_CODE diff --git a/tools/ci_build/github/linux/docker/scripts/setup_rocm_yum_repo.sh b/tools/ci_build/github/linux/docker/scripts/setup_rocm_yum_repo.sh deleted file mode 100755 index 0be64d96f3a34..0000000000000 --- a/tools/ci_build/github/linux/docker/scripts/setup_rocm_yum_repo.sh +++ /dev/null @@ -1,43 +0,0 @@ -#!/bin/bash -set -e -x - -# version -ROCM_VERSION=6.2.3 - -while getopts "r:" parameter_Option -do case "${parameter_Option}" -in -r) ROCM_VERSION=${OPTARG};; -esac -done - -tee /etc/yum.repos.d/amdgpu.repo < Date: Tue, 2 Dec 2025 22:09:41 -0800 Subject: [PATCH 137/138] Revert "Sync with Microsoft ONNX Runtime - 03/12/2025 (#867)" This reverts commit 39d6db58a737449b0bd3ddc9b5a8777f4426d99d. --- .github/workflows/android.yml | 6 +- .github/workflows/cffconvert.yml | 2 +- .github/workflows/codeql.yml | 2 +- .../workflows/gradle-wrapper-validation.yml | 2 +- .github/workflows/ios.yml | 2 +- .github/workflows/lint.yml | 8 +- .../linux-wasm-ci-build-and-test-workflow.yml | 2 +- .github/workflows/linux_cuda_ci.yml | 2 +- .github/workflows/linux_minimal_build.yml | 20 +- .github/workflows/linux_tensorrt_ci.yml | 2 +- .github/workflows/mac.yml | 4 +- .../macos-ci-build-and-test-workflow.yml | 2 +- .github/workflows/pr_checks.yml | 2 +- .github/workflows/publish-c-apidocs.yml | 2 +- .github/workflows/publish-csharp-apidocs.yml | 2 +- .github/workflows/publish-java-apidocs.yml | 2 +- .github/workflows/publish-js-apidocs.yml | 2 +- .../workflows/publish-objectivec-apidocs.yml | 2 +- .github/workflows/publish-python-apidocs.yml | 2 +- .github/workflows/react_native.yml | 8 +- .github/workflows/reusable_linux_build.yml | 2 +- .github/workflows/web.yml | 2 +- .github/workflows/windows-web-ci-workflow.yml | 2 +- .github/workflows/windows_build_x64_asan.yml | 2 +- .github/workflows/windows_cuda.yml | 4 +- .github/workflows/windows_dml.yml | 2 +- .github/workflows/windows_openvino.yml | 2 +- .github/workflows/windows_qnn_x64.yml | 2 +- .github/workflows/windows_tensorrt.yml | 4 +- .github/workflows/windows_webgpu.yml | 6 +- .../windows_x64_debug_build_x64_debug.yml | 2 +- .../windows_x64_release_build_x64_release.yml | 2 +- ...build_x64_release_ep_generic_interface.yml | 2 +- ..._x64_release_vitisai_build_x64_release.yml | 2 +- .../workflows/windows_x64_release_xnnpack.yml | 2 +- .github/workflows/windows_x86.yml | 2 +- dockerfiles/Dockerfile.rocm | 24 + dockerfiles/README.md | 17 +- dockerfiles/scripts/install_rocm_deps.sh | 84 ++ js/package-lock.json | 144 ++- js/react_native/package-lock.json | 125 +-- js/web/lib/wasm/jsep/webgpu/ops/attention.ts | 4 +- js/web/lib/wasm/jsep/webgpu/ops/softmax.ts | 2 +- .../contrib_ops/rocm/bert/attention.cu | 215 ++++ onnxruntime/contrib_ops/rocm/bert/attention.h | 33 + .../contrib_ops/rocm/bert/attention_impl.cu | 435 +++++++++ .../contrib_ops/rocm/bert/attention_impl.h | 180 ++++ .../contrib_ops/rocm/bert/attention_softmax.h | 465 +++++++++ .../bert/batched_gemm_permute_pipelines.cuh | 125 +++ .../impl.cuh | 177 ++++ .../impl_fp16.cu | 60 ++ .../impl_fp16_biased.cu | 60 ++ .../impl_fp16_biased_biased.cu | 60 ++ ...ed_gemm_softmax_gemm_permute_pipelines.cuh | 915 ++++++++++++++++++ .../rocm/bert/decoder_attention_impl.h | 46 + .../contrib_ops/rocm/bert/elementwise.h | 84 ++ .../rocm/bert/elementwise_impl/impl.cuh | 256 +++++ .../bert/elementwise_impl/impl_fastgelu.cu | 9 + .../rocm/bert/elementwise_impl/impl_gelu.cu | 9 + .../rocm/bert/elementwise_impl/impl_relu.cu | 8 + .../contrib_ops/rocm/bert/gemm_fast_gelu.cc | 75 ++ .../contrib_ops/rocm/bert/gemm_fast_gelu.h | 23 + .../rocm/bert/gemm_fast_gelu_ck.cuh | 133 +++ .../rocm/bert/gemm_fast_gelu_common.h | 47 + .../rocm/bert/gemm_fast_gelu_impl.cu | 91 ++ .../rocm/bert/gemm_fast_gelu_impl.h | 40 + .../rocm/bert/gemm_fast_gelu_tunable.cuh | 83 ++ .../rocm/bert/group_query_attention.cu | 530 ++++++++++ .../rocm/bert/group_query_attention.h | 38 + .../contrib_ops/rocm/bert/layer_norm.cuh | 270 ++++++ .../rocm/bert/multihead_attention.cu | 286 ++++++ .../rocm/bert/multihead_attention.h | 51 + .../contrib_ops/rocm/bert/skip_layer_norm.cc | 132 +++ .../contrib_ops/rocm/bert/skip_layer_norm.h | 26 + .../rocm/bert/skip_layer_norm_impl.cu | 86 ++ .../rocm/bert/skip_layer_norm_impl.h | 31 + .../rocm/bert/skip_layer_norm_impl_kernel.h | 162 ++++ .../rocm/bert/skip_layer_norm_tunable_op.h | 161 +++ .../rocm/bert/transformer_common.cc | 37 + .../rocm/bert/transformer_common.h | 46 + .../rocm/diffusion/group_norm_ck.cuh | 105 ++ .../diffusion/group_norm_ck_impl/impl.cuh | 130 +++ .../diffusion/group_norm_ck_impl/impl_fp16.cu | 39 + .../diffusion/group_norm_ck_impl/impl_fp32.cu | 39 + .../rocm/diffusion/group_norm_common.h | 56 ++ .../rocm/diffusion/group_norm_impl.cu | 76 ++ .../rocm/diffusion/group_norm_triton.cuh | 105 ++ .../rocm/diffusion/group_norm_triton.py | 135 +++ .../rocm/diffusion/group_norm_tunable_op.h | 220 +++++ .../contrib_ops/rocm/diffusion/nhwc_conv.cc | 27 + onnxruntime/contrib_ops/rocm/fused_conv.cc | 439 +++++++++ .../contrib_ops/rocm/math/gemm_float8.cu | 213 ++++ .../contrib_ops/rocm/math/gemm_float8_ck.cuh | 276 ++++++ .../math/gemm_float8_ck_impl/add_instance.cu | 124 +++ ...xdl_splitk_f16_f8_f16_mk_kn_mn_instance.cu | 97 ++ ...k_f16_f8_f16_mk_kn_mn_instance_original.cu | 80 ++ ...xdl_splitk_f16_f8_f16_mk_nk_mn_instance.cu | 94 ++ ...k_f8_f16_f16_mk_kn_mn_instance_original.cu | 97 ++ .../contrib_ops/rocm/rocm_contrib_kernels.cc | 347 +++++++ .../contrib_ops/rocm/rocm_contrib_kernels.h | 14 + .../contrib_ops/webgpu/bert/attention.cc | 6 +- .../webgpu/bert/flash_attention.cc | 6 +- .../webgpu/bert/flash_attention.wgsl.template | 2 +- .../flash_attention_decode_qkt.wgsl.template | 2 +- ...sh_attention_decode_split_vx.wgsl.template | 2 +- .../webgpu/bert/group_query_attention.cc | 4 +- .../contrib_ops/webgpu/moe/gate.wgsl.template | 2 +- .../core/framework/allocation_planner.cc | 3 +- .../core/framework/ort_value_name_idx_map.h | 2 +- .../contrib_ops/nhwc_inference_context.h | 7 +- onnxruntime/core/platform/telemetry.cc | 4 - onnxruntime/core/platform/telemetry.h | 2 - .../core/platform/windows/telemetry.cc | 14 - onnxruntime/core/platform/windows/telemetry.h | 2 - .../core/providers/js/operators/unary.cc | 2 +- .../nv_tensorrt_rtx/nv_execution_provider.cc | 77 +- .../qnn/builder/opbuilder/base_op_builder.cc | 3 - .../core/providers/qnn/builder/qnn_def.cc | 4 - .../core/providers/qnn/builder/qnn_def.h | 2 - .../core/providers/qnn/builder/qnn_model.cc | 4 +- .../tensorrt/tensorrt_execution_provider.cc | 77 +- .../vsinpu/builders/impl/clip_op_builder.cc | 4 +- .../core/providers/webgpu/allocator.cc | 2 +- onnxruntime/core/providers/webgpu/allocator.h | 5 - .../core/providers/webgpu/compute_context.cc | 23 +- .../core/providers/webgpu/compute_context.h | 103 +- .../webgpu/math/binary_elementwise_ops.cc | 11 +- .../core/providers/webgpu/math/gemm_packed.cc | 15 +- .../core/providers/webgpu/math/gemm_packed.h | 5 +- .../core/providers/webgpu/math/gemm_utils.cc | 46 +- .../core/providers/webgpu/math/matmul.cc | 4 +- .../providers/webgpu/math/matmul_packed.h | 5 +- .../core/providers/webgpu/math/softmax.cc | 2 +- onnxruntime/core/providers/webgpu/nn/conv.cc | 40 - onnxruntime/core/providers/webgpu/nn/conv.h | 7 - .../core/providers/webgpu/nn/conv2d_mm.cc | 5 +- .../core/providers/webgpu/nn/conv2d_mm.h | 5 +- .../core/providers/webgpu/tensor/slice.cc | 22 +- .../core/providers/webgpu/tensor/transpose.cc | 2 +- .../core/providers/webgpu/tensor/transpose.h | 2 +- .../core/providers/webgpu/webgpu_context.cc | 18 +- .../core/providers/webgpu/webgpu_context.h | 21 +- .../webgpu/webgpu_execution_provider.cc | 14 +- .../core/providers/webgpu/webgpu_kernel.cc | 47 +- .../core/providers/webgpu/webgpu_kernel.h | 33 - .../core/providers/webgpu/webgpu_utils.cc | 15 +- .../core/providers/webgpu/webgpu_utils.h | 5 +- onnxruntime/core/session/inference_session.cc | 2 - onnxruntime/core/session/utils.cc | 1 - .../providers/cuda/cuda_mempool_arena_test.cc | 15 +- .../nv_tensorrt_rtx/nv_basic_test.cc | 42 - onnxruntime/test/providers/qnn/README.md | 70 -- .../test/providers/qnn/qnn_test_utils.cc | 60 -- .../test/providers/qnn/qnn_test_utils.h | 147 +-- .../providers/tensorrt/tensorrt_basic_test.cc | 49 +- .../test/testdata/node_output_not_used.onnx | Bin 189 -> 0 bytes .../test/testdata/node_output_not_used.py | 43 - .../topk_and_multiple_graph_outputs.onnx | Bin 393 -> 0 bytes .../topk_and_multiple_graph_outputs.py | 78 -- .../github/linux/build_rocm_c_api_package.sh | 40 + .../docker/scripts/setup_rocm_yum_repo.sh | 43 + 161 files changed, 8816 insertions(+), 1176 deletions(-) create mode 100644 dockerfiles/Dockerfile.rocm create mode 100644 dockerfiles/scripts/install_rocm_deps.sh create mode 100644 onnxruntime/contrib_ops/rocm/bert/attention.cu create mode 100644 onnxruntime/contrib_ops/rocm/bert/attention.h create mode 100644 onnxruntime/contrib_ops/rocm/bert/attention_impl.cu create mode 100644 onnxruntime/contrib_ops/rocm/bert/attention_impl.h create mode 100644 onnxruntime/contrib_ops/rocm/bert/attention_softmax.h create mode 100644 onnxruntime/contrib_ops/rocm/bert/batched_gemm_permute_pipelines.cuh create mode 100644 onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl.cuh create mode 100644 onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16.cu create mode 100644 onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16_biased.cu create mode 100644 onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16_biased_biased.cu create mode 100644 onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh create mode 100644 onnxruntime/contrib_ops/rocm/bert/decoder_attention_impl.h create mode 100644 onnxruntime/contrib_ops/rocm/bert/elementwise.h create mode 100644 onnxruntime/contrib_ops/rocm/bert/elementwise_impl/impl.cuh create mode 100644 onnxruntime/contrib_ops/rocm/bert/elementwise_impl/impl_fastgelu.cu create mode 100644 onnxruntime/contrib_ops/rocm/bert/elementwise_impl/impl_gelu.cu create mode 100644 onnxruntime/contrib_ops/rocm/bert/elementwise_impl/impl_relu.cu create mode 100644 onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu.cc create mode 100644 onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu.h create mode 100644 onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_ck.cuh create mode 100644 onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_common.h create mode 100644 onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_impl.cu create mode 100644 onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_impl.h create mode 100644 onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_tunable.cuh create mode 100644 onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu create mode 100644 onnxruntime/contrib_ops/rocm/bert/group_query_attention.h create mode 100644 onnxruntime/contrib_ops/rocm/bert/layer_norm.cuh create mode 100644 onnxruntime/contrib_ops/rocm/bert/multihead_attention.cu create mode 100644 onnxruntime/contrib_ops/rocm/bert/multihead_attention.h create mode 100644 onnxruntime/contrib_ops/rocm/bert/skip_layer_norm.cc create mode 100644 onnxruntime/contrib_ops/rocm/bert/skip_layer_norm.h create mode 100644 onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl.cu create mode 100644 onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl.h create mode 100644 onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl_kernel.h create mode 100644 onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_tunable_op.h create mode 100644 onnxruntime/contrib_ops/rocm/bert/transformer_common.cc create mode 100644 onnxruntime/contrib_ops/rocm/bert/transformer_common.h create mode 100644 onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck.cuh create mode 100644 onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl.cuh create mode 100644 onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl_fp16.cu create mode 100644 onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl_fp32.cu create mode 100644 onnxruntime/contrib_ops/rocm/diffusion/group_norm_common.h create mode 100644 onnxruntime/contrib_ops/rocm/diffusion/group_norm_impl.cu create mode 100644 onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh create mode 100644 onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.py create mode 100644 onnxruntime/contrib_ops/rocm/diffusion/group_norm_tunable_op.h create mode 100644 onnxruntime/contrib_ops/rocm/diffusion/nhwc_conv.cc create mode 100644 onnxruntime/contrib_ops/rocm/fused_conv.cc create mode 100644 onnxruntime/contrib_ops/rocm/math/gemm_float8.cu create mode 100644 onnxruntime/contrib_ops/rocm/math/gemm_float8_ck.cuh create mode 100644 onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/add_instance.cu create mode 100644 onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instance.cu create mode 100644 onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instance_original.cu create mode 100644 onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instance.cu create mode 100644 onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instance_original.cu create mode 100644 onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc create mode 100644 onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.h delete mode 100644 onnxruntime/test/providers/qnn/README.md delete mode 100644 onnxruntime/test/testdata/node_output_not_used.onnx delete mode 100644 onnxruntime/test/testdata/node_output_not_used.py delete mode 100644 onnxruntime/test/testdata/topk_and_multiple_graph_outputs.onnx delete mode 100644 onnxruntime/test/testdata/topk_and_multiple_graph_outputs.py create mode 100755 tools/ci_build/github/linux/build_rocm_c_api_package.sh create mode 100755 tools/ci_build/github/linux/docker/scripts/setup_rocm_yum_repo.sh diff --git a/.github/workflows/android.yml b/.github/workflows/android.yml index f12eadc2ce794..7f7ff74959d52 100644 --- a/.github/workflows/android.yml +++ b/.github/workflows/android.yml @@ -27,7 +27,7 @@ jobs: runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] steps: - name: Checkout repository - uses: actions/checkout@v6 + uses: actions/checkout@v5 with: submodules: false @@ -112,7 +112,7 @@ jobs: android_nnapi_ep: runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] steps: - - uses: actions/checkout@v6 + - uses: actions/checkout@v5 - name: Use jdk 17 uses: actions/setup-java@v5 @@ -187,7 +187,7 @@ jobs: name: Android CI Pipeline runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] steps: - - uses: actions/checkout@v6 + - uses: actions/checkout@v5 - name: Use jdk 17 uses: actions/setup-java@v5 diff --git a/.github/workflows/cffconvert.yml b/.github/workflows/cffconvert.yml index ddf4a52a0ccb0..30f832f67c5ee 100644 --- a/.github/workflows/cffconvert.yml +++ b/.github/workflows/cffconvert.yml @@ -12,7 +12,7 @@ jobs: runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] steps: - name: Check out a copy of the repository - uses: actions/checkout@v6 + uses: actions/checkout@v5 - name: Check whether the citation metadata from CITATION.cff is valid uses: citation-file-format/cffconvert-github-action@2.0.0 diff --git a/.github/workflows/codeql.yml b/.github/workflows/codeql.yml index 1db84400c272a..d33e4d923a0bc 100644 --- a/.github/workflows/codeql.yml +++ b/.github/workflows/codeql.yml @@ -38,7 +38,7 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v6 + uses: actions/checkout@v5 # Initializes the CodeQL tools for scanning. - name: Initialize CodeQL diff --git a/.github/workflows/gradle-wrapper-validation.yml b/.github/workflows/gradle-wrapper-validation.yml index d8f13d13d3f88..04177b11e9c30 100644 --- a/.github/workflows/gradle-wrapper-validation.yml +++ b/.github/workflows/gradle-wrapper-validation.yml @@ -15,7 +15,7 @@ jobs: name: "Validation" runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] steps: - - uses: actions/checkout@v6 + - uses: actions/checkout@v5 - uses: gradle/actions/wrapper-validation@v5 concurrency: group: ${{ github.workflow }}-${{ github.event_name == 'pull_request' && github.ref || github.sha }} diff --git a/.github/workflows/ios.yml b/.github/workflows/ios.yml index ed572aa339ce9..0d2046b980783 100644 --- a/.github/workflows/ios.yml +++ b/.github/workflows/ios.yml @@ -20,7 +20,7 @@ jobs: runs-on: macos-14 steps: - name: Checkout repository - uses: actions/checkout@v6 + uses: actions/checkout@v5 with: submodules: false - uses: microsoft/onnxruntime-github-actions/setup-build-tools@v0.0.9 diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 5c618dc5787a5..5aaab5f8e1a10 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -17,7 +17,7 @@ jobs: name: Optional Lint runs-on: ubuntu-latest steps: - - uses: actions/checkout@v6 + - uses: actions/checkout@v5 - name: misspell # Check spellings as well uses: reviewdog/action-misspell@v1 with: @@ -42,7 +42,7 @@ jobs: contents: read security-events: write steps: - - uses: actions/checkout@v6 + - uses: actions/checkout@v5 - name: Setup Python uses: actions/setup-python@v6 with: @@ -87,7 +87,7 @@ jobs: name: Optional Lint C++ runs-on: ubuntu-latest steps: - - uses: actions/checkout@v6 + - uses: actions/checkout@v5 - name: Update PATH run: | echo "$HOME/.local/bin" >> "$GITHUB_PATH" @@ -116,7 +116,7 @@ jobs: name: Lint JavaScript runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] steps: - - uses: actions/checkout@v6 + - uses: actions/checkout@v5 - uses: actions/setup-node@v6 with: node-version: 20 diff --git a/.github/workflows/linux-wasm-ci-build-and-test-workflow.yml b/.github/workflows/linux-wasm-ci-build-and-test-workflow.yml index 5763b9c39bcc6..2370c631b7a7a 100644 --- a/.github/workflows/linux-wasm-ci-build-and-test-workflow.yml +++ b/.github/workflows/linux-wasm-ci-build-and-test-workflow.yml @@ -49,7 +49,7 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@v6 + uses: actions/checkout@v5 with: submodules: recursive diff --git a/.github/workflows/linux_cuda_ci.yml b/.github/workflows/linux_cuda_ci.yml index e7e3be8c5f9ed..886705471b7de 100644 --- a/.github/workflows/linux_cuda_ci.yml +++ b/.github/workflows/linux_cuda_ci.yml @@ -48,7 +48,7 @@ jobs: packages: read steps: - name: Checkout code - uses: actions/checkout@v6 + uses: actions/checkout@v5 - uses: microsoft/onnxruntime-github-actions/build-docker-image@v0.0.9 id: build_docker_image_step diff --git a/.github/workflows/linux_minimal_build.yml b/.github/workflows/linux_minimal_build.yml index 4d9579a746892..af86975ee6cdc 100644 --- a/.github/workflows/linux_minimal_build.yml +++ b/.github/workflows/linux_minimal_build.yml @@ -28,7 +28,7 @@ jobs: packages: write steps: - name: Checkout repository - uses: actions/checkout@v6 + uses: actions/checkout@v5 with: submodules: false @@ -65,7 +65,7 @@ jobs: id-token: write # If using OIDC for ACR login steps: - name: Checkout repository - uses: actions/checkout@v6 + uses: actions/checkout@v5 with: submodules: false - uses: actions/setup-node@v6 @@ -122,7 +122,7 @@ jobs: id-token: write # If using OIDC for ACR login steps: - name: Checkout repository - uses: actions/checkout@v6 + uses: actions/checkout@v5 with: submodules: false - uses: actions/setup-node@v6 @@ -156,7 +156,7 @@ jobs: id-token: write # If using OIDC for ACR login steps: - name: Checkout repository - uses: actions/checkout@v6 + uses: actions/checkout@v5 with: submodules: false - uses: actions/setup-node@v6 @@ -188,7 +188,7 @@ jobs: id-token: write # If using OIDC for ACR login steps: - name: Checkout repository - uses: actions/checkout@v6 + uses: actions/checkout@v5 with: submodules: false - uses: actions/setup-node@v6 @@ -222,7 +222,7 @@ jobs: id-token: write # If using OIDC for ACR login steps: - name: Checkout repository - uses: actions/checkout@v6 + uses: actions/checkout@v5 with: submodules: false - uses: actions/setup-node@v6 @@ -286,7 +286,7 @@ jobs: id-token: write # If using OIDC for ACR login steps: - name: Checkout repository - uses: actions/checkout@v6 + uses: actions/checkout@v5 with: submodules: false @@ -363,7 +363,7 @@ jobs: id-token: write # If using OIDC for ACR login steps: - name: Checkout repository - uses: actions/checkout@v6 + uses: actions/checkout@v5 with: submodules: false @@ -430,7 +430,7 @@ jobs: id-token: write # If using OIDC for ACR login steps: - name: Checkout repository - uses: actions/checkout@v6 + uses: actions/checkout@v5 with: submodules: false @@ -505,7 +505,7 @@ jobs: id-token: write # If using OIDC for ACR login steps: - name: Checkout repository - uses: actions/checkout@v6 + uses: actions/checkout@v5 with: submodules: false - uses: actions/setup-node@v6 diff --git a/.github/workflows/linux_tensorrt_ci.yml b/.github/workflows/linux_tensorrt_ci.yml index 47b7c1ba7e889..0e26576829e94 100644 --- a/.github/workflows/linux_tensorrt_ci.yml +++ b/.github/workflows/linux_tensorrt_ci.yml @@ -48,7 +48,7 @@ jobs: packages: read steps: - name: Checkout code - uses: actions/checkout@v6 + uses: actions/checkout@v5 # --- Build the Docker image needed for testing --- - name: Build Docker Image for Testing diff --git a/.github/workflows/mac.yml b/.github/workflows/mac.yml index 8ba87bc1f731c..e545406d8d20f 100644 --- a/.github/workflows/mac.yml +++ b/.github/workflows/mac.yml @@ -76,7 +76,7 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@v6 + uses: actions/checkout@v5 - uses: microsoft/onnxruntime-github-actions/setup-build-tools@v0.0.9 with: vcpkg-version: '2025.06.13' @@ -124,7 +124,7 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@v6 + uses: actions/checkout@v5 - uses: microsoft/onnxruntime-github-actions/setup-build-tools@v0.0.9 with: vcpkg-version: '2025.06.13' diff --git a/.github/workflows/macos-ci-build-and-test-workflow.yml b/.github/workflows/macos-ci-build-and-test-workflow.yml index 8e1d0264496f6..329584c68d7d1 100644 --- a/.github/workflows/macos-ci-build-and-test-workflow.yml +++ b/.github/workflows/macos-ci-build-and-test-workflow.yml @@ -75,7 +75,7 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@v6 + uses: actions/checkout@v5 - uses: microsoft/onnxruntime-github-actions/setup-build-tools@v0.0.9 with: vcpkg-version: '2025.06.13' diff --git a/.github/workflows/pr_checks.yml b/.github/workflows/pr_checks.yml index 7ca330742f69b..abe627f4ff7bc 100644 --- a/.github/workflows/pr_checks.yml +++ b/.github/workflows/pr_checks.yml @@ -24,7 +24,7 @@ jobs: contents: read pull-requests: write steps: - - uses: actions/checkout@v6 + - uses: actions/checkout@v5 - name: Setup Python uses: actions/setup-python@v6 with: diff --git a/.github/workflows/publish-c-apidocs.yml b/.github/workflows/publish-c-apidocs.yml index d9fb72271967f..25b7899584bbf 100644 --- a/.github/workflows/publish-c-apidocs.yml +++ b/.github/workflows/publish-c-apidocs.yml @@ -24,7 +24,7 @@ jobs: name: Generate C/C++ API docs runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] steps: - - uses: actions/checkout@v6 + - uses: actions/checkout@v5 - name: Install doxygen and dependencies run: | sudo apt update diff --git a/.github/workflows/publish-csharp-apidocs.yml b/.github/workflows/publish-csharp-apidocs.yml index dd55bbd917337..34b9c1af9552f 100644 --- a/.github/workflows/publish-csharp-apidocs.yml +++ b/.github/workflows/publish-csharp-apidocs.yml @@ -24,7 +24,7 @@ jobs: env: DOCFXVERSION: 2.62.2 steps: - - uses: actions/checkout@v6 + - uses: actions/checkout@v5 - name: Install DocFX run: | dotnet tool update -g docfx diff --git a/.github/workflows/publish-java-apidocs.yml b/.github/workflows/publish-java-apidocs.yml index 81defeae518a3..656d0627ed17d 100644 --- a/.github/workflows/publish-java-apidocs.yml +++ b/.github/workflows/publish-java-apidocs.yml @@ -23,7 +23,7 @@ jobs: name: Generate Java docs runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] steps: - - uses: actions/checkout@v6 + - uses: actions/checkout@v5 - name: Set up JDK 11 uses: actions/setup-java@v5 with: diff --git a/.github/workflows/publish-js-apidocs.yml b/.github/workflows/publish-js-apidocs.yml index 9da78d7d9ed9c..e71d3b3c57a4b 100644 --- a/.github/workflows/publish-js-apidocs.yml +++ b/.github/workflows/publish-js-apidocs.yml @@ -23,7 +23,7 @@ jobs: name: Generate JS API docs runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] steps: - - uses: actions/checkout@v6 + - uses: actions/checkout@v5 - name: Setup Node.js uses: actions/setup-node@v6 with: diff --git a/.github/workflows/publish-objectivec-apidocs.yml b/.github/workflows/publish-objectivec-apidocs.yml index a73b62eba6050..983d3d478a49d 100644 --- a/.github/workflows/publish-objectivec-apidocs.yml +++ b/.github/workflows/publish-objectivec-apidocs.yml @@ -23,7 +23,7 @@ jobs: name: Generate Objective-C API docs runs-on: macos-latest steps: - - uses: actions/checkout@v6 + - uses: actions/checkout@v5 - uses: microsoft/onnxruntime-github-actions/setup-build-tools@v0.0.9 with: vcpkg-version: '2025.06.13' diff --git a/.github/workflows/publish-python-apidocs.yml b/.github/workflows/publish-python-apidocs.yml index e35e6a04adbef..389d1683fb1ff 100644 --- a/.github/workflows/publish-python-apidocs.yml +++ b/.github/workflows/publish-python-apidocs.yml @@ -24,7 +24,7 @@ jobs: name: Generate Python API docs runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] steps: - - uses: actions/checkout@v6 + - uses: actions/checkout@v5 - name: Install tools run: | sudo apt-get update diff --git a/.github/workflows/react_native.yml b/.github/workflows/react_native.yml index 4a56dfbd35406..343186b1aec8c 100644 --- a/.github/workflows/react_native.yml +++ b/.github/workflows/react_native.yml @@ -20,7 +20,7 @@ jobs: aar_path: ${{ runner.temp }}/.artifacts steps: - name: Checkout repository - uses: actions/checkout@v6 + uses: actions/checkout@v5 with: submodules: false @@ -75,7 +75,7 @@ jobs: run: echo "ANDROID_AVD_HOME=${{ runner.temp }}/android-avd" >> $GITHUB_ENV - name: Checkout repository - uses: actions/checkout@v6 + uses: actions/checkout@v5 - name: Use Python 3.12 uses: actions/setup-python@v6 @@ -175,7 +175,7 @@ jobs: timeout-minutes: 120 steps: - name: Checkout repository - uses: actions/checkout@v6 + uses: actions/checkout@v5 - name: Use Xcode 15.3.0 run: sudo xcode-select --switch /Applications/Xcode_15.3.0.app/Contents/Developer @@ -218,7 +218,7 @@ jobs: timeout-minutes: 90 steps: - name: Checkout repository - uses: actions/checkout@v6 + uses: actions/checkout@v5 - name: Download iOS pod artifact uses: actions/download-artifact@v6 diff --git a/.github/workflows/reusable_linux_build.yml b/.github/workflows/reusable_linux_build.yml index f0da87647b8b0..795e35b06bfb0 100644 --- a/.github/workflows/reusable_linux_build.yml +++ b/.github/workflows/reusable_linux_build.yml @@ -75,7 +75,7 @@ jobs: id-token: write steps: - name: Checkout code - uses: actions/checkout@v6 + uses: actions/checkout@v5 - name: Set up Python ${{ inputs.python_version }} if: inputs.architecture != 'arm64' diff --git a/.github/workflows/web.yml b/.github/workflows/web.yml index 6ae25ccc0bf3e..016feab5e0d94 100644 --- a/.github/workflows/web.yml +++ b/.github/workflows/web.yml @@ -22,7 +22,7 @@ jobs: commit_sha: ${{ steps.extract_commit.outputs.commit_sha }} steps: - name: Checkout code - uses: actions/checkout@v6 + uses: actions/checkout@v5 with: submodules: true diff --git a/.github/workflows/windows-web-ci-workflow.yml b/.github/workflows/windows-web-ci-workflow.yml index c16ce6eb222eb..eee98332056f6 100644 --- a/.github/workflows/windows-web-ci-workflow.yml +++ b/.github/workflows/windows-web-ci-workflow.yml @@ -29,7 +29,7 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v6 + uses: actions/checkout@v5 with: submodules: false diff --git a/.github/workflows/windows_build_x64_asan.yml b/.github/workflows/windows_build_x64_asan.yml index ac5f08717155f..05fd4acd4de9a 100644 --- a/.github/workflows/windows_build_x64_asan.yml +++ b/.github/workflows/windows_build_x64_asan.yml @@ -19,7 +19,7 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v6 + uses: actions/checkout@v5 with: submodules: false diff --git a/.github/workflows/windows_cuda.yml b/.github/workflows/windows_cuda.yml index 5d6e9b1da31a2..fd5b65eb039a3 100644 --- a/.github/workflows/windows_cuda.yml +++ b/.github/workflows/windows_cuda.yml @@ -21,7 +21,7 @@ jobs: name: Windows GPU CUDA CI Pipeline runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-vs2022-latest"] steps: - - uses: actions/checkout@v6 + - uses: actions/checkout@v5 with: fetch-depth: 0 submodules: 'none' @@ -152,7 +152,7 @@ jobs: timeout-minutes: 300 runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Win2022-GPU-A10"] steps: - - uses: actions/checkout@v6 + - uses: actions/checkout@v5 with: fetch-depth: 0 submodules: 'none' diff --git a/.github/workflows/windows_dml.yml b/.github/workflows/windows_dml.yml index 0abf6b650f986..e8ee7751348b4 100644 --- a/.github/workflows/windows_dml.yml +++ b/.github/workflows/windows_dml.yml @@ -27,7 +27,7 @@ jobs: ONNXRUNTIME_TEST_GPU_DEVICE_ID: '0' runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Win2022-GPU-A10"] steps: - - uses: actions/checkout@v6 + - uses: actions/checkout@v5 with: fetch-depth: 0 # Fetch all history for all tags and branches submodules: 'none' diff --git a/.github/workflows/windows_openvino.yml b/.github/workflows/windows_openvino.yml index 537ff1fb00071..b608c0879aa45 100644 --- a/.github/workflows/windows_openvino.yml +++ b/.github/workflows/windows_openvino.yml @@ -31,7 +31,7 @@ jobs: steps: - name: Checkout - uses: actions/checkout@v6 + uses: actions/checkout@v5 with: fetch-depth: 0 submodules: none diff --git a/.github/workflows/windows_qnn_x64.yml b/.github/workflows/windows_qnn_x64.yml index f6176164354bb..4f0b50e65df6e 100644 --- a/.github/workflows/windows_qnn_x64.yml +++ b/.github/workflows/windows_qnn_x64.yml @@ -31,7 +31,7 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v6 + uses: actions/checkout@v5 - name: Setup Python uses: actions/setup-python@v6 diff --git a/.github/workflows/windows_tensorrt.yml b/.github/workflows/windows_tensorrt.yml index 4a564a3b1cb36..229efb01f0018 100644 --- a/.github/workflows/windows_tensorrt.yml +++ b/.github/workflows/windows_tensorrt.yml @@ -21,7 +21,7 @@ jobs: name: Windows GPU TensorRT CI Pipeline runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-vs2022-latest"] steps: - - uses: actions/checkout@v6 + - uses: actions/checkout@v5 with: fetch-depth: 0 submodules: 'none' @@ -157,7 +157,7 @@ jobs: timeout-minutes: 300 runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Win2022-GPU-A10"] steps: - - uses: actions/checkout@v6 + - uses: actions/checkout@v5 with: fetch-depth: 0 submodules: 'none' diff --git a/.github/workflows/windows_webgpu.yml b/.github/workflows/windows_webgpu.yml index f729cda5ea576..899a8b66eac7a 100644 --- a/.github/workflows/windows_webgpu.yml +++ b/.github/workflows/windows_webgpu.yml @@ -34,7 +34,7 @@ jobs: ONNXRUNTIME_TEST_GPU_DEVICE_ID: "0" steps: - name: Checkout - uses: actions/checkout@v6 + uses: actions/checkout@v5 with: fetch-depth: 0 submodules: none @@ -156,7 +156,7 @@ jobs: timeout-minutes: 300 steps: - name: Checkout - uses: actions/checkout@v6 + uses: actions/checkout@v5 with: fetch-depth: 0 submodules: none @@ -209,7 +209,7 @@ jobs: ONNXRUNTIME_TEST_GPU_DEVICE_ID: "0" steps: - name: Checkout - uses: actions/checkout@v6 + uses: actions/checkout@v5 with: fetch-depth: 0 submodules: none diff --git a/.github/workflows/windows_x64_debug_build_x64_debug.yml b/.github/workflows/windows_x64_debug_build_x64_debug.yml index 385d03c1a6705..d62c7130e0ebb 100644 --- a/.github/workflows/windows_x64_debug_build_x64_debug.yml +++ b/.github/workflows/windows_x64_debug_build_x64_debug.yml @@ -18,7 +18,7 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v6 + uses: actions/checkout@v5 with: submodules: false diff --git a/.github/workflows/windows_x64_release_build_x64_release.yml b/.github/workflows/windows_x64_release_build_x64_release.yml index ee045b70b6efa..a2991bb0f1131 100644 --- a/.github/workflows/windows_x64_release_build_x64_release.yml +++ b/.github/workflows/windows_x64_release_build_x64_release.yml @@ -18,7 +18,7 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v6 + uses: actions/checkout@v5 with: submodules: false diff --git a/.github/workflows/windows_x64_release_ep_generic_interface_build_x64_release_ep_generic_interface.yml b/.github/workflows/windows_x64_release_ep_generic_interface_build_x64_release_ep_generic_interface.yml index 25dfc41e6922c..bb6c5035b0dce 100644 --- a/.github/workflows/windows_x64_release_ep_generic_interface_build_x64_release_ep_generic_interface.yml +++ b/.github/workflows/windows_x64_release_ep_generic_interface_build_x64_release_ep_generic_interface.yml @@ -18,7 +18,7 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v6 + uses: actions/checkout@v5 with: submodules: false diff --git a/.github/workflows/windows_x64_release_vitisai_build_x64_release.yml b/.github/workflows/windows_x64_release_vitisai_build_x64_release.yml index e738db262f3a2..4378231338673 100644 --- a/.github/workflows/windows_x64_release_vitisai_build_x64_release.yml +++ b/.github/workflows/windows_x64_release_vitisai_build_x64_release.yml @@ -18,7 +18,7 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v6 + uses: actions/checkout@v5 with: submodules: false diff --git a/.github/workflows/windows_x64_release_xnnpack.yml b/.github/workflows/windows_x64_release_xnnpack.yml index 5672e4043c624..b453cd570ac05 100644 --- a/.github/workflows/windows_x64_release_xnnpack.yml +++ b/.github/workflows/windows_x64_release_xnnpack.yml @@ -18,7 +18,7 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v6 + uses: actions/checkout@v5 with: submodules: false diff --git a/.github/workflows/windows_x86.yml b/.github/workflows/windows_x86.yml index 381d9dda5cd42..d20778d56f60b 100644 --- a/.github/workflows/windows_x86.yml +++ b/.github/workflows/windows_x86.yml @@ -18,7 +18,7 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v6 + uses: actions/checkout@v5 with: submodules: false diff --git a/dockerfiles/Dockerfile.rocm b/dockerfiles/Dockerfile.rocm new file mode 100644 index 0000000000000..aca8c3feaff71 --- /dev/null +++ b/dockerfiles/Dockerfile.rocm @@ -0,0 +1,24 @@ +# -------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------- +# Dockerfile to run ONNXRuntime with ROCm integration +#-------------------------------------------------------------------------- + +FROM rocm/pytorch:rocm6.2.3_ubuntu22.04_py3.10_pytorch_release_2.3.0 + +ARG ONNXRUNTIME_REPO=https://github.com/Microsoft/onnxruntime +ARG ONNXRUNTIME_BRANCH=main + +WORKDIR /code + +ENV PATH=/code/cmake-3.27.3-linux-x86_64/bin:${PATH} + +# Prepare onnxruntime repository & build onnxruntime +RUN git clone --single-branch --branch ${ONNXRUNTIME_BRANCH} --recursive ${ONNXRUNTIME_REPO} onnxruntime &&\ + /bin/sh onnxruntime/dockerfiles/scripts/install_common_deps.sh &&\ + cd onnxruntime &&\ + /bin/sh ./build.sh --allow_running_as_root --config Release --build_wheel --update --build --parallel --cmake_extra_defines\ + ONNXRUNTIME_VERSION=$(cat ./VERSION_NUMBER) --use_rocm --rocm_home=/opt/rocm &&\ + pip install /code/onnxruntime/build/Linux/Release/dist/*.whl &&\ + cd .. diff --git a/dockerfiles/README.md b/dockerfiles/README.md index 88c542b63ccd2..4c69098103edd 100644 --- a/dockerfiles/README.md +++ b/dockerfiles/README.md @@ -1,8 +1,9 @@ # Dockerfiles **Execution Providers** - CPU: [Dockerfile](Dockerfile.source), [Instructions](#cpu) -- CUDA: [Dockerfile](Dockerfile.cuda), [Instructions](#cuda) +- CUDA/cuDNN: [Dockerfile](Dockerfile.cuda), [Instructions](#cuda) - MIGraphX: [Dockerfile](Dockerfile.migraphx), [Instructions](#migraphx) +- ROCm: [Dockerfile](Dockerfile.rocm), [Instructions](#rocm) - OpenVINO: [Dockerfile](Dockerfile.openvino), [Instructions](#openvino) - TensorRT: [Dockerfile](Dockerfile.tensorrt), [Instructions](#tensorrt) - VitisAI: [Dockerfile](Dockerfile.vitisai) @@ -303,3 +304,17 @@ Note: When running the container you built in Docker, please either use 'nvidia- ``` docker run -it --device=/dev/kfd --device=/dev/dri --group-add video onnxruntime-migraphx ``` + + ## ROCm +**Ubuntu 22.04, ROCm6.2.3** + +1. Build the docker image from the Dockerfile in this repository. + ``` + docker build -t onnxruntime-rocm -f Dockerfile.rocm . + ``` + +2. Run the Docker image + + ``` + docker run -it --device=/dev/kfd --device=/dev/dri --group-add video onnxruntime-rocm + ``` diff --git a/dockerfiles/scripts/install_rocm_deps.sh b/dockerfiles/scripts/install_rocm_deps.sh new file mode 100644 index 0000000000000..fd445be87479b --- /dev/null +++ b/dockerfiles/scripts/install_rocm_deps.sh @@ -0,0 +1,84 @@ +#!/bin/bash +prefix=/opt/rocm +DEBIAN_FRONTEND=noninteractive +apt-get update && apt-get install -y --no-install-recommends \ + wget \ + zip \ + ca-certificates \ + build-essential \ + curl \ + libcurl4-openssl-dev \ + libssl-dev \ + python3-dev + +# rocm-cmake +rocm_cmake_version=4.5.2 +wget --quiet https://github.com/RadeonOpenCompute/rocm-cmake/archive/refs/tags/rocm-${rocm_cmake_version}.tar.gz +tar -xzvf rocm-${rocm_cmake_version}.tar.gz +rm rocm-${rocm_cmake_version}.tar.gz +cd rocm-cmake-rocm-${rocm_cmake_version} +mkdir build +cd build +cmake -DCMAKE_INSTALL_PREFIX=$prefix .. +make -j8 +make install +cd ../.. +rm -rf rocm-cmake-rocm-${rocm_cmake_version} + +# rccl +rccl_version=4.5.2 +wget --quiet https://github.com/ROCmSoftwarePlatform/rccl/archive/refs/tags/rocm-${rccl_version}.tar.gz +tar -xzvf rocm-${rccl_version}.tar.gz +rm rocm-${rccl_version}.tar.gz +cd rccl-rocm-${rccl_version} +mkdir build +cd build +CXX=/opt/rocm/bin/hipcc cmake -DCMAKE_INSTALL_PREFIX=$prefix .. +make -j8 +make install +cd ../.. +rm -rf rccl-rocm-${rccl_version} + +#rocrand +rocrand_version=4.5.2 +wget --quiet https://github.com/ROCmSoftwarePlatform/rocRAND/archive/refs/tags/rocm-${rocrand_version}.tar.gz +tar -xzvf rocm-${rocrand_version}.tar.gz +rm rocm-${rocrand_version}.tar.gz +cd rocRAND-rocm-${rocrand_version} +mkdir build +cd build +CXX=/opt/rocm/bin/hipcc cmake -DCMAKE_INSTALL_PREFIX=$prefix .. +make -j8 +make install +cd ../.. +rm -rf rocRAND-rocm-${rocrand_version} + +#hipcub +hipcub_version=4.5.2 +wget --quiet https://github.com/ROCmSoftwarePlatform/hipCUB/archive/refs/tags/rocm-${hipcub_version}.tar.gz +tar -xzvf rocm-${hipcub_version}.tar.gz +rm rocm-${hipcub_version}.tar.gz +cd hipCUB-rocm-${hipcub_version} +mkdir build +cd build +CXX=/opt/rocm/bin/hipcc cmake -DCMAKE_INSTALL_PREFIX=$prefix .. +make -j8 +make package +make install +cd ../.. +rm -rf hipCUB-rocm-${hipcub_version} + +#rocprim +rocprim_version=4.5.2 +wget --quiet https://github.com/ROCmSoftwarePlatform/rocPRIM/archive/refs/tags/rocm-${rocprim_version}.tar.gz +tar -xzvf rocm-${rocprim_version}.tar.gz +rm rocm-${rocprim_version}.tar.gz +cd rocPRIM-rocm-${rocprim_version} +mkdir build +cd build +CXX=/opt/rocm/bin/hipcc cmake -DCMAKE_INSTALL_PREFIX=$prefix .. +make -j8 +make install +cd ../.. +rm -rf rocPRIM-rocm-${rocprim_version} + diff --git a/js/package-lock.json b/js/package-lock.json index 0fca515b61238..1e9f5cb29fe6c 100644 --- a/js/package-lock.json +++ b/js/package-lock.json @@ -4,7 +4,6 @@ "requires": true, "packages": { "": { - "name": "js", "license": "MIT", "devDependencies": { "@eslint/compat": "^1.4.0", @@ -3231,27 +3230,6 @@ "url": "https://github.com/sponsors/ljharb" } }, - "node_modules/glob": { - "version": "10.5.0", - "resolved": "https://registry.npmjs.org/glob/-/glob-10.5.0.tgz", - "integrity": "sha512-DfXN8DfhJ7NH3Oe7cFmu3NCu1wKbkReJ8TorzSAFbSKrlNaQSKfIzqYqVY8zlbs2NLBbWpRiU52GX2PbaBVNkg==", - "dev": true, - "license": "ISC", - "dependencies": { - "foreground-child": "^3.1.0", - "jackspeak": "^3.1.2", - "minimatch": "^9.0.4", - "minipass": "^7.1.2", - "package-json-from-dist": "^1.0.0", - "path-scurry": "^1.11.1" - }, - "bin": { - "glob": "dist/esm/bin.mjs" - }, - "funding": { - "url": "https://github.com/sponsors/isaacs" - } - }, "node_modules/glob-parent": { "version": "6.0.2", "resolved": "https://registry.npmjs.org/glob-parent/-/glob-parent-6.0.2.tgz", @@ -3264,32 +3242,6 @@ "node": ">=10.13.0" } }, - "node_modules/glob/node_modules/brace-expansion": { - "version": "2.0.2", - "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-2.0.2.tgz", - "integrity": "sha512-Jt0vHyM+jmUBqojB7E1NIYadt0vI0Qxjxd2TErW94wDz+E2LAm5vKMXXwg6ZZBTHPuUlDgQHKXvjGBdfcF1ZDQ==", - "dev": true, - "license": "MIT", - "dependencies": { - "balanced-match": "^1.0.0" - } - }, - "node_modules/glob/node_modules/minimatch": { - "version": "9.0.5", - "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-9.0.5.tgz", - "integrity": "sha512-G6T0ZX48xgozx7587koeX9Ys2NYy6Gmv//P89sEte9V9whIapMNF4idKxnW2QtCcLiTWlb/wfCabAtAFWhhBow==", - "dev": true, - "license": "ISC", - "dependencies": { - "brace-expansion": "^2.0.1" - }, - "engines": { - "node": ">=16 || 14 >=14.17" - }, - "funding": { - "url": "https://github.com/sponsors/isaacs" - } - }, "node_modules/global-agent": { "version": "3.0.0", "resolved": "https://registry.npmjs.org/global-agent/-/global-agent-3.0.0.tgz", @@ -4359,6 +4311,43 @@ "balanced-match": "^1.0.0" } }, + "node_modules/mocha/node_modules/glob": { + "version": "10.4.5", + "resolved": "https://registry.npmjs.org/glob/-/glob-10.4.5.tgz", + "integrity": "sha512-7Bv8RF0k6xjo7d4A/PxYLbUCfb6c+Vpd2/mB2yRDlew7Jb5hEXiCD9ibfO7wpk8i4sevK6DFny9h7EYbM3/sHg==", + "dev": true, + "license": "ISC", + "dependencies": { + "foreground-child": "^3.1.0", + "jackspeak": "^3.1.2", + "minimatch": "^9.0.4", + "minipass": "^7.1.2", + "package-json-from-dist": "^1.0.0", + "path-scurry": "^1.11.1" + }, + "bin": { + "glob": "dist/esm/bin.mjs" + }, + "funding": { + "url": "https://github.com/sponsors/isaacs" + } + }, + "node_modules/mocha/node_modules/glob/node_modules/minimatch": { + "version": "9.0.5", + "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-9.0.5.tgz", + "integrity": "sha512-G6T0ZX48xgozx7587koeX9Ys2NYy6Gmv//P89sEte9V9whIapMNF4idKxnW2QtCcLiTWlb/wfCabAtAFWhhBow==", + "dev": true, + "license": "ISC", + "dependencies": { + "brace-expansion": "^2.0.1" + }, + "engines": { + "node": ">=16 || 14 >=14.17" + }, + "funding": { + "url": "https://github.com/sponsors/isaacs" + } + }, "node_modules/mocha/node_modules/minimatch": { "version": "5.1.6", "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-5.1.6.tgz", @@ -8089,40 +8078,6 @@ "get-intrinsic": "^1.2.6" } }, - "glob": { - "version": "10.5.0", - "resolved": "https://registry.npmjs.org/glob/-/glob-10.5.0.tgz", - "integrity": "sha512-DfXN8DfhJ7NH3Oe7cFmu3NCu1wKbkReJ8TorzSAFbSKrlNaQSKfIzqYqVY8zlbs2NLBbWpRiU52GX2PbaBVNkg==", - "dev": true, - "requires": { - "foreground-child": "^3.1.0", - "jackspeak": "^3.1.2", - "minimatch": "^9.0.4", - "minipass": "^7.1.2", - "package-json-from-dist": "^1.0.0", - "path-scurry": "^1.11.1" - }, - "dependencies": { - "brace-expansion": { - "version": "2.0.2", - "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-2.0.2.tgz", - "integrity": "sha512-Jt0vHyM+jmUBqojB7E1NIYadt0vI0Qxjxd2TErW94wDz+E2LAm5vKMXXwg6ZZBTHPuUlDgQHKXvjGBdfcF1ZDQ==", - "dev": true, - "requires": { - "balanced-match": "^1.0.0" - } - }, - "minimatch": { - "version": "9.0.5", - "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-9.0.5.tgz", - "integrity": "sha512-G6T0ZX48xgozx7587koeX9Ys2NYy6Gmv//P89sEte9V9whIapMNF4idKxnW2QtCcLiTWlb/wfCabAtAFWhhBow==", - "dev": true, - "requires": { - "brace-expansion": "^2.0.1" - } - } - } - }, "glob-parent": { "version": "6.0.2", "resolved": "https://registry.npmjs.org/glob-parent/-/glob-parent-6.0.2.tgz", @@ -8817,6 +8772,31 @@ "balanced-match": "^1.0.0" } }, + "glob": { + "version": "10.4.5", + "resolved": "https://registry.npmjs.org/glob/-/glob-10.4.5.tgz", + "integrity": "sha512-7Bv8RF0k6xjo7d4A/PxYLbUCfb6c+Vpd2/mB2yRDlew7Jb5hEXiCD9ibfO7wpk8i4sevK6DFny9h7EYbM3/sHg==", + "dev": true, + "requires": { + "foreground-child": "^3.1.0", + "jackspeak": "^3.1.2", + "minimatch": "^9.0.4", + "minipass": "^7.1.2", + "package-json-from-dist": "^1.0.0", + "path-scurry": "^1.11.1" + }, + "dependencies": { + "minimatch": { + "version": "9.0.5", + "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-9.0.5.tgz", + "integrity": "sha512-G6T0ZX48xgozx7587koeX9Ys2NYy6Gmv//P89sEte9V9whIapMNF4idKxnW2QtCcLiTWlb/wfCabAtAFWhhBow==", + "dev": true, + "requires": { + "brace-expansion": "^2.0.1" + } + } + } + }, "minimatch": { "version": "5.1.6", "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-5.1.6.tgz", diff --git a/js/react_native/package-lock.json b/js/react_native/package-lock.json index de8d631362db7..e6ed2bdb9e17b 100644 --- a/js/react_native/package-lock.json +++ b/js/react_native/package-lock.json @@ -33,7 +33,6 @@ "version": "1.24.0", "license": "MIT", "devDependencies": { - "globby": "^15.0.0", "typedoc": "^0.25.7" } }, @@ -62,15 +61,15 @@ } }, "node_modules/@babel/code-frame": { - "version": "7.27.1", - "resolved": "https://registry.npmjs.org/@babel/code-frame/-/code-frame-7.27.1.tgz", - "integrity": "sha512-cjQ7ZlQ0Mv3b47hABuTevyTuYN4i+loJKGeV9flcCgIK37cCXRh+L1bd3iBHlynerhQ7BhCkn2BPbQUL+rGqFg==", + "version": "7.26.2", + "resolved": "https://registry.npmjs.org/@babel/code-frame/-/code-frame-7.26.2.tgz", + "integrity": "sha512-RJlIHRueQgwWitWgF8OdFYGZX328Ax5BCemNGlqHfplnRT9ESi8JkFlvaVYbS+UubVY6dpv87Fs2u5M29iNFVQ==", "dev": true, "license": "MIT", "dependencies": { - "@babel/helper-validator-identifier": "^7.27.1", + "@babel/helper-validator-identifier": "^7.25.9", "js-tokens": "^4.0.0", - "picocolors": "^1.1.1" + "picocolors": "^1.0.0" }, "engines": { "node": ">=6.9.0" @@ -411,9 +410,9 @@ } }, "node_modules/@babel/helper-string-parser": { - "version": "7.27.1", - "resolved": "https://registry.npmjs.org/@babel/helper-string-parser/-/helper-string-parser-7.27.1.tgz", - "integrity": "sha512-qMlSxKbpRlAridDExk92nSobyDdpPijUq2DW6oDnUqd0iOGxmQjyqhMIihI9+zv4LPyZdRje2cavWPbCbWm3eA==", + "version": "7.25.9", + "resolved": "https://registry.npmjs.org/@babel/helper-string-parser/-/helper-string-parser-7.25.9.tgz", + "integrity": "sha512-4A/SCr/2KLd5jrtOMFzaKjVtAei3+2r/NChoBNoZ3EyP/+GlhoaEGoWOZUmFmoITP7zOJyHIMm+DYRd8o3PvHA==", "dev": true, "license": "MIT", "engines": { @@ -421,9 +420,9 @@ } }, "node_modules/@babel/helper-validator-identifier": { - "version": "7.28.5", - "resolved": "https://registry.npmjs.org/@babel/helper-validator-identifier/-/helper-validator-identifier-7.28.5.tgz", - "integrity": "sha512-qSs4ifwzKJSV39ucNjsvc6WVHs6b7S03sOh2OcHF9UHfVPqWWALUsNUVzhSBiItjRZoLHx7nIarVjqKVusUZ1Q==", + "version": "7.25.9", + "resolved": "https://registry.npmjs.org/@babel/helper-validator-identifier/-/helper-validator-identifier-7.25.9.tgz", + "integrity": "sha512-Ed61U6XJc3CVRfkERJWDz4dJwKe7iLmmJsbOGu9wSloNSFttHV0I8g6UAgb7qnK5ly5bGLPd4oXZlxCdANBOWQ==", "dev": true, "license": "MIT", "engines": { @@ -456,27 +455,27 @@ } }, "node_modules/@babel/helpers": { - "version": "7.28.4", - "resolved": "https://registry.npmjs.org/@babel/helpers/-/helpers-7.28.4.tgz", - "integrity": "sha512-HFN59MmQXGHVyYadKLVumYsA9dBFun/ldYxipEjzA4196jpLZd8UjEEBLkbEkvfYreDqJhZxYAWFPtrfhNpj4w==", + "version": "7.25.6", + "resolved": "https://registry.npmjs.org/@babel/helpers/-/helpers-7.25.6.tgz", + "integrity": "sha512-Xg0tn4HcfTijTwfDwYlvVCl43V6h4KyVVX2aEm4qdO/PC6L2YvzLHFdmxhoeSA3eslcE6+ZVXHgWwopXYLNq4Q==", "dev": true, "license": "MIT", "dependencies": { - "@babel/template": "^7.27.2", - "@babel/types": "^7.28.4" + "@babel/template": "^7.25.0", + "@babel/types": "^7.25.6" }, "engines": { "node": ">=6.9.0" } }, "node_modules/@babel/parser": { - "version": "7.28.5", - "resolved": "https://registry.npmjs.org/@babel/parser/-/parser-7.28.5.tgz", - "integrity": "sha512-KKBU1VGYR7ORr3At5HAtUQ+TV3SzRCXmA/8OdDZiLDBIZxVyzXuztPjfLd3BV1PRAQGCMWWSHYhL0F8d5uHBDQ==", + "version": "7.26.9", + "resolved": "https://registry.npmjs.org/@babel/parser/-/parser-7.26.9.tgz", + "integrity": "sha512-81NWa1njQblgZbQHxWHpxxCzNsa3ZwvFqpUg7P+NNUU6f3UU2jBEg4OlF/J6rl8+PQGh1q6/zWScd001YwcA5A==", "dev": true, "license": "MIT", "dependencies": { - "@babel/types": "^7.28.5" + "@babel/types": "^7.26.9" }, "bin": { "parser": "bin/babel-parser.js" @@ -2115,25 +2114,35 @@ } }, "node_modules/@babel/runtime": { - "version": "7.28.4", - "resolved": "https://registry.npmjs.org/@babel/runtime/-/runtime-7.28.4.tgz", - "integrity": "sha512-Q/N6JNWvIvPnLDvjlE1OUBLPQHH6l3CltCEsHIujp45zQUSSh8K+gHnaEX45yAT1nyngnINhvWtzN+Nb9D8RAQ==", + "version": "7.25.6", + "resolved": "https://registry.npmjs.org/@babel/runtime/-/runtime-7.25.6.tgz", + "integrity": "sha512-VBj9MYyDb9tuLq7yzqjgzt6Q+IBQLrGZfdjOekyEirZPHxXWoTSGUTMrpsfi58Up73d13NfYLv8HT9vmznjzhQ==", "dev": true, "license": "MIT", + "dependencies": { + "regenerator-runtime": "^0.14.0" + }, "engines": { "node": ">=6.9.0" } }, + "node_modules/@babel/runtime/node_modules/regenerator-runtime": { + "version": "0.14.1", + "resolved": "https://registry.npmjs.org/regenerator-runtime/-/regenerator-runtime-0.14.1.tgz", + "integrity": "sha512-dYnhHh0nJoMfnkZs6GmmhFknAGRrLznOu5nc9ML+EJxGvrx6H7teuevqVqCuPcPK//3eDrrjQhehXVx9cnkGdw==", + "dev": true, + "license": "MIT" + }, "node_modules/@babel/template": { - "version": "7.27.2", - "resolved": "https://registry.npmjs.org/@babel/template/-/template-7.27.2.tgz", - "integrity": "sha512-LPDZ85aEJyYSd18/DkjNh4/y1ntkE5KwUHWTiqgRxruuZL2F1yuHligVHLvcHY2vMHXttKFpJn6LwfI7cw7ODw==", + "version": "7.26.9", + "resolved": "https://registry.npmjs.org/@babel/template/-/template-7.26.9.tgz", + "integrity": "sha512-qyRplbeIpNZhmzOysF/wFMuP9sctmh2cFzRAZOn1YapxBsE1i9bJIY586R/WBLfLcmcBlM8ROBiQURnnNy+zfA==", "dev": true, "license": "MIT", "dependencies": { - "@babel/code-frame": "^7.27.1", - "@babel/parser": "^7.27.2", - "@babel/types": "^7.27.1" + "@babel/code-frame": "^7.26.2", + "@babel/parser": "^7.26.9", + "@babel/types": "^7.26.9" }, "engines": { "node": ">=6.9.0" @@ -2180,14 +2189,14 @@ "license": "MIT" }, "node_modules/@babel/types": { - "version": "7.28.5", - "resolved": "https://registry.npmjs.org/@babel/types/-/types-7.28.5.tgz", - "integrity": "sha512-qQ5m48eI/MFLQ5PxQj4PFaprjyCTLI37ElWMmNs0K8Lk3dVeOdNpB3ks8jc7yM5CDmVC73eMVk/trk3fgmrUpA==", + "version": "7.26.9", + "resolved": "https://registry.npmjs.org/@babel/types/-/types-7.26.9.tgz", + "integrity": "sha512-Y3IR1cRnOxOCDvMmNiym7XpXQ93iGDDPHx+Zj+NM+rg0fBaShfQLkg+hKPaZCEvg5N/LeCo4+Rj/i3FuJsIQaw==", "dev": true, "license": "MIT", "dependencies": { - "@babel/helper-string-parser": "^7.27.1", - "@babel/helper-validator-identifier": "^7.28.5" + "@babel/helper-string-parser": "^7.25.9", + "@babel/helper-validator-identifier": "^7.25.9" }, "engines": { "node": ">=6.9.0" @@ -3310,9 +3319,9 @@ } }, "node_modules/babel-plugin-module-resolver/node_modules/brace-expansion": { - "version": "2.0.2", - "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-2.0.2.tgz", - "integrity": "sha512-Jt0vHyM+jmUBqojB7E1NIYadt0vI0Qxjxd2TErW94wDz+E2LAm5vKMXXwg6ZZBTHPuUlDgQHKXvjGBdfcF1ZDQ==", + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-2.0.1.tgz", + "integrity": "sha512-XnAIvQ8eM+kC6aULx6wuQiwVsnzsi9d3WxzV3FpWTGA19F621kwdbsAcFKXgKUHZWsy+mY6iL1sHTxWEFCytDA==", "dev": true, "license": "MIT", "dependencies": { @@ -3468,9 +3477,7 @@ } }, "node_modules/brace-expansion": { - "version": "1.1.12", - "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-1.1.12.tgz", - "integrity": "sha512-9T9UjW3r0UW5c1Q7GTwllptXwhvYmEzFhzMfZ9H7FQWt+uZePjZPjBP/W1ZEyZ1twGWom5/56TF4lPcqjnDHcg==", + "version": "1.1.11", "dev": true, "license": "MIT", "dependencies": { @@ -3824,9 +3831,9 @@ } }, "node_modules/compression": { - "version": "1.8.1", - "resolved": "https://registry.npmjs.org/compression/-/compression-1.8.1.tgz", - "integrity": "sha512-9mAqGPHLakhCLeNyxPkK4xVo746zQ/czLH1Ky+vkitMnWfWZps8r0qXuwhwizagCRttsL4lfG4pIOvaWLpAP0w==", + "version": "1.8.0", + "resolved": "https://registry.npmjs.org/compression/-/compression-1.8.0.tgz", + "integrity": "sha512-k6WLKfunuqCYD3t6AsuPGvQWaKwuLLh2/xHNcX4qE+vIfDNXpSqnrhwA7O53R7WVQUnt8dVAIW+YHr7xTgOgGA==", "dev": true, "license": "MIT", "dependencies": { @@ -3834,7 +3841,7 @@ "compressible": "~2.0.18", "debug": "2.6.9", "negotiator": "~0.6.4", - "on-headers": "~1.1.0", + "on-headers": "~1.0.2", "safe-buffer": "5.2.1", "vary": "~1.1.2" }, @@ -4814,9 +4821,9 @@ } }, "node_modules/image-size": { - "version": "1.2.1", - "resolved": "https://registry.npmjs.org/image-size/-/image-size-1.2.1.tgz", - "integrity": "sha512-rH+46sQJ2dlwfjfhCyNx5thzrv+dtmBIhPHk0zgRUukHzZ/kRueTJXoYYsclBaKcSMBWuGbOFXtioLpzTb5euw==", + "version": "1.2.0", + "resolved": "https://registry.npmjs.org/image-size/-/image-size-1.2.0.tgz", + "integrity": "sha512-4S8fwbO6w3GeCVN6OPtA9I5IGKkcDMPcKndtUlpJuCwu7JLjtj7JZpwqLuyY2nrmQT3AWsCJLSKPsc2mPBSl3w==", "dev": true, "license": "MIT", "dependencies": { @@ -5243,9 +5250,7 @@ "license": "MIT" }, "node_modules/js-yaml": { - "version": "3.14.2", - "resolved": "https://registry.npmjs.org/js-yaml/-/js-yaml-3.14.2.tgz", - "integrity": "sha512-PMSmkqxr106Xa156c2M265Z+FTrPl+oxd/rgOQy2tijQeK5TxQ43psO1ZCwhVOSdnn+RzkzlRz/eY4BgJBYVpg==", + "version": "3.14.1", "dev": true, "license": "MIT", "dependencies": { @@ -6539,9 +6544,9 @@ } }, "node_modules/on-headers": { - "version": "1.1.0", - "resolved": "https://registry.npmjs.org/on-headers/-/on-headers-1.1.0.tgz", - "integrity": "sha512-737ZY3yNnXy37FHkQxPzt4UZ2UWPWiCZWLvFZ4fu5cueciegX0zGPnrlY6bwRg4FdQOe9YU8MkmJwGhoMybl8A==", + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/on-headers/-/on-headers-1.0.2.tgz", + "integrity": "sha512-pZAE+FJLoyITytdqK0U5s+FIpjN0JP3OzFi/u8Rx+EV5/W+JTWGXG8xFzevE7AjBfDqHv/8vL8qQsIhHnqRkrA==", "dev": true, "license": "MIT", "engines": { @@ -7125,9 +7130,9 @@ "license": "Python-2.0" }, "node_modules/react-native-builder-bob/node_modules/brace-expansion": { - "version": "2.0.2", - "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-2.0.2.tgz", - "integrity": "sha512-Jt0vHyM+jmUBqojB7E1NIYadt0vI0Qxjxd2TErW94wDz+E2LAm5vKMXXwg6ZZBTHPuUlDgQHKXvjGBdfcF1ZDQ==", + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-2.0.1.tgz", + "integrity": "sha512-XnAIvQ8eM+kC6aULx6wuQiwVsnzsi9d3WxzV3FpWTGA19F621kwdbsAcFKXgKUHZWsy+mY6iL1sHTxWEFCytDA==", "dev": true, "license": "MIT", "dependencies": { @@ -7198,9 +7203,9 @@ } }, "node_modules/react-native-builder-bob/node_modules/js-yaml": { - "version": "4.1.1", - "resolved": "https://registry.npmjs.org/js-yaml/-/js-yaml-4.1.1.tgz", - "integrity": "sha512-qQKT4zQxXl8lLwBtHMWwaTcGfFOZviOJet3Oy/xmGk2gZH677CJM9EvtfdSkgWcATZhj/55JZ0rmy3myCT5lsA==", + "version": "4.1.0", + "resolved": "https://registry.npmjs.org/js-yaml/-/js-yaml-4.1.0.tgz", + "integrity": "sha512-wpxZs9NoxZaJESJGIZTyDEaYpl0FKSA+FB9aJiyemKhMwkxQg63h4T1KJgUGHpTqPDNRcmmYLugrRjJlBtWvRA==", "dev": true, "license": "MIT", "dependencies": { diff --git a/js/web/lib/wasm/jsep/webgpu/ops/attention.ts b/js/web/lib/wasm/jsep/webgpu/ops/attention.ts index f0f7527f665b9..6a8dffb73fa08 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/attention.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/attention.ts @@ -360,7 +360,7 @@ const createInPlaceSoftmaxProgramInfo = ( let local_offset = local_idx * uniforms.elements_per_thread; let offset = (global_idx / ${WG}) * uniforms.total_sequence_length + local_offset; let seq_causal_length = ${seqLens ? 'u32(past_sequence_length + workgroup_id.y + 1)' : 'total_sequence_length'}; - var thread_max_vector = ${f32Type}(-3.4028234663852886e+38f); + var thread_max_vector = ${f32Type}(-3.402823e+38f); for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < seq_causal_length; i++) { thread_max_vector = max(${f32Type}(x[offset + i]), thread_max_vector); } @@ -378,7 +378,7 @@ const createInPlaceSoftmaxProgramInfo = ( })()}; workgroupBarrier(); - var max_value = f32(-3.4028234663852886e+38f); + var max_value = f32(-3.402823e+38f); for (var i = 0u; i < ${WG}; i++) { max_value = max(thread_max[i], max_value); } diff --git a/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts b/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts index f6882280e91df..2056416873df5 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts @@ -81,7 +81,7 @@ const createSoftmaxProgramInfo = (context: ComputeContext, attributes: SoftmaxAt // 6.2.4 in wgsl spec const threadMaxDecl = tensorTypeToWsglStorageType(transposedInput.dataType) === 'f32' - ? `var threadMax = ${valueType}(-3.4028234663852886e+38f);` + ? `var threadMax = ${valueType}(-3.402823e+38f);` : `var threadMax = ${valueType}(-65504.0h);`; const getShaderSource = (shaderHelper: ShaderHelper) => ` var rowMaxShared : ${valueType}; diff --git a/onnxruntime/contrib_ops/rocm/bert/attention.cu b/onnxruntime/contrib_ops/rocm/bert/attention.cu new file mode 100644 index 0000000000000..b40fc2bf0eef8 --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/bert/attention.cu @@ -0,0 +1,215 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "contrib_ops/rocm/bert/attention.h" +#include "contrib_ops/rocm/bert/attention_impl.h" +#include "contrib_ops/rocm/bert/batched_gemm_permute_pipelines.cuh" +#include "contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh" +#include "contrib_ops/rocm/bert/transformer_common.h" +#include "core/providers/rocm/rocm_common.h" +#include "core/providers/rocm/shared_inc/fpgeneric.h" +#include "core/providers/rocm/tunable/gemm.h" + +using namespace onnxruntime::rocm; +using namespace ::onnxruntime::common; +using namespace ONNX_NAMESPACE; + +namespace onnxruntime { +namespace contrib { +namespace rocm { + +constexpr int kPastSequenceLengthInputIndex = 6; +constexpr int kPastInputIndex = 4; +constexpr int kPresentOutputIndex = 1; + +#define REGISTER_KERNEL_TYPED(T) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + Attention, \ + kMSDomain, \ + 1, \ + T, \ + kRocmExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .MayInplace(kPastInputIndex, kPresentOutputIndex) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ + .InputMemoryType(OrtMemTypeCPUInput, kPastSequenceLengthInputIndex), \ + Attention); + +REGISTER_KERNEL_TYPED(float) +REGISTER_KERNEL_TYPED(MLFloat16) + +template +Attention::Attention(const OpKernelInfo& info) + : RocmKernel(info), AttentionBase(info, true), attn_type_(kAttention) { + using HipT = typename ToHipType::MappedType; + using AttentionTunableOp = GemmSoftmaxGemmPermuteTunableOp; + tunable_op_ = std::make_shared(); +} + +template +Status Attention::ComputeInternal(OpKernelContext* context) const { + const Tensor* input = context->Input(0); + const Tensor* weights = context->Input(1); + const Tensor* bias = context->Input(2); + const Tensor* mask_index = context->Input(3); + const Tensor* past = context->Input(4); + const Tensor* attention_bias = context->Input(5); + const Tensor* past_seq_len = context->Input(kPastSequenceLengthInputIndex); + + auto& device_prop = GetDeviceProp(); + RocmAttentionParameters attn; + ORT_RETURN_IF_ERROR(CheckInputs(input->Shape(), + weights->Shape(), + bias->Shape(), + mask_index, + past, + attention_bias, + &attn, + device_prop.maxThreadsPerBlock, + past_seq_len)); + ORT_ENFORCE(attn.sequence_length == attn.kv_sequence_length); // self attention + ORT_ENFORCE(attn.qkv_format == Q_K_V_BNSH); // non-packed, permuted + + TensorShapeVector output_shape(3); + output_shape[0] = static_cast(attn.batch_size); + output_shape[1] = static_cast(attn.sequence_length); + output_shape[2] = static_cast(attn.v_hidden_size); + Tensor* output = context->Output(0, output_shape); + + std::vector present_dims{ + 2, attn.batch_size, attn.num_heads, + past_present_share_buffer_ ? attn.max_sequence_length : attn.total_sequence_length, + attn.head_size}; + TensorShape present_shape(present_dims); + Tensor* present = context->Output(kPresentOutputIndex, present_shape); + + auto stream = Stream(context); + hipblasHandle_t hipblas = GetHipblasHandle(context); + + using HipT = typename ToHipType::MappedType; + using QkvProjectGeneric = GemmPermuteGenericPipeline; + using AttentionGeneric = GemmSoftmaxGemmPermuteGenericPipeline; + using AttentionTunableOp = GemmSoftmaxGemmPermuteTunableOp; + + ORT_RETURN_IF_ERROR(ClassifyAttentionMode(attn_type_, &attn, /*qkv=*/{}, /*past=*/{past}, /*present=*/{present})); + ORT_ENFORCE(attn.mode == QFMT_KFMT_VFMT_NONE_NONE_NONE_NONE || + attn.mode == QFMT_KFMT_VFMT_NONE_NONE_2BNTH_NONE || + attn.mode == QFMT_KFMT_VFMT_NONE_NONE_2BNMH_NONE || + attn.mode == QFMT_KFMT_VFMT_2BNPH_NONE_2BNTH_NONE || + attn.mode == QFMT_KFMT_VFMT_2BNMH_NONE_2BNMH_NONE); + + size_t qkv_project_output_bytes = QkvProjectGeneric::GetOutputNumBytes(&attn); + size_t shared_workspace_bytes = std::max(QkvProjectGeneric::GetWorkspaceNumBytes(&attn), + AttentionGeneric::GetWorkspaceNumBytes(&attn)); + if (GetTuningContext()->IsTunableOpEnabled()) { + shared_workspace_bytes = std::max(shared_workspace_bytes, AttentionTunableOp::GetWorkspaceNumBytes(&attn)); + } + + auto qkv_project_output = GetScratchBuffer(qkv_project_output_bytes, context->GetComputeStream()); + auto workspace = GetScratchBuffer(shared_workspace_bytes, context->GetComputeStream()); + + GemmPermuteParams gemm_permute_params; + { + auto& params = gemm_permute_params; + params.tuning_ctx = GetTuningContext(); + params.stream = context->GetComputeStream(); + params.handle = hipblas; + params.attention = &attn; + params.device_prop = &device_prop; + + params.input_buffer = reinterpret_cast(input->DataRaw()); + params.weight_buffer = reinterpret_cast(weights->DataRaw()); + params.bias_buffer = reinterpret_cast(bias->DataRaw()); + params.out_buffer = reinterpret_cast(qkv_project_output.get()); + params.ones = GetConstOnes(attn.batch_size * attn.sequence_length, stream); + params.workspace_buffer = reinterpret_cast(workspace.get()); + } + + ORT_RETURN_IF_ERROR(QkvProjectGeneric::Run(&gemm_permute_params)); + auto [q_buffer, k_buffer, v_buffer] = QkvProjectGeneric::UnspliceOutputQKV(&gemm_permute_params); + + // NOTE: GemmPermute always output 3BNSH, k_buffer and v_buffer can be treated as 2BNSH + if (nullptr != present) { + Strides dst_strides; // the output buffer is present Tensor, the buffer is the same + + int4 add_shape{2 * attn.batch_size, attn.num_heads, attn.sequence_length, attn.head_size}; + HipT* add_dest = nullptr; // destination of concatenated data to present + const HipT* const add_src = k_buffer; // source of concatenated data to present + const auto add_src_strides = Strides::BNSHMemory( + 2 * attn.batch_size, attn.num_heads, attn.sequence_length, attn.head_size); + + if (attn.mode == QFMT_KFMT_VFMT_NONE_NONE_2BNTH_NONE) { + dst_strides = Strides::BNSHMemory(2 * attn.batch_size, attn.num_heads, attn.total_sequence_length, attn.head_size); + add_dest = reinterpret_cast(present->MutableDataRaw()) /* + dst_strides.OffsetAt(0, 0, 0, 0)*/; + } else if (attn.mode == QFMT_KFMT_VFMT_2BNPH_NONE_2BNTH_NONE) { + dst_strides = Strides::BNSHMemory(2 * attn.batch_size, attn.num_heads, attn.total_sequence_length, attn.head_size); + add_dest = reinterpret_cast(present->MutableDataRaw()) + dst_strides.OffsetAt(0, 0, attn.past_sequence_length, 0); + + // We only need to copy past to present in this case. All other cases will be build the present incrementally + const int4 past_shape = {2 * attn.batch_size, attn.num_heads, attn.past_sequence_length, attn.head_size}; + HipT* const past_dest = reinterpret_cast(present->MutableDataRaw()); + const HipT* const past_src = reinterpret_cast(past->DataRaw()); + const Strides past_src_strides = Strides::BNSHMemory( + 2 * attn.batch_size, attn.num_heads, attn.past_sequence_length, attn.head_size); + + ORT_RETURN_IF_ERROR(LaunchStridedCopy(stream, past_src, past_shape, past_src_strides.ForBNSHCoord(), + past_dest, dst_strides.ForBNSHCoord(), device_prop.maxThreadsPerBlock)); + } else if (attn.mode == QFMT_KFMT_VFMT_NONE_NONE_2BNMH_NONE) { + dst_strides = Strides::BNSHMemory(2 * attn.batch_size, attn.num_heads, attn.max_sequence_length, attn.head_size); + add_dest = reinterpret_cast(present->MutableDataRaw()) /* + dst_strides.OffsetAt(0, 0, 0, 0)*/; + } else if (attn.mode == QFMT_KFMT_VFMT_2BNMH_NONE_2BNMH_NONE) { + dst_strides = Strides::BNSHMemory(2 * attn.batch_size, attn.num_heads, attn.max_sequence_length, attn.head_size); + add_dest = reinterpret_cast(present->MutableDataRaw()) + dst_strides.OffsetAt(0, 0, attn.past_sequence_length, 0); + } + + ORT_RETURN_IF_ERROR(LaunchStridedCopy(stream, add_src, add_shape, add_src_strides.ForBNSHCoord(), + add_dest, dst_strides.ForBNSHCoord(), device_prop.maxThreadsPerBlock)); + + // update pointers to present_k and present_v. TODO: switch to ConvertToOffsetedBufferViews + k_buffer = reinterpret_cast(present->MutableDataRaw()); + v_buffer = reinterpret_cast(present->MutableDataRaw()) + dst_strides.OffsetAt(attn.batch_size, 0, 0, 0); + } + + // For testing, environment variable ORT_TRANSFORMER_OPTIONS=1 could enable persistent softmax + const TransformerOptions* options = TransformerOptions::GetInstance(); + bool use_persistent_softmax = options->IsPrecisionMode() && !options->DisablePersistentSoftmax(); + + GemmSoftmaxGemmPermuteParams gemm_softmax_gemm_permute_params; + { + auto& params = gemm_softmax_gemm_permute_params; + params.tuning_ctx = GetTuningContext(); + params.stream = context->GetComputeStream(); + params.handle = hipblas; + params.attention = &attn; + params.device_prop = &device_prop; + // FIXME: the params.scale seems to be different from AttentionParameters::scale; + params.scale = 1.0f / sqrt(static_cast(attn.head_size)); + // TODO: switch to ConvertToOffsetedBufferViews + params.q_buffer = q_buffer; + params.k_buffer = k_buffer; + params.v_buffer = v_buffer; + params.out_buffer = reinterpret_cast(output->MutableDataRaw()); + + if (attention_bias != nullptr) { + params.bias_buffer = reinterpret_cast(attention_bias->DataRaw()); + } + + if (mask_index != nullptr) { + params.mask_index_buffer = mask_index->Data(); + params.mask_index_dims = mask_index->Shape().AsShapeVector(); + } + + params.workspace_buffer = reinterpret_cast(workspace.get()); + } + + if (this->GetTuningContext()->IsTunableOpEnabled() && + !use_persistent_softmax) { + return (*std::static_pointer_cast(tunable_op_))(&gemm_softmax_gemm_permute_params); + } else { + return AttentionGeneric::Run(&gemm_softmax_gemm_permute_params, use_persistent_softmax); + } +} + +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/attention.h b/onnxruntime/contrib_ops/rocm/bert/attention.h new file mode 100644 index 0000000000000..7204fd660a516 --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/bert/attention.h @@ -0,0 +1,33 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/rocm/rocm_kernel.h" +#include "contrib_ops/cpu/bert/attention_base.h" +#include "contrib_ops/rocm/bert/attention_impl.h" + +namespace onnxruntime { +namespace contrib { +namespace rocm { + +using namespace onnxruntime::rocm; + +template +class Attention final : public RocmKernel, public AttentionBase { + public: + Attention(const OpKernelInfo& info); + Status ComputeInternal(OpKernelContext* context) const override; + + public: + AttentionType attn_type_; + + // type-erased GemmSoftmaxGemmPermuteTunableOp, the reason for this is: + // 1. We don't want to include the cuh file where GemmSoftmaxGemmPermuteTunableOp is defined. + // 2. We don't want to construct the object repeatly (which is expansive) during Compute. + std::shared_ptr tunable_op_; +}; + +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/attention_impl.cu b/onnxruntime/contrib_ops/rocm/bert/attention_impl.cu new file mode 100644 index 0000000000000..270a8e51daf88 --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/bert/attention_impl.cu @@ -0,0 +1,435 @@ +/* + The implementation of this file is based on qkvToContext plugin in TensorRT demo: + https://github.com/NVIDIA/TensorRT/tree/release/5.1/demo/BERT/ + +Copyright 2019 NVIDIA Corporation + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// Modifications: scaling is moved from masked softmax to the gemm before that. +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include "core/providers/rocm/cu_inc/common.cuh" +#include "core/providers/rocm/rocm_common.h" +#include "core/providers/rocm/shared_inc/fpgeneric.h" +#include "core/providers/rocm/tunable/gemm.h" +#include "core/providers/rocm/tunable/rocm_tunable.h" +#include "contrib_ops/cpu/bert/attention_base.h" +#include "contrib_ops/rocm/bert/attention_impl.h" +#include "contrib_ops/rocm/bert/attention_softmax.h" +#include "contrib_ops/rocm/bert/decoder_attention_impl.h" + +using namespace onnxruntime::rocm; + +namespace blas = onnxruntime::rocm::tunable::blas; + +#define CHECK_ROCM(expr) HIP_RETURN_IF_ERROR(expr) + +using namespace onnxruntime::rocm; +using namespace ::onnxruntime::common; +using namespace ONNX_NAMESPACE; + +namespace onnxruntime { +namespace contrib { +namespace rocm { + +static size_t AlignTo(size_t a, size_t b) { + return CeilDiv(a, b) * b; +} + +size_t GetAttentionScratchSize(size_t element_size, + int batch_size, + int num_heads, + int sequence_length, + int total_sequence_length) { + const size_t bytes = element_size * batch_size * num_heads * sequence_length * total_sequence_length; + + const size_t alignment = 256; + const size_t bytesAligned = AlignTo(bytes, alignment); + return bytesAligned; +} + +size_t GetAttentionWorkspaceSize( + size_t element_size, + int batch_size, + int num_heads, + int head_size, + int sequence_length, + int total_sequence_length) { + size_t qkv_size = element_size * 3 * batch_size * sequence_length * num_heads * head_size; + return qkv_size + 2 * GetAttentionScratchSize(element_size, batch_size, num_heads, + sequence_length, total_sequence_length); +} + +inline int3 Get2DMaskStrides(int total_sequence_length) { + // stride == 0 indicate broadcasting + return {total_sequence_length, 0, 1}; +} + +Status ClassifyAttentionMode( + AttentionType attn_type, + RocmAttentionParameters* attn, + const std::vector& qkv, + const std::vector& past, + const std::vector& present) { + size_t num_qkv = std::count_if(qkv.cbegin(), qkv.cend(), [](auto it) { return it != nullptr; }); + size_t num_past = std::count_if(past.cbegin(), past.cend(), [](auto it) { return it != nullptr; }); + size_t num_present = std::count_if(present.cbegin(), present.cend(), [](auto it) { return it != nullptr; }); + + auto hint = MakeString(num_qkv, " qkv inputs, ", num_past, " past inputs and ", num_present, " present inputs"); + LOGS_DEFAULT(VERBOSE) << hint; + + if (attn_type == kAttention) { + ORT_ENFORCE(num_qkv == 0); + if (num_past == 0 && num_present == 0) { + attn->mode = QFMT_KFMT_VFMT_NONE_NONE_NONE_NONE; + return Status::OK(); + } else if (num_past == 0 && num_present == 1) { + if (attn->past_present_share_buffer == false) { + attn->mode = QFMT_KFMT_VFMT_NONE_NONE_2BNTH_NONE; + return Status::OK(); + } else { + attn->mode = QFMT_KFMT_VFMT_NONE_NONE_2BNMH_NONE; + return Status::OK(); + } + } else if (num_past == 1 && num_present == 1) { + if (attn->past_present_share_buffer == false) { + attn->mode = QFMT_KFMT_VFMT_2BNPH_NONE_2BNTH_NONE; + return Status::OK(); + } else { + attn->mode = QFMT_KFMT_VFMT_2BNMH_NONE_2BNMH_NONE; + return Status::OK(); + } + } + } else if (attn_type == kMultiHeadAttention || attn_type == kDecoderMaskedMultiHeadAttention) { + if (num_qkv == 3 && num_past == 0 && num_present == 0) { + if (attn->qkv_format == Q_K_V_BSNH) { + attn->mode = BSNH_BLNH_BLNH_NONE_NONE_NONE_NONE; + return Status::OK(); + } else if (attn->qkv_format == Q_K_V_BSNH_BNSH_BNSH) { + attn->mode = BSNH_BNLH_BNLH_NONE_NONE_NONE_NONE; + return Status::OK(); + } + } else if (num_qkv == 3 && num_past == 0 && num_present == 2) { + if (attn->past_present_share_buffer == false) { + if (attn->qkv_format == Q_K_V_BSNH) { + attn->mode = BSNH_BLNH_BLNH_NONE_NONE_BNTH_BNTH; + return Status::OK(); + } else if (attn->qkv_format == Q_K_V_BSNH_BNSH_BNSH) { + attn->mode = BSNH_BNLH_BNLH_NONE_NONE_BNTH_BNTH; + return Status::OK(); + } + } else { + if (attn->qkv_format == Q_K_V_BSNH) { + attn->mode = BSNH_BLNH_BLNH_NONE_NONE_BNMH_BNMH; + return Status::OK(); + } else if (attn->qkv_format == Q_K_V_BSNH_BNSH_BNSH) { + attn->mode = BSNH_BNLH_BNLH_NONE_NONE_BNMH_BNMH; + return Status::OK(); + } + } + } else if (num_qkv == 3 && num_past == 2 && num_present == 2) { + if (attn->past_present_share_buffer == false) { + if (attn->qkv_format == Q_K_V_BSNH) { + attn->mode = BSNH_BLNH_BLNH_BNPH_BNPH_BNTH_BNTH; + return Status::OK(); + } else if (attn->qkv_format == Q_K_V_BSNH_BNSH_BNSH) { + attn->mode = BSNH_BNLH_BNLH_BNPH_BNPH_BNTH_BNTH; + return Status::OK(); + } + } else { + if (attn->qkv_format == Q_K_V_BSNH) { + attn->mode = BSNH_BLNH_BLNH_BNMH_BNMH_BNMH_BNMH; + return Status::OK(); + } else if (attn->qkv_format == Q_K_V_BSNH_BNSH_BNSH) { + attn->mode = BSNH_BNLH_BNLH_BNMH_BNMH_BNMH_BNMH; + return Status::OK(); + } + } + } else if (num_qkv == 1 && num_past == 0 && num_present == 0) { + if (attn->qkv_format == QKV_BSN3H) { + attn->mode = BLN3H_NONE_NONE_NONE_NONE_NONE_NONE; + return Status::OK(); + } + } else if (num_qkv == 2 && num_past == 0 && num_present == 0) { + if (attn->qkv_format == Q_KV_BSNH_BSN2H) { + attn->mode = BSNH_BLN2H_NONE_NONE_NONE_NONE_NONE; + return Status::OK(); + } + } + } + return ORT_MAKE_STATUS( + ONNXRUNTIME, INVALID_ARGUMENT, + "Unsupported AttentionMode for ", attn_type, ". Got qkv format ", attn->qkv_format, + ". Got ", hint); +} + +template +Status DecoderQkvToContext( + const hipDeviceProp_t& prop, + RocmTuningContext* tuning_ctx, + Stream* ort_stream, + hipblasHandle_t& hipblas, + const size_t element_size, + const int batch_size, + const int sequence_length, + const int kv_sequence_length, + const int num_heads, + const int head_size, + const bool static_kv, + const bool use_past, + const bool has_layer_state, + const bool has_key_padding_mask, + const float mask_filter_value, + const T* gemm_query_buffer, + const T* gemm_kv_buffer, + const bool* key_padding_mask, + const T* key_cache, + const T* value_cache, + T* qkv_buffer, + T* workspace_buffer, + T* output, + T* new_key_cache, + T* new_value_cache) { + const int max_threads_per_block = prop.maxThreadsPerBlock; + const int BN = batch_size * num_heads; + const int BHN = BN * head_size; + const int BNS = BN * sequence_length; + const int k_buffer_offset = sequence_length * BHN; + const int v_buffer_offset = (sequence_length + kv_sequence_length) * BHN; + + T* temp_qkv_buffer = workspace_buffer; + auto stream = static_cast(ort_stream->GetHandle()); + + const T* q = qkv_buffer; + // transpose q and copy them to qkv_buffer + ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, sequence_length, batch_size, head_size, + num_heads, max_threads_per_block, true, gemm_query_buffer, qkv_buffer)); + + const T* k = qkv_buffer + k_buffer_offset; + const T* v = qkv_buffer + v_buffer_offset; + if (!has_layer_state || !use_past) { + if (!static_kv) { + // transpose kv and copy them to qkv_buffer + ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 2, sequence_length, batch_size, head_size, num_heads, + max_threads_per_block, true, gemm_kv_buffer, qkv_buffer + k_buffer_offset)); + } else { + // transpose kv and copy them to qkv_buffer + ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 2, kv_sequence_length, batch_size, head_size, num_heads, + max_threads_per_block, true, gemm_kv_buffer, qkv_buffer + k_buffer_offset)); + } + } else { + if (!static_kv) { + // transpose kv and copy them to temp_buffer + ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 2, sequence_length, batch_size, head_size, num_heads, + max_threads_per_block, true, gemm_kv_buffer, temp_qkv_buffer)); + // concat cache-k with k and copy to qkv_buffer + if (nullptr != key_cache) { + ORT_RETURN_IF_ERROR(LaunchConcatTensorToTensor(stream, kv_sequence_length, sequence_length, + batch_size, head_size, num_heads, + max_threads_per_block, 1, key_cache, + temp_qkv_buffer, qkv_buffer + k_buffer_offset)); + } + // concat cache-v with v and copy to qkv_buffer + if (nullptr != value_cache) { + ORT_RETURN_IF_ERROR(LaunchConcatTensorToTensor(stream, kv_sequence_length, sequence_length, + batch_size, head_size, num_heads, + max_threads_per_block, 1, value_cache, + temp_qkv_buffer + k_buffer_offset, + qkv_buffer + v_buffer_offset)); + } + } + } + + if (has_layer_state) { + if (use_past && static_kv) { + CHECK_ROCM(hipMemcpyAsync(new_key_cache, key_cache, + kv_sequence_length * BHN * sizeof(T), hipMemcpyDeviceToDevice, stream)); + CHECK_ROCM(hipMemcpyAsync(new_value_cache, value_cache, + kv_sequence_length * BHN * sizeof(T), hipMemcpyDeviceToDevice, stream)); + } else { + CHECK_ROCM(hipMemcpyAsync(new_key_cache, k, + kv_sequence_length * BHN * sizeof(T), hipMemcpyDeviceToDevice, stream)); + CHECK_ROCM(hipMemcpyAsync(new_value_cache, v, + kv_sequence_length * BHN * sizeof(T), hipMemcpyDeviceToDevice, stream)); + } + } + + // scratch1: BxNxSxS* buffer + // scratch2: BxNxSxS* buffer + // scratch3: BxNxSxH buffer + T* scratch1 = temp_qkv_buffer + 3 * BHN * sequence_length; + T* scratch2 = scratch1 + BNS * kv_sequence_length; + T* scratch3 = scratch2 + BNS * kv_sequence_length; + + // compute Q*K' (as K'*Q), scaled by 1/sqrt(H) and store in scratch1: BxNxSxS* + // Q: BxNxSxH, K (present_k): BxNxS*xH, Q*K': BxNxSxS* + const float rsqrt_head_size = 1.f / sqrt(static_cast(head_size)); + const int temp_matrix_size = sequence_length * kv_sequence_length; + + const int strideA = kv_sequence_length * head_size; + const int strideB = sequence_length * head_size; + if (use_past && static_kv) { + ORT_RETURN_IF_ERROR(blas::column_major::StridedBatchedGemm( + tuning_ctx, ort_stream, hipblas, + blas::BlasOp::Trans, blas::BlasOp::NonTrans, + kv_sequence_length, sequence_length, head_size, + /*alpha=*/rsqrt_head_size, + key_cache, head_size, strideA, + q, head_size, strideB, + /*beta=*/0.0f, + scratch1, kv_sequence_length, temp_matrix_size, + BN)); + } else { + ORT_RETURN_IF_ERROR(blas::column_major::StridedBatchedGemm( + tuning_ctx, ort_stream, hipblas, + blas::BlasOp::Trans, blas::BlasOp::NonTrans, + kv_sequence_length, sequence_length, head_size, + /*alpha=*/rsqrt_head_size, + k, head_size, strideA, + q, head_size, strideB, + /*beta=*/0.0f, + scratch1, kv_sequence_length, temp_matrix_size, + BN)); + } + + if (has_key_padding_mask) { + int3 strides = Get2DMaskStrides(kv_sequence_length); + ORT_RETURN_IF_ERROR(ComputeSoftmaxWithRawMask( + ort_stream, kv_sequence_length, sequence_length, batch_size, num_heads, + strides, nullptr, key_padding_mask, nullptr, scratch1, scratch2, + false, 1.0f, false, nullptr, mask_filter_value)); + } else { + ORT_RETURN_IF_ERROR(ComputeSoftmax(stream, kv_sequence_length, sequence_length, batch_size, + num_heads, nullptr, scratch1, scratch2, false)); + } + + // compute P*V (as V*P), and store in scratch3: BxNxSxH + if (use_past && static_kv) { + ORT_RETURN_IF_ERROR(blas::column_major::StridedBatchedGemm( + tuning_ctx, ort_stream, hipblas, + blas::BlasOp::NonTrans, blas::BlasOp::NonTrans, + head_size, sequence_length, kv_sequence_length, + /*alpha=*/1.0f, + value_cache, head_size, strideA, + scratch2, kv_sequence_length, temp_matrix_size, + /*beta=*/0.0f, + scratch3, head_size, strideB, + BN)); + } else { + ORT_RETURN_IF_ERROR(blas::column_major::StridedBatchedGemm( + tuning_ctx, ort_stream, hipblas, + blas::BlasOp::NonTrans, blas::BlasOp::NonTrans, + head_size, sequence_length, kv_sequence_length, + /*alpha=*/1.0f, + v, head_size, strideA, + scratch2, kv_sequence_length, temp_matrix_size, + /*beta=*/0.0f, + scratch3, head_size, strideB, + BN)); + } + + // scratch3 is BxNxSxH, transpose to output SxBxNxH + return LaunchTransCtx(stream, sequence_length, batch_size, head_size, + num_heads, max_threads_per_block, true, scratch3, output); +} + +Status LaunchDecoderAttentionKernel( + const hipDeviceProp_t& prop, + RocmTuningContext* tuning_ctx, + Stream* stream, + hipblasHandle_t& hipblas, + const size_t element_size, + const int batch_size, + const int sequence_length, + const int kv_sequence_length, + const int num_heads, + const int head_size, + const bool static_kv, + const bool use_past, + const bool has_layer_state, + const bool has_key_padding_mask, + const float mask_filter_value, + const void* gemm_query_buffer, + const void* gemm_kv_buffer, + const bool* key_padding_mask, + const void* key_cache, + const void* value_cache, + void* qkv_buffer, + void* workspace_buffer, + void* output, + void* new_key_cache, + void* new_value_cache) { + if (element_size == 2) { + return DecoderQkvToContext( + prop, + tuning_ctx, + stream, + hipblas, + element_size, + batch_size, + sequence_length, + kv_sequence_length, + num_heads, + head_size, + static_kv, + use_past, + has_layer_state, + has_key_padding_mask, + mask_filter_value, + reinterpret_cast(gemm_query_buffer), + reinterpret_cast(gemm_kv_buffer), + key_padding_mask, + reinterpret_cast(key_cache), + reinterpret_cast(value_cache), + reinterpret_cast(qkv_buffer), + reinterpret_cast(workspace_buffer), + reinterpret_cast(output), + reinterpret_cast(new_key_cache), + reinterpret_cast(new_value_cache)); + } else { + return DecoderQkvToContext( + prop, + tuning_ctx, + stream, + hipblas, + element_size, + batch_size, + sequence_length, + kv_sequence_length, + num_heads, + head_size, + static_kv, + use_past, + has_layer_state, + has_key_padding_mask, + mask_filter_value, + reinterpret_cast(gemm_query_buffer), + reinterpret_cast(gemm_kv_buffer), + key_padding_mask, + reinterpret_cast(key_cache), + reinterpret_cast(value_cache), + reinterpret_cast(qkv_buffer), + reinterpret_cast(workspace_buffer), + reinterpret_cast(output), + reinterpret_cast(new_key_cache), + reinterpret_cast(new_value_cache)); + } +} + +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/attention_impl.h b/onnxruntime/contrib_ops/rocm/bert/attention_impl.h new file mode 100644 index 0000000000000..07d875e90fa4b --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/bert/attention_impl.h @@ -0,0 +1,180 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include "contrib_ops/cpu/bert/attention_common.h" +#include "contrib_ops/cpu/bert/attention_parameters.h" +#include "core/providers/rocm/shared_inc/rocm_utils.h" +#include "core/providers/rocm/tunable/rocm_tunable.h" + +namespace onnxruntime { +namespace contrib { +namespace rocm { + +typedef struct __align__(32) { + long long int x, y, z, w; +} LongLong4; + +size_t GetAttentionScratchSize( + size_t element_size, + int batch_size, + int num_heads, + int sequence_length, + int all_sequence_length); + +size_t GetAttentionWorkspaceSize( + size_t element_size, + int batch_size, + int num_heads, + int head_size, + int sequence_length, + int past_sequence_length); + +Status LaunchTransCtx(hipStream_t stream, + const int sequence_length, const int batch_size, const int head_size, const int num_heads, + const int max_threads_per_block, const bool reversed_bs, const float* input, float* output); + +Status LaunchTransCtx(hipStream_t stream, + const int sequence_length, const int batch_size, const int head_size, const int num_heads, + const int max_threads_per_block, const bool reversed_bs, const half* input, half* output); + +Status LaunchTransQkv(hipStream_t stream, const int matrix_num, + const int sequence_length, const int batch_size, const int head_size, const int num_heads, + const int max_threads_per_block, const bool reversed_bs, const float* input, float* output, + int total_matrix_count = -1); + +Status LaunchTransQkv(hipStream_t stream, const int matrix_num, + const int sequence_length, const int batch_size, const int head_size, const int num_heads, + const int max_threads_per_block, const bool reversed_bs, const half* input, half* output, + int total_matrix_count = -1); + +Status LaunchConcatTensorToTensor(hipStream_t stream, + const int all_sequence_length, + const int sequence_length, + const int batch_size, + const int head_size, + const int num_heads, + const int max_threads_per_block, + const int matrix_num, + const float* tensor_in, + const float* tensor_add, + float* tensor_out); + +Status LaunchConcatTensorToTensor(hipStream_t stream, + const int all_sequence_length, + const int sequence_length, + const int batch_size, + const int head_size, + const int num_heads, + const int max_threads_per_block, + const int matrix_num, + const half* tensor_in, + const half* tensor_add, + half* tensor_out); + +inline hipblasStatus_t _compat_hipblas_gemm_strided_batched_ex(hipblasHandle_t handle, + hipblasOperation_t transa, + hipblasOperation_t transb, + int m, + int n, + int k, + const void* alpha, + const void* A, + hipDataType a_type, + int lda, + hipblasStride stride_A, + const void* b, + hipDataType b_type, + int ldb, + hipblasStride stride_b, + const void* beta, + void* c, + hipDataType c_type, + int ldc, + hipblasStride stride_c, + int batch_count, + hipblasComputeType_t compute_type, + hipblasGemmAlgo_t algo) { + return hipblasGemmStridedBatchedEx(handle, + transa, + transb, + m, // m + n, // n + k, // k + alpha, // alpha + A, // A + a_type, // A type + lda, // lda + stride_A, // strideA + b, // B + b_type, // B type + ldb, // ldb + stride_b, // strideB + beta, // beta + c, // C + c_type, // C type + ldc, // ldc + stride_c, // strideC + batch_count, // batch count + compute_type, + algo); +} + +// Compatible for CublasMathModeSetter +class CompatHipblasMathModeSetter { + public: + CompatHipblasMathModeSetter(const hipDeviceProp_t&, + hipblasHandle_t, + int) { + } +}; + +enum AttentionMode { + // Q,K,V,PastK,PastV,PresentK,PresentV + QFMT_KFMT_VFMT_NONE_NONE_NONE_NONE, + QFMT_KFMT_VFMT_NONE_NONE_2BNTH_NONE, + QFMT_KFMT_VFMT_NONE_NONE_2BNMH_NONE, + QFMT_KFMT_VFMT_2BNPH_NONE_2BNTH_NONE, + QFMT_KFMT_VFMT_2BNMH_NONE_2BNMH_NONE, + BSNH_BLNH_BLNH_NONE_NONE_NONE_NONE, + BSNH_BNLH_BNLH_NONE_NONE_NONE_NONE, + BSNH_BLNH_BLNH_NONE_NONE_BNTH_BNTH, + BSNH_BNLH_BNLH_NONE_NONE_BNTH_BNTH, + BSNH_BLNH_BLNH_NONE_NONE_BNMH_BNMH, + BSNH_BNLH_BNLH_NONE_NONE_BNMH_BNMH, + BSNH_BLNH_BLNH_BNPH_BNPH_BNTH_BNTH, + BSNH_BNLH_BNLH_BNPH_BNPH_BNTH_BNTH, + BSNH_BLNH_BLNH_BNMH_BNMH_BNMH_BNMH, + BSNH_BNLH_BNLH_BNMH_BNMH_BNMH_BNMH, + BLN3H_NONE_NONE_NONE_NONE_NONE_NONE, + BSNH_BLN2H_NONE_NONE_NONE_NONE_NONE, +}; + +struct RocmAttentionParameters : AttentionParameters { + AttentionMode mode; +}; + +Status ClassifyAttentionMode(AttentionType type, + RocmAttentionParameters* attn, + const std::vector& qkv, + const std::vector& past, + const std::vector& present); + +template +Status LaunchStridedCopy( + hipStream_t stream, + const T* in, int4 in_shape, LongLong4 in_strides, const int* in_seqlens_offset, // coord (b,n,s,h) + T* out, LongLong4 out_strides, const int* out_seqlens_offset, // coord (b,n,s,h) + int max_threads_per_block); + +template +Status LaunchStridedCopy(hipStream_t stream, + const T* in, int4 in_shape, LongLong4 in_strides, // coord (b,n,s,h) + T* out, LongLong4 out_strides, // coord (b,n,s,h) + int max_threads_per_block); +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/attention_softmax.h b/onnxruntime/contrib_ops/rocm/bert/attention_softmax.h new file mode 100644 index 0000000000000..9f2faa228cf79 --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/bert/attention_softmax.h @@ -0,0 +1,465 @@ +#include "hip/hip_runtime.h" +/* + The implementation of this file is based on qkvToContext plugin in TensorRT demo: + https://github.com/NVIDIA/TensorRT/tree/release/5.1/demo/BERT/ + +Copyright 2019 NVIDIA Corporation + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#pragma once + +#include +#include +#include +#include "core/providers/rocm/cu_inc/common.cuh" +#include "core/providers/rocm/rocm_common.h" +#include "core/providers/rocm/math/softmax.h" + +#define ROCMRT_INF_F __int_as_float(0x7f800000) + +using namespace onnxruntime::rocm; +using namespace hipcub; + +namespace onnxruntime { +namespace contrib { +namespace rocm { + +template +__device__ inline void Softmax(const int all_sequence_length, + const int valid_end, + const int valid_start, + const T* attn_bias, + const T* input, + T* output) { + using BlockReduce = hipcub::BlockReduce; + __shared__ typename BlockReduce::TempStorage tmp_storage; + + __shared__ float sum_reverse_block; + __shared__ float max_block; + + float thread_data_max(-ROCMRT_INF_F); + + // e^x is represented as infinity if x is large enough, like 100.f. + // Infinity divided by Infinity is a NAN. Thus, softmax gets a NAN if one or more item are large enough. + // a math transform as below is leveraged to get a stable softmax: + // e^xi/(e^x1 + ...e^xn) = e^(xi - max) / (e^(x1 - max) + ... + e^(xn - max)) + const int offset = (blockIdx.y * gridDim.x + blockIdx.x) * all_sequence_length; + for (int i = threadIdx.x; i < valid_end; i += TPB) { + if (i >= valid_start) { + const int index = offset + i; + float input_at_idx = attn_bias == nullptr + ? static_cast(input[index]) + : static_cast(input[index] + attn_bias[index]); + if (thread_data_max < input_at_idx) { + thread_data_max = input_at_idx; + } + } + } + + const auto max = BlockReduce(tmp_storage).Reduce(thread_data_max, hipcub::Max()); + + // Store max value + if (threadIdx.x == 0) { + max_block = max; + } + __syncthreads(); + + float thread_data_sum(0.f); + for (int i = threadIdx.x; i < valid_end; i += TPB) { + if (i >= valid_start) { + const int index = offset + i; + float val = attn_bias == nullptr ? input[index] : input[index] + attn_bias[index]; + thread_data_sum += expf(val - max_block); + } + } + + const auto sum = BlockReduce(tmp_storage).Reduce(thread_data_sum, hipcub::Sum()); + if (threadIdx.x == 0) { + sum_reverse_block = 1.f / sum; + } + __syncthreads(); + + for (int i = threadIdx.x; i < all_sequence_length; i += TPB) { + const int index = offset + i; + float input_at_idx = attn_bias == nullptr + ? static_cast(input[index]) + : static_cast(input[index] + attn_bias[index]); + const float val = (i >= valid_start && i < valid_end) ? expf(input_at_idx - max_block) * sum_reverse_block : 0.f; + output[index] = T(val); + } +} + +template +__device__ inline void SoftmaxSmall(const int all_sequence_length, + const int sequence_length, + const int valid_end, + const int valid_start, + const T* attn_bias, + const T* input, + T* output, + bool causal) { + using BlockReduce = hipcub::BlockReduce; + __shared__ typename BlockReduce::TempStorage tmp_storage; + + __shared__ float sum_reverse_block; + __shared__ float max_block; + + // Input dimension is BxNxSxS*; blockIdx.y is batch index b; gridDim.x=N*S; blockIdx.x is index within N*S; + const int offset = (blockIdx.y * gridDim.x + blockIdx.x) * all_sequence_length; + const int index = offset + threadIdx.x; + + bool is_valid = false; // whether it has attention mask == 1. + + // Update end position for causal. + int end = valid_end; + if (causal) { + const int end_causal = all_sequence_length - sequence_length + (blockIdx.x % sequence_length) + 1; + if (end_causal < end) { + end = end_causal; + } + } + + is_valid = (threadIdx.x >= valid_start && threadIdx.x < end); + + // e^x is represented as infinity if x is large enough, like 100.f. + // Infinity divided by Infinity is a NAN. Thus, softmax gets a NAN if one or more item are large enough. + // a math transform as below is leveraged to get a stable softmax: + // e^xi/(e^x1 + ...e^xn) = e^(xi - max) / (e^(x1 - max) + ... + e^(xn - max)) + float input_data = attn_bias == nullptr + ? static_cast(input[index]) + : static_cast(input[index] + attn_bias[index]); + float thread_data_max = is_valid ? input_data : float(-ROCMRT_INF_F); + const auto max = BlockReduce(tmp_storage).Reduce(thread_data_max, hipcub::Max(), end); + + // Store max value + if (threadIdx.x == 0) { + max_block = max; + } + __syncthreads(); + + float thread_data_exp(0.f); + if (is_valid) { + thread_data_exp = expf(input_data - max_block); + } + + const auto sum = BlockReduce(tmp_storage).Reduce(thread_data_exp, hipcub::Sum(), end); + + // Store value of 1.0/sum. + if (threadIdx.x == 0) { + sum_reverse_block = (1.f) / sum; + } + __syncthreads(); + + // threadIdx.x might be larger than all_sequence_length due to alignment to 32x. + if (threadIdx.x < all_sequence_length) { + output[index] = is_valid ? T(thread_data_exp * sum_reverse_block) : T(0.f); + } +} + +// Note about the attention_mask_strides and attention_mask/key_padding_mask +// attention_mask accepts 2D, 3D or 4D tensor, but it will be viewed as 3D tensor uniformally and it will be indexed +// as [batch_index, sequence_index, token_index]. +template +__global__ void SoftmaxWithRawMaskSmallKernel( + const int all_sequence_length, + const int sequence_length, + const int3 attention_mask_strides, + const int* attention_mask, // 2D, 3D or 4D attention mask + const bool* key_padding_mask, + const T* attn_bias, + const T* input, + T* output, + const bool causal, + const float rsqrt_head_size, + const bool skip_softmax, + const float mask_filter_value) { + using BlockReduce = hipcub::BlockReduce; + __shared__ typename BlockReduce::TempStorage tmp_storage; + + __shared__ float sum_reverse_block; + __shared__ float max_block; + + // Input dimension is BxNxSxS*; blockIdx.y is batch index b; gridDim.x=N*S; blockIdx.x is index within N*S; + int index = (blockIdx.y * gridDim.x + blockIdx.x) * all_sequence_length + threadIdx.x; + + // Mask all thread_data values to negative infinity to allow BlockReduce Max operation over all thread_data + // members with all invalid members set to a value that does not impact the final result. This is necessary + // to avoid the performance impact from using the valid_items interface. + float thread_data = -ROCMRT_INF_F; + if (threadIdx.x < all_sequence_length) { + thread_data = float(input[index]) * rsqrt_head_size; + + const int sequence_index = blockIdx.x % sequence_length; + if (causal) { + int from_index = all_sequence_length - sequence_length + sequence_index; // offset in all sequence length. + if (threadIdx.x > from_index) { + thread_data = -ROCMRT_INF_F; + } + } + + const int batch_index = blockIdx.y; + int mask_offset = attention_mask_strides.x * batch_index + + attention_mask_strides.y * sequence_index + + attention_mask_strides.z * threadIdx.x; + + if (nullptr == key_padding_mask) { + const int& mask = attention_mask[mask_offset]; + if (mask == 0) + thread_data += mask_filter_value; + } else { + const bool mask = key_padding_mask[mask_offset]; + if (mask) { + thread_data = -ROCMRT_INF_F; + } + } + + if (attn_bias != nullptr) { + thread_data += float(attn_bias[index]); + } + } + + if (skip_softmax) { + if (threadIdx.x < all_sequence_length) { + output[index] = T(thread_data); + } + return; + } + + const float max = BlockReduce(tmp_storage).Reduce(thread_data, hipcub::Max()); + + // Store max value + if (threadIdx.x == 0) { + max_block = max; + } + __syncthreads(); + + // Mask all thread_data_exp values to zero to allow BlockReduce Sum operation over all thread_data_exp + // members with all invalid members set to a value that does not impact the final result. This is necessary + // to avoid the performance impact from using the valid_items interface. + float thread_data_exp = threadIdx.x < all_sequence_length ? expf(thread_data - max_block) : 0.0f; + const auto sum = BlockReduce(tmp_storage).Reduce(thread_data_exp, hipcub::Sum()); + + // Store value of 1.0/sum + if (threadIdx.x == 0) { + sum_reverse_block = (1.f) / sum; + } + __syncthreads(); + + if (threadIdx.x < all_sequence_length) { + output[index] = T(thread_data_exp * sum_reverse_block); + } +} + +template +__global__ void SoftmaxKernelSmall(const int all_sequence_length, const int sequence_length, + const T* attn_bias, const T* input, T* output, bool causal) { + SoftmaxSmall(all_sequence_length, sequence_length, all_sequence_length, 0, + attn_bias, input, output, causal); +} + +template +__global__ void SoftmaxKernel(const int all_sequence_length, const T* attn_bias, const T* input, T* output) { + Softmax(all_sequence_length, all_sequence_length, 0, attn_bias, input, output); +} + +template +Status ComputeSoftmax( + hipStream_t stream, + const int all_sequence_length, const int sequence_length, const int batch_size, const int num_heads, + const T* attn_bias, const T* input, T* output, bool causal) { + const dim3 grid(sequence_length * num_heads, batch_size, 1); + if (all_sequence_length <= 32) { + const int blockSize = 32; + SoftmaxKernelSmall<<>>( + all_sequence_length, sequence_length, attn_bias, input, output, causal); + } else if (all_sequence_length <= 64) { + const int blockSize = 64; + SoftmaxKernelSmall<<>>( + all_sequence_length, sequence_length, attn_bias, input, output, causal); + } else if (all_sequence_length <= 128) { + const int blockSize = 128; + SoftmaxKernelSmall<<>>( + all_sequence_length, sequence_length, attn_bias, input, output, causal); + } else if (all_sequence_length <= 256) { + const int blockSize = 256; + SoftmaxKernelSmall<<>>( + all_sequence_length, sequence_length, attn_bias, input, output, causal); + } else if (all_sequence_length <= 512) { + const int blockSize = 512; + SoftmaxKernelSmall<<>>( + all_sequence_length, sequence_length, attn_bias, input, output, causal); + } else if (all_sequence_length <= 1024) { + const int blockSize = 1024; + SoftmaxKernelSmall<<>>( + all_sequence_length, sequence_length, attn_bias, input, output, causal); + } else if (!causal) { + const int blockSize = 1024; + SoftmaxKernel<<>>( + all_sequence_length, attn_bias, input, output); + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Attention ROCM operator does not support total sequence length > 1024."); + } + + return HIP_CALL(hipPeekAtLastError()); +} + +template +__global__ void MaskedSoftmaxKernelSmall(const int all_sequence_length, const int sequence_length, + const int* mask_end, const int* mask_start, + const T* attn_bias, const T* input, T* output, + bool causal) { + __shared__ int start_position; + __shared__ int end_position; + + if (threadIdx.x == 0) { + const int batch = blockIdx.y; + start_position = mask_start != nullptr ? max(0, mask_start[batch]) : 0; + end_position = min(all_sequence_length, mask_end[batch]); + + // Attend to no word has same effect as attend to all words. This is added to get parity with CPU result. + if (start_position >= end_position) { + start_position = 0; + end_position = all_sequence_length; + } + } + __syncthreads(); + + SoftmaxSmall(all_sequence_length, sequence_length, end_position, start_position, + attn_bias, input, output, causal); +} + +template +__global__ void MaskedSoftmaxKernel(const int all_sequence_length, const int* mask_end, const int* mask_start, + const T* attn_bias, const T* input, T* output) { + __shared__ int start_position; + __shared__ int end_position; + + if (threadIdx.x == 0) { + const int batch = blockIdx.y; + start_position = mask_start != nullptr ? max(0, mask_start[batch]) : 0; + end_position = min(all_sequence_length, mask_end[batch]); + + // Attend to no word has same effect as attend to all words. This is added to get parity with CPU result. + if (start_position >= end_position) { + start_position = 0; + end_position = all_sequence_length; + } + } + __syncthreads(); + + Softmax(all_sequence_length, end_position, start_position, attn_bias, input, output); +} + +template +Status ComputeSoftmaxWithMask1D( + hipStream_t stream, + const int all_sequence_length, const int sequence_length, const int batch_size, const int num_heads, + const int* mask_index, const int* mask_start, + const T* attn_bias, const T* input, T* output, const bool causal) { + const dim3 grid(sequence_length * num_heads, batch_size, 1); + +#define DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(block_size) \ + MaskedSoftmaxKernelSmall<<>>( \ + all_sequence_length, sequence_length, mask_index, mask_start, \ + attn_bias, input, output, causal); + + if (all_sequence_length <= 32) { + DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(32); + } else if (all_sequence_length <= 64) { + DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(64); + } else if (all_sequence_length <= 128) { + DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(128); + } else if (all_sequence_length <= 256) { + DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(256); + } else if (all_sequence_length <= 512) { + DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(512); + } else if (all_sequence_length <= 1024) { + DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(1024); + } else if (!causal) { + const int blockSize = 1024; + MaskedSoftmaxKernel<<>>( + all_sequence_length, mask_index, mask_start, + attn_bias, input, output); + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Attention ROCM operator does not support total sequence length > 1024."); + } + +#undef DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE + + return HIP_CALL(hipPeekAtLastError()); +} + +template +Status ComputeSoftmaxWithRawMask(Stream* ort_stream, + const int all_sequence_length, + const int sequence_length, + const int batch_size, + const int num_heads, + const int3 attention_mask_strides, + const int* attention_mask, + const bool* key_padding_mask, + const T* attn_bias, + const T* input, + T* output, + const bool causal, + const float rsqrt_head_size, + const bool use_persistent_softmax, + T* persistent_softmax_workspace, + const float mask_filter_value) { + const dim3 grid(sequence_length * num_heads, batch_size, 1); + + T* out = use_persistent_softmax ? persistent_softmax_workspace : output; + auto stream = static_cast(ort_stream->GetHandle()); + +#define DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(block_size) \ + SoftmaxWithRawMaskSmallKernel<<>>( \ + all_sequence_length, sequence_length, attention_mask_strides, \ + attention_mask, key_padding_mask, attn_bias, input, out, \ + causal, rsqrt_head_size, \ + use_persistent_softmax, mask_filter_value); + + if (all_sequence_length <= 32) { + DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(32); + } else if (all_sequence_length <= 64) { + DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(64); + } else if (all_sequence_length <= 128) { + DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(128); + } else if (all_sequence_length <= 256) { + DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(256); + } else if (all_sequence_length <= 512) { + DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(512); + } else if (all_sequence_length <= 1024) { + DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(1024); + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Attention ROCM operator does not support total sequence length > 1024."); + } + +#undef DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE + + if (use_persistent_softmax) { + return dispatch_warpwise_softmax_forward(ort_stream, + output, + persistent_softmax_workspace, + all_sequence_length, + all_sequence_length, + batch_size * num_heads * sequence_length); + } + + return HIP_CALL(hipPeekAtLastError()); +} + +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_permute_pipelines.cuh b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_permute_pipelines.cuh new file mode 100644 index 0000000000000..213940f132963 --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_permute_pipelines.cuh @@ -0,0 +1,125 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/rocm/rocm_common.h" +#include "core/providers/rocm/rocm_kernel.h" +#include "core/providers/rocm/tunable/gemm.h" +#include "core/providers/rocm/tunable/rocm_tunable.h" +#include "contrib_ops/cpu/bert/attention_common.h" +#include "contrib_ops/cpu/bert/attention_parameters.h" + +namespace onnxruntime { +namespace contrib { +namespace rocm { + +namespace blas = onnxruntime::rocm::tunable::blas; + +namespace { +std::tuple GetQkvProjectGemmMNKBatch(const AttentionParameters* attention) { + int m = attention->sequence_length; + int n = (attention->hidden_size + attention->hidden_size + attention->v_hidden_size); // q + k + v + int k = attention->input_hidden_size; + int batch = attention->batch_size; + return {m, n, k, batch}; +} +} // namespace + +template +struct GemmPermuteParams : onnxruntime::rocm::tunable::OpParams { + std::string Signature() const override { + auto [m, n, k, batch] = GetQkvProjectGemmMNKBatch(attention); + return MakeString("M", m, "_N", n, "_K", k, "_B", batch); + } + + hipblasHandle_t handle; + const AttentionParameters* attention; + const hipDeviceProp_t* device_prop; + + const T* input_buffer; + const T* weight_buffer; + const T* bias_buffer; + T* out_buffer; + + int3 bias_strides; + + const T* ones; // used for broadcasting bias if the underlying algorithm does not support strides + T* workspace_buffer; +}; + +template +struct GemmPermuteGenericPipeline { + inline static size_t GetOutputNumBytes(const AttentionParameters* attn) { + auto [m, n, _, batch] = GetQkvProjectGemmMNKBatch(attn); + return sizeof(T) * m * n * batch; + } + + inline static size_t GetWorkspaceNumBytes(const AttentionParameters* attn) { + return GetOutputNumBytes(attn); + } + + inline static std::tuple GetGemmMNK(const GemmPermuteParams* params) { + auto [m, n, k, batch] = GetQkvProjectGemmMNKBatch(params->attention); + return {batch * m, n, k}; + } + + inline static std::tuple UnspliceOutputQKV(const GemmPermuteParams* params) { + auto* attn = params->attention; + int64_t batch = attn->batch_size * attn->num_heads; + int64_t num_elems_per_batch = attn->sequence_length * attn->head_size; + int64_t num_elems = batch * num_elems_per_batch; + auto q = params->out_buffer + 0 * num_elems; + auto k = params->out_buffer + 1 * num_elems; + auto v = params->out_buffer + 2 * num_elems; + return {q, k, v}; + } + + inline static Status BroadcastBias(const GemmPermuteParams* params) { + auto [m, n, k] = GetGemmMNK(params); + // Bias shape is (N), broadcast using B(M, N) = ones(M, 1) x bias(1, N). + // TODO: use custom kernel of expand to improve the performance. + return blas::row_major::Gemm( + params->TuningContext(), params->Stream(), params->handle, + blas::BlasOp::NonTrans, blas::BlasOp::NonTrans, + m, n, 1, + /*alpha=*/1.0f, + params->ones, 1, + params->bias_buffer, n, + /*beta=*/0.0f, + params->workspace_buffer, n); + } + + inline static Status Gemm(const GemmPermuteParams* params) { + auto [m, n, k] = GetGemmMNK(params); + // result(M, N) = input x weights + bias. + return blas::row_major::Gemm( + params->TuningContext(), params->Stream(), params->handle, + blas::BlasOp::NonTrans, blas::BlasOp::NonTrans, + m, n, k, + /*alpha=*/1.0f, + params->input_buffer, k, + params->weight_buffer, n, + /*beta=*/1.0f, + params->workspace_buffer, n); + } + + inline static Status Permute0213(const GemmPermuteParams* params) { + auto* attn = params->attention; + // input should be BxSx3xNxH => gemm_buffer: 3xBxNxSxH + return LaunchTransQkv( + params->StreamHandle(), 3, attn->sequence_length, attn->batch_size, attn->head_size, attn->num_heads, + params->device_prop->maxThreadsPerBlock, false, params->workspace_buffer, params->out_buffer); + } + + static Status Run(const GemmPermuteParams* params) { + ORT_RETURN_IF_ERROR(BroadcastBias(params)); + ORT_RETURN_IF_ERROR(Gemm(params)); + ORT_RETURN_IF_ERROR(Permute0213(params)); + return Status::OK(); + } +}; + +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl.cuh b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl.cuh new file mode 100644 index 0000000000000..be8508670e4b1 --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl.cuh @@ -0,0 +1,177 @@ +// SPDX-License-Identifier: MIT +// Modifications Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#ifdef USE_COMPOSABLE_KERNEL +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" +#include "ck/utility/data_type.hpp" + +namespace onnxruntime { +namespace contrib { +namespace rocm { +namespace internal { + +using F16 = ck::half_t; +using F32 = float; + +template +using S = ck::Sequence; + +using MaskingSpecialization = ck::tensor_operation::device::MaskingSpecialization; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute; // the interface +using ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle; // the implementation + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmPadded = ck::tensor_operation::device::GemmSpecialization::MNKOPadding; + +static constexpr auto TensorDefault = ck::tensor_operation::device::TensorSpecialization::Default; + +template +using device_batched_gemm_softmax_gemm_permute_instances = + std::tuple< + // clang-format off + // #############################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| AData| B0Data| B1Data| CData| Acc0BiasData| Acc1BiasData| AccData| CShuffle| A| B0| Acc0| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| MaskingSpec| D0s Bias| + // #############################################| | | | | | Type| Type| Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | SrcScalar| + // #############################################| | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | PerVector| + // #############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, DT, DT, DT, DT, D0sDT, ck::Tuple<>, AccDT, DT, PassThrough, PassThrough, D0Op, PassThrough, PassThrough, GemmPadded, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec, 1>, + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, DT, DT, DT, DT, D0sDT, ck::Tuple<>, AccDT, DT, PassThrough, PassThrough, D0Op, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 256, 128, 32, 64, 32, 8, 8, 2, 32, 32, 2, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, DT, DT, DT, DT, D0sDT, ck::Tuple<>, AccDT, DT, PassThrough, PassThrough, D0Op, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 256, 128, 32, 128, 32, 8, 8, 2, 32, 32, 2, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, +#if ROCM_VERSION >= 50500 + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, DT, DT, DT, DT, D0sDT, ck::Tuple<>, AccDT, DT, PassThrough, PassThrough, D0Op, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 256, 32, 64, 32, 8, 8, 2, 32, 32, 1, 8, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, +#endif + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, DT, DT, DT, DT, D0sDT, ck::Tuple<>, AccDT, DT, PassThrough, PassThrough, D0Op, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 256, 32, 128, 32, 8, 8, 2, 32, 32, 1, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, DT, DT, DT, DT, D0sDT, ck::Tuple<>, AccDT, DT, PassThrough, PassThrough, D0Op, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 64, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, DT, DT, DT, DT, D0sDT, ck::Tuple<>, AccDT, DT, PassThrough, PassThrough, D0Op, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 32, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, DT, DT, DT, DT, D0sDT, ck::Tuple<>, AccDT, DT, PassThrough, PassThrough, D0Op, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, DT, DT, DT, DT, D0sDT, ck::Tuple<>, AccDT, DT, PassThrough, PassThrough, D0Op, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 32, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, DT, DT, DT, DT, D0sDT, ck::Tuple<>, AccDT, DT, PassThrough, PassThrough, D0Op, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 64, 256, 32, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 8, S<1, 16, 1,16>, 8, MaskingSpec>, + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, DT, DT, DT, DT, D0sDT, ck::Tuple<>, AccDT, DT, PassThrough, PassThrough, D0Op, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 64, 256, 32, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, 8, MaskingSpec>, + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, DT, DT, DT, DT, D0sDT, ck::Tuple<>, AccDT, DT, PassThrough, PassThrough, D0Op, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 64, 256, 64, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 8, S<1, 16, 1,16>, 8, MaskingSpec>, + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, DT, DT, DT, DT, D0sDT, ck::Tuple<>, AccDT, DT, PassThrough, PassThrough, D0Op, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 64, 256, 64, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, 8, MaskingSpec>, + // Padded fallback kernel + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, DT, DT, DT, DT, D0sDT, ck::Tuple<>, AccDT, DT, PassThrough, PassThrough, D0Op, PassThrough, PassThrough, GemmPadded, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec, 1>, + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, DT, DT, DT, DT, D0sDT, ck::Tuple<>, AccDT, DT, PassThrough, PassThrough, D0Op, PassThrough, PassThrough, GemmPadded, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, DT, DT, DT, DT, D0sDT, ck::Tuple<>, AccDT, DT, PassThrough, PassThrough, D0Op, PassThrough, PassThrough, GemmPadded, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec> + // clang-format on + >; + +struct PreSoftmaxAttentionScoreOp { + PreSoftmaxAttentionScoreOp(float scale) : scale_(scale) {} + + // non-biased, non-masked + __host__ __device__ void operator()(float& y, const float& x) const { + y = scale_ * x; + } + + // biased or converted masked + __host__ __device__ void operator()(float& y, const float& x, const F16& bias) const { + y = scale_ * x + ck::type_convert(bias); + } + + // biased and converted masked + __host__ __device__ void operator()(float& y, const float& x, const F16& bias, const F16& converted_mask) const { + y = scale_ * x + ck::type_convert(bias) + ck::type_convert(converted_mask); + } + + float scale_; +}; + +// Use this function to gat implementation +template +std::vector, + PassThrough, PassThrough, D0Op, PassThrough, PassThrough, + MaskingSpec>>> +GetDeviceBatchedGemmSoftmaxGemmPermuteInstances() { + return {}; +} + +// implemented in impl_{fp16,bf16}[_biased][_masked].cu +// fp16, non-biased, non-masked +template <> +std::vector, ck::Tuple<>, + PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough, + MaskingSpecialization::MaskDisabled>>> +GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< + F16, ck::Tuple<>, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskDisabled>(); + +// fp16, biased, non-masked +template <> +std::vector, ck::Tuple<>, + PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough, + MaskingSpecialization::MaskDisabled>>> +GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< + F16, ck::Tuple, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskDisabled>(); + +// fp16, biased, fp16 masked, basically, two bias +template <> +std::vector, ck::Tuple<>, + PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough, + MaskingSpecialization::MaskDisabled>>> +GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< + F16, ck::Tuple, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskDisabled>(); + +template <> +std::vector, ck::Tuple<>, + PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough, + MaskingSpecialization::MaskOutUpperTriangle>>> +GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< + F16, ck::Tuple<>, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskOutUpperTriangle>(); + +// fp16, biased, non-masked +template <> +std::vector, ck::Tuple<>, + PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough, + MaskingSpecialization::MaskOutUpperTriangle>>> +GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< + F16, ck::Tuple, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskOutUpperTriangle>(); + +// fp16, biased, fp16 masked, basically, two bias +template <> +std::vector, ck::Tuple<>, + PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough, + MaskingSpecialization::MaskOutUpperTriangle>>> +GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< + F16, ck::Tuple, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskOutUpperTriangle>(); + +} // namespace internal +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime +#endif // USE_COMPOSABLE_KERNEL diff --git a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16.cu b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16.cu new file mode 100644 index 0000000000000..2e32a6594d164 --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16.cu @@ -0,0 +1,60 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifdef USE_COMPOSABLE_KERNEL +#include "contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl.cuh" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp" + +namespace onnxruntime { +namespace contrib { +namespace rocm { +namespace internal { + +using NonBiasedNonmasked = DeviceBatchedGemmSoftmaxGemmPermute< + 2, 1, 1, 1, 1, + F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, + PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough, + MaskingSpecialization::MaskDisabled>; + +template <> +std::vector> +GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< + F16, ck::Tuple<>, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskDisabled>() { + std::vector> instances; + ck::tensor_operation::device::instance::add_device_operation_instances( + instances, + device_batched_gemm_softmax_gemm_permute_instances< + 2, 1, 1, 1, 1, + F16, ck::Tuple<>, F32, PreSoftmaxAttentionScoreOp, + MaskingSpecialization::MaskDisabled>{}); + + return instances; +} + +using NonBiasedNonmaskedCausal = DeviceBatchedGemmSoftmaxGemmPermute< + 2, 1, 1, 1, 1, + F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, + PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough, + MaskingSpecialization::MaskOutUpperTriangle>; + +template <> +std::vector> +GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< + F16, ck::Tuple<>, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskOutUpperTriangle>() { + std::vector> instances; + ck::tensor_operation::device::instance::add_device_operation_instances( + instances, + device_batched_gemm_softmax_gemm_permute_instances< + 2, 1, 1, 1, 1, + F16, ck::Tuple<>, F32, PreSoftmaxAttentionScoreOp, + MaskingSpecialization::MaskOutUpperTriangle>{}); + + return instances; +} + +} // namespace internal +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime +#endif // USE_COMPOSABLE_KERNEL diff --git a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16_biased.cu b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16_biased.cu new file mode 100644 index 0000000000000..91da8d9e1f9a8 --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16_biased.cu @@ -0,0 +1,60 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifdef USE_COMPOSABLE_KERNEL +#include "contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl.cuh" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp" + +namespace onnxruntime { +namespace contrib { +namespace rocm { +namespace internal { + +using BiasedNonmasked = DeviceBatchedGemmSoftmaxGemmPermute< + 2, 1, 1, 1, 1, + F16, F16, F16, F16, ck::Tuple, ck::Tuple<>, + PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough, + MaskingSpecialization::MaskDisabled>; + +template <> +std::vector> +GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< + F16, ck::Tuple, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskDisabled>() { + std::vector> instances; + ck::tensor_operation::device::instance::add_device_operation_instances( + instances, + device_batched_gemm_softmax_gemm_permute_instances< + 2, 1, 1, 1, 1, + F16, ck::Tuple, F32, PreSoftmaxAttentionScoreOp, + MaskingSpecialization::MaskDisabled>{}); + + return instances; +} + +using BiasedNonmaskedCausal = DeviceBatchedGemmSoftmaxGemmPermute< + 2, 1, 1, 1, 1, + F16, F16, F16, F16, ck::Tuple, ck::Tuple<>, + PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough, + MaskingSpecialization::MaskOutUpperTriangle>; + +template <> +std::vector> +GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< + F16, ck::Tuple, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskOutUpperTriangle>() { + std::vector> instances; + ck::tensor_operation::device::instance::add_device_operation_instances( + instances, + device_batched_gemm_softmax_gemm_permute_instances< + 2, 1, 1, 1, 1, + F16, ck::Tuple, F32, PreSoftmaxAttentionScoreOp, + MaskingSpecialization::MaskOutUpperTriangle>{}); + + return instances; +} + +} // namespace internal +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime +#endif // USE_COMPOSABLE_KERNEL diff --git a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16_biased_biased.cu b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16_biased_biased.cu new file mode 100644 index 0000000000000..b08123be18977 --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16_biased_biased.cu @@ -0,0 +1,60 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifdef USE_COMPOSABLE_KERNEL +#include "contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl.cuh" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp" + +namespace onnxruntime { +namespace contrib { +namespace rocm { +namespace internal { + +using BiasedNonmasked = DeviceBatchedGemmSoftmaxGemmPermute< + 2, 1, 1, 1, 1, + F16, F16, F16, F16, ck::Tuple, ck::Tuple<>, + PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough, + MaskingSpecialization::MaskDisabled>; + +template <> +std::vector> +GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< + F16, ck::Tuple, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskDisabled>() { + std::vector> instances; + ck::tensor_operation::device::instance::add_device_operation_instances( + instances, + device_batched_gemm_softmax_gemm_permute_instances< + 2, 1, 1, 1, 1, + F16, ck::Tuple, F32, PreSoftmaxAttentionScoreOp, + MaskingSpecialization::MaskDisabled>{}); + + return instances; +} + +using BiasedNonmaskedCausal = DeviceBatchedGemmSoftmaxGemmPermute< + 2, 1, 1, 1, 1, + F16, F16, F16, F16, ck::Tuple, ck::Tuple<>, + PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough, + MaskingSpecialization::MaskOutUpperTriangle>; + +template <> +std::vector> +GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< + F16, ck::Tuple, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskOutUpperTriangle>() { + std::vector> instances; + ck::tensor_operation::device::instance::add_device_operation_instances( + instances, + device_batched_gemm_softmax_gemm_permute_instances< + 2, 1, 1, 1, 1, + F16, ck::Tuple, F32, PreSoftmaxAttentionScoreOp, + MaskingSpecialization::MaskOutUpperTriangle>{}); + + return instances; +} + +} // namespace internal +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime +#endif // USE_COMPOSABLE_KERNEL diff --git a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh new file mode 100644 index 0000000000000..226b89cfb2b86 --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh @@ -0,0 +1,915 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +/* About Computing in these Pipelines + +B: batch size of Attention Op. NOTE: To be disambiguated with batch size of GEMMs +S: sequence length +T: total sequence length +N: num of heads +H: head dimension + +The following use qkv_format == Q_K_V_BNSH (mode == BSNH_BNLH_BNLH_NONE_NONE_NONE_NONE) as a example: + +BN: B*N, which is the batch size of GEMMs. NOTE: To be disambiguated with batch size of Attention Op + +In QKV projection (prior to this pipeline): + /-> Q [B,S,N*H] ->Reshape-> [B,S,N,H] ->Permute0213-> [B,N,S,H] +X --o--> K [B,T,N*H] ->Reshape-> [B,T,N,H] ->Permute0213-> [B,N,T,H] + \-> V [B,T,N*H] ->Reshape-> [B,T,N,H] ->Permute0213-> [B,N,T,H] + +pre_softmax_attn_scores = Q*K' = [B,N,S,H] * [BxNxTxH]' = [B,N,S,T] Batched GEMM1 +pre_softmax_attn_scores_masked = pre_softmax_attn_scores * scale +? bias +? mask Scale Add Bias, +? is optional +attn_scores = softmax(pre_softmax_attn_scores_masked) = [B,N,S,T] Softmax +scaled_multi_head_attn = attn_scores * V = [B,N,S,T] * [B,N,T,H] = [B,N,S,H] Batched GEMM2 + +Op outputs scaled_multi_head_attn: +[B,N,S,H] ->Permute0213-> [B,S,N,H] ->Reshape-> [B,S,N*H] + + +For the computing of pre_softmax_attn_scores +? mask +? bias: + +GemmSoftmaxGemmPermuteGenericPipeline handles it in specialized softmax. TODO: remove it! + +CK in GemmSoftmaxGemmPermuteTunablePipeline + + Q*K' ---> scale ---> [B,N,S,T] -------+?--> masked + bias --------------> [B,N,S,T] --+?--/ +mask_2d ---> [B,T] ---> [B,1,1,T] -/ + + Q*K' ---> scale ---> [B,N,S,T] -------+?--> masked + bias --------------> [B,N,S,T] --+?--/ +mask_3d --> [B,S,T] --> [B,1,S,T] -/ + + Q*K' ---> scale ---> [B,N,S,T] -------+?--> masked + bias --------------> [B,N,S,T] --+?--/ +mask_4d -> [B,1,M,M] -> [B,1,S,T] -/ M is max_sequence_length from megatron, we will create a + **sub-view** from original mask buffer + +For CK implementation, there will be four cases combined: +non-biased, non-masked, no special processing. + biased, non-masked, no special processing, add the mask directly. +non-biased, masked, convert the mask to [B,1,1_or_S,T] and perform broadcast add with scaled Q*K'. + biased, masked, convert the mask to [B,1,1_or_S,T] and perform broadcast add with bias and scaled Q*K'. + +Broadcast add is not actually perform the broadcasting, just broadcast the load operation from memory. The impl details +are in composable kernels. The scale and add logic is performed via Acc0ElementOp + +# Classified modes: + +| Q | K | V | past(K)| pastV | present(K)| presentV | Op, desc +| ---- | ---- | ---- | ------ | ----- | --------- | -------- | --------- +| QFMT | KFMT | VFMT | - | - | - | - | A, basic, qkv is impl dependent by qkv_format +| QFMT | KFMT | VFMT | 2BNPH | - | 2BNTH *^ | - | A, past_present_share_buffer = false, qkv is impl dependent by qkv_format +| QFMT | KFMT | VFMT | 2BNMH | - | 2BNMH *^ | - | A, past_present_share_buffer = true, qkv is impl dependent by qkv_format +| BSNH | BLNH*| BLNH^| - | - | - | - | MHA basic +| BSNH | BNLH*| BNLH^| - | - | - | - | MHA cross, pass_past_in_kv = true +| BSNH | - | - | - | - | BNLH * | BNLH ^ | MHA cross, pass_past_in_kv = false +| BSNH | BLNH | BLNH | - | - | BNTH * | BNTH ^ | MHA cross, past_present_share_buffer = false +| BSNH | BNLH | BNLH | - | - | BNTH * | BNTH ^ | MHA cross, past_present_share_buffer = false +| BSNH | BLNH | BLNH | - | - | BNMH * | BNMH ^ | MHA cross, past_present_share_buffer = true +| BSNH | BNLH | BNLH | - | - | BNMH * | BNMH ^ | MHA cross, past_present_share_buffer = true +| BSNH | BLNH | BLNH | BNPH | BNPH | BNTH * | BNTH ^ | MHA self, past_present_share_buffer = false +| BSNH | BNLH | BNLH | BNPH | BNPH | BNTH * | BNTH ^ | MHA self, past_present_share_buffer = false +| BSNH | BLNH | BLNH | BNMH | BNMH | BNMH * | BNMH ^ | MHA self, past_present_share_buffer = true +| BSNH | BNLH | BNLH | BNMH | BNMH | BNMH * | BNMH ^ | MHA self, past_present_share_buffer = true +| BLN3H*^| - | - | - | - | - | - | MHA basic, qkv_packed +| BSNH | BLN2H*^| - | - | - | - | - | MHA basic, kv_packed + +Q, K, V, past(K), pastV, present(K), presentV is the Input of the contrib OpKernel + +About k_buffer and v_buffer, we always explicitly concat past to present and use present_k for k_buffer and present_v for v_buffer + +- Marked with `*` indicate the Tensor is used for k_buffer passing. +- Marked with `^` indicate the Tensor is used for v_buffer passing. + +# Supported Op + +- A: Attention +- MHA: MultiHeadAttention + +# Dim Value + +- B: batch_size +- N: num_heads +- H: head_size + +- S: sequence_length +- L: kv_sequence_length +- P: past_sequence_length +- T: total_sequence_length = P + L +- M: max_sequence_length +*/ + +#include "core/framework/tensor_shape.h" +#include "core/providers/rocm/tunable/gemm.h" +#include "core/providers/rocm/tunable/rocm_tunable.h" +#include "contrib_ops/cpu/bert/attention_base.h" +#include "contrib_ops/rocm/bert/attention_impl.h" +#include "contrib_ops/rocm/bert/attention_softmax.h" +#ifdef USE_COMPOSABLE_KERNEL +#include "contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl.cuh" +#include "core/providers/rocm/composable_kernel_common.h" + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#endif // USE_COMPOSABLE_KERNEL + +#include +#include + +namespace blas = onnxruntime::rocm::tunable::blas; + +namespace onnxruntime { +namespace contrib { +namespace rocm { + +inline int3 Get2DMaskStrides(int total_sequence_length) { + // stride == 0 indicate broadcasting + return {total_sequence_length, 0, 1}; +} + +// A stride maps from natural coordinate to physical offset of underlying memory storage buffer offset. We need to +// specify both of the natural coordinate order, say (b,n,s,h), (b,s,n,h) or (b,n,h,s), and memory order, say BNSH or +// BSNH, to determain the strides. To obtain the offset, we just do the inner product of coordinate with the strides. +// This wrapper create the stride vector from the physical dimension (or physical shape). +struct Strides { + // Create the strides for BNSH physically indexed memory buffer + static Strides BNSHMemory(int batch_dim, + int num_head_dim, + int seqlen_dim, + int head_size_dim) { + ORT_UNUSED_PARAMETER(batch_dim); + return Strides{LongLong4{ + static_cast(num_head_dim) * seqlen_dim * head_size_dim, + static_cast(seqlen_dim) * head_size_dim, + static_cast(head_size_dim), + static_cast(1), + }}; + } + + // Create the strides for BSNH physically indexed memory buffer + static Strides BSNHMemory(int batch_dim, + int seqlen_dim, + int num_head_dim, + int head_size_dim) { + ORT_UNUSED_PARAMETER(batch_dim); + return Strides{LongLong4{ + static_cast(seqlen_dim) * num_head_dim * head_size_dim, + static_cast(head_size_dim), + static_cast(num_head_dim) * head_size_dim, + static_cast(1), + }}; + } + + template + T ForBNSHCoord() const { + using E = typename T::value_type; + return T{static_cast(strides_for_bnsh_coord.x), + static_cast(strides_for_bnsh_coord.y), + static_cast(strides_for_bnsh_coord.z), + static_cast(strides_for_bnsh_coord.w)}; + } + + template + T ForBSNHCoord() const { + using E = typename T::value_type; + return T{static_cast(strides_for_bnsh_coord.x), + static_cast(strides_for_bnsh_coord.z), + static_cast(strides_for_bnsh_coord.y), + static_cast(strides_for_bnsh_coord.w)}; + } + + template + T ForBNHSCoord() const { + using E = typename T::value_type; + return T{static_cast(strides_for_bnsh_coord.x), + static_cast(strides_for_bnsh_coord.y), + static_cast(strides_for_bnsh_coord.w), + static_cast(strides_for_bnsh_coord.z)}; + } + + int64_t OffsetAt(int b, int n, int s, int h) const { + return strides_for_bnsh_coord.x * b + strides_for_bnsh_coord.y * n + + strides_for_bnsh_coord.z * s + strides_for_bnsh_coord.w * h; + } + + // store intermediate strides in the canonical (b,n,s,h) coordinate order + LongLong4 strides_for_bnsh_coord; +}; + +template +std::tuple ConvertToOffsetedBufferViews( + const RocmAttentionParameters* attn, + const T* query = nullptr, // q or packed_qkv + const T* key = nullptr, // k or packed kv + const T* value = nullptr, // + const T* present = nullptr, // present or present_k + const T* present_v = nullptr) { + switch (attn->mode) { + case QFMT_KFMT_VFMT_NONE_NONE_NONE_NONE: + case BSNH_BNLH_BNLH_NONE_NONE_NONE_NONE: + case BSNH_BLNH_BLNH_NONE_NONE_NONE_NONE: { + return {reinterpret_cast(query), + reinterpret_cast(key), + reinterpret_cast(value)}; + } + case QFMT_KFMT_VFMT_NONE_NONE_2BNTH_NONE: + case QFMT_KFMT_VFMT_2BNPH_NONE_2BNTH_NONE: { + auto offset = static_cast(attn->batch_size) * attn->num_heads * attn->total_sequence_length * + attn->head_size; + return {reinterpret_cast(query), + reinterpret_cast(present), + reinterpret_cast(present) + offset}; + } + case QFMT_KFMT_VFMT_NONE_NONE_2BNMH_NONE: + case QFMT_KFMT_VFMT_2BNMH_NONE_2BNMH_NONE: { + auto offset = static_cast(attn->batch_size) * attn->num_heads * attn->max_sequence_length * + attn->head_size; + return {reinterpret_cast(query), + reinterpret_cast(present), + reinterpret_cast(present) + offset}; + } + case BSNH_BLNH_BLNH_NONE_NONE_BNTH_BNTH: + case BSNH_BNLH_BNLH_NONE_NONE_BNTH_BNTH: + case BSNH_BLNH_BLNH_NONE_NONE_BNMH_BNMH: + case BSNH_BNLH_BNLH_NONE_NONE_BNMH_BNMH: + case BSNH_BLNH_BLNH_BNPH_BNPH_BNTH_BNTH: + case BSNH_BNLH_BNLH_BNPH_BNPH_BNTH_BNTH: + case BSNH_BLNH_BLNH_BNMH_BNMH_BNMH_BNMH: + case BSNH_BNLH_BNLH_BNMH_BNMH_BNMH_BNMH: + return {reinterpret_cast(query), + reinterpret_cast(present), + reinterpret_cast(present_v)}; + case BSNH_BLN2H_NONE_NONE_NONE_NONE_NONE: { + auto packed_kv = reinterpret_cast(key); + return {reinterpret_cast(query), packed_kv, packed_kv + attn->head_size}; + } + case BLN3H_NONE_NONE_NONE_NONE_NONE_NONE: { + auto packed_qkv = reinterpret_cast(query); + return {packed_qkv, packed_qkv + 1 * attn->head_size, packed_qkv + 2 * attn->head_size}; + } + default: + ORT_ENFORCE("unreachable"); + return {}; + } +} + +inline std::tuple GetQkvStrides(const RocmAttentionParameters* attn) { + // G0 not used, because it is the slowest dimension + const int& B = attn->batch_size; + const int& N = attn->num_heads; + const int& S = attn->sequence_length; + const int& L = attn->kv_sequence_length; + const int& T = attn->total_sequence_length; + const int& M = attn->max_sequence_length; + const int& H = attn->head_size; + const int& Hv = attn->v_head_size; + + switch (attn->mode) { + case QFMT_KFMT_VFMT_NONE_NONE_NONE_NONE: + if (attn->qkv_format == Q_K_V_BNSH) { + return { + Strides::BNSHMemory(B, N, S, H), + Strides::BNSHMemory(B, N, L, H), + Strides::BNSHMemory(B, N, L, Hv), + }; + } else if (attn->qkv_format == Q_K_V_BSNH) { + return { + Strides::BSNHMemory(B, S, N, H), + Strides::BSNHMemory(B, L, N, H), + Strides::BSNHMemory(B, L, N, Hv), + }; + } + case BSNH_BLNH_BLNH_NONE_NONE_BNTH_BNTH: + case BSNH_BNLH_BNLH_NONE_NONE_BNTH_BNTH: + case BSNH_BLNH_BLNH_BNPH_BNPH_BNTH_BNTH: + case BSNH_BNLH_BNLH_BNPH_BNPH_BNTH_BNTH: + return { + Strides::BSNHMemory(B, S, N, H), + Strides::BNSHMemory(B, N, T, H), + Strides::BNSHMemory(B, N, T, Hv), + }; + case BSNH_BLNH_BLNH_NONE_NONE_BNMH_BNMH: + case BSNH_BNLH_BNLH_NONE_NONE_BNMH_BNMH: + case BSNH_BLNH_BLNH_BNMH_BNMH_BNMH_BNMH: + case BSNH_BNLH_BNLH_BNMH_BNMH_BNMH_BNMH: + return { + Strides::BSNHMemory(B, S, N, H), + Strides::BNSHMemory(B, N, M, H), + Strides::BNSHMemory(B, N, M, Hv), + }; + case BSNH_BLNH_BLNH_NONE_NONE_NONE_NONE: + return { + Strides::BSNHMemory(B, S, N, H), + Strides::BSNHMemory(B, L, N, H), + Strides::BSNHMemory(B, L, N, Hv), + }; + case BSNH_BNLH_BNLH_NONE_NONE_NONE_NONE: + return { + Strides::BSNHMemory(B, S, N, H), + Strides::BNSHMemory(B, N, L, H), + Strides::BNSHMemory(B, N, L, Hv), + }; + case BSNH_BLN2H_NONE_NONE_NONE_NONE_NONE: + return { + Strides::BSNHMemory(B, S, N, H), + Strides::BSNHMemory(B, L, N, 2 * H), + Strides::BSNHMemory(B, L, N, 2 * Hv), + }; + case BLN3H_NONE_NONE_NONE_NONE_NONE_NONE: + return { + Strides::BSNHMemory(B, L, N, 3 * H), + Strides::BSNHMemory(B, L, N, 3 * H), + Strides::BSNHMemory(B, L, N, 3 * Hv), + }; + default: + ORT_ENFORCE("unreachable"); + return {}; + } +} + +inline std::tuple GetRawMaskBufferAddrSizesAndStrides( + const int* buffer, const RocmAttentionParameters* attn) { + const int* offseted_buffer{buffer}; // how to view the mask buffer + int3 sizes{0, 0, 0}; // the logical shape of the view + int3 strides{-1, -1, -1}; // the physical memory layout + switch (attn->mask_type) { + case MASK_NONE: + case MASK_2D_DUMMY: + break; // No mask + case MASK_2D_KEY_PADDING: + sizes = {attn->batch_size, 1, attn->total_sequence_length}; + strides = Get2DMaskStrides(attn->total_sequence_length); + break; + case MASK_3D_ATTENTION: + sizes = {attn->batch_size, attn->sequence_length, attn->total_sequence_length}; + strides = {attn->sequence_length * attn->total_sequence_length, attn->total_sequence_length, 1}; + break; + case MASK_4D_MEGATRON: + // offset to skip past sequence part, so that we can index it with [batch_index, sequence_index, token_index] + offseted_buffer = buffer + attn->past_sequence_length * attn->max_sequence_length; + sizes = {attn->batch_size, attn->sequence_length, attn->total_sequence_length}; + strides = {attn->max_sequence_length * attn->max_sequence_length, attn->max_sequence_length, 1}; + break; + default: + LOGS_DEFAULT(FATAL) << "unsupported mask type: " << attn->mask_type; + throw std::runtime_error("unsupported mask type"); + } + return {offseted_buffer, sizes, strides}; +} + +template +struct GemmSoftmaxGemmPermuteParams : onnxruntime::rocm::tunable::OpParams { + std::string Signature() const override { + return MakeString( + "B", attention->batch_size, + "_S", attention->sequence_length, + "_T", attention->total_sequence_length, + "_N", attention->num_heads, + "_H", attention->head_size, + "_Hv", attention->v_head_size, + bias_buffer != nullptr ? "_B" : "_NB", + "_M", mask_index_dims.size(), + "_QKV", attention->qkv_format, + "_MODE", attention->mode); + } + + std::tuple GetGemmsMNKOBatch() const { + ORT_ENFORCE(attention != nullptr); + auto m = attention->sequence_length; + auto n = attention->total_sequence_length; + auto k = attention->head_size; + auto o = attention->v_head_size; + auto batch = attention->batch_size * attention->num_heads; + return {m, n, k, o, batch}; + } + + hipblasHandle_t handle; + const RocmAttentionParameters* attention; + const hipDeviceProp_t* device_prop; + + float scale; + const T* q_buffer; + const T* k_buffer; + const T* v_buffer; + T* out_buffer; + + // optional, attention bias [B,N,S,T] + // TODO: support shape [B,1,S,T], [1, N, S, T], [1, 1, S, T] with broadcast. + const T* bias_buffer{nullptr}; + + // optional, mask value + const int* mask_index_buffer{nullptr}; + TensorShapeVector mask_index_dims{}; + + // optional, internal + void* workspace_buffer{nullptr}; +}; + +inline bool IsKVBNMH(AttentionMode mode) { + switch (mode) { + case QFMT_KFMT_VFMT_NONE_NONE_2BNMH_NONE: + case QFMT_KFMT_VFMT_2BNMH_NONE_2BNMH_NONE: + case BSNH_BLNH_BLNH_NONE_NONE_BNMH_BNMH: + case BSNH_BNLH_BNLH_NONE_NONE_BNMH_BNMH: + case BSNH_BLNH_BLNH_BNMH_BNMH_BNMH_BNMH: + case BSNH_BNLH_BNLH_BNMH_BNMH_BNMH_BNMH: + return true; + default: + return false; + } +} + +template +struct GemmSoftmaxGemmPermuteGenericPipeline { + static bool UseRawAttentionMask(const GemmSoftmaxGemmPermuteParams* params) { + return params->mask_index_buffer != nullptr && params->mask_index_dims.size() >= 2; + } + + static std::tuple GetWorkspacePlan(const GemmSoftmaxGemmPermuteParams* params) { + auto bytes = GetAttentionScratchSize( + sizeof(T), + params->attention->batch_size, + params->attention->num_heads, + params->attention->sequence_length, + params->attention->total_sequence_length); + auto gemm1_out = reinterpret_cast(params->workspace_buffer); + auto softmax_out = gemm1_out + (bytes / sizeof(T)); + auto gemm2_out = softmax_out + (bytes / sizeof(T)); + return {gemm1_out, softmax_out, gemm2_out}; + } + + inline static size_t GetWorkspaceNumBytes(const RocmAttentionParameters* attn) { + return GetAttentionWorkspaceSize( + sizeof(T), + attn->batch_size, + attn->num_heads, + attn->head_size, + attn->sequence_length, + attn->total_sequence_length); + } + + inline static Status Gemm1(const GemmSoftmaxGemmPermuteParams* params) { + auto [m, n, k, o, batch] = params->GetGemmsMNKOBatch(); + auto [gemm1_out, softmax_out, gemm2_out] = GetWorkspacePlan(params); + + int k_buffer_stride = n * k; + if (IsKVBNMH(params->attention->mode)) { + k_buffer_stride = params->attention->max_sequence_length * params->attention->head_size; + } + + // GEMM1 [m,k] * [n,k]' -> [m,n] + return blas::row_major::StridedBatchedGemm( + params->TuningContext(), params->Stream(), params->handle, + blas::BlasOp::NonTrans, blas::BlasOp::Trans, + m, n, k, + // For raw attention mask, the scalar is moved to softmax computation. + /*alpha=*/UseRawAttentionMask(params) ? 1.0f : params->scale, + params->q_buffer, k, m * k, + params->k_buffer, k, k_buffer_stride, + /*beta=*/0.0f, + gemm1_out, n, m * n, + batch); + } + + inline static Status SoftmaxRawMask(const GemmSoftmaxGemmPermuteParams* params, bool use_persistent_softmax) { + // Softmax on [m,n] along the n dimension. + // Raw attention mask could be 2D (B,S) or 3D (B,S,T) or 4D(B,1,M,M), where M is the max sequence length. + auto attn = params->attention; + auto [buffer, sizes, strides] = GetRawMaskBufferAddrSizesAndStrides(params->mask_index_buffer, attn); + auto [gemm1_out, softmax_out, gemm2_out] = GetWorkspacePlan(params); + T* persistent_softmax_workspace = gemm1_out; // replace Q*K' in place if persistent softmax is selected. + return ComputeSoftmaxWithRawMask( + params->Stream(), attn->total_sequence_length, attn->sequence_length, attn->batch_size, attn->num_heads, + strides, buffer, nullptr, params->bias_buffer, gemm1_out, softmax_out, + attn->is_unidirectional, /* FIXME: this must not be attn.scale! */ params->scale, + use_persistent_softmax, persistent_softmax_workspace, attn->mask_filter_value); + } + + inline static Status Softmax1DIndexMask(const GemmSoftmaxGemmPermuteParams* params) { + auto mask_1d = params->mask_index_buffer; + auto mask_1d_size = params->mask_index_dims[0]; + // Softmax on [m,n] along the n dimension. + // mask_index has 1D shape: either (batch_size) or (2*batch_size). Only the later one has start postions. + auto attn = params->attention; + auto [gemm1_out, softmax_out, gemm2_out] = GetWorkspacePlan(params); + const int* mask_start = (mask_1d_size > attn->batch_size) ? mask_1d + attn->batch_size : nullptr; + return ComputeSoftmaxWithMask1D( + params->StreamHandle(), attn->total_sequence_length, attn->sequence_length, attn->batch_size, attn->num_heads, + mask_1d, mask_start, params->bias_buffer, gemm1_out, softmax_out, attn->is_unidirectional); + } + + inline static Status SoftmaxNoMask(const GemmSoftmaxGemmPermuteParams* params) { + // Softmax on [m,n] along the n dimension. + auto attn = params->attention; + auto [gemm1_out, softmax_out, gemm2_out] = GetWorkspacePlan(params); + return ComputeSoftmax( + params->StreamHandle(), attn->total_sequence_length, attn->sequence_length, attn->batch_size, attn->num_heads, + params->bias_buffer, gemm1_out, softmax_out, attn->is_unidirectional); + } + + inline static Status Gemm2(const GemmSoftmaxGemmPermuteParams* params) { + auto [m, n, k, o, batch] = params->GetGemmsMNKOBatch(); + auto [gemm1_out, softmax_out, gemm2_out] = GetWorkspacePlan(params); + + int v_buffer_stride = n * o; + if (IsKVBNMH(params->attention->mode)) { + v_buffer_stride = params->attention->max_sequence_length * params->attention->v_head_size; + } + + // GEMM2 [m,n] * [n,o] -> [m,o] + // semantically, the output buffer contains B*N matrices of shape [S,H], compactly, thus B,N,S,H. + return blas::row_major::StridedBatchedGemm( + params->TuningContext(), params->Stream(), params->handle, + blas::BlasOp::NonTrans, blas::BlasOp::NonTrans, + m, o, n, + /*alpha=*/1.0f, + softmax_out, n, m * n, + params->v_buffer, o, v_buffer_stride, + /*beta=*/0.0f, + gemm2_out, o, m * o, + batch); + } + + inline static Status Permute0213(const GemmSoftmaxGemmPermuteParams* params) { + // Permute 0213 + // gemm2_out is B,N,S,H, transpose to out_buffer as B,S,N,H + auto attn = params->attention; + auto [gemm1_out, softmax_out, gemm2_out] = GetWorkspacePlan(params); + return LaunchTransCtx( + params->StreamHandle(), + attn->sequence_length, attn->batch_size, attn->head_size, attn->num_heads, + params->device_prop->maxThreadsPerBlock, false, gemm2_out, params->out_buffer); + } + + static Status GetSupportedStatus(const GemmSoftmaxGemmPermuteParams* params) { + const auto& attn = params->attention; + // TODO: address the BNMH k,v strides + switch (attn->mode) { + case QFMT_KFMT_VFMT_NONE_NONE_NONE_NONE: + case QFMT_KFMT_VFMT_NONE_NONE_2BNTH_NONE: + case QFMT_KFMT_VFMT_2BNPH_NONE_2BNTH_NONE: + case QFMT_KFMT_VFMT_NONE_NONE_2BNMH_NONE: + case QFMT_KFMT_VFMT_2BNMH_NONE_2BNMH_NONE: + if (attn->qkv_format == Q_K_V_BNSH) { + return Status::OK(); + } else { + return TUNABLE_OP_UNSUPPORTED("GenericPipeline only supports qkv_format as Q_K_V_BNSH, got ", + attn->qkv_format); + } + case BSNH_BLNH_BLNH_NONE_NONE_NONE_NONE: + return TUNABLE_OP_UNSUPPORTED("GenericPipeline only supports qkv_format as Q_K_V_BNSH but k, v are BLNH"); + case BSNH_BNLH_BNLH_NONE_NONE_NONE_NONE: + case BSNH_BLNH_BLNH_NONE_NONE_BNTH_BNTH: + case BSNH_BNLH_BNLH_NONE_NONE_BNTH_BNTH: + case BSNH_BLNH_BLNH_BNPH_BNPH_BNTH_BNTH: + case BSNH_BNLH_BNLH_BNPH_BNPH_BNTH_BNTH: + case BSNH_BLNH_BLNH_NONE_NONE_BNMH_BNMH: + case BSNH_BLNH_BLNH_BNMH_BNMH_BNMH_BNMH: + case BSNH_BNLH_BNLH_NONE_NONE_BNMH_BNMH: + case BSNH_BNLH_BNLH_BNMH_BNMH_BNMH_BNMH: + // If sequence_length is 1, query of B1NH can be simply viewed as BN1H. + if (attn->sequence_length == 1) { + return Status::OK(); + } else { + return TUNABLE_OP_UNSUPPORTED("GenericPipeline only supports qkv_format as Q_K_V_BNSH, ", + "only if sequence_length is 1, query of BSNH can be viewed as BNSH"); + } + case BLN3H_NONE_NONE_NONE_NONE_NONE_NONE: + case BSNH_BLN2H_NONE_NONE_NONE_NONE_NONE: + return TUNABLE_OP_UNSUPPORTED("GenericPipeline only supports qkv_format as Q_K_V_BNSH"); + default: + return TUNABLE_OP_UNSUPPORTED("unknonw"); + } + return TUNABLE_OP_UNSUPPORTED("unknonw case"); + } + + static Status Run(const GemmSoftmaxGemmPermuteParams* params, bool use_persistent_softmax) { + auto supported_status = GetSupportedStatus(params); + if (!supported_status.IsOK()) { + return supported_status; + } + ORT_RETURN_IF_ERROR(Gemm1(params)); + + if (UseRawAttentionMask(params)) { + ORT_RETURN_IF_ERROR(SoftmaxRawMask(params, use_persistent_softmax)); + } else if (params->mask_index_dims.size() == 1) { // 1d index mask + ORT_RETURN_IF_ERROR(Softmax1DIndexMask(params)); + } else { + ORT_RETURN_IF_ERROR(SoftmaxNoMask(params)); + } + + ORT_RETURN_IF_ERROR(Gemm2(params)); + ORT_RETURN_IF_ERROR(Permute0213(params)); + return Status::OK(); + } +}; + +template +class GemmSoftmaxGemmPermuteTunableOp : public tunable::TunableOp> { + public: + GemmSoftmaxGemmPermuteTunableOp(); + + inline static bool IsSupportedMode(const RocmAttentionParameters* attn) { + switch (attn->mode) { + case QFMT_KFMT_VFMT_NONE_NONE_NONE_NONE: + case QFMT_KFMT_VFMT_2BNPH_NONE_2BNTH_NONE: + // depends on qkv format + if (attn->qkv_format == Q_K_V_BNSH || attn->qkv_format == Q_K_V_BSNH) { + return true; + } else { + return false; + } + case BSNH_BLNH_BLNH_BNPH_BNPH_BNTH_BNTH: + case BSNH_BNLH_BNLH_BNPH_BNPH_BNTH_BNTH: + case BSNH_BLNH_BLNH_BNMH_BNMH_BNMH_BNMH: + case BSNH_BNLH_BNLH_BNMH_BNMH_BNMH_BNMH: + case BSNH_BLNH_BLNH_NONE_NONE_NONE_NONE: + case BSNH_BNLH_BNLH_NONE_NONE_NONE_NONE: + case BSNH_BLNH_BLNH_NONE_NONE_BNTH_BNTH: + case BSNH_BNLH_BNLH_NONE_NONE_BNTH_BNTH: + case BSNH_BLNH_BLNH_NONE_NONE_BNMH_BNMH: + case BSNH_BNLH_BNLH_NONE_NONE_BNMH_BNMH: + case BSNH_BLN2H_NONE_NONE_NONE_NONE_NONE: + case BLN3H_NONE_NONE_NONE_NONE_NONE_NONE: + return true; + default: + return false; + } + } + + inline static bool IsSupportedMaskType(const RocmAttentionParameters* attn) { + switch (attn->mask_type) { + case MASK_NONE: + case MASK_2D_DUMMY: + case MASK_2D_KEY_PADDING: + case MASK_3D_ATTENTION: + case MASK_4D_MEGATRON: + return true; + default: + return false; + } + } + + inline static size_t GetWorkspaceNumBytes(const RocmAttentionParameters* attn) { + size_t num_bytes = GemmSoftmaxGemmPermuteGenericPipeline::GetWorkspaceNumBytes(attn); + +#ifdef USE_COMPOSABLE_KERNEL + if (IsSupportedMaskType(attn)) { + auto [buffer, sizes, strides] = GetRawMaskBufferAddrSizesAndStrides(nullptr, attn); + num_bytes = std::max(num_bytes, sizeof(T) * sizes.x * sizes.y * sizes.z); + } +#endif + + return num_bytes; + } + + template + __global__ static void ConvertToFilledMaskValue( + T* __restrict__ out, + const int3 out_strides, + const int* __restrict__ mask_buffer, + const int3 mask_lengths, // [B,S,T] + const int3 mask_strides, + Converter cvt) { + const int64_t global_idx = blockDim.x * blockIdx.x + threadIdx.x; + if (global_idx >= mask_lengths.x * mask_lengths.y * CeilDiv(mask_lengths.z, VecSize)) { + return; + } + + const int tidx = (global_idx % CeilDiv(mask_lengths.z, VecSize)) * VecSize; + const int bs_idx = global_idx / CeilDiv(mask_lengths.z, VecSize); + const int sidx = bs_idx % mask_lengths.y; + const int bidx = bs_idx / mask_lengths.y; + + int64_t in_offset = mask_strides.x * bidx + mask_strides.y * sidx + mask_strides.z * tidx; + int64_t out_offset = out_strides.x * bidx + out_strides.y * sidx + out_strides.z * tidx; + + if (tidx + VecSize <= mask_lengths.z) { + using LoadT = const aligned_vector; + using StoreT = aligned_vector; + LoadT load = *reinterpret_cast(mask_buffer + in_offset); + StoreT store; + +#pragma unroll + for (int i = 0; i < VecSize; i++) { + store.val[i] = cvt(load.val[i]); + } + *reinterpret_cast(out + out_offset) = store; + } else { +#pragma unroll + for (int i = 0; i < mask_lengths.z - tidx; i++) { + out[out_offset + i] = cvt(mask_buffer[in_offset + i]); + } + } + } + + static Status LaunchConvertToFilledMaskValue(const GemmSoftmaxGemmPermuteParams* params) { + constexpr const int kThreadPerBlock = 256; + constexpr const int kVecSize = 4; + + auto attn = params->attention; + auto [buffer, lengths, strides] = GetRawMaskBufferAddrSizesAndStrides(params->mask_index_buffer, attn); + int64_t total_threads = lengths.x * lengths.y * CeilDiv(lengths.z, kVecSize); + auto num_blocks = CeilDiv(total_threads, kThreadPerBlock); + + auto mask_filter_value = attn->mask_filter_value; + auto cvt = [=] __device__(int v) -> T { + return v == 1 ? 0 : mask_filter_value; + }; + + ConvertToFilledMaskValue<<StreamHandle()>>>( + reinterpret_cast(params->workspace_buffer), {lengths.y * lengths.z, lengths.z, 1}, // out desc + buffer, lengths, strides, // mask desc + cvt); + + return HIP_CALL(hipGetLastError()); + } +}; + +#ifdef USE_COMPOSABLE_KERNEL + +template +auto GetArgAndRunInvoker(const U& impl, const V& invoker, const GemmSoftmaxGemmPermuteParams* params) { + constexpr const int kNumBiasBuffer = static_cast(USE_BIAS) + static_cast(USE_MASK); + + using Nop = ck::tensor_operation::element_wise::PassThrough; + using Acc0ElementOp = internal::PreSoftmaxAttentionScoreOp; + + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( + !GemmSoftmaxGemmPermuteTunableOp::IsSupportedMode(params->attention), + "attention mode is not supported, got ", params->attention->mode); + if constexpr (USE_BIAS) { + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( + params->bias_buffer == nullptr, "biased version only support input with bias"); + } else { + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( + params->bias_buffer != nullptr, "non-biased version only support input without bias"); + } + if constexpr (USE_MASK) { + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( + !GemmSoftmaxGemmPermuteTunableOp::IsSupportedMaskType(params->attention), + "mask type is not supported, got ", params->attention->mask_type); + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( + params->mask_index_buffer == nullptr, "masked version only support input with mask"); + } else { + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( + params->mask_index_buffer != nullptr, "non-masked version only support input without mask"); + } + + auto attn = params->attention; + const int& G0 = attn->batch_size; + const int& G1 = attn->num_heads; + const int& M = attn->sequence_length; + const int& N = attn->total_sequence_length; + const int& K = attn->head_size; + const int& O = attn->v_head_size; + { + auto [m, n, k, o, batch] = params->GetGemmsMNKOBatch(); + ORT_ENFORCE(M == m && N == n && K == k && O == o && G0 * G1 == batch, "semantic mismatch"); + } + + auto [qs, ks, vs] = GetQkvStrides(attn); + std::vector q_buffer_lengths = {G0, G1, M, K}; + std::vector q_buffer_strides = qs.template ForBNSHCoord>(); + std::vector k_buffer_lengths = {G0, G1, N, K}; + std::vector k_buffer_strides = ks.template ForBNSHCoord>(); + std::vector v_buffer_lengths = {G0, G1, O, N}; + std::vector v_buffer_strides = vs.template ForBNHSCoord>(); + std::vector out_buffer_lengths = {G0, G1, M, O}; + std::vector out_buffer_strides = {M * G1 * O, O, G1 * O, 1}; // permute 0213 + + std::array bias_buffers{}; + std::array, kNumBiasBuffer> bias_lengths{}; + std::array, kNumBiasBuffer> bias_strides{}; + if constexpr (USE_BIAS) { + bias_buffers[0] = const_cast(params->bias_buffer); + bias_lengths[0] = {G0, G1, M, N}; // BN(G0*G1), S(M), T(N) + bias_strides[0] = {G1 * M * N, M * N, N, 1}; + } + if constexpr (USE_MASK) { + bias_buffers[kNumBiasBuffer - 1] = params->workspace_buffer; + bias_lengths[kNumBiasBuffer - 1] = {G0, G1, M, N}; // BN(G0*G1), S(M), T(N) + if (params->mask_index_dims.size() == 2) { // [B,T] + bias_strides[kNumBiasBuffer - 1] = {N, 0, 0, 1}; + } else if (params->mask_index_dims.size() == 3) { // [B,S,T] + bias_strides[kNumBiasBuffer - 1] = {M * N, 0, N, 1}; + } else if (params->mask_index_dims.size() == 4) { // [B,1,max_seq_len,max_seq_len] -->convert--> [B,S,T] + bias_strides[kNumBiasBuffer - 1] = {M * N, 0, N, 1}; + } else { + ORT_ENFORCE(false, "Unreachable"); + } + } + + auto arg = impl->MakeArgumentPointer( + params->q_buffer, params->k_buffer, params->v_buffer, params->out_buffer, + bias_buffers, // Gemm1 bias, as attention mask + {}, // Gemm2 bias + q_buffer_lengths, q_buffer_strides, + k_buffer_lengths, k_buffer_strides, + v_buffer_lengths, v_buffer_strides, + out_buffer_lengths, out_buffer_strides, + bias_lengths, bias_strides, + {}, + {}, + Nop{}, + Nop{}, + Acc0ElementOp{params->scale}, + Nop{}, + Nop{}); + + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!impl->IsSupportedArgument(arg.get()), + impl->GetTypeString(), " does not support the params"); + + if constexpr (USE_MASK) { + ORT_RETURN_IF_ERROR(GemmSoftmaxGemmPermuteTunableOp::LaunchConvertToFilledMaskValue(params)); + } + + invoker->Run(arg.get(), StreamConfig{params->StreamHandle()}); + return Status::OK(); +} + +template +auto GetCKGemmSoftmaxGemmPermuteTypeStringAndOps() { + using CKDataType = typename CKDataTypeAdaptor::type; + using D0DataType = typename ck::detail::tuple_concat< + std::conditional_t, ck::Tuple<>>, + std::conditional_t, ck::Tuple<>>>::type; + + constexpr static auto MaskingSpecMaskDisabled = + ck::tensor_operation::device::MaskingSpecialization::MaskDisabled; + constexpr static auto MaskingSpecMaskOutUpperTriangle = + ck::tensor_operation::device::MaskingSpecialization::MaskOutUpperTriangle; + + std::vector>>> + ret; + + for (auto&& impl : internal::GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< + CKDataType, D0DataType, internal::F32, internal::PreSoftmaxAttentionScoreOp, MaskingSpecMaskDisabled>()) { + auto type_string = impl->GetTypeString(); + + auto invoker = impl->MakeInvokerPointer(); + auto op = [impl = std::move(impl), invoker = std::move(invoker)]( + const GemmSoftmaxGemmPermuteParams* params) -> Status { + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( + params->attention->is_unidirectional, "unidirectional attention is not supported with MaskingSpecMaskDisabled"); + + return GetArgAndRunInvoker(impl, invoker, params); + }; + ret.emplace_back(std::make_pair(std::move(type_string), std::move(op))); + } + + for (auto&& impl : internal::GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< + CKDataType, D0DataType, internal::F32, internal::PreSoftmaxAttentionScoreOp, MaskingSpecMaskOutUpperTriangle>()) { + auto type_string = impl->GetTypeString(); + + auto invoker = impl->MakeInvokerPointer(); + auto op = [impl = std::move(impl), invoker = std::move(invoker)]( + const GemmSoftmaxGemmPermuteParams* params) -> Status { + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( + !params->attention->is_unidirectional, "bidirectional attention is not supported with MaskingSpecMaskOutUpperTriangle"); + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( + params->attention->sequence_length != params->attention->total_sequence_length, + "seqence_length != total_seqence_length is not supported with MaskingSpecMaskOutUpperTriangle"); + + return GetArgAndRunInvoker(impl, invoker, params); + }; + ret.emplace_back(std::make_pair(std::move(type_string), std::move(op))); + } + + return ret; +} +#endif // USE_COMPOSABLE_KERNEL + +template +GemmSoftmaxGemmPermuteTunableOp::GemmSoftmaxGemmPermuteTunableOp() { + this->RegisterOp([](const GemmSoftmaxGemmPermuteParams* params) { + return GemmSoftmaxGemmPermuteGenericPipeline::Run(params, false); + }); + +#ifdef USE_COMPOSABLE_KERNEL + for (auto&& [_, op] : GetCKGemmSoftmaxGemmPermuteTypeStringAndOps()) { + this->RegisterOp(std::move(op)); + } + + for (auto&& [_, op] : GetCKGemmSoftmaxGemmPermuteTypeStringAndOps()) { + this->RegisterOp(std::move(op)); + } + + for (auto&& [_, op] : GetCKGemmSoftmaxGemmPermuteTypeStringAndOps()) { + this->RegisterOp(std::move(op)); + } + + for (auto&& [_, op] : GetCKGemmSoftmaxGemmPermuteTypeStringAndOps()) { + this->RegisterOp(std::move(op)); + } +#endif +} + +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/decoder_attention_impl.h b/onnxruntime/contrib_ops/rocm/bert/decoder_attention_impl.h new file mode 100644 index 0000000000000..0aff519d20e99 --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/bert/decoder_attention_impl.h @@ -0,0 +1,46 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include "contrib_ops/cpu/bert/attention_common.h" +#include "core/providers/rocm/shared_inc/rocm_utils.h" +#include "core/providers/rocm/tunable/rocm_tunable.h" + +namespace onnxruntime { +namespace contrib { +namespace rocm { + +Status LaunchDecoderAttentionKernel( + const hipDeviceProp_t& prop, // Device Properties + RocmTuningContext* tuning_ctx, // context for tuning + Stream* stream, // ORT Stream + hipblasHandle_t& hipblas, // hipblas handle + const size_t element_size, // Element size of input tensor + const int batch_size, // Batch size (B) + const int sequence_length, // Sequence length (S) + const int kv_sequence_length, // Key/Value/Cache sequence length + const int num_heads, // Number of attention heads (N) + const int head_size, // Hidden layer size per head (H) + const bool static_kv, // Whether cross attention or not + const bool use_past, // Whether use cache or not + const bool has_layer_state, // Whether output cache or not + const bool has_key_padding_mask, // Whether use key_padding_mask or not + const float mask_filter_value, // Mask filter value + const void* gemm_query_buffer, // Query buffer + const void* gemm_kv_buffer, // Key and value buffer + const bool* key_padding_mask, // Key padding mask + const void* key_cache, // Input key cache + const void* value_cache, // Input value cache + void* qkv_buffer, // Temporary buffer + void* workspace_buffer, // Temporary buffer + void* output, // Output tensor + void* new_key_cache, // New_key_cache tensor + void* new_value_cache // New_value_cache tensor +); + +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/elementwise.h b/onnxruntime/contrib_ops/rocm/bert/elementwise.h new file mode 100644 index 0000000000000..768295767835a --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/bert/elementwise.h @@ -0,0 +1,84 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/common/common.h" +#include "core/providers/rocm/tunable/rocm_tunable.h" + +namespace onnxruntime { +namespace contrib { +namespace rocm { + +template +Status LaunchElementwiseKernel(RocmTuningContext* tuning_ctx, Stream* stream, + const T* input, int input_length, + const T* bias, int bias_length, + T* output); + +// The following is LaunchElementwiseKernel implementation detail. Their interfaces are exposed for kernel explorer. +namespace internal { + +template +struct ElementwiseParams : OpParams { + ElementwiseParams(RocmTuningContext* tuning_ctx, onnxruntime::Stream* stream, + const T* input, const T* bias, T* output, int input_length, int bias_length) + : OpParams(tuning_ctx, stream), + input(input), + bias(bias), + output(output), + input_length(input_length), + bias_length(bias_length) {} + + std::string Signature() const override { + std::string sig = std::to_string(input_length) + "_" + std::to_string(bias_length); + return sig; + } + + const T* input; + const T* bias; + T* output; + int input_length; + int bias_length; +}; + +template +class ElementwiseOp { + public: + Status operator()(const ElementwiseParams* params); + Status IsSupported(const ElementwiseParams* params); +}; + +template +Status ElementwiseStaticSelection(const ElementwiseParams* params); + +template +class ElementwiseTunableOp : public TunableOp> { + public: + ElementwiseTunableOp(); +}; + +} // namespace internal + +#define ELEMENTWISE_FWD_DECL(FnName, T) \ + namespace functor { \ + struct FnName; \ + } + +ELEMENTWISE_FWD_DECL(FastGeLU, float); +ELEMENTWISE_FWD_DECL(FastGeLU, double); +ELEMENTWISE_FWD_DECL(FastGeLU, half); +ELEMENTWISE_FWD_DECL(FastGeLU, BFloat16); + +ELEMENTWISE_FWD_DECL(GeLU, float); +ELEMENTWISE_FWD_DECL(GeLU, double); +ELEMENTWISE_FWD_DECL(GeLU, half); +ELEMENTWISE_FWD_DECL(GeLU, BFloat16); + +ELEMENTWISE_FWD_DECL(ReLU, float); +ELEMENTWISE_FWD_DECL(ReLU, half); +ELEMENTWISE_FWD_DECL(ReLU, BFloat16); + +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/elementwise_impl/impl.cuh b/onnxruntime/contrib_ops/rocm/bert/elementwise_impl/impl.cuh new file mode 100644 index 0000000000000..8255e70d27e48 --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/bert/elementwise_impl/impl.cuh @@ -0,0 +1,256 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/rocm/tunable/util.h" +#include "core/providers/rocm/cu_inc/common.cuh" +#include "contrib_ops/rocm/bert/elementwise.h" + +namespace onnxruntime { +namespace contrib { +namespace rocm { + +namespace functor { + +struct FastGeLU { + template + __host__ __device__ __forceinline__ void operator()(T& y, const T& x) const { + constexpr const float b = 0.7978845608028654f; // sqrt(2.0/M_PI) + + // const T cdf = a + a * _Tanh(in * (c * in * in + b)); + const T xb = x * T(b); + const T u = xb * T(0.044715f) * x * x + xb; + const T emu = __expf(-u - u); + const T cdf = T(1.0f) / (T(1.0f) + emu); + y = x * cdf; + } +}; + +struct GeLU { + template + __host__ __device__ __forceinline__ void operator()(T& y, const T& x) const { + y = T(0.5f) * x * (T(1.f) + T(erf(0.70710678118f * float(x)))); + } +}; + +struct ReLU { + template + __host__ __device__ __forceinline__ void operator()(T& y, const T& x) const { + y = x >= T{} ? x : T{}; + } +}; + +} // namespace functor + +using onnxruntime::rocm::CeilDiv; +using onnxruntime::rocm::GPU_WARP_SIZE; + +template +__global__ void ElementwiseKernel( + const T* __restrict__ input, int input_length, + const T* __restrict__ bias, int bias_length, + T* __restrict__ output) { + const int idx = blockIdx.x * TPB + threadIdx.x; + Fn f{}; + + if (idx < input_length) { + const T x = input[idx] + (bias == nullptr ? T{} : bias[idx % bias_length]); + f(output[idx], x); + } +} + +template +__global__ void ElementwiseKernelVec( + const T* __restrict__ input, int input_length, + const T* __restrict__ bias, int bias_length, + T* output) { + using VecT = onnxruntime::rocm::aligned_vector; + Fn f{}; + + const int idx = (blockIdx.x * TPB + threadIdx.x) * ILP; + if (idx < input_length) { + T input_v[ILP]; + VecT* input_val = reinterpret_cast(&input_v); + *input_val = *reinterpret_cast(&input[idx]); + T output_v[ILP]; + VecT* output_val = reinterpret_cast(&output_v); + T bias_v[ILP]; + if (bias != nullptr) { + VecT* bias_val = reinterpret_cast(&bias_v); + *bias_val = *reinterpret_cast(&bias[idx % bias_length]); + } + +#pragma unroll + for (int i = 0; i < ILP; i++) { + const T x = (bias == nullptr) ? input_v[i] : (T)(input_v[i] + bias_v[i]); + f(output_v[i], x); + } + *(reinterpret_cast(&output[idx])) = *output_val; + } +} + +template +Status LaunchElementwiseKernel( + RocmTuningContext* tuning_ctx, Stream* stream, + const T* input, int input_length, + const T* bias, int bias_length, + T* output) { + internal::ElementwiseParams params(tuning_ctx, stream, input, bias, output, input_length, bias_length); + if (tuning_ctx->IsTunableOpEnabled()) { + static internal::ElementwiseTunableOp op; + return op(¶ms); + } + + return internal::ElementwiseStaticSelection(¶ms); +} + +namespace internal { + +template +Status ElementwiseOp::operator()(const ElementwiseParams* params) { + dim3 blocks(CeilDiv(params->input_length, ThreadsPerBlock * VecSize)); + ElementwiseKernelVec<<StreamHandle()>>>( + params->input, params->input_length, + params->bias, params->bias_length, + params->output); + return HIP_CALL(hipGetLastError()); +} + +template +Status ElementwiseOp::IsSupported(const ElementwiseParams* params) { + // TODO(anyone): Add tail handling for FastGelu + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( + !((params->bias_length > 0 && params->bias_length % VecSize == 0 && params->input_length % VecSize == 0) || + (params->bias_length == 0 && params->input_length % VecSize == 0))); + // Avoid redundant configurations + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!(params->input_length > (ThreadsPerBlock - GPU_WARP_SIZE) * VecSize)); + + return Status::OK(); +} + +template +Status ElementwiseStaticSelection(const ElementwiseParams* params) { + constexpr int block_size = 256; + if constexpr (std::is_same_v) { + if (params->bias != nullptr) { + if (0 == (params->bias_length % 8) && (params->input_length >= 3145728)) { // 3145728=8*128*3072 + const int grid_size = (params->input_length / 8 + block_size - 1) / block_size; + ElementwiseKernelVec<<StreamHandle()>>>( + params->input, params->input_length, params->bias, params->bias_length, params->output); + } else if (0 == (params->bias_length % 4)) { + const int grid_size = (params->input_length / 4 + block_size - 1) / block_size; + ElementwiseKernelVec<<StreamHandle()>>>( + params->input, params->input_length, params->bias, params->bias_length, params->output); + } else if (0 == (params->bias_length % 2)) { + const int grid_size = (params->input_length / 2 + block_size - 1) / block_size; + ElementwiseKernelVec<<StreamHandle()>>>( + params->input, params->input_length, params->bias, params->bias_length, params->output); + } else { + const int grid_size = (params->input_length + block_size - 1) / block_size; + ElementwiseKernel<<StreamHandle()>>>( + params->input, params->input_length, params->bias, params->bias_length, params->output); + } + } else { + if (0 == (params->input_length % 8) && (params->input_length >= 3145728)) { // 3145728=8*128*3072 + const int grid_size = (params->input_length / 8 + block_size - 1) / block_size; + ElementwiseKernelVec<<StreamHandle()>>>( + params->input, params->input_length, params->bias, params->bias_length, params->output); + } else if (0 == (params->input_length % 4)) { + const int grid_size = (params->input_length / 4 + block_size - 1) / block_size; + ElementwiseKernelVec<<StreamHandle()>>>( + params->input, params->input_length, params->bias, params->bias_length, params->output); + } else if (0 == (params->input_length % 2)) { + const int grid_size = (params->input_length / 2 + block_size - 1) / block_size; + ElementwiseKernelVec<<StreamHandle()>>>( + params->input, params->input_length, params->bias, params->bias_length, params->output); + } else { + const int grid_size = (params->input_length + block_size - 1) / block_size; + ElementwiseKernel<<StreamHandle()>>>( + params->input, params->input_length, params->bias, params->bias_length, params->output); + } + } + } else { + const int grid_size = (params->input_length + block_size - 1) / block_size; + ElementwiseKernel<<StreamHandle()>>>( + params->input, params->input_length, params->bias, params->bias_length, params->output); + } + return HIP_CALL(hipGetLastError()); +} + +template +ElementwiseTunableOp::ElementwiseTunableOp() { + this->RegisterOp(ElementwiseStaticSelection); + + this->RegisterOp(ElementwiseOp{}); + this->RegisterOp(ElementwiseOp{}); + this->RegisterOp(ElementwiseOp{}); + this->RegisterOp(ElementwiseOp{}); + this->RegisterOp(ElementwiseOp{}); + + this->RegisterOp(ElementwiseOp{}); + this->RegisterOp(ElementwiseOp{}); + this->RegisterOp(ElementwiseOp{}); + this->RegisterOp(ElementwiseOp{}); + this->RegisterOp(ElementwiseOp{}); + + this->RegisterOp(ElementwiseOp{}); + this->RegisterOp(ElementwiseOp{}); + this->RegisterOp(ElementwiseOp{}); + this->RegisterOp(ElementwiseOp{}); + this->RegisterOp(ElementwiseOp{}); + + this->RegisterOp(ElementwiseOp{}); + this->RegisterOp(ElementwiseOp{}); + this->RegisterOp(ElementwiseOp{}); + this->RegisterOp(ElementwiseOp{}); + this->RegisterOp(ElementwiseOp{}); + + this->RegisterOp(ElementwiseOp{}); + this->RegisterOp(ElementwiseOp{}); + this->RegisterOp(ElementwiseOp{}); + this->RegisterOp(ElementwiseOp{}); + this->RegisterOp(ElementwiseOp{}); + + this->RegisterOp(ElementwiseOp{}); + this->RegisterOp(ElementwiseOp{}); + this->RegisterOp(ElementwiseOp{}); + this->RegisterOp(ElementwiseOp{}); + this->RegisterOp(ElementwiseOp{}); + + this->RegisterOp(ElementwiseOp{}); + this->RegisterOp(ElementwiseOp{}); + this->RegisterOp(ElementwiseOp{}); + this->RegisterOp(ElementwiseOp{}); + this->RegisterOp(ElementwiseOp{}); + + this->RegisterOp(ElementwiseOp{}); + this->RegisterOp(ElementwiseOp{}); + this->RegisterOp(ElementwiseOp{}); + this->RegisterOp(ElementwiseOp{}); + this->RegisterOp(ElementwiseOp{}); +} + +#undef ADD_OP + +} // namespace internal + +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime + +#define ELEMENTWISE_KERNEL_IMPL(Fn, T) \ + namespace onnxruntime { \ + namespace contrib { \ + namespace rocm { \ + template Status LaunchElementwiseKernel( \ + RocmTuningContext * tuning_ctx, Stream* stream, \ + const T* input, int input_length, \ + const T* bias, int bias_length, \ + T* output); \ + namespace internal { \ + template class ElementwiseTunableOp; \ + } \ + } \ + } \ + } diff --git a/onnxruntime/contrib_ops/rocm/bert/elementwise_impl/impl_fastgelu.cu b/onnxruntime/contrib_ops/rocm/bert/elementwise_impl/impl_fastgelu.cu new file mode 100644 index 0000000000000..c2a670ea76aca --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/bert/elementwise_impl/impl_fastgelu.cu @@ -0,0 +1,9 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "contrib_ops/rocm/bert/elementwise_impl/impl.cuh" + +ELEMENTWISE_KERNEL_IMPL(functor::FastGeLU, float); +ELEMENTWISE_KERNEL_IMPL(functor::FastGeLU, double); +ELEMENTWISE_KERNEL_IMPL(functor::FastGeLU, half); +ELEMENTWISE_KERNEL_IMPL(functor::FastGeLU, BFloat16); diff --git a/onnxruntime/contrib_ops/rocm/bert/elementwise_impl/impl_gelu.cu b/onnxruntime/contrib_ops/rocm/bert/elementwise_impl/impl_gelu.cu new file mode 100644 index 0000000000000..97f0f74640c6e --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/bert/elementwise_impl/impl_gelu.cu @@ -0,0 +1,9 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "contrib_ops/rocm/bert/elementwise_impl/impl.cuh" + +ELEMENTWISE_KERNEL_IMPL(functor::GeLU, double); +ELEMENTWISE_KERNEL_IMPL(functor::GeLU, float); +ELEMENTWISE_KERNEL_IMPL(functor::GeLU, half); +ELEMENTWISE_KERNEL_IMPL(functor::GeLU, BFloat16); diff --git a/onnxruntime/contrib_ops/rocm/bert/elementwise_impl/impl_relu.cu b/onnxruntime/contrib_ops/rocm/bert/elementwise_impl/impl_relu.cu new file mode 100644 index 0000000000000..67e50869133f5 --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/bert/elementwise_impl/impl_relu.cu @@ -0,0 +1,8 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "contrib_ops/rocm/bert/elementwise_impl/impl.cuh" + +ELEMENTWISE_KERNEL_IMPL(functor::ReLU, float); +ELEMENTWISE_KERNEL_IMPL(functor::ReLU, half); +ELEMENTWISE_KERNEL_IMPL(functor::ReLU, BFloat16); diff --git a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu.cc b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu.cc new file mode 100644 index 0000000000000..fdb62d3a2aec5 --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu.cc @@ -0,0 +1,75 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "contrib_ops/rocm/bert/gemm_fast_gelu.h" + +#include "contrib_ops/rocm/bert/gemm_fast_gelu_common.h" +#include "contrib_ops/rocm/bert/gemm_fast_gelu_impl.h" +#include "core/providers/cpu/math/matmul_helper.h" +#include "core/providers/rocm/rocm_common.h" + +using onnxruntime::rocm::ToHipType; + +namespace onnxruntime { +namespace contrib { +namespace rocm { + +#define REGISTER_KERNEL_TYPED(T) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + GemmFastGelu, \ + kMSDomain, \ + 1, \ + T, \ + kRocmExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + GemmFastGelu); + +REGISTER_KERNEL_TYPED(float) +REGISTER_KERNEL_TYPED(MLFloat16) +REGISTER_KERNEL_TYPED(BFloat16) + +template +Status GemmFastGelu::ComputeInternal(OpKernelContext* ctx) const { + typedef typename ToHipType::MappedType HipT; + + const auto* X = ctx->Input(0); + const auto* W = ctx->Input(1); + const auto* bias = ctx->Input(2); + + bool transa = false; + bool transb = false; + bool trans_batch_a = false; + bool trans_batch_b = false; + + MatMulComputeHelper helper; + ORT_RETURN_IF_ERROR(helper.Compute(X->Shape(), W->Shape(), transa, transb, trans_batch_a, trans_batch_b, false)); + + Tensor* Y = ctx->Output(0, helper.OutputShape()); + + // Bail out early if the output is going to be empty + if (Y->Shape().Size() == 0) + return Status::OK(); + + // gemmfastgelu only support alpha == 1 and beta == 0 + const HipT alpha = ToHipType::FromFloat(1.0f); + const HipT beta = ToHipType::FromFloat(0.0f); + + using onnxruntime::rocm::tunable::blas::BlasOp; + + return blas::row_major::GemmFastGelu( + GetTuningContext(), ctx->GetComputeStream(), GetHipblasHandle(ctx), + transa ? BlasOp::Trans : BlasOp::NonTrans, + transb ? BlasOp::Trans : BlasOp::NonTrans, + helper.M(), helper.N(), helper.K(), + alpha, + reinterpret_cast(X->Data()), helper.Lda(transa), + reinterpret_cast(W->Data()), helper.Ldb(transb), + (nullptr != bias) ? reinterpret_cast(bias->Data()) : nullptr, + beta, + reinterpret_cast(Y->MutableData()), helper.Ldc()); +} + +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu.h b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu.h new file mode 100644 index 0000000000000..ae4f84fa5f033 --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu.h @@ -0,0 +1,23 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "core/common/common.h" +#include "core/providers/rocm/rocm_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace rocm { + +using onnxruntime::rocm::RocmKernel; + +template +class GemmFastGelu final : public RocmKernel { + public: + GemmFastGelu(const OpKernelInfo& op_kernel_info) : RocmKernel(op_kernel_info) {} + Status ComputeInternal(OpKernelContext* ctx) const override; +}; + +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_ck.cuh b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_ck.cuh new file mode 100644 index 0000000000000..77f53f9eed027 --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_ck.cuh @@ -0,0 +1,133 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include + +#ifdef USE_COMPOSABLE_KERNEL +#include "core/providers/rocm/composable_kernel_common.h" + +#include "ck/ck.hpp" +#include "ck/library/tensor_operation_instance/gpu/gemm_add_fastgelu.hpp" +#include "ck/library/tensor_operation_instance/gpu/gemm_fastgelu.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#endif + +#include "contrib_ops/rocm/bert/gemm_fast_gelu_common.h" + +using onnxruntime::rocm::ToHipType; + +namespace onnxruntime { +namespace contrib { +namespace rocm { +namespace blas { +namespace internal { + +#ifdef USE_COMPOSABLE_KERNEL + +using onnxruntime::rocm::CKBlasOpAdaptor; +using onnxruntime::rocm::CKDataTypeAdaptor; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using Nop = ck::tensor_operation::element_wise::PassThrough; +using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu; +using FastGelu = ck::tensor_operation::element_wise::FastGelu; + +template +auto GetCKGemmAddFastGeluTypeStringAndOps() { + using CKDataType = typename CKDataTypeAdaptor::type; + using ALayout = typename CKBlasOpAdaptor::type; + using BLayout = typename CKBlasOpAdaptor::type; + using DeviceGemmAddFastGelu = ck::tensor_operation::device::DeviceGemmMultipleD< + ALayout, BLayout, ck::Tuple, Row, + CKDataType, CKDataType, ck::Tuple, CKDataType, + Nop, Nop, AddFastGelu>; + using InstanceFactory = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory; + + std::vector>>> ret; + for (auto&& impl : InstanceFactory::GetInstances()) { + auto type_string = onnxruntime::MakeString("withbias ", impl->GetTypeString()); + + auto invoker = impl->MakeInvokerPointer(); + auto ck_gemmfastgelu_op = [impl = std::move(impl), invoker = std::move(invoker)](const GemmFastGeluParams* params) -> Status { + auto one = ToHipType::FromFloat(1.0f); + auto zero = ToHipType::FromFloat(0.0f); + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( + params->alpha != one || params->beta != zero || params->bias == nullptr, + impl->GetTypeString(), " only supports alpha == 1 and beta == 0 and bias != nullptr"); + + auto nop = Nop{}; + auto addfastgelu = AddFastGelu{}; + auto arg = impl->MakeArgumentPointer(params->a, params->b, std::array{params->bias}, params->c, + params->m, params->n, params->k, + params->lda, params->ldb, std::array{0}, params->ldc, + nop, nop, addfastgelu); + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!impl->IsSupportedArgument(arg.get()), + impl->GetTypeString(), " does not support the params"); + invoker->Run(arg.get(), StreamConfig{params->StreamHandle()}); + return Status::OK(); + }; + ret.emplace_back(std::make_pair(std::move(type_string), std::move(ck_gemmfastgelu_op))); + } + return ret; +} + +template +auto GetCKGemmFastGeluTypeStringAndOps() { + using CKDataType = typename CKDataTypeAdaptor::type; + using ALayout = typename CKBlasOpAdaptor::type; + using BLayout = typename CKBlasOpAdaptor::type; + using DeviceGemmFastGelu = ck::tensor_operation::device::DeviceGemmMultipleD< + ALayout, BLayout, ck::Tuple<>, Row, + CKDataType, CKDataType, ck::Tuple<>, CKDataType, + Nop, Nop, FastGelu>; + using InstanceFactory = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory; + + std::vector>>> ret; + for (auto&& impl : InstanceFactory::GetInstances()) { + auto type_string = onnxruntime::MakeString("nobias ", impl->GetTypeString()); + auto invoker = impl->MakeInvokerPointer(); + auto ck_gemmfastgelu_op = [impl = std::move(impl), invoker = std::move(invoker)](const GemmFastGeluParams* params) -> Status { + auto one = ToHipType::FromFloat(1.0f); + auto zero = ToHipType::FromFloat(0.0f); + + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( + params->alpha != one || params->beta != zero || params->bias != nullptr, + impl->GetTypeString(), " only supports alpha == 1 and beta == 0 and bias == nullptr"); + + auto nop = Nop{}; + auto fastgelu = FastGelu{}; + auto arg = impl->MakeArgumentPointer(params->a, params->b, + {}, + params->c, + params->m, params->n, params->k, + params->lda, params->ldb, + {}, + params->ldc, + nop, nop, fastgelu); + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!impl->IsSupportedArgument(arg.get()), + impl->GetTypeString(), " does not support the params"); + invoker->Run(arg.get(), StreamConfig{params->StreamHandle()}); + return Status::OK(); + }; + ret.emplace_back(std::make_pair(std::move(type_string), std::move(ck_gemmfastgelu_op))); + } + return ret; +} +#else +struct Row {}; +struct Col {}; +#endif // USE_COMPOSABLE_KERNEL + +} // namespace internal +} // namespace blas +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_common.h b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_common.h new file mode 100644 index 0000000000000..2b8a21b83f177 --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_common.h @@ -0,0 +1,47 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include + +#include "core/providers/rocm/rocm_common.h" +#include "core/providers/rocm/tunable/gemm_common.h" +#include "core/providers/rocm/tunable/rocm_tunable.h" + +using onnxruntime::rocm::tunable::blas::BlasOp; +using onnxruntime::rocm::tunable::blas::BlasOpToString; + +namespace onnxruntime { +namespace contrib { +namespace rocm { +namespace blas { + +template +struct GemmFastGeluParams : OpParams { + std::string Signature() const override { + bool has_bias = (nullptr != bias) ? 0 : 1; + return MakeString(BlasOpToString(opa), BlasOpToString(opb), "_", m, "_", n, "_", k, '_', has_bias); + } + hipblasHandle_t handle; + BlasOp opa; + BlasOp opb; + int64_t m; + int64_t n; + int64_t k; + T alpha; + const T* a; + int64_t lda; + const T* b; + int64_t ldb; + const T* bias; + T beta; + T* c; + int64_t ldc; +}; + +} // namespace blas +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_impl.cu b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_impl.cu new file mode 100644 index 0000000000000..8d7e64b1015be --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_impl.cu @@ -0,0 +1,91 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#define _GEMM_FASTGELU_H_KEEP_SIGNATURE_DEFINES +#include "contrib_ops/rocm/bert/gemm_fast_gelu_impl.h" + +#include +#include + +#include "contrib_ops/rocm/bert/gemm_fast_gelu_tunable.cuh" +#include "core/providers/rocm/shared_inc/fpgeneric.h" + +namespace onnxruntime { +namespace contrib { +namespace rocm { +namespace blas { + +namespace row_major { + +template +inline GEMMFASTGELU(T, ScalarT) { + GemmFastGeluParams params; + params.tuning_ctx = tuning_ctx; + params.stream = stream; + params.handle = handle; + + params.opa = opa; + params.opb = opb; + params.m = m; + params.n = n; + params.k = k; + if constexpr (!std::is_same_v && std::is_same_v) { + params.alpha = ToHipType::FromFloat(std::forward(alpha)); + } else { + params.alpha = alpha; + } + params.a = a; + params.lda = lda; + params.b = b; + params.ldb = ldb; + params.bias = bias; + if constexpr (!std::is_same_v && std::is_same_v) { + params.beta = ToHipType::FromFloat(std::forward(beta)); + } else { + params.beta = beta; + } + params.c = c; + params.ldc = ldc; + + if (tuning_ctx->IsTunableOpEnabled()) { + if (opa == BlasOp::N && opb == BlasOp::N) { + static internal::GemmFastGeluTunableOp gemm_fast_gelu{}; + return gemm_fast_gelu(¶ms); + } else if (opa == BlasOp::T && opb == BlasOp::N) { + static internal::GemmFastGeluTunableOp gemm_fast_gelu{}; + return gemm_fast_gelu(¶ms); + } else if (opa == BlasOp::N && opb == BlasOp::T) { + static internal::GemmFastGeluTunableOp gemm_fast_gelu{}; + return gemm_fast_gelu(¶ms); + } else /*if (opa == BlasOp::T && opb == BlasOp::T)*/ { + static internal::GemmFastGeluTunableOp gemm_fast_gelu{}; + return gemm_fast_gelu(¶ms); + } + } + + return internal::GemmFastGeluUnfused(¶ms); +} + +#define CALL_GEMMFASTGELU(T, ScalarT) \ + GemmFastGelu(tuning_ctx, stream, handle, \ + opa, opb, \ + m, n, k, \ + alpha, a, lda, b, ldb, bias, \ + beta, c, ldc) + +// clang-format off +GEMMFASTGELU(float, float ) { return CALL_GEMMFASTGELU(float, float ); } +GEMMFASTGELU(half, half ) { return CALL_GEMMFASTGELU(half, half ); } +GEMMFASTGELU(BFloat16, BFloat16) { return CALL_GEMMFASTGELU(BFloat16, BFloat16); } +GEMMFASTGELU(half, float ) { return CALL_GEMMFASTGELU(half, float ); } +GEMMFASTGELU(BFloat16, float ) { return CALL_GEMMFASTGELU(BFloat16, float ); } +// clang-format on + +#undef CALL_GEMMFASTGELU + +} // namespace row_major + +} // namespace blas +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_impl.h b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_impl.h new file mode 100644 index 0000000000000..b707c63ef44be --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_impl.h @@ -0,0 +1,40 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "contrib_ops/rocm/bert/gemm_fast_gelu_common.h" +#include "core/common/status.h" +#include "core/common/float16.h" + +namespace onnxruntime { +namespace contrib { +namespace rocm { +namespace blas { + +#define GEMMFASTGELU(T, ScalarT) \ + common::Status GemmFastGelu( \ + RocmTuningContext* tuning_ctx, Stream* stream, hipblasHandle_t handle, \ + BlasOp opa, BlasOp opb, \ + std::int64_t m, std::int64_t n, std::int64_t k, \ + ScalarT alpha, const T* a, std::int64_t lda, const T* b, std::int64_t ldb, \ + const T* bias, ScalarT beta, T* c, std::int64_t ldc) + +namespace row_major { + +GEMMFASTGELU(float, float); +GEMMFASTGELU(half, half); +GEMMFASTGELU(BFloat16, BFloat16); +GEMMFASTGELU(half, float); +GEMMFASTGELU(BFloat16, float); + +} // namespace row_major + +} // namespace blas +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime + +#ifndef _GEMM_FASTGELU_H_KEEP_SIGNATURE_DEFINES +#undef GEMMFASTGELU +#endif diff --git a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_tunable.cuh b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_tunable.cuh new file mode 100644 index 0000000000000..e157aa57f8c43 --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_tunable.cuh @@ -0,0 +1,83 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include + +#include "contrib_ops/rocm/bert/elementwise.h" +#include "contrib_ops/rocm/bert/gemm_fast_gelu_ck.cuh" +#include "contrib_ops/rocm/bert/gemm_fast_gelu_common.h" +#include "core/providers/rocm/tunable/gemm.h" +#include "core/providers/rocm/tunable/gemm_hipblaslt.h" +#include "core/providers/rocm/tunable/rocm_tunable.h" + +namespace onnxruntime { +namespace contrib { +namespace rocm { +namespace blas { +namespace internal { + +using namespace onnxruntime::rocm::tunable::blas::internal; + +template +Status GemmFastGeluUnfused(const GemmFastGeluParams* params) { + namespace column_major = onnxruntime::rocm::tunable::blas::column_major; + ORT_RETURN_IF_ERROR(column_major::Gemm(params->tuning_ctx, params->stream, params->handle, + params->opb, params->opa, + params->n, params->m, params->k, + params->alpha, params->b, params->ldb, params->a, params->lda, + params->beta, params->c, params->ldc)); + + int64_t fast_gelu_input_length = params->m * params->n; + int64_t bias_length = (params->bias != nullptr) ? params->n : 0; + + // Because of GemmFastGeluUnfused is a combination of GemmOp and FastGeluOp, FastGeluOp in this combination is + // an inplace computation. + // 1. If we call GemmFastGeluUnfused directly with enabled tuning, it may cause the input buffer of FastGelu been + // updated accumulatedly and result in incorrect result finally. This only happens if the tuning's FindFastest is invoked. + // 2. It's safe to call GemmFastGeluUnfused with disabled tuning, FastGelu only run once and produce correct result. + // 3. It's safe to call GemmFastGeluUnfused as part of GemmFastGeluTunableOp with enable tuning, GemmTunableOp and + // FastGeluTunableOp will do tune in first warmup step separately during GemmFastGeluUnfused profiling process. + // After that, the call to GemmFastGeluUnfused not invoke tuning's FindFastest of FastGelu. + // + // Note: If any change cause directly usage of GemmFastGeluUnfused, add PreTuning() and PostTuning() in FastGeluTunableOp + // to protect original input value. + return onnxruntime::contrib::rocm::LaunchElementwiseKernel( + params->tuning_ctx, params->Stream(), + params->c, static_cast(fast_gelu_input_length), + params->bias, static_cast(bias_length), + params->c); +} + +template +class GemmFastGeluTunableOp : public TunableOp> { + public: + GemmFastGeluTunableOp() { + this->RegisterOp(GemmFastGeluUnfused); +#ifdef USE_COMPOSABLE_KERNEL + for (auto&& [_, op] : GetCKGemmAddFastGeluTypeStringAndOps()) { + ORT_UNUSED_PARAMETER(_); + this->RegisterOp(std::move(op)); + } + for (auto&& [_, op] : GetCKGemmFastGeluTypeStringAndOps()) { + ORT_UNUSED_PARAMETER(_); + this->RegisterOp(std::move(op)); + } +#endif + +#ifdef USE_HIPBLASLT + for (auto&& [_, op] : GetHipBlasLtGemmFastGeluTypeStringAndOps()) { + ORT_UNUSED_PARAMETER(_); + this->RegisterOp(std::move(op)); + } +#endif + } +}; + +} // namespace internal +} // namespace blas +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu b/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu new file mode 100644 index 0000000000000..09a6550549614 --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu @@ -0,0 +1,530 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/shared_library/provider_api.h" +#include "core/providers/rocm/rocm_common.h" +#include "core/platform/env_var_utils.h" +#include "contrib_ops/rocm/bert/group_query_attention.h" +#include "contrib_ops/cpu/bert/group_query_attention_helper.h" +#include "contrib_ops/rocm/bert/rotary_embedding_impl.h" +#include "contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh" + +#ifdef USE_COMPOSABLE_KERNEL_CK_TILE +#include "ck_tile/core/numeric/integer.hpp" +#include "fmha_fwd.hpp" +#endif + +using namespace onnxruntime::rocm; +using namespace ::onnxruntime::common; +using namespace ONNX_NAMESPACE; + +namespace onnxruntime { +namespace contrib { +namespace rocm { + +#define REGISTER_KERNEL_TYPED(T) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + GroupQueryAttention, \ + kMSDomain, \ + 1, \ + T, \ + kRocmExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("M", DataTypeImpl::GetTensorType()) \ + .MayInplace(3, 1) \ + .MayInplace(4, 2) \ + .InputMemoryType(OrtMemTypeCPUInput, 6), \ + GroupQueryAttention); + +// REGISTER_KERNEL_TYPED(float) +REGISTER_KERNEL_TYPED(MLFloat16) +// REGISTER_KERNEL_TYPED(BFloat16) + +template +std::string GetCkFmhaDataTypeString(); + +template <> +std::string GetCkFmhaDataTypeString() { + return "fp16"; +} + +template <> +std::string GetCkFmhaDataTypeString() { + return "bf16"; +} + +__global__ void seqlens_inc_kernel(const int* seqlens, int* out, int num_elems, int inc) { + int idx = blockDim.x * blockIdx.x + threadIdx.x; + if (idx < num_elems) { + out[idx] = seqlens[idx] + inc; + } +} + +Status LaunchSeqlensInc(hipStream_t stream, const int* seqlens, int* out, int num_elems, int inc) { + constexpr int NumThreads = 128; + int num_blks = CeilDiv(num_elems, NumThreads); + seqlens_inc_kernel<<>>(seqlens, out, num_elems, inc); + return HIP_CALL(hipGetLastError()); +} + +__global__ void seqstart_init_kernel(int* out, int num_elems, int length_per_seq) { + int idx = blockDim.x * blockIdx.x + threadIdx.x; + if (idx < num_elems) { + out[idx] = idx * length_per_seq; + } + if (idx == 0) { + out[num_elems] = num_elems * length_per_seq; + } +} + +Status LaunchSeqStartInit(hipStream_t stream, int* out, int num_elems, int length_per_seq) { + constexpr int NumThreads = 128; + int num_blks = CeilDiv(num_elems, NumThreads); + seqstart_init_kernel<<>>(out, num_elems, length_per_seq); + return HIP_CALL(hipGetLastError()); +} + +// Kernel to convert seqlens_k to position_ids +__global__ void SeqlensToPosIdsPrompt(const int32_t* seqlens_k, int64_t* position_ids, const int seqlen, + const int batch_size) { + int tid = blockDim.x * blockIdx.x + threadIdx.x; + int b = tid / seqlen; + int s = tid % seqlen; + if (b < batch_size) { + if (s < seqlens_k[b] + 1) { + position_ids[tid] = s; + } else { + position_ids[tid] = 1; + } + } +} + +// Kernel to convert seqlens_k to position_ids +__global__ void SeqlensToPosIdsToken(const int32_t* seqlens_k, int64_t* position_ids, const int batch_size) { + int tid = blockDim.x * blockIdx.x + threadIdx.x; + if (tid < batch_size) { + position_ids[tid] = seqlens_k[tid]; + } +} + +// Convert seqlens_k to position_ids +Status LaunchSeqlensToPosIds(contrib::GroupQueryAttentionParameters& parameters, const int32_t* seqlens_k, + int64_t* position_ids, hipStream_t stream, const int max_threads_per_block) { + const int seqlen = parameters.sequence_length; + const int batch_size = parameters.batch_size; + const int threads = max_threads_per_block; + const int blocks = (batch_size * seqlen + threads - 1) / threads; + if (parameters.is_first_prompt) { + SeqlensToPosIdsPrompt<<>>(seqlens_k, position_ids, seqlen, batch_size); + } else { + SeqlensToPosIdsToken<<>>(seqlens_k, position_ids, batch_size); + } + return HIP_CALL(hipGetLastError()); +} + +template +GroupQueryAttention::GroupQueryAttention(const OpKernelInfo& info) + : RocmKernel(info) { + int64_t num_heads = 0; + int64_t kv_num_heads = 0; + ORT_ENFORCE(info.GetAttr("num_heads", &num_heads).IsOK() && num_heads > 0); + ORT_ENFORCE(info.GetAttr("kv_num_heads", &kv_num_heads).IsOK() && kv_num_heads > 0 && num_heads % kv_num_heads == 0); + num_heads_ = static_cast(num_heads); + kv_num_heads_ = static_cast(kv_num_heads); + is_past_bsnh_ = false; + is_unidirectional_ = true; + local_window_size_ = static_cast(info.GetAttrOrDefault("local_window_size", -1)); + do_rotary_ = info.GetAttrOrDefault("do_rotary", 0) == 1; + rotary_interleaved_ = info.GetAttrOrDefault("rotary_interleaved", 0) == 1; + scale_ = info.GetAttrOrDefault("scale", 0.0f); +} + +template <> +std::once_flag GroupQueryAttention::arch_checking_{}; + +template <> +std::once_flag GroupQueryAttention::arch_checking_{}; + +template +Status GroupQueryAttention::ComputeInternal(OpKernelContext* ctx) const { +#if USE_COMPOSABLE_KERNEL_CK_TILE + auto hip_stream = static_cast(ctx->GetComputeStream()->GetHandle()); + const Tensor* query = ctx->Input(0); + const Tensor* key = ctx->Input(1); + const Tensor* value = ctx->Input(2); + const Tensor* past_key = ctx->Input(3); + const Tensor* past_value = ctx->Input(4); + const Tensor* seqlens_k = ctx->Input(5); + const Tensor* total_seqlen = ctx->Input(6); + const Tensor* cos_cache = ctx->Input(7); + const Tensor* sin_cache = ctx->Input(8); + + auto& device_prop = GetDeviceProp(); + std::call_once( + arch_checking_, + [](const hipDeviceProp_t& device_prop) { + if (std::string_view(device_prop.gcnArchName).find("gfx90a") == std::string_view::npos && + std::string_view(device_prop.gcnArchName).find("gfx942") == std::string_view::npos) { + LOGS_DEFAULT(WARNING) + << "GroupQueryAttention currently only supports ck_tile fmha backend which only supports " + << "CDNA2 and CDNA3 archs."; + LOGS_DEFAULT(WARNING) + << "GroupQueryAttention running on an unsuppoted GPU may result in " + << "hipErrorNoBinaryForGpu or hipErrorSharedObjectInitFailedshared error."; + } + }, + device_prop); + + GroupQueryAttentionParameters parameters; + using HipT = typename ToHipType::MappedType; + + const int max_thr_per_blk = device_prop.maxThreadsPerBlock; + + ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckInputs(query, + key, + value, + past_key, + past_value, + cos_cache, + sin_cache, + ¶meters, + num_heads_, + kv_num_heads_, + seqlens_k, + total_seqlen, + is_past_bsnh_, + scale_, + max_thr_per_blk)); + + const int batch_size = parameters.batch_size; + const int sequence_length = parameters.sequence_length; + const int kv_sequence_length = parameters.sequence_length; + const int num_heads = parameters.num_heads; + const int kv_num_heads = parameters.kv_num_heads; + const int head_size = parameters.head_size; + AttentionQkvFormat past_kv_format = parameters.past_kv_format; + + parameters.local_window_size = local_window_size_; + parameters.is_unidirectional = is_unidirectional_; + // parameters.zeros_count = kZerosCount; + // parameters.zero_ptr = zeros_.get(); + // parameters.left_padding = left_padding_; + parameters.do_rotary = do_rotary_; + parameters.rotary_interleaved = rotary_interleaved_; + + ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckNoQKOutput( + context->OutputCount(), + static_cast(Info().GetAttrOrDefault("qk_output", static_cast(QKOutputType::NO_OUTPUT))))); + + if (do_rotary_ && (cos_cache == nullptr || sin_cache == nullptr)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "cos_cache and sin_cache must be passed to GroupQueryAttention when do_rotary = 1"); + } + + TensorShapeVector output_shape(3); + output_shape[0] = static_cast(batch_size); + output_shape[1] = static_cast(sequence_length); + output_shape[2] = static_cast(parameters.hidden_size); + Tensor* output = ctx->Output(0, output_shape); + Strides output_strides = Strides::BSNHMemory(batch_size, sequence_length, num_heads, head_size); + + int4 past_shape; + std::vector present_dims; + Strides present_strides; + Strides past_strides; + if (past_kv_format == AttentionQkvFormat::Q_K_V_BSNH) { + past_shape = { + batch_size, parameters.seqlen_past_kv_cache, kv_num_heads, head_size}; + past_strides = Strides::BSNHMemory( + batch_size, parameters.seqlen_past_kv_cache, kv_num_heads, head_size); + present_dims = { + batch_size, parameters.seqlen_present_kv_cache, kv_num_heads, head_size}; + present_strides = Strides::BSNHMemory( + batch_size, parameters.seqlen_present_kv_cache, kv_num_heads, head_size); + } else { // BNSH + past_shape = { + batch_size, kv_num_heads, parameters.seqlen_past_kv_cache, head_size}; + past_strides = Strides::BNSHMemory( + batch_size, kv_num_heads, parameters.seqlen_past_kv_cache, head_size); + present_dims = { + batch_size, kv_num_heads, parameters.seqlen_present_kv_cache, head_size}; + present_strides = Strides::BNSHMemory( + batch_size, kv_num_heads, parameters.seqlen_present_kv_cache, head_size); + } + TensorShape present_shape(present_dims); + Tensor* present_key = ctx->Output(1, present_shape); + Tensor* present_value = ctx->Output(2, present_shape); + + Strides query_strides; + Strides key_strides; + Strides value_strides; + int4 kv_shape{batch_size, kv_num_heads, kv_sequence_length, head_size}; // BNSH coord + const HipT* query_ptr = reinterpret_cast(query->DataRaw()); + const HipT* key_ptr; + const HipT* value_ptr; + if (!parameters.is_packed_qkv) { + query_strides = Strides::BSNHMemory(batch_size, sequence_length, num_heads, head_size); + key_strides = Strides::BSNHMemory(batch_size, kv_sequence_length, kv_num_heads, head_size); + value_strides = key_strides; + key_ptr = reinterpret_cast(key->DataRaw()); + value_ptr = reinterpret_cast(value->DataRaw()); + } else { + query_strides = Strides::BSNHMemory(batch_size, sequence_length, num_heads + 2 * kv_num_heads, head_size); + key_strides = Strides::BSNHMemory(batch_size, sequence_length, num_heads + 2 * kv_num_heads, head_size); + value_strides = query_strides; + const size_t key_offset = static_cast(num_heads * head_size); + const size_t value_offset = static_cast(kv_num_heads * head_size); + key_ptr = query_ptr + key_offset; + value_ptr = key_ptr + value_offset; + } + + IAllocatorUniquePtr rotary_q_tmp; + IAllocatorUniquePtr rotary_k_tmp; + if (parameters.do_rotary) { + size_t q_size = static_cast(batch_size * sequence_length * num_heads * head_size); + size_t k_size = static_cast(batch_size * sequence_length * kv_num_heads * head_size); + auto rotary_q_strides = Strides::BSNHMemory(batch_size, sequence_length, num_heads, head_size); + auto rotary_k_strides = Strides::BSNHMemory(batch_size, sequence_length, kv_num_heads, head_size); + + rotary_q_tmp = GetScratchBuffer(q_size, ctx->GetComputeStream()); + rotary_k_tmp = GetScratchBuffer(k_size, ctx->GetComputeStream()); + auto rotary_position_ids_tmp = GetScratchBuffer(sequence_length * batch_size, ctx->GetComputeStream()); + ORT_RETURN_IF_ERROR(LaunchSeqlensToPosIds(parameters, + reinterpret_cast(seqlens_k->DataRaw()), + reinterpret_cast(rotary_position_ids_tmp.get()), + hip_stream, max_thr_per_blk)); + // Launch rotary embedding kernel + ORT_RETURN_IF_ERROR(LaunchRotaryEmbeddingKernel(hip_stream, rotary_q_tmp.get(), query_ptr, + reinterpret_cast(rotary_position_ids_tmp.get()), + reinterpret_cast(cos_cache->DataRaw()), + reinterpret_cast(sin_cache->DataRaw()), + parameters.batch_size, parameters.sequence_length, + parameters.num_heads, parameters.head_size, + parameters.rotary_dim, parameters.seqlen_present_kv_cache, + /*position_ids_format*/ 1, parameters.rotary_interleaved, + max_thr_per_blk, + query_strides.ForBNSHCoord(), + rotary_q_strides.ForBNSHCoord())); + ORT_RETURN_IF_ERROR(LaunchRotaryEmbeddingKernel(hip_stream, rotary_k_tmp.get(), key_ptr, + reinterpret_cast(rotary_position_ids_tmp.get()), + reinterpret_cast(cos_cache->DataRaw()), + reinterpret_cast(sin_cache->DataRaw()), + parameters.batch_size, parameters.sequence_length, + parameters.kv_num_heads, parameters.head_size, + parameters.rotary_dim, parameters.seqlen_present_kv_cache, + /*position_ids_format*/ 1, parameters.rotary_interleaved, + max_thr_per_blk, + key_strides.ForBNSHCoord(), + rotary_k_strides.ForBNSHCoord())); + query_ptr = reinterpret_cast(rotary_q_tmp.get()); + key_ptr = reinterpret_cast(rotary_k_tmp.get()); + query_strides = rotary_q_strides; + key_strides = rotary_k_strides; + } + + const int* seqlens_k_ptr = seqlens_k ? reinterpret_cast(seqlens_k->DataRaw()) : nullptr; + IAllocatorUniquePtr seqlens_k_tmp; + + // build present kv cache + auto* present_key_ptr = reinterpret_cast(present_key->MutableDataRaw()); + auto* present_value_ptr = reinterpret_cast(present_value->MutableDataRaw()); + if (parameters.is_first_prompt) { + // copy prompt kv to present kv + ORT_RETURN_IF_ERROR(LaunchStridedCopy(hip_stream, key_ptr, kv_shape, key_strides.ForBNSHCoord(), + present_key_ptr, present_strides.ForBNSHCoord(), max_thr_per_blk)); + ORT_RETURN_IF_ERROR(LaunchStridedCopy(hip_stream, value_ptr, kv_shape, value_strides.ForBNSHCoord(), + present_value_ptr, present_strides.ForBNSHCoord(), max_thr_per_blk)); + } else { + const auto* past_key_ptr = past_key == nullptr ? nullptr : reinterpret_cast(past_key->DataRaw()); + const auto* past_value_ptr = past_key == nullptr ? nullptr : reinterpret_cast(past_value->DataRaw()); + parameters.kv_share_buffer = past_key_ptr == present_key_ptr; // FIXME: + if (!parameters.kv_share_buffer) { + // copy past to present, + // NOTE: we do a low perf full buffer copy due to the seqlens_k indicate the seqlen of different seqs are + // not the same, aka, can not be as simple as strided + ORT_RETURN_IF_ERROR(LaunchStridedCopy(hip_stream, past_key_ptr, past_shape, past_strides.ForBNSHCoord(), + present_key_ptr, present_strides.ForBNSHCoord(), max_thr_per_blk)); + ORT_RETURN_IF_ERROR(LaunchStridedCopy(hip_stream, past_value_ptr, past_shape, past_strides.ForBNSHCoord(), + present_value_ptr, present_strides.ForBNSHCoord(), max_thr_per_blk)); + } else { + // In the case of share buffer + ORT_ENFORCE(past_key_ptr == nullptr || past_key_ptr == present_key_ptr); + ORT_ENFORCE(past_key_ptr == nullptr || past_value_ptr == present_value_ptr); + } + // then append new kv to present + size_t buffer_offset = seqlens_k ? 0 : present_strides.OffsetAt(0, 0, kv_sequence_length, 0); + ORT_RETURN_IF_ERROR(LaunchStridedCopy( + hip_stream, key_ptr, kv_shape, key_strides.ForBNSHCoord(), /*in_seqlens_offset=*/nullptr, + present_key_ptr + buffer_offset, present_strides.ForBNSHCoord(), seqlens_k_ptr, + max_thr_per_blk)); + ORT_RETURN_IF_ERROR(LaunchStridedCopy( + hip_stream, value_ptr, kv_shape, value_strides.ForBNSHCoord(), /*in_seqlens_offset=*/nullptr, + present_value_ptr + buffer_offset, present_strides.ForBNSHCoord(), seqlens_k_ptr, + max_thr_per_blk)); + + // NOTE: ORT: seqlens_k Indicates past sequence lengths for token generation case. + // we should call fmha with total sequence lengths + seqlens_k_tmp = GetScratchBuffer(batch_size * sizeof(int), ctx->GetComputeStream()); + ORT_RETURN_IF_ERROR(LaunchSeqlensInc(hip_stream, seqlens_k_ptr, seqlens_k_tmp.get(), batch_size, sequence_length)); + seqlens_k_ptr = seqlens_k_tmp.get(); + } + static_assert(std::is_same_v); + + const float scale = parameters.scale == 0.0f + ? 1.f / sqrt(static_cast(parameters.head_size)) + : parameters.scale; + bias_enum bias_type = bias_enum::no_bias; + + mask_info mask = [&]() { + if (local_window_size_ != -1) { + mask_info ret; + ret.type = mask_enum::window_generic; + ret.left = local_window_size_; + ret.right = parameters.is_unidirectional ? 0 : -1; + // ret.x = kv_sequence_length - (sequence_length - ret.left); + // ret.y = sequence_length + (ret.right - kv_sequence_length); + return ret; + } + + if (parameters.is_first_prompt && is_unidirectional_) { + return mask_info::decode("t", sequence_length, kv_sequence_length); + } + + return mask_info::decode("0", sequence_length, kv_sequence_length); + }(); + + auto seqstart_q_tmp = GetScratchBuffer((batch_size + 1) * sizeof(int), ctx->GetComputeStream()); + auto seqstart_k_tmp = GetScratchBuffer((batch_size + 1) * sizeof(int), ctx->GetComputeStream()); + ORT_RETURN_IF_ERROR(LaunchSeqStartInit( + hip_stream, seqstart_q_tmp.get(), batch_size, + query_strides.strides_for_bnsh_coord.x / query_strides.strides_for_bnsh_coord.z)); + ORT_RETURN_IF_ERROR(LaunchSeqStartInit( + hip_stream, seqstart_k_tmp.get(), batch_size, + present_strides.strides_for_bnsh_coord.x / present_strides.strides_for_bnsh_coord.z)); + + fmha_fwd_args args{ + query_ptr, + present_key->DataRaw(), + present_value->DataRaw(), + nullptr, // bias, alibi/element + nullptr, // lse, logsumexp buffer + output->MutableDataRaw(), + seqstart_q_tmp.get(), // seqstart_q_ptr, for group mode + seqstart_k_tmp.get(), // seqstart_k_ptr, for group mode + seqlens_k_ptr, // seqlen_k_ptr, for group mode + sequence_length, // seqlen_q, for batch mode + kv_sequence_length, // seqlen_k, for batch mode + parameters.batch_size, // batch + parameters.sequence_length, // max_seqlen_q + parameters.head_size, // hdim_q + parameters.head_size, // hdim_v + parameters.num_heads, + parameters.kv_num_heads, + scale, + 1.0f, // scale_p of squant, useless + 1.0f, // scale_o of squant, useless + static_cast(query_strides.strides_for_bnsh_coord.z), // stride_q, to be regarded as stride of dim S + static_cast(present_strides.strides_for_bnsh_coord.z), // stride_k, to be regarded as stride of dim S + static_cast(present_strides.strides_for_bnsh_coord.z), // stride_v, to be regarded as stride of dim S + batch_size, // stride_bias, if alibi, b*h need set this to h, 1*h need set this to 0 + static_cast(output_strides.strides_for_bnsh_coord.z), // stride_o, to be regarded as stride of dim S + static_cast(query_strides.strides_for_bnsh_coord.y), // nhead_stride_q, to be regarded as stride of dim N + static_cast(present_strides.strides_for_bnsh_coord.y), // nhead_stride_k, to be regarded as stride of dim N + static_cast(present_strides.strides_for_bnsh_coord.y), // nhead_stride_v, to be regarded as stride of dim N + 0, // nhead_stride_bias + batch_size, // nhead_stride_lse + static_cast(output_strides.strides_for_bnsh_coord.y), // batch_stride_o, to be regarded as stride of dim B + static_cast(query_strides.strides_for_bnsh_coord.x), // batch_stride_q, to be regarded as stride of dim B + static_cast(present_strides.strides_for_bnsh_coord.x), // batch_stride_k, to be regarded as stride of dim B + static_cast(present_strides.strides_for_bnsh_coord.x), // batch_stride_v, to be regarded as stride of dim B + 0, // batch_stride_bias + num_heads * batch_size, // batch_stride_lse + static_cast(output_strides.strides_for_bnsh_coord.x), // batch_stride_o, to be regarded as stride of dim B + mask.left, // window_size_left + mask.right, // window_size_right + static_cast(mask.type)}; + +#if 0 + std::cout + << "\n sequence_length:" << sequence_length + << "\n kv_sequence_length:" << kv_sequence_length + << "\n seqlen_past_kv_cache:" << parameters.seqlen_past_kv_cache + << "\n seqlen_present_kv_cache:" << parameters.seqlen_present_kv_cache << std::endl; + + std::cout + << "\n q_ptr:" << args.q_ptr + << "\n k_ptr:" << args.k_ptr + << "\n v_ptr:" << args.v_ptr + << "\n bias_ptr:" << args.bias_ptr + << "\n lse_ptr:" << args.lse_ptr + << "\n o_ptr:" << args.o_ptr + << "\n seqstart_q_ptr:" << args.seqstart_q_ptr + << "\n seqstart_k_ptr:" << args.seqstart_k_ptr + << "\n seqlen_k_ptr:" << args.seqlen_k_ptr + << "\n seqlen_q:" << args.seqlen_q + << "\n seqlen_k:" << args.seqlen_k + << "\n batch:" << args.batch + << "\n max_seqlen_q:" << args.max_seqlen_q + << "\n hdim_q:" << args.hdim_q + << "\n hdim_v:" << args.hdim_v + << "\n nhead_q:" << args.nhead_q + << "\n nhead_k:" << args.nhead_k + << "\n scale_s:" << args.scale_s + << "\n scale_p:" << args.scale_p + << "\n scale_o:" << args.scale_o + << "\n stride_q:" << args.stride_q + << "\n stride_k:" << args.stride_k + << "\n stride_v:" << args.stride_v + << "\n stride_bias:" << args.stride_bias + << "\n stride_o:" << args.stride_o + << "\n nhead_stride_q:" << args.nhead_stride_q + << "\n nhead_stride_k:" << args.nhead_stride_k + << "\n nhead_stride_v:" << args.nhead_stride_v + << "\n nhead_stride_bias:" << args.nhead_stride_bias + << "\n nhead_stride_lse:" << args.nhead_stride_lse + << "\n nhead_stride_o:" << args.nhead_stride_o + << "\n batch_stride_q:" << args.batch_stride_q + << "\n batch_stride_k:" << args.batch_stride_k + << "\n batch_stride_v:" << args.batch_stride_v + << "\n batch_stride_bias:" << args.batch_stride_bias + << "\n batch_stride_lse:" << args.batch_stride_lse + << "\n batch_stride_o:" << args.batch_stride_o + << "\n window_size_left:" << args.window_size_left + << "\n window_size_right:" << args.window_size_right + << "\n mask_type:" << args.mask_type + << std::endl; +#endif + + fmha_fwd_traits traits{ + parameters.head_size, + parameters.head_size, // v head size + GetCkFmhaDataTypeString(), + !parameters.is_first_prompt, // true, // is_group_mode + true, // is_v_rowmajor ? dim is fastest : seq is fastest + mask.type, + bias_type, + false, // has_lse + false, // do_fp8_static_quant, aka, squant + }; + + ck_tile::stream_config stream_config{ + hip_stream, + false // time_kernel + }; + + auto duration = fmha_fwd(traits, args, stream_config); + if (duration < 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "fmha_fwd internal error"); + } + HIP_RETURN_IF_ERROR(hipGetLastError()); + + return Status::OK(); +#else + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "GroupQueryAttention requires ck_tile to be enabled"); +#endif +} + +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/group_query_attention.h b/onnxruntime/contrib_ops/rocm/bert/group_query_attention.h new file mode 100644 index 0000000000000..ce0de1f761aa5 --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/bert/group_query_attention.h @@ -0,0 +1,38 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include "core/providers/rocm/rocm_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace rocm { + +using namespace onnxruntime::rocm; + +template +class GroupQueryAttention final : public RocmKernel { + public: + GroupQueryAttention(const OpKernelInfo& info); + Status ComputeInternal(OpKernelContext* context) const override; + + protected: + int num_heads_; // number of attention heads + int kv_num_heads_; // different for k and v for group query attention + int local_window_size_; + bool is_unidirectional_; + bool is_past_bsnh_; + bool do_rotary_; + bool rotary_interleaved_; + float scale_; + + private: + static std::once_flag arch_checking_; +}; + +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/layer_norm.cuh b/onnxruntime/contrib_ops/rocm/bert/layer_norm.cuh new file mode 100644 index 0000000000000..2eeb7c3e8f279 --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/bert/layer_norm.cuh @@ -0,0 +1,270 @@ +#include "hip/hip_runtime.h" +/* + The implementation of this file is based on bert plugins in TensorRT demo: + https://github.com/NVIDIA/TensorRT/tree/release/5.1/demo/BERT/ + +Copyright 2019 NVIDIA Corporation + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#pragma once + +#include +#include +#include +#include "core/providers/rocm/rocm_common.h" +#include "core/providers/rocm/cu_inc/common.cuh" +#include "core/providers/rocm/shared_inc/rocm_call.h" + +using namespace onnxruntime::rocm; + +namespace onnxruntime { +namespace contrib { +namespace rocm { + +template +__device__ inline T Rsqrt(const T& x); + +template <> +__device__ inline float Rsqrt(const float& x) { + return rsqrtf(x); +} + +template <> +__device__ inline half Rsqrt(const half& x) { +#if __CUDA_ARCH__ >= 530 || !defined(__CUDA_ARCH__) + return hrsqrt(x); +#else + return half(rsqrtf(static_cast(x))); +#endif +} + +__device__ inline half2 AddHalf2(const half2 a, const half2 b) { +#if __CUDA_ARCH__ >= 530 || !defined(__CUDA_ARCH__) + return __hadd2(a, b); +#else + return __halves2half2(__hadd(a.x, b.x), __hadd(a.y, b.y)); +#endif +} + +struct KeyValuePairSum { + __device__ inline hipcub::KeyValuePair operator()(const hipcub::KeyValuePair& a, + const hipcub::KeyValuePair& b) { + return hipcub::KeyValuePair(a.key + b.key, a.value + b.value); + } + + __device__ inline hipcub::KeyValuePair operator()(const hipcub::KeyValuePair& a, + const hipcub::KeyValuePair& b) { + const half2 a2 = __halves2half2(a.key, a.value); + const half2 b2 = __halves2half2(b.key, b.value); + const half2 res = AddHalf2(a2, b2); + return hipcub::KeyValuePair(__low2half(res), __high2half(res)); + } + + __device__ inline hipcub::KeyValuePair operator()(const hipcub::KeyValuePair& a, + const hipcub::KeyValuePair& b) { + return hipcub::KeyValuePair(AddHalf2(a.key, b.key), AddHalf2(a.value, b.value)); + } +}; + +template +__device__ inline void LayerNorm( + const hipcub::KeyValuePair& thread_data, const int ld, const int offset, const V* beta, + const V* gamma, const U epsilon, V* output) { + // Assuming thread_data is already divided by ld + + using BlockReduce = hipcub::BlockReduce, TPB>; + __shared__ typename BlockReduce::TempStorage temp_storage; + __shared__ U mu; // mean + __shared__ U rsigma; // 1 / std.dev. + + KeyValuePairSum pair_sum; + const auto sum_kv = BlockReduce(temp_storage).Reduce(thread_data, pair_sum); + + if (threadIdx.x == 0) { + mu = sum_kv.key; + rsigma = Rsqrt(sum_kv.value - mu * mu + epsilon); + } + __syncthreads(); + + for (int i = threadIdx.x; i < ld; i += TPB) { + const int idx = offset + i; + const U val = static_cast(output[idx]); + const U g = static_cast(gamma[i]); + const U b = (nullptr == beta) ? U(0.f) : static_cast(beta[i]); + output[idx] = static_cast(g * (val - mu) * rsigma + b); + } +} + +template +__device__ inline void SimplifiedLayerNorm( + const U& thread_data, const int ld, const int offset, const V* gamma, const U epsilon, V* output) { + // Assuming thread_data is already divided by ld + + using BlockReduce = hipcub::BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + __shared__ U rsigma; // 1 / std.dev. + + const U sum = BlockReduce(temp_storage).Sum(thread_data); + + if (threadIdx.x == 0) { + rsigma = Rsqrt(sum + epsilon); + } + __syncthreads(); + + for (int i = threadIdx.x; i < ld; i += TPB) { + const int idx = offset + i; + const U val = static_cast(output[idx]); + const U g = static_cast(gamma[i]); + output[idx] = static_cast(g * val * rsigma); + } +} + +template +__device__ inline void SimplifiedLayerNormVec( + const U& thread_data, const int ld, const int offset, const V* gamma, const U epsilon, V* output) { + // Assuming thread_data is already divided by ld + using VecV = aligned_vector; + using BlockReduce = hipcub::BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + __shared__ U rsigma; // 1 / std.dev. + + const U sum = BlockReduce(temp_storage).Sum(thread_data); + + if (threadIdx.x == 0) { + rsigma = Rsqrt(sum + epsilon); + } + __syncthreads(); + + if (ILP * threadIdx.x < ld) { + for (int i = threadIdx.x * ILP; i < ld; i += TPB * ILP) { + int idx = offset + i; + const VecV gamma_v = *reinterpret_cast(gamma + i); + VecV output_v = *reinterpret_cast(output + idx); + +#pragma unroll + for (int k = 0; k < ILP; k++) { + output_v.val[k] = U(gamma_v.val[k]) * U(output_v.val[k]) * rsigma; + } + *(reinterpret_cast(output + idx)) = output_v; + } + } +} + +template +__device__ inline void LayerNormVec( + const hipcub::KeyValuePair& thread_data, const int ld, const int offset, const V* beta, + const V* gamma, const U epsilon, V* output) { + // Assuming thread_data is already divided by ld + using VecV = aligned_vector; + using BlockReduce = hipcub::BlockReduce, TPB>; + __shared__ typename BlockReduce::TempStorage temp_storage; + __shared__ U mu; // mean + __shared__ U rsigma; // 1 / std.dev. + + KeyValuePairSum pair_sum; + const auto sum_kv = BlockReduce(temp_storage).Reduce(thread_data, pair_sum); + + if (threadIdx.x == 0) { + mu = sum_kv.key; + rsigma = Rsqrt(sum_kv.value - mu * mu + epsilon); + } + __syncthreads(); + + if (ILP * threadIdx.x < ld) { + for (int i = threadIdx.x * ILP; i < ld; i += TPB * ILP) { + int idx = offset + i; + const VecV beta_v = (beta != nullptr) ? *reinterpret_cast(beta + i) : VecV(); + const VecV gamma_v = *reinterpret_cast(gamma + i); + VecV output_v = *reinterpret_cast(output + idx); + +#pragma unroll + for (int k = 0; k < ILP; k++) { + output_v.val[k] = (beta != nullptr) ? U(gamma_v.val[k]) * (U(output_v.val[k]) - mu) * rsigma + U(beta_v.val[k]) : U(gamma_v.val[k]) * (U(output_v.val[k]) - mu) * rsigma; + } + *(reinterpret_cast(output + idx)) = output_v; + } + } +} + +template +__device__ inline void LayerNormSmall(const T* input_v, const hipcub::KeyValuePair& thread_data, + const int ld, const int idx, const V* beta, const V* gamma, + const U epsilon, V* output) { + // Assuming thread_data is already divided by ld + // Small settings: the block covers the leading dimension TPB >= ld. The input + // value is available in a register + using VecV = aligned_vector; + using BlockReduce = hipcub::BlockReduce, TPB>; + __shared__ typename BlockReduce::TempStorage temp_storage; + __shared__ U mu; // mean + __shared__ U rsigma; // 1 / std.dev. + + KeyValuePairSum pair_sum; + const hipcub::KeyValuePair sum_kv = BlockReduce(temp_storage).Reduce(thread_data, pair_sum); + + if (threadIdx.x == 0) { + mu = sum_kv.key; + rsigma = Rsqrt(sum_kv.value - mu * mu + epsilon); + } + __syncthreads(); + + if (ILP * threadIdx.x < ld) { + const VecV beta_v = (beta != nullptr) ? *reinterpret_cast(beta + threadIdx.x * ILP) : VecV(); + const VecV gamma_v = *reinterpret_cast(gamma + threadIdx.x * ILP); + VecV output_v; + +#pragma unroll + for (int i = 0; i < ILP; i++) { + output_v.val[i] = (beta != nullptr) ? U(gamma_v.val[i]) * (U(input_v[i]) - mu) * rsigma + U(beta_v.val[i]) : U(gamma_v.val[i]) * (U(input_v[i]) - mu) * rsigma; + } + *(reinterpret_cast(output + idx)) = output_v; + } +} + +template +__device__ inline void SimplifiedLayerNormSmall(const T* input_v, const U& thread_data, const int ld, const int idx, + const V* gamma, const U epsilon, V* output) { + // Assuming thread_data is already divided by ld + // Small settings: the block covers the leading dimension TPB >= ld. The input + // value is available in a register + using VecV = aligned_vector; + using BlockReduce = hipcub::BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + __shared__ U rsigma; // 1 / std.dev. + + const U sum = BlockReduce(temp_storage).Sum(thread_data); + + if (threadIdx.x == 0) { + rsigma = Rsqrt(sum + epsilon); + } + __syncthreads(); + + if (ILP * threadIdx.x < ld) { + const VecV gamma_v = *reinterpret_cast(gamma + threadIdx.x * ILP); + VecV output_v; + +#pragma unroll + for (int i = 0; i < ILP; i++) { + output_v.val[i] = U(gamma_v.val[i]) * U(input_v[i]) * rsigma; + } + *(reinterpret_cast(output + idx)) = output_v; + } +} + +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/multihead_attention.cu b/onnxruntime/contrib_ops/rocm/bert/multihead_attention.cu new file mode 100644 index 0000000000000..5d4ef53b8ba97 --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/bert/multihead_attention.cu @@ -0,0 +1,286 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "contrib_ops/rocm/bert/multihead_attention.h" + +#include "contrib_ops/cpu/bert/multihead_attention_helper.h" +#include "contrib_ops/rocm/bert/attention_impl.h" +#include "contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh" +#include "core/platform/env_var_utils.h" +#include "core/providers/rocm/rocm_common.h" + +using namespace onnxruntime::rocm; +using namespace ::onnxruntime::common; +using namespace ONNX_NAMESPACE; + +namespace onnxruntime { +namespace contrib { +namespace rocm { + +#define REGISTER_MHA_KERNEL_TYPED(T) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + MultiHeadAttention, \ + kMSDomain, \ + 1, \ + T, \ + kRocmExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + MultiHeadAttention) + +REGISTER_MHA_KERNEL_TYPED(float); +REGISTER_MHA_KERNEL_TYPED(MLFloat16); + +static constexpr int kPastSequenceLengthInputIndex = 7; +static constexpr int kBeamWidthInputIndex = 8; +static constexpr int kPastInputIndex = 5; +static constexpr int kPresentOutputIndex = 1; + +#define REGISTER_DMMHA_KERNEL_TYPED(T) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + DecoderMaskedMultiHeadAttention, \ + kMSDomain, \ + 1, \ + T, \ + kRocmExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .MayInplace(kPastInputIndex, kPresentOutputIndex) \ + .MayInplace(kPastInputIndex + 1, kPresentOutputIndex + 1) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ + .InputMemoryType(OrtMemTypeCPUInput, kPastSequenceLengthInputIndex) \ + .InputMemoryType(OrtMemTypeCPUInput, kBeamWidthInputIndex), \ + MultiHeadAttention) + +REGISTER_DMMHA_KERNEL_TYPED(float); +REGISTER_DMMHA_KERNEL_TYPED(MLFloat16); + +template +MultiHeadAttention::MultiHeadAttention(const OpKernelInfo& info) + : RocmKernel(info), + attn_type_(info.node().OpType() == "DecoderMaskedMultiHeadAttention" ? kDecoderMaskedMultiHeadAttention + : kMultiHeadAttention) { + int64_t num_heads = 0; + ORT_ENFORCE(info.GetAttr("num_heads", &num_heads).IsOK() && num_heads > 0); + num_heads_ = static_cast(num_heads); + + mask_filter_value_ = info.GetAttrOrDefault("mask_filter_value", -10000.0f); + + scale_ = info.GetAttrOrDefault("scale", 0.0f); + + past_present_share_buffer_ = info.GetAttrOrDefault("past_present_share_buffer", 0LL) != 0LL; + is_unidirectional_ = info.GetAttrOrDefault("unidirectional", 0) == 1; + + using HipT = typename ToHipType::MappedType; + using AttentionTunableOp = GemmSoftmaxGemmPermuteTunableOp; + tunable_op_ = std::make_shared(); +} + +template +Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { + ORT_ENFORCE( + GetTuningContext()->IsTunableOpEnabled(), + "MultiHeadAttention of ROCm EP is only supported if tunable op is used and tuning is enabled."); + + const Tensor* query = context->Input(0); + const Tensor* key = context->Input(1); + const Tensor* value = context->Input(2); + + const Tensor* bias{}; + const Tensor* key_padding_mask{}; + const Tensor* attention_bias{}; + const Tensor* past_key{}; + const Tensor* past_value{}; + const Tensor* past_seq_len{}; + + const Tensor* cache_indirection = nullptr; + + if (attn_type_ == kMultiHeadAttention) { + bias = context->Input(3); + key_padding_mask = context->Input(4); + attention_bias = context->Input(5); + past_key = context->Input(6); + past_value = context->Input(7); + } else if (attn_type_ == kDecoderMaskedMultiHeadAttention) { + key_padding_mask = context->Input(3); + attention_bias = context->Input(4); + past_key = context->Input(5); + past_value = context->Input(6); + past_seq_len = context->Input(kPastSequenceLengthInputIndex); + // const Tensor* beam_width = context->Input(8); // NOTE: not used + // const Tensor* cache_indirection = context->Input(9); // TODO: should not present for ROCm EP + bias = context->Input(10); + } + + if (nullptr != bias) { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, + "qkv_bias is not supported on ROCm EP. " + "User should fuse the qkv bias to qkv projection instead."); + } + + auto& device_prop = GetDeviceProp(); + RocmAttentionParameters attn; + ORT_RETURN_IF_ERROR(multihead_attention_helper::CheckInputs(query, + key, + value, + bias, + key_padding_mask, + attention_bias, + past_key, + past_value, + cache_indirection, + past_seq_len, + &attn, /* parameters */ + num_heads_, + mask_filter_value_, + scale_, + is_unidirectional_, + past_present_share_buffer_, + attn_type_, + device_prop.maxThreadsPerBlock)); + + if (attn_type_ == kDecoderMaskedMultiHeadAttention && attn.sequence_length != 1) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input sequence length should be 1 to use DecoderMaskedMultiHeadAttention"); + } + + TensorShapeVector output_shape(3); + output_shape[0] = static_cast(attn.batch_size); + output_shape[1] = static_cast(attn.sequence_length); + output_shape[2] = static_cast(attn.v_hidden_size); + Tensor* output = context->Output(0, output_shape); + + std::vector present_dims{ + attn.batch_size, + attn.num_heads, + past_present_share_buffer_ ? attn.max_sequence_length : attn.total_sequence_length, + attn.head_size, + }; + TensorShape present_shape(present_dims); + Tensor* present_key = context->Output(1, present_shape); + Tensor* present_value = context->Output(2, present_shape); + + ORT_RETURN_IF_ERROR(ClassifyAttentionMode( + attn_type_, &attn, + /*qkv=*/{query, key, value}, + /*past=*/{past_key, past_value}, + /*present=*/{present_key, present_value})); + + using HipT = typename ToHipType::MappedType; + using AttentionTunableOp = GemmSoftmaxGemmPermuteTunableOp; + auto workspace_bytes = AttentionTunableOp::GetWorkspaceNumBytes(&attn); + auto workspace = GetScratchBuffer(workspace_bytes, context->GetComputeStream()); + + hipStream_t stream = Stream(context); + if (nullptr != present_key) { // process past present concat + Strides dst_strides; + + int4 past_shape; + Strides past_src_strides; + const HipT* past_key_src; + const HipT* past_value_src; + HipT* past_key_dst{}; + HipT* past_value_dst{}; + + int4 add_shape; + Strides add_src_strides; + const HipT* add_key_src = reinterpret_cast(key->DataRaw()); + const HipT* add_value_src = reinterpret_cast(value->DataRaw()); + HipT* add_key_dst; + HipT* add_value_dst; + + if (attn.mode == BSNH_BLNH_BLNH_BNPH_BNPH_BNTH_BNTH || + attn.mode == BSNH_BNLH_BNLH_BNPH_BNPH_BNTH_BNTH) { + dst_strides = Strides::BNSHMemory(attn.batch_size, attn.num_heads, attn.total_sequence_length, attn.head_size); + + past_shape = {attn.batch_size, attn.num_heads, attn.past_sequence_length, attn.head_size}; + past_src_strides = Strides::BNSHMemory(attn.batch_size, attn.num_heads, attn.past_sequence_length, attn.head_size); + past_key_src = reinterpret_cast(past_key->DataRaw()); + past_value_src = reinterpret_cast(past_value->DataRaw()); + past_key_dst = reinterpret_cast(present_key->MutableDataRaw()); + past_value_dst = reinterpret_cast(present_value->MutableDataRaw()); + + if (attn.mode == BSNH_BLNH_BLNH_BNPH_BNPH_BNTH_BNTH) { + add_src_strides = Strides::BSNHMemory(attn.batch_size, attn.kv_sequence_length, attn.num_heads, attn.head_size); + } else if (attn.mode == BSNH_BNLH_BNLH_BNPH_BNPH_BNTH_BNTH) { + add_src_strides = Strides::BNSHMemory(attn.batch_size, attn.num_heads, attn.kv_sequence_length, attn.head_size); + } + } else if (attn.mode == BSNH_BLNH_BLNH_NONE_NONE_BNTH_BNTH || + attn.mode == BSNH_BNLH_BNLH_NONE_NONE_BNTH_BNTH) { + dst_strides = Strides::BNSHMemory(attn.batch_size, attn.num_heads, attn.total_sequence_length, attn.head_size); + + if (attn.mode == BSNH_BLNH_BLNH_NONE_NONE_BNTH_BNTH) { + add_src_strides = Strides::BSNHMemory(attn.batch_size, attn.kv_sequence_length, attn.num_heads, attn.head_size); + } else if (attn.mode == BSNH_BNLH_BNLH_NONE_NONE_BNTH_BNTH) { + add_src_strides = Strides::BNSHMemory(attn.batch_size, attn.num_heads, attn.kv_sequence_length, attn.head_size); + } + } else if ( + attn.mode == BSNH_BLNH_BLNH_NONE_NONE_BNMH_BNMH || + attn.mode == BSNH_BNLH_BNLH_NONE_NONE_BNMH_BNMH || + attn.mode == BSNH_BLNH_BLNH_BNMH_BNMH_BNMH_BNMH || + attn.mode == BSNH_BNLH_BNLH_BNMH_BNMH_BNMH_BNMH) { + dst_strides = Strides::BNSHMemory(attn.batch_size, attn.num_heads, attn.max_sequence_length, attn.head_size); + + if (attn.mode == BSNH_BLNH_BLNH_NONE_NONE_BNMH_BNMH || attn.mode == BSNH_BLNH_BLNH_BNMH_BNMH_BNMH_BNMH) { + add_src_strides = Strides::BSNHMemory(attn.batch_size, attn.kv_sequence_length, attn.num_heads, attn.head_size); + } else if (attn.mode == BSNH_BNLH_BNLH_NONE_NONE_BNMH_BNMH || attn.mode == BSNH_BNLH_BNLH_BNMH_BNMH_BNMH_BNMH) { + add_src_strides = Strides::BNSHMemory(attn.batch_size, attn.num_heads, attn.kv_sequence_length, attn.head_size); + } + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, + "past present concatenation is not implemented for attention mode ", attn.mode); + } + add_shape = {attn.batch_size, attn.num_heads, attn.kv_sequence_length, attn.head_size}; // kernel in coord (b,n,s,h) + add_key_dst = reinterpret_cast(present_key->MutableDataRaw()) + dst_strides.OffsetAt(0, 0, attn.past_sequence_length, 0); + add_value_dst = reinterpret_cast(present_value->MutableDataRaw()) + dst_strides.OffsetAt(0, 0, attn.past_sequence_length, 0); + + if (past_key_dst) { + ORT_RETURN_IF_ERROR(LaunchStridedCopy( + stream, past_key_src, past_shape, past_src_strides.ForBNSHCoord(), + past_key_dst, dst_strides.ForBNSHCoord(), device_prop.maxThreadsPerBlock)); + } + if (past_value_dst) { + ORT_RETURN_IF_ERROR(LaunchStridedCopy( + stream, past_value_src, past_shape, past_src_strides.ForBNSHCoord(), + past_value_dst, dst_strides.ForBNSHCoord(), device_prop.maxThreadsPerBlock)); + } + + ORT_RETURN_IF_ERROR(LaunchStridedCopy( + stream, add_key_src, add_shape, add_src_strides.ForBNSHCoord(), + add_key_dst, dst_strides.ForBNSHCoord(), device_prop.maxThreadsPerBlock)); + ORT_RETURN_IF_ERROR(LaunchStridedCopy( + stream, add_value_src, add_shape, add_src_strides.ForBNSHCoord(), + add_value_dst, dst_strides.ForBNSHCoord(), device_prop.maxThreadsPerBlock)); + } + + GemmSoftmaxGemmPermuteParams params; + params.tuning_ctx = GetTuningContext(); + params.stream = context->GetComputeStream(); + params.handle = GetHipblasHandle(context); + params.attention = &attn; + params.device_prop = &device_prop; + params.scale = scale_ == 0 ? 1.0f / sqrt(attn.head_size) : scale_; + std::tie(params.q_buffer, params.k_buffer, params.v_buffer) = ConvertToOffsetedBufferViews( + &attn, + nullptr == query ? nullptr : reinterpret_cast(query->DataRaw()), + nullptr == key ? nullptr : reinterpret_cast(key->DataRaw()), + nullptr == value ? nullptr : reinterpret_cast(value->DataRaw()), + nullptr == present_key ? nullptr : reinterpret_cast(present_key->DataRaw()), + nullptr == present_value ? nullptr : reinterpret_cast(present_value->DataRaw())); + params.out_buffer = reinterpret_cast(output->MutableDataRaw()); + + if (key_padding_mask != nullptr) { + params.mask_index_buffer = key_padding_mask->Data(); + params.mask_index_dims = key_padding_mask->Shape().AsShapeVector(); + } + + if (attention_bias != nullptr) { + params.bias_buffer = reinterpret_cast(attention_bias->DataRaw()); + } + + params.workspace_buffer = reinterpret_cast(workspace.get()); + return (*std::static_pointer_cast(tunable_op_))(¶ms); +} + +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/multihead_attention.h b/onnxruntime/contrib_ops/rocm/bert/multihead_attention.h new file mode 100644 index 0000000000000..1d676d7a7bcac --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/bert/multihead_attention.h @@ -0,0 +1,51 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include "core/providers/rocm/rocm_kernel.h" +#include "contrib_ops/rocm/bert/attention_impl.h" + +namespace onnxruntime { +namespace contrib { +namespace rocm { + +using namespace onnxruntime::rocm; + +template +class MultiHeadAttention final : public RocmKernel { + public: + MultiHeadAttention(const OpKernelInfo& info); + Status ComputeInternal(OpKernelContext* context) const override; + + protected: + AttentionType attn_type_; + int num_heads_; // number of attention heads + float mask_filter_value_; + float scale_; + bool past_present_share_buffer_{false}; + bool is_unidirectional_{false}; + + // type-erased GemmSoftmaxGemmPermuteTunableOp, the reason for this is: + // 1. We don't want to include the cuh file where GemmSoftmaxGemmPermuteTunableOp is defined. + // 2. We don't want to construct the object repeatly (which is expansive) during Compute. + std::shared_ptr tunable_op_; +}; + +template +class DecoderMaskedMultiHeadAttention final : public RocmKernel { + public: + DecoderMaskedMultiHeadAttention(const OpKernelInfo& info); + Status ComputeInternal(OpKernelContext* context) const override; + + protected: + AttentionType mha_type; + int num_heads_; // number of attention heads + float mask_filter_value_; + float scale_; +}; + +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm.cc b/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm.cc new file mode 100644 index 0000000000000..9e649fb591896 --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm.cc @@ -0,0 +1,132 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "contrib_ops/rocm/bert/skip_layer_norm.h" + +#include "core/providers/rocm/rocm_common.h" +#include "contrib_ops/rocm/bert/skip_layer_norm_impl.h" +#include "contrib_ops/rocm/bert/transformer_common.h" + +namespace onnxruntime { +namespace contrib { +namespace rocm { + +#define REGISTER_KERNEL_TYPED(T) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + SkipLayerNormalization, \ + kMSDomain, \ + 1, \ + T, \ + kRocmExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + SkipLayerNorm); \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + SkipSimplifiedLayerNormalization, \ + kMSDomain, \ + 1, \ + T, \ + kRocmExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + SkipLayerNorm); + +REGISTER_KERNEL_TYPED(float) +REGISTER_KERNEL_TYPED(MLFloat16) + +using namespace ONNX_NAMESPACE; + +template +SkipLayerNorm::SkipLayerNorm(const OpKernelInfo& op_kernel_info) : RocmKernel(op_kernel_info) { + ORT_ENFORCE(op_kernel_info.GetAttr("epsilon", &epsilon_).IsOK()); + ORT_ENFORCE(epsilon_ >= 0); +} + +template +Status SkipLayerNorm::ComputeInternal(OpKernelContext* ctx) const { + const Tensor* input = ctx->Input(0); + const Tensor* skip = ctx->Input(1); + const Tensor* gamma = ctx->Input(2); + + const Tensor* beta = Simplified ? nullptr : ctx->Input(3); + const Tensor* bias = Simplified ? ctx->Input(3) : ctx->Input(4); + + Tensor* output = ctx->Output(0, input->Shape()); + + // For inferencing, we support one more optional output which is the sum + // of the input and skip tensors + Tensor* skip_input_bias_add_output = ctx->Output(3, input->Shape()); + + if (input->Shape() != skip->Shape()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "skip is expected to have same shape as input"); + } + + if (input->Shape().Size() == 0) { + return Status::OK(); + } + + const auto& input_dims = input->Shape().GetDims(); + size_t input_dims_size = input_dims.size(); + if (input_dims_size != 3 && input_dims_size != 2) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "input is expected to have 3 or 2 dimensions, got ", input_dims_size); + } + + int hidden_size = static_cast(input_dims[input_dims_size - 1]); + + const auto& gamma_dims = gamma->Shape().GetDims(); + if (gamma_dims.size() != 1) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "gamma is expected to have 1 dimension, got ", gamma_dims.size()); + } + if (gamma_dims[0] != hidden_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Last dimension of gamma and input does not match"); + } + + if (nullptr != beta) { + const auto& beta_dims = beta->Shape().GetDims(); + if (beta_dims.size() != 1) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "beta is expected to have 1 dimension, got ", beta_dims.size()); + } + if (beta_dims[0] != hidden_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Last dimension of beta and input does not match"); + } + } + + if (nullptr != bias) { + const auto& bias_dims = bias->Shape().GetDims(); + if (bias_dims.size() != 1) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "bias is expected to have 1 dimension, got ", bias_dims.size()); + } + if (bias_dims[0] != hidden_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Last dimension of bias and input does not match"); + } + } + + int64_t element_count = input->Shape().Size(); + typedef typename ToHipType::MappedType HipT; + + return LaunchSkipLayerNormKernel( + GetTuningContext(), + ctx->GetComputeStream(), + reinterpret_cast(output->MutableData()), + skip_input_bias_add_output != nullptr ? reinterpret_cast(skip_input_bias_add_output->MutableData()) : nullptr, + reinterpret_cast(input->Data()), + reinterpret_cast(skip->Data()), + reinterpret_cast(gamma->Data()), + (beta != nullptr) ? reinterpret_cast(beta->Data()) : nullptr, + (bias != nullptr) ? reinterpret_cast(bias->Data()) : nullptr, + epsilon_, + hidden_size, + static_cast(element_count)); +} + +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm.h b/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm.h new file mode 100644 index 0000000000000..02228bc59cedc --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm.h @@ -0,0 +1,26 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "core/common/common.h" +#include "core/providers/rocm/rocm_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace rocm { + +using namespace onnxruntime::rocm; + +template +class SkipLayerNorm final : public RocmKernel { + public: + SkipLayerNorm(const OpKernelInfo& op_kernel_info); + Status ComputeInternal(OpKernelContext* context) const override; + + private: + float epsilon_; +}; + +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl.cu b/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl.cu new file mode 100644 index 0000000000000..8387c49a3310b --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl.cu @@ -0,0 +1,86 @@ +#include "hip/hip_runtime.h" +/* + The implementation of this file is based on skipLayerNorm plugin in TensorRT demo: + https://github.com/NVIDIA/TensorRT/tree/release/5.1/demo/BERT/ + +Copyright 2019 NVIDIA Corporation + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// Modifications: Add SkipLayerNormKernelVec to +// leverage vectorized load/write. +// and templatize ComputeSkipLayerNorm for different +// data types. +// Copyright (c) Advanced Micro Devices, Inc. All rights reserved. +// Licensed under the MIT License. + +#include "contrib_ops/rocm/bert/skip_layer_norm_impl.h" + +#include + +#include "contrib_ops/rocm/bert/skip_layer_norm_impl_kernel.h" +#include "contrib_ops/rocm/bert/skip_layer_norm_tunable_op.h" + +namespace onnxruntime { +namespace contrib { +namespace rocm { + +template +Status LaunchSkipLayerNormKernel( + RocmTuningContext* tuning_ctx, Stream* stream, V* output, T* skip_input_bias_add_output, const T* input, + const T* skip, const V* gamma, const V* beta, const T* bias, float epsilon, int ld, int element_count) { + // this must be true because element_count is the total size of the tensor + assert(element_count % ld == 0); + + SkipLayerNormParams params(tuning_ctx, stream, output, skip_input_bias_add_output, input, skip, + gamma, beta, bias, epsilon, ld, element_count); + + if (tuning_ctx->IsTunableOpEnabled()) { + static SkipLayerNormTunableOp op; + return op(¶ms); + } + + return SkipLayerNormStaticSelection(¶ms); +} + +template Status LaunchSkipLayerNormKernel( + RocmTuningContext* tuning_ctx, Stream* stream, float* output, float* skip_input_bias_add_output, const float* input, + const float* skip, const float* gamma, const float* beta, + const float* bias, float epsilon, int ld, + int element_count); + +template Status LaunchSkipLayerNormKernel( + RocmTuningContext* tuning_ctx, Stream* stream, half* output, half* skip_input_bias_add_output, const half* input, + const half* skip, const half* gamma, const half* beta, + const half* bias, float epsilon, int ld, + int element_count); + +template Status LaunchSkipLayerNormKernel( + RocmTuningContext* tuning_ctx, Stream* stream, float* output, float* skip_input_bias_add_output, const float* input, + const float* skip, const float* gamma, const float* beta, + const float* bias, float epsilon, int ld, + int element_count); + +template Status LaunchSkipLayerNormKernel( + RocmTuningContext* tuning_ctx, Stream* stream, half* output, half* skip_input_bias_add_output, const half* input, + const half* skip, const half* gamma, const half* beta, + const half* bias, float epsilon, int ld, + int element_count); + +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl.h b/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl.h new file mode 100644 index 0000000000000..5e2a92447d2f5 --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl.h @@ -0,0 +1,31 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/common/common.h" +#include "core/providers/rocm/tunable/rocm_tunable.h" + +namespace onnxruntime { +namespace contrib { +namespace rocm { + +template +Status LaunchSkipLayerNormKernel( + RocmTuningContext* tuning, + Stream* stream, + V* output, // output tensor + T* skip_input_bias_add_output, // optional output tensor + const T* input, // input tensor + const T* skip, // skip tensor + const V* gamma, // Layer normalization gamma tensor + const V* beta, // Layer normalization beta tensor + const T* bias, // Layer normalization beta tensor + float epsilon, // Layer normalization epsilon + int hidden_size, // hidden size, it is the leading dimension (ld) + int element_count // number of elements in input tensor +); + +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl_kernel.h b/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl_kernel.h new file mode 100644 index 0000000000000..fcfbc8969e498 --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl_kernel.h @@ -0,0 +1,162 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include "contrib_ops/rocm/bert/layer_norm.cuh" + +namespace onnxruntime { +namespace contrib { +namespace rocm { + +template +T maybe2half(float x); + +template <> +float maybe2half(float x) { + return x; +} + +template <> +half maybe2half(float x) { + return __float2half_rn(x); +} + +template +__global__ void SkipLayerNormKernel( + const int ld, const T* input, const T* skip, const V* beta, const V* gamma, const T* bias, + const U epsilon, V* output, T* skip_input_bias_add_output) { + const U reverse_ld = U(1.f / ld); + const int offset = blockIdx.x * ld; + + KeyValuePairSum pair_sum; + // reduce x and x^2 + hipcub::KeyValuePair thread_data(U(0.f), U(0.f)); + + for (int i = threadIdx.x; i < ld; i += TPB) { + const int idx = offset + i; + const U val = (bias == nullptr) ? static_cast(input[idx]) + static_cast(skip[idx]) : static_cast(input[idx]) + static_cast(skip[idx]) + static_cast(bias[i]); + const U rldval = reverse_ld * val; + thread_data = pair_sum(thread_data, hipcub::KeyValuePair(rldval, rldval * val)); + + if (skip_input_bias_add_output != nullptr) { + skip_input_bias_add_output[idx] = static_cast(val); + } + + output[idx] = static_cast(val); + } + + if constexpr (Simplified) { + SimplifiedLayerNorm(thread_data.value, ld, offset, gamma, epsilon, output); + return; + } + + LayerNorm(thread_data, ld, offset, beta, gamma, epsilon, output); +} + +// Vectorized kernel +template +__global__ void SkipLayerNormKernelVec( + const int ld, const T* input, const T* skip, const V* beta, const V* gamma, + const T* bias, const U epsilon, V* output, T* skip_input_bias_add_output, + bool hasBias, bool hasSkipInputBiasAdditionOutput) { + const U reverse_ld = U(1.f / ld); + const int offset = blockIdx.x * ld; + + KeyValuePairSum pair_sum; + // reduce x and x^2 + hipcub::KeyValuePair thread_data(U(0.f), U(0.f)); + + using VecT = aligned_vector; + using VecV = aligned_vector; + if (threadIdx.x * ILP < ld) { + for (int i = threadIdx.x * ILP; i < ld; i += TPB * ILP) { + int idx = offset + i; + + const VecT input_v = *reinterpret_cast(input + idx); + const VecT skip_v = *reinterpret_cast(skip + idx); + const VecT bias_v = hasBias ? *reinterpret_cast(bias + i) : VecT(); + VecT skip_input_bias_add_output_v, output_v; + +#pragma unroll + for (int k = 0; k < ILP; k++) { + const U val = hasBias ? static_cast(input_v.val[k]) + static_cast(skip_v.val[k]) + static_cast(bias_v.val[k]) : static_cast(input_v.val[k]) + static_cast(skip_v.val[k]); + const U rldval = reverse_ld * val; + + if (hasSkipInputBiasAdditionOutput) { + skip_input_bias_add_output_v.val[k] = static_cast(val); + } + thread_data = pair_sum(thread_data, hipcub::KeyValuePair(rldval, rldval * val)); + output_v.val[k] = static_cast(val); + } + + if (hasSkipInputBiasAdditionOutput) { + *(reinterpret_cast(skip_input_bias_add_output + idx)) = skip_input_bias_add_output_v; + } + + *(reinterpret_cast(output + idx)) = output_v; + } + } + + if constexpr (Simplified) { + SimplifiedLayerNormVec(thread_data.value, ld, offset, gamma, epsilon, output); + return; + } + + LayerNormVec(thread_data, ld, offset, beta, gamma, epsilon, output); +} + +// Vectorized kernel +template +__global__ void SkipLayerNormKernelSmall( + const int ld, const T* input, const T* skip, const V* beta, const V* gamma, + const T* bias, const U epsilon, V* output, T* skip_input_bias_add_output, + bool hasBias, bool hasSkipInputBiasAdditionOutput) { + const U rld = U(1.f / ld); + const int idx = blockIdx.x * ld + threadIdx.x * ILP; // grid_size = n / ld + + using VecT = aligned_vector; + hipcub::KeyValuePair thread_data(U(0.f), U(0.f)); + + VecT input_v; + if (ILP * threadIdx.x < ld) { + input_v = *reinterpret_cast(input + idx); + const VecT skip_v = *reinterpret_cast(skip + idx); + const VecT bias_v = hasBias ? *reinterpret_cast(bias + threadIdx.x * ILP) : VecT(); + VecT skip_input_bias_add_output_v; + + U rldval_sum = U(0.f); + U rldvalsq_sum = U(0.f); +#pragma unroll + for (int i = 0; i < ILP; i++) { + const U val = hasBias ? static_cast(input_v.val[i]) + static_cast(skip_v.val[i]) + static_cast(bias_v.val[i]) : static_cast(input_v.val[i]) + static_cast(skip_v.val[i]); + + if (hasSkipInputBiasAdditionOutput) { + skip_input_bias_add_output_v.val[i] = static_cast(val); + } + + const U rldval = rld * val; + rldval_sum += rldval; + rldvalsq_sum += rldval * val; + input_v.val[i] = static_cast(val); + } + + if (hasSkipInputBiasAdditionOutput) { + *(reinterpret_cast(skip_input_bias_add_output + idx)) = skip_input_bias_add_output_v; + } + + thread_data = hipcub::KeyValuePair(rldval_sum, rldvalsq_sum); + } + + if constexpr (Simplified) { + SimplifiedLayerNormSmall(input_v.val, thread_data.value, ld, idx, gamma, epsilon, output); + return; + } + + LayerNormSmall(input_v.val, thread_data, ld, idx, beta, gamma, epsilon, output); +} + +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_tunable_op.h b/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_tunable_op.h new file mode 100644 index 0000000000000..0391704ce1c56 --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_tunable_op.h @@ -0,0 +1,161 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include +#include + +#include "contrib_ops/rocm/bert/skip_layer_norm_impl_kernel.h" +#include "core/providers/rocm/cu_inc/common.cuh" +#include "core/providers/rocm/tunable/rocm_tunable.h" + +using onnxruntime::rocm::CeilDiv; + +namespace onnxruntime { +namespace contrib { +namespace rocm { + +template +struct SkipLayerNormParams : OpParams { + SkipLayerNormParams(RocmTuningContext* tuning_ctx, onnxruntime::Stream* stream, V* output, T* skip_input_bias_add_output, const T* input, + const T* skip, const V* gamma, const V* beta, + const T* bias, float epsilon, int ld, int element_count) + : OpParams(tuning_ctx, stream), output(output), skip_input_bias_add_output(skip_input_bias_add_output), input(input), skip(skip), gamma(gamma), beta(beta), bias(bias), epsilon(epsilon), ld(ld), element_count(element_count) {} + + std::string Signature() const override { + std::string sig = std::to_string(ld) + "_" + std::to_string(element_count); + return sig; + } + + V* output; + T* skip_input_bias_add_output; + const T* input; + const T* skip; + const V* gamma; + const V* beta; + const T* bias; + float epsilon; + int ld; + int element_count; +}; + +template +Status SkipLayerNormSmallOp(const SkipLayerNormParams* params) { + // Loosen the hard constraint for ld (hidden_size) to include more possible *Small kernels, + // which could offer better performance in some combinations of ThreadsPerBlock and VecSize. + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( + !((params->ld <= 8192 && params->ld % VecSize == 0 && + params->ld <= ThreadsPerBlock * VecSize && params->ld > (ThreadsPerBlock - GPU_WARP_SIZE) * VecSize))); + SkipLayerNormKernelSmall<<element_count, params->ld)), + dim3(ThreadsPerBlock), + 0, params->StreamHandle()>>>( + params->ld, params->input, params->skip, + params->beta, params->gamma, params->bias, static_cast(params->epsilon), params->output, params->skip_input_bias_add_output, + (params->bias == nullptr) ? false : true, (params->skip_input_bias_add_output == nullptr) ? false : true); + return HIP_CALL(hipGetLastError()); +} + +template +Status SkipLayerNormRegularOp(const SkipLayerNormParams* params) { + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( + !((params->ld > 0 && params->ld % VecSize == 0 && + (params->ld >= ThreadsPerBlock * VecSize || + (params->ld < GPU_WARP_SIZE && params->ld > (ThreadsPerBlock - GPU_WARP_SIZE) * VecSize))))); + SkipLayerNormKernelVec<<element_count, params->ld)), + dim3(ThreadsPerBlock), + 0, params->StreamHandle()>>>( + params->ld, params->input, params->skip, + params->beta, params->gamma, params->bias, static_cast(params->epsilon), params->output, params->skip_input_bias_add_output, + (params->bias == nullptr) ? false : true, (params->skip_input_bias_add_output == nullptr) ? false : true); + return HIP_CALL(hipGetLastError()); +} + +template +Status SkipLayerNormStaticSelection(const SkipLayerNormParams* params) { + bool hasBias = (params->bias == nullptr) ? false : true; + bool hasSkipInputBiasAdditionOutput = (params->skip_input_bias_add_output == nullptr) ? false : true; + const int grid_size = params->element_count / params->ld; + const int block_size = 256; + +#define LAUNCH_SKIPLAYERNORM_SMALL_FORWARD(ELEMENTS, TPB, ILP) \ + if (params->ld <= ELEMENTS) { \ + SkipLayerNormKernelSmall<<StreamHandle()>>>( \ + params->ld, params->input, params->skip, params->beta, params->gamma, params->bias, \ + static_cast(params->epsilon), params->output, params->skip_input_bias_add_output, \ + hasBias, hasSkipInputBiasAdditionOutput); \ + break; \ + } + if (0 == (params->ld % 4)) { + do { + LAUNCH_SKIPLAYERNORM_SMALL_FORWARD(32, 32, 1) + LAUNCH_SKIPLAYERNORM_SMALL_FORWARD(64, 32, 2) + LAUNCH_SKIPLAYERNORM_SMALL_FORWARD(128, 32, 4) + LAUNCH_SKIPLAYERNORM_SMALL_FORWARD(384, 96, 4) + LAUNCH_SKIPLAYERNORM_SMALL_FORWARD(768, 192, 4) + LAUNCH_SKIPLAYERNORM_SMALL_FORWARD(1024, 256, 4) + + SkipLayerNormKernel<<StreamHandle()>>>( + params->ld, params->input, params->skip, params->beta, params->gamma, params->bias, + static_cast(params->epsilon), params->output, params->skip_input_bias_add_output); + } while (0); + } else { + do { + LAUNCH_SKIPLAYERNORM_SMALL_FORWARD(32, 32, 1) + LAUNCH_SKIPLAYERNORM_SMALL_FORWARD(64, 64, 1) + LAUNCH_SKIPLAYERNORM_SMALL_FORWARD(128, 128, 1) + LAUNCH_SKIPLAYERNORM_SMALL_FORWARD(384, 384, 1) + + SkipLayerNormKernel<<StreamHandle()>>>( + params->ld, params->input, params->skip, params->beta, params->gamma, params->bias, + static_cast(params->epsilon), params->output, params->skip_input_bias_add_output); + } while (0); + } + return HIP_CALL(hipPeekAtLastError()); +} // namespace rocm + +#define ADD_OP_FOR_ALL_VEC_SIZE(name, threads_per_block) \ + this->RegisterOp(name); \ + this->RegisterOp(name); \ + this->RegisterOp(name); \ + this->RegisterOp(name); \ + this->RegisterOp(name); + +#define ADD_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE(name) \ + ADD_OP_FOR_ALL_VEC_SIZE(name, 64) \ + ADD_OP_FOR_ALL_VEC_SIZE(name, 128) \ + ADD_OP_FOR_ALL_VEC_SIZE(name, 192) \ + ADD_OP_FOR_ALL_VEC_SIZE(name, 256) \ + ADD_OP_FOR_ALL_VEC_SIZE(name, 320) \ + ADD_OP_FOR_ALL_VEC_SIZE(name, 384) \ + ADD_OP_FOR_ALL_VEC_SIZE(name, 448) \ + ADD_OP_FOR_ALL_VEC_SIZE(name, 512) \ + ADD_OP_FOR_ALL_VEC_SIZE(name, 576) \ + ADD_OP_FOR_ALL_VEC_SIZE(name, 640) \ + ADD_OP_FOR_ALL_VEC_SIZE(name, 704) \ + ADD_OP_FOR_ALL_VEC_SIZE(name, 768) \ + ADD_OP_FOR_ALL_VEC_SIZE(name, 832) \ + ADD_OP_FOR_ALL_VEC_SIZE(name, 896) \ + ADD_OP_FOR_ALL_VEC_SIZE(name, 1024) + +template +class SkipLayerNormTunableOp : public TunableOp> { + public: + SkipLayerNormTunableOp() { + this->RegisterOp(SkipLayerNormStaticSelection); + ADD_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE(SkipLayerNormSmallOp) + ADD_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE(SkipLayerNormRegularOp) + + // NOTE: the 1st kernel is SkipLayerNorm Original implementation. + this->SetDefaultId(0); + } +}; + +#undef ADD_OP_FOR_ALL_VEC_SIZE +#undef ADD_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE + +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/transformer_common.cc b/onnxruntime/contrib_ops/rocm/bert/transformer_common.cc new file mode 100644 index 0000000000000..6ae8d1202d462 --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/bert/transformer_common.cc @@ -0,0 +1,37 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#include +#include "core/providers/shared_library/provider_api.h" // Include this otherwise Windows build complains Env::Default() missing +#include "core/platform/env_var_utils.h" +#include "contrib_ops/rocm/bert/transformer_common.h" + +namespace onnxruntime { +namespace contrib { +namespace rocm { + +// The environment variable is for testing purpose only, and it might be removed in the future. +// If you need some option in production, please file a feature request. +constexpr const char* kTransformerOptions = "ORT_TRANSFORMER_OPTIONS"; + +// Initialize the singleton instance +TransformerOptions TransformerOptions::instance; + +const TransformerOptions* TransformerOptions::GetInstance() { + if (!instance.initialized_) { + // We do not use critical section here since it is fine to initialize multiple times by different threads. + int value = ParseEnvironmentVariableWithDefault(kTransformerOptions, 0); + instance.Initialize(value); + + if (value > 0) + std::cout << "ORT_TRANSFORMER_OPTIONS: IsPrecisionMode=" << instance.IsPrecisionMode() + << ",DisablePersistentSoftmax=" << instance.DisablePersistentSoftmax() + << ",DisableHalf2=" << instance.DisableHalf2() + << std::endl; + } + + return &instance; +} + +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/transformer_common.h b/onnxruntime/contrib_ops/rocm/bert/transformer_common.h new file mode 100644 index 0000000000000..6816b5b9d07ec --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/bert/transformer_common.h @@ -0,0 +1,46 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/rocm/rocm_common.h" + +namespace onnxruntime { +namespace contrib { +namespace rocm { + +class TransformerOptions { + public: + static const TransformerOptions* GetInstance(); + + bool IsPrecisionMode() const { return is_precision_mode_; } + + bool DisablePersistentSoftmax() const { return disable_persistent_softmax_; } + + bool DisableHalf2() const { return disable_half2_; } + + void Initialize(int value) { + is_precision_mode_ = (value & 0x01) > 0; + disable_persistent_softmax_ = (value & 0x02) > 0; + disable_half2_ = (value & 0x04) > 0; + initialized_ = true; + } + + private: + // Default is false. If the mode is on, prefer precision than speed. + bool is_precision_mode_{false}; + + // Disable persistent softmax. + bool disable_persistent_softmax_{false}; + + // Disable half2 kernel. + bool disable_half2_{false}; + + bool initialized_{false}; + + static TransformerOptions instance; +}; + +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck.cuh b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck.cuh new file mode 100644 index 0000000000000..d0a0d09fcbae3 --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck.cuh @@ -0,0 +1,105 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include + +#ifdef USE_COMPOSABLE_KERNEL +#include "core/providers/rocm/composable_kernel_common.h" + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "contrib_ops/rocm/diffusion/group_norm_ck_impl/impl.cuh" +#endif // USE_COMPOSABLE_KERNEL + +#include "contrib_ops/rocm/diffusion/group_norm_common.h" + +namespace onnxruntime { +namespace contrib { +namespace rocm { + +#ifdef USE_COMPOSABLE_KERNEL + +using onnxruntime::rocm::CKDataTypeAdaptor; + +// The SiLU function is a special case of Swish function, +// The Swish function is parametrized by b, which is set to 1.0 for SiLU. They are defined as: +// SiLU(x) = x * sigmoid(x) +// Swish(x) = x * sigmoid(bx) +// The default value of b is 1.0 in ck::tensor_operation::element_wise::Swish function. We treat them as the same function here. +using Silu = ck::tensor_operation::element_wise::Swish; +using Pass = ck::tensor_operation::element_wise::PassThrough; + +constexpr int Rank = 5; +constexpr int NumReduceDim = 3; + +template +auto GetCKGroupNormNHWCTypeStringAndOps() { + using XDataType = typename CKDataTypeAdaptor::type; + using YDataType = typename CKDataTypeAdaptor::type; + using SaveMeanInvStdDataType = typename CKDataTypeAdaptor::type; + using GammaDataType = float; + using BetaDataType = float; + + using Activation = std::conditional_t; + + std::vector>>> ret; + for (auto&& impl : internal::GetDeviceGroupNormInstances()) { + std::string silu_suffix = WithSilu ? "_Silu" : "_Pass"; + auto type_string = onnxruntime::MakeString(impl->GetTypeString()) + silu_suffix; + auto invoker = impl->MakeInvokerPointer(); + + auto ck_group_norm_op = [impl = std::move(impl), invoker = std::move(invoker)]( + const GroupNormNHWCTunableParams* params) -> Status { + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF((params->skip != nullptr || params->bias != nullptr), + "Input skip or bias is not supported by composable kernel."); + if constexpr (WithSilu) { + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( + !params->use_silu, "Silu version only support groupnorm with silu"); + } else { + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( + params->use_silu, "Pass version only support groupnorm without silu"); + } + std::vector in_lengths{params->n, params->h, params->w, params->groups, params->channels_per_group}; + std::vector in_out_strides{params->h * params->w * params->c, params->w * params->c, + params->c, params->channels_per_group, 1}; + std::vector gamma_beta_strides{0, 0, 0, params->channels_per_group, 1}; + std::vector reduce_dims{1, 2, 4}; + + auto activation = Activation{}; + + auto arg = impl->MakeArgumentPointer(in_lengths, // lengths + in_out_strides, // xStrides + gamma_beta_strides, // gammaStrides + gamma_beta_strides, // betaStrides + in_out_strides, // yStrides + {0, 0}, // saveMeanStrides + {0, 0}, // saveInvStdStrides + reduce_dims, // reduceDims + params->epsilon, + params->src, + params->gamma, + params->beta, + params->dst, + nullptr, + nullptr, + activation); + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!impl->IsSupportedArgument(arg.get()), + impl->GetTypeString(), " does not support the params"); + invoker->Run(arg.get(), StreamConfig{params->StreamHandle()}); + return Status::OK(); + }; + ret.emplace_back(std::make_pair(std::move(type_string), std::move(ck_group_norm_op))); + } + return ret; +} +#endif // USE_COMPOSABLE_KERNEL + +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl.cuh b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl.cuh new file mode 100644 index 0000000000000..68f7d47282845 --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl.cuh @@ -0,0 +1,130 @@ +// SPDX-License-Identifier: MIT +// Modifications Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#ifdef USE_COMPOSABLE_KERNEL +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_normalization_fwd.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_normalization_fwd_impl.hpp" +#include "ck/utility/data_type.hpp" + +namespace onnxruntime { +namespace contrib { +namespace rocm { +namespace internal { + +using F16 = ck::half_t; +using F32 = float; + +using Silu = ck::tensor_operation::element_wise::Swish; +using Pass = ck::tensor_operation::element_wise::PassThrough; + +using ck::tensor_operation::device::DeviceNormalizationFwd; // the interface +using ck::tensor_operation::device::DeviceNormalizationFwdImpl; // the implementation + +// See https://github.com/ROCmSoftwarePlatform/composable_kernel/blob/1fefd82ed8/library/src/tensor_operation_instance/gpu/normalization_fwd/normalization_fwd_instance_common.hpp + +template +using device_normalization_f32_instances = std::tuple< + // clang-format off + // XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, SaveMeanInvStdDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorDim, GammaSrcVectorSize, BetaSrcVectorDim, BetaSrcVectorSize, YDstVectorSize, SaveMeanInvStdScalarPerVector> + DeviceNormalizationFwdImpl, // irregular size + DeviceNormalizationFwdImpl, // irregular size + DeviceNormalizationFwdImpl, // irregular size + DeviceNormalizationFwdImpl, // irregular size + DeviceNormalizationFwdImpl, // irregular size + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl + // clang-format on + >; + +template +using device_normalization_f16_instances = + // clang-format off + std::tuple < + // XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, SaveMeanInvStdDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorDim, GammaSrcVectorSize, BetaSrcVectorDim, BetaSrcVectorSize, YDstVectorSize, SaveMeanInvStdScalarPerVector> + DeviceNormalizationFwdImpl, // irregular size + DeviceNormalizationFwdImpl, // irregular size + DeviceNormalizationFwdImpl, // irregular size + DeviceNormalizationFwdImpl, // irregular size + DeviceNormalizationFwdImpl, // irregular size + DeviceNormalizationFwdImpl, // irregular size + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl + // clang-format on + >; + +// Use this function to get implementation +template +std::vector>> +GetDeviceGroupNormInstances() { + return {}; +} + +template <> +std::vector>> +GetDeviceGroupNormInstances< + F16, F32, F32, F16, F32, Silu, 5, 3>(); + +template <> +std::vector>> +GetDeviceGroupNormInstances< + F16, F32, F32, F16, F32, Pass, 5, 3>(); + +template <> +std::vector>> +GetDeviceGroupNormInstances< + F32, F32, F32, F32, F32, Silu, 5, 3>(); + +template <> +std::vector>> +GetDeviceGroupNormInstances< + F32, F32, F32, F32, F32, Pass, 5, 3>(); + +} // namespace internal +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime +#endif // USE_COMPOSABLE_KERNEL diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl_fp16.cu b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl_fp16.cu new file mode 100644 index 0000000000000..ad191314e5e4c --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl_fp16.cu @@ -0,0 +1,39 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifdef USE_COMPOSABLE_KERNEL +#include "contrib_ops/rocm/diffusion/group_norm_ck_impl/impl.cuh" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace onnxruntime { +namespace contrib { +namespace rocm { +namespace internal { + +template <> +std::vector>> +GetDeviceGroupNormInstances() { + std::vector>> instances; + ck::tensor_operation::device::instance::add_device_operation_instances( + instances, + device_normalization_f16_instances{}); + + return instances; +} + +template <> +std::vector>> +GetDeviceGroupNormInstances() { + std::vector>> instances; + ck::tensor_operation::device::instance::add_device_operation_instances( + instances, + device_normalization_f16_instances{}); + + return instances; +} + +} // namespace internal +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime +#endif // USE_COMPOSABLE_KERNEL diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl_fp32.cu b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl_fp32.cu new file mode 100644 index 0000000000000..ceb53ed442abc --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl_fp32.cu @@ -0,0 +1,39 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifdef USE_COMPOSABLE_KERNEL +#include "contrib_ops/rocm/diffusion/group_norm_ck_impl/impl.cuh" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace onnxruntime { +namespace contrib { +namespace rocm { +namespace internal { + +template <> +std::vector>> +GetDeviceGroupNormInstances() { + std::vector>> instances; + ck::tensor_operation::device::instance::add_device_operation_instances( + instances, + device_normalization_f32_instances{}); + + return instances; +} + +template <> +std::vector>> +GetDeviceGroupNormInstances() { + std::vector>> instances; + ck::tensor_operation::device::instance::add_device_operation_instances( + instances, + device_normalization_f32_instances{}); + + return instances; +} + +} // namespace internal +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime +#endif // USE_COMPOSABLE_KERNEL diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_common.h b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_common.h new file mode 100644 index 0000000000000..7cff640db2f34 --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_common.h @@ -0,0 +1,56 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include + +#include "core/providers/rocm/cu_inc/common.cuh" +#include "core/providers/rocm/rocm_common.h" +#include "core/providers/rocm/tunable/rocm_tunable.h" +#include "contrib_ops/rocm/diffusion/group_norm_common_base.h" + +namespace onnxruntime { +namespace contrib { +namespace rocm { + +template +struct GroupNormNHWCTunableParams : OpParams, GroupNormNHWCParams { + GroupNormNHWCTunableParams(RocmTuningContext* tuning_ctx, + onnxruntime::Stream* ort_stream, + T* output, + T* add_out, + const T* input, + const T* skip, + const T* bias, + const float* gamma, + const float* beta, + float* workspace, + float epsilon, + int batch_size, + int num_channels, + int height, + int width, + int num_groups, + bool use_silu, + bool broadcast_skip, + int channels_per_block) + : OpParams(tuning_ctx, ort_stream), + GroupNormNHWCParams(output, add_out, input, skip, bias, gamma, beta, workspace, epsilon, batch_size, + num_channels, height, width, num_groups, use_silu, broadcast_skip, channels_per_block) {} + + std::string Signature() const override { + std::string silu_suffix = this->use_silu ? "_silu" : "_pass"; + std::string skip_suffix = this->skip != nullptr ? "_skip" : "_noskip"; + std::string broadcast_suffix = this->broadcast_skip ? "_broadcast" : "_nobroadcast"; + std::string bias_suffix = this->bias != nullptr ? "_bias" : "_nobias"; + std::string sig = std::to_string(this->n) + "_" + std::to_string(this->h * this->w) + "_" + + std::to_string(this->c) + "_" + std::to_string(this->groups) + silu_suffix + + skip_suffix + broadcast_suffix + bias_suffix; + return sig; + } +}; + +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_impl.cu b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_impl.cu new file mode 100644 index 0000000000000..142aaf14e8d2d --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_impl.cu @@ -0,0 +1,76 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// The ROCM kernel is hipified from CUDA kernel. +#include "contrib_ops/rocm/diffusion/group_norm_impl.h" + +#include +#include "contrib_ops/rocm/diffusion/group_norm_common.h" +#include "contrib_ops/rocm/diffusion/group_norm_tunable_op.h" + +namespace onnxruntime { +namespace contrib { +namespace rocm { + +template +Status LaunchGroupNormKernel( + RocmTuningContext* tuning_ctx, + Stream* ort_stream, + T* output, + T* add_out, + const T* input, + const T* skip, + const T* bias, + const float* gamma, + const float* beta, + void* workspace, + float epsilon, + int batch_size, + int num_channels, + int height, + int width, + int num_groups, + bool use_silu, + bool broadcast_skip, + int channels_per_block) { + GroupNormNHWCTunableParams params(tuning_ctx, ort_stream, output, add_out, input, skip, bias, gamma, beta, + reinterpret_cast(workspace), epsilon, batch_size, num_channels, + height, width, num_groups, use_silu, broadcast_skip, channels_per_block); + + if (params.channels_per_block % params.channels_per_group != 0 || + params.channels_per_block > kMaxSize || + (params.channels_per_group % CHANNELS_PER_THREAD != 0)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, + "GroupNorm in ROCM does not support the input: n=", batch_size, + " h=", height, + " w=", width, + " c=", num_channels, + " groups=", num_groups); + } + + HIP_RETURN_IF_ERROR(hipMemsetAsync( + params.group_sum_buffer, 0, GetGroupNormWorkspaceSizeInBytes(batch_size, num_groups), params.StreamHandle())); + + if (tuning_ctx->IsTunableOpEnabled()) { + static GroupNormNHWCTunableOp op; + return op(¶ms); + } + + return GroupNormNHWCStaticSelection(¶ms); +} + +template Status LaunchGroupNormKernel(RocmTuningContext* tuning_ctx, Stream* stream, half* output, + half* add_out, const half* input, const half* skip, const half* bias, + const float* gamma, const float* beta, void* workspace, float epsilon, + int batch_size, int num_channels, int height, int width, int num_groups, + bool use_silu, bool broadcast_skip, int channels_per_block); + +template Status LaunchGroupNormKernel(RocmTuningContext* tuning_ctx, Stream* stream, float* output, + float* add_out, const float* input, const float* skip, const float* bias, + const float* gamma, const float* beta, void* workspace, float epsilon, + int batch_size, int num_channels, int height, int width, int num_groups, + bool use_silu, bool broadcast_skip, int channels_per_block); + +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh new file mode 100644 index 0000000000000..c6ca16bfdfc80 --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh @@ -0,0 +1,105 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include + +#include "contrib_ops/rocm/diffusion/group_norm_common.h" +#include "core/providers/rocm/triton_kernel.h" + +using namespace onnxruntime::rocm; + +namespace onnxruntime { +namespace contrib { +namespace rocm { + +#ifdef USE_TRITON_KERNEL + +namespace { + +template +std::string GetGroupNormTritonGroupName() { + std::string ret = "GroupNormTriton_"; + std::string silu_suffix = WithSilu ? "Silu_" : "Pass_"; + ret += silu_suffix; + ret += GetDataTypeName(); + return ret; +} + +} // namespace + +template +auto GetTritonGroupNormNHWCTypeStringAndOps() { + std::vector>>> ret; + auto group_name = GetGroupNormTritonGroupName(); + auto* kernel_list = GetOrtTritonKernelByGroup(group_name); + if (kernel_list == nullptr) { + return ret; + } + + for (auto i : *kernel_list) { + // Check params match + auto* metadata = GetOrtTritonKernelMetadata(i); + auto block_size = metadata->constants.at("BLOCK_SIZE"); + auto hw_size = metadata->constants.at("HW_SIZE"); + auto impl = [i, block_size, hw_size](const GroupNormNHWCTunableParams* params) -> Status { + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( + params->channels_per_group > block_size || params->channels_per_group * 2 <= block_size, + "Arg block_size (", block_size, ") is not the next power of 2 of channels_per_group (", + params->channels_per_group, ")."); + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( + params->hw % hw_size != 0, "Arg hw_size (", hw_size, ") is not a divisor of hw (", params->hw, ")."); + if constexpr (WithSilu) { + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!params->use_silu, "Silu version does not support GN w/o silu."); + } else { + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(params->use_silu, "Pass version does not support GN w/ silu."); + } + // Construct args for launch kernel + struct { + const void* src; + const void* skip; + const void* bias; + void* out; + void* add_out; + const void* gamma; + const void* beta; + int hw; + int c; + int c_per_group; + float eps; + bool has_skip; + bool has_bias; + bool broadcast_skip; + } args = { + (const void*)params->src, + (const void*)params->skip, + (const void*)params->bias, + (void*)params->dst, + (void*)params->skip_workspace, + (const void*)params->gamma, + (const void*)params->beta, + params->hw, + params->c, + params->channels_per_group, + params->epsilon, + params->skip != nullptr, + params->bias != nullptr, + params->broadcast_skip, + }; + + // Grid dim is (batch_count, groups, 1) + return LaunchTritonKernel(params->StreamHandle(), i, params->n, params->groups, 1, &args, sizeof(args)); + }; + ret.emplace_back(std::make_pair(metadata->name, std::move(impl))); + } + return ret; +} + +#endif // USE_TRITON_KERNEL + +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.py b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.py new file mode 100644 index 0000000000000..5ba96ebc117f0 --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.py @@ -0,0 +1,135 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- + +from itertools import product + +import triton +import triton.language as tl + + +@triton.jit +def group_norm_kernel( + input_ptr, + skip_ptr, + bias_ptr, + output_ptr, + add_out_ptr, + gamma_ptr, + beta_ptr, + img_size, + c, + c_per_group, + eps, + has_skip, + has_bias, + broadcast_skip, + BLOCK_SIZE: tl.constexpr, + HW_SIZE: tl.constexpr, + ACTIVATION_SILU: tl.constexpr, +): + row_x = tl.program_id(0) + row_y = tl.program_id(1) + stride = img_size * c + input_ptr += row_x * stride + row_y * c_per_group + output_ptr += row_x * stride + row_y * c_per_group + gamma_ptr += row_y * c_per_group + beta_ptr += row_y * c_per_group + + cols = tl.arange(0, BLOCK_SIZE) + hw = tl.arange(0, HW_SIZE) + offsets = hw[:, None] * c + cols[None, :] + mask = (cols < c_per_group)[None, :] + + bias = tl.zeros([BLOCK_SIZE], dtype=tl.float32) + if has_skip: + add_out_ptr += row_x * stride + row_y * c_per_group + if broadcast_skip: + broadcast_skip_ptr = skip_ptr + row_x * c + row_y * c_per_group + bias += tl.load(broadcast_skip_ptr + cols, mask=cols < c_per_group, other=0.0).to(tl.float32) + else: + skip_ptr += row_x * stride + row_y * c_per_group + if has_bias: + bias_ptr += row_y * c_per_group + bias += tl.load(bias_ptr + cols, mask=cols < c_per_group, other=0.0).to(tl.float32) + + # Calculate mean and variance + _sum = tl.zeros([HW_SIZE, BLOCK_SIZE], dtype=tl.float32) + _square_sum = tl.zeros([HW_SIZE, BLOCK_SIZE], dtype=tl.float32) + for i in range(tl.cdiv(img_size, HW_SIZE)): + x_ptr = input_ptr + i * HW_SIZE * c + a = tl.load(x_ptr + offsets, mask=mask, other=0.0).to(tl.float32) + if has_skip and not broadcast_skip: + s_ptr = skip_ptr + i * HW_SIZE * c + s = tl.load(s_ptr + offsets, mask=mask, other=0.0).to(tl.float32) + a += s + if has_bias or broadcast_skip: + a += bias + _sum += a + _square_sum += a * a + if has_skip: + add_y_ptr = add_out_ptr + i * HW_SIZE * c + tl.store(add_y_ptr + offsets, a, mask=mask) + + # Set axis=None (or leave it unspecified) to reduce all axes. + # TODO: In older Triton we have to reduce an axis at a time, but in our case + # for some configs it may have some issue when reducing sequentially along the axes. + group_mean = tl.sum(_sum, axis=None) / (img_size * c_per_group) + group_var = tl.sum(_square_sum, axis=None) / (img_size * c_per_group) - group_mean * group_mean + + rstd = 1 / tl.sqrt(group_var + eps) + + # Normalize and apply linear transformation + gamma = tl.load(gamma_ptr + cols, mask=cols < c_per_group).to(tl.float32) + beta = tl.load(beta_ptr + cols, mask=cols < c_per_group).to(tl.float32) + for i in range(tl.cdiv(img_size, HW_SIZE)): + y_ptr = output_ptr + i * HW_SIZE * c + if has_skip: + add_y_ptr = add_out_ptr + i * HW_SIZE * c + x = tl.load(add_y_ptr + offsets, mask=mask, other=0.0).to(tl.float32) + else: + x_ptr = input_ptr + i * HW_SIZE * c + x = tl.load(x_ptr + offsets, mask=mask, other=0.0).to(tl.float32) + x_hat = (x - group_mean) * rstd + y = x_hat * gamma + beta + if ACTIVATION_SILU: + y *= tl.sigmoid(y) + tl.store(y_ptr + offsets, y, mask=mask) + + +# We can have more combinations of blocks and hw_sizes, e.g., +# blocks = [16, 32, 64, 128, 256, 512] +# hw_sizes = [8, 16, 32, 64, 128, 256, 512] +# but this will result in too many functions and slow down the compilation. +with_silu = [True, False] +dtypes = ["fp32", "fp16"] +blocks = [16, 32, 64, 128] +hw_sizes = [8, 16, 32, 64, 128, 256] +warps = [1, 2, 4, 8, 16] +name_pattern = "GroupNormTriton_{}_{}_b{}_hw{}_w{}" +sig_pattern = "*{},*{},*{},*{},*{},*fp32,*fp32,i32,i32,i32,fp32,i1,i1,i1" +group_pattern = "GroupNormTriton_{}_{}" + + +def get_function_table(): + func_table = [] + + for silu, dtype, hw_size, warp, b in product(with_silu, dtypes, hw_sizes, warps, blocks): + silu_suffix = "Silu" if silu else "Pass" + name = name_pattern.format(silu_suffix, dtype, b, hw_size, warp) + group = group_pattern.format(silu_suffix, dtype) + sig = sig_pattern.format(dtype, dtype, dtype, dtype, dtype) + kwargs = { + "num_warps": warp, + "constants": {"BLOCK_SIZE": b, "HW_SIZE": hw_size, "ACTIVATION_SILU": int(silu)}, + } + func_desc = {"name": name, "group": group, "func": group_norm_kernel, "sig": sig, "kwargs": kwargs} + func_table.append(func_desc) + return func_table + + +if __name__ == "__main__": + func_table = get_function_table() + for func_desc in func_table: + print(func_desc) diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_tunable_op.h b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_tunable_op.h new file mode 100644 index 0000000000000..e6831f764b418 --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_tunable_op.h @@ -0,0 +1,220 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include "core/providers/rocm/cu_inc/common.cuh" +#include "core/providers/rocm/rocm_common.h" +#include "contrib_ops/rocm/diffusion/group_norm_ck.cuh" +#include "contrib_ops/rocm/diffusion/group_norm_common.h" +#include "contrib_ops/rocm/diffusion/group_norm_impl.h" +#include "contrib_ops/rocm/diffusion/group_norm_impl_kernel.cuh" +#include "contrib_ops/rocm/diffusion/group_norm_triton.cuh" + +namespace onnxruntime { +namespace contrib { +namespace rocm { + +using onnxruntime::rocm::GPU_WARP_SIZE; + +template +void GroupNormNHWCSum(const GroupNormNHWCTunableParams* params) { + dim3 grid; + + // The number of blocks to compute all the channels. + grid.x = DivUp(params->c, params->channels_per_block); + // The number of blocks to compute all the activations in a given instance. + grid.y = DivUp(params->hw, params->hw_per_block); + // The number of instances. + grid.z = params->n; + +#define LAUNCH_GROUPNORM_SUM(ThreadsPerBlock, VecSize) \ + GroupNormNHWCSumKernel \ + <<StreamHandle()>>>( \ + params->skip_workspace, params->group_sum_buffer, params->src, params->skip, params->bias, \ + params->channels_per_block, params->hw_per_block, params->hw, params->hwc, params->c, \ + params->channels_per_group, params->groups, params->groups_per_block, params->broadcast_skip); \ + break; + + // Threads_per_block is half of values in kSizes since CHANNELS_PER_THREAD = 2. + switch (params->threads_per_block) { + case 256: + LAUNCH_GROUPNORM_SUM(256, CHANNELS_PER_THREAD) + case 192: + LAUNCH_GROUPNORM_SUM(192, CHANNELS_PER_THREAD) + case 160: + LAUNCH_GROUPNORM_SUM(160, CHANNELS_PER_THREAD) + case 128: + LAUNCH_GROUPNORM_SUM(128, CHANNELS_PER_THREAD) + case 64: + LAUNCH_GROUPNORM_SUM(64, CHANNELS_PER_THREAD) + default: + ORT_NOT_IMPLEMENTED("Not implemented"); + } +} + +template +Status GroupNormNHWCSumOp(const GroupNormNHWCTunableParams* params) { + dim3 grid; + grid.x = DivUp(params->c, params->channels_per_block); + grid.y = DivUp(params->hw, params->hw_per_block); + grid.z = params->n; + + GroupNormNHWCSumKernel + <<StreamHandle()>>>( + params->skip_workspace, params->group_sum_buffer, params->src, params->skip, params->bias, + params->channels_per_block, params->hw_per_block, params->hw, params->hwc, params->c, + params->channels_per_group, params->groups, params->groups_per_block, params->broadcast_skip); + return HIP_CALL(hipGetLastError()); +} + +template +void GroupNormNHWCScale(const GroupNormNHWCTunableParams* params) { + dim3 grid; + + // The number of blocks to compute all the channels. + grid.x = DivUp(params->c, params->channels_per_block); + // The number of blocks to compute all the activations in a given instance. + grid.y = DivUp(params->hw, params->hw_per_block); + // The number of instances. + grid.z = params->n; + +#define LAUNCH_GROUPNORM_SCALE(ThreadsPerBlock, VecSize) \ + GroupNormNHWCScaleKernel \ + <<StreamHandle()>>>( \ + params->dst, params->src, params->skip, params->gamma, params->beta, params->skip_workspace, \ + params->group_sum_buffer, params->epsilon, params->c, params->channels_per_block, \ + params->channels_per_group, params->groups, params->hwc, params->inv_hw_channels_per_group, \ + params->hw, params->hw_per_block, params->use_silu); \ + break; + + // Threads_per_block is half of values in kSizes since CHANNELS_PER_THREAD = 2. + switch (params->threads_per_block) { + case 256: + LAUNCH_GROUPNORM_SCALE(256, CHANNELS_PER_THREAD) + case 192: + LAUNCH_GROUPNORM_SCALE(192, CHANNELS_PER_THREAD) + case 160: + LAUNCH_GROUPNORM_SCALE(160, CHANNELS_PER_THREAD) + case 128: + LAUNCH_GROUPNORM_SCALE(128, CHANNELS_PER_THREAD) + case 64: + LAUNCH_GROUPNORM_SCALE(64, CHANNELS_PER_THREAD) + default: + ORT_NOT_IMPLEMENTED("Not implemented"); + } +} + +template +Status GroupNormNHWCScaleOp(const GroupNormNHWCTunableParams* params) { + dim3 grid; + grid.x = DivUp(params->c, params->channels_per_block); + grid.y = DivUp(params->hw, params->hw_per_block); + grid.z = params->n; + + GroupNormNHWCScaleKernel + <<StreamHandle()>>>( + params->dst, params->src, params->skip, params->gamma, params->beta, params->skip_workspace, + params->group_sum_buffer, params->epsilon, params->c, params->channels_per_block, params->channels_per_group, + params->groups, params->hwc, params->inv_hw_channels_per_group, params->hw, params->hw_per_block, + params->use_silu); + return HIP_CALL(hipGetLastError()); +} + +template +class GroupNormNHWCOp { + public: + Status operator()(const GroupNormNHWCTunableParams* params) { + HIP_RETURN_IF_ERROR(hipMemsetAsync(params->group_sum_buffer, + 0, + GetGroupNormWorkspaceSizeInBytes(params->n, params->groups), + params->StreamHandle())); + auto status = GroupNormNHWCSumOp(params); + ORT_RETURN_IF_ERROR(status); + HIP_RETURN_IF_ERROR(hipGetLastError()); + status = GroupNormNHWCScaleOp(params); + ORT_RETURN_IF_ERROR(status); + HIP_RETURN_IF_ERROR(hipGetLastError()); + return Status::OK(); + } + + Status IsSupported(const GroupNormNHWCTunableParams* params) { + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( + !(params->c % VecSize == 0 && params->channels_per_group % VecSize == 0), + "The number of channels (", params->c, ") or the number of channels per group (", params->channels_per_group, + ") isn't divisible by the number of vector size: ", VecSize); + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!(params->channels_per_block <= ThreadsPerBlock * VecSize && + params->channels_per_block > (ThreadsPerBlock - GPU_WARP_SIZE) * VecSize), + "Configuration: Threads (", ThreadsPerBlock, "), vector size (", + VecSize, ") is redundant for the number of channels per group: ", + params->channels_per_block); + + return Status::OK(); + } +}; + +template +Status GroupNormNHWCStaticSelection(const GroupNormNHWCTunableParams* params) { + HIP_RETURN_IF_ERROR(hipMemsetAsync(params->group_sum_buffer, + 0, + GetGroupNormWorkspaceSizeInBytes(params->n, params->groups), + params->StreamHandle())); + GroupNormNHWCSum(params); + HIP_RETURN_IF_ERROR(hipGetLastError()); + GroupNormNHWCScale(params); + HIP_RETURN_IF_ERROR(hipGetLastError()); + return Status::OK(); +} + +#define ADD_OP_FOR_ALL_VEC_SIZE(name, threads_per_block) \ + this->RegisterOp(name{}); \ + this->RegisterOp(name{}); \ + this->RegisterOp(name{}); + +#define ADD_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE(name) \ + ADD_OP_FOR_ALL_VEC_SIZE(name, 64) \ + ADD_OP_FOR_ALL_VEC_SIZE(name, 128) \ + ADD_OP_FOR_ALL_VEC_SIZE(name, 192) \ + ADD_OP_FOR_ALL_VEC_SIZE(name, 256) \ + ADD_OP_FOR_ALL_VEC_SIZE(name, 320) + +template +class GroupNormNHWCTunableOp : public TunableOp> { + public: + GroupNormNHWCTunableOp() { + this->RegisterOp(GroupNormNHWCStaticSelection); + ADD_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE(GroupNormNHWCOp) + +#ifdef USE_COMPOSABLE_KERNEL + for (auto&& [_, op] : GetCKGroupNormNHWCTypeStringAndOps()) { + ORT_UNUSED_PARAMETER(_); + this->RegisterOp(std::move(op)); + } + + for (auto&& [_, op] : GetCKGroupNormNHWCTypeStringAndOps()) { + ORT_UNUSED_PARAMETER(_); + this->RegisterOp(std::move(op)); + } +#endif // USE_COMPOSABLE_KERNEL + +#ifdef USE_TRITON_KERNEL + for (auto&& [_, op] : GetTritonGroupNormNHWCTypeStringAndOps()) { + ORT_UNUSED_PARAMETER(_); + this->RegisterOp(std::move(op)); + } + for (auto&& [_, op] : GetTritonGroupNormNHWCTypeStringAndOps()) { + ORT_UNUSED_PARAMETER(_); + this->RegisterOp(std::move(op)); + } +#endif + } +}; + +#undef ADD_OP_FOR_ALL_VEC_SIZE +#undef ADD_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE + +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/diffusion/nhwc_conv.cc b/onnxruntime/contrib_ops/rocm/diffusion/nhwc_conv.cc new file mode 100644 index 0000000000000..35427a02c631d --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/diffusion/nhwc_conv.cc @@ -0,0 +1,27 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/rocm/nn/conv.h" + +using namespace onnxruntime::rocm; + +namespace onnxruntime { +namespace contrib { +namespace rocm { + +#define REGISTER_KERNEL_TYPED(T) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + NhwcConv, \ + kMSDomain, \ + 1, \ + T, \ + kRocmExecutionProvider, \ + (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + Conv); + +REGISTER_KERNEL_TYPED(float) +REGISTER_KERNEL_TYPED(MLFloat16) + +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/fused_conv.cc b/onnxruntime/contrib_ops/rocm/fused_conv.cc new file mode 100644 index 0000000000000..4f3be98d97f80 --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/fused_conv.cc @@ -0,0 +1,439 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include +#include "core/common/status.h" +#include "core/providers/rocm/nn/conv.h" +#include "core/providers/rocm/rocm_common.h" + +namespace onnxruntime { +namespace contrib { +namespace rocm { + +namespace { + +// Copied from hipDNN/library/src/hcc_detail/hipdnn_miopen.cpp +miopenStatus_t _miopenAddTensor( + miopenHandle_t handle, + const void* alpha, + const miopenTensorDescriptor_t aDesc, + const void* A, + const void* beta, + const miopenTensorDescriptor_t cDesc, + void* C, + const void* zero_scalar) { + const miopenTensorOp_t tensorOp = miopenTensorOpAdd; + // Using miopenOpTensor to implement Add operator. + // opnd2 = Add ( 0.0 * opnd0, alpha * opnd1 ) + beta * opnd2 + return miopenOpTensor(handle, tensorOp, + zero_scalar, cDesc, C, + alpha, aDesc, A, + beta, cDesc, C); +} + +} // namespace + +template +struct FNVHash { + uint32_t GetValue() const { return value_; } + + void Hash(const void* in_ptr, size_t nbytes) { + auto ptr = reinterpret_cast(in_ptr); + for (size_t i = 0; i < nbytes; ++i) { + value_ ^= ptr[i]; + value_ *= PRIME; + } + } + + template ::value, size_t>::type = 0> + FNVHash& operator<<(const T& pod) { + Hash(&pod, sizeof(pod)); + return *this; + } + + template + FNVHash& operator<<(const std::vector& pod_array) { + for (const auto& pod : pod_array) { + (*this) << pod; + } + return *this; + } + + void HashTensor(miopenTensorDescriptor_t tdesc) { + int size = 0; + miopenGetTensorDescriptorSize(tdesc, &size); + (*this) << size; + std::vector dims(size); + std::vector strides(size); + miopenDataType_t dtype; + miopenGetTensorDescriptor(tdesc, &dtype, dims.data(), strides.data()); + (*this) << dtype; + (*this) << dims; + (*this) << strides; + } + + void HashConvolutionDescriptor(miopenConvolutionDescriptor_t cdesc) { + int spatial_dim = 1; +#if ROCM_VERSION >= 50500 + MIOPEN_CALL(miopenGetConvolutionSpatialDim(cdesc, &spatial_dim)); + std::vector pads{spatial_dim}; + std::vector strides{spatial_dim}; + std::vector dilations{spatial_dim}; + miopenConvolutionMode_t mode; + MIOPEN_CALL(miopenGetConvolutionNdDescriptor(cdesc, spatial_dim, &spatial_dim, pads.data(), strides.data(), dilations.data(), &mode)); +#else + // Previous versions of MIOpen doesn't provide API to probe the dimension of a + // miopenConvolutionDescriptor_t, so we have to guess. + // This algorithm is based on a specific behavior of miopenGetConvolutionNdDescriptor, + // which fails when requestedSpatialDim > the convolution's spatial dimension + constexpr const int kMaxSpatialDim = 5; + std::vector pads{kMaxSpatialDim}; + std::vector strides{kMaxSpatialDim}; + std::vector dilations{kMaxSpatialDim}; + miopenConvolutionMode_t mode; + bool spatial_dim_guessed = false; + for (int i = 0; i < kMaxSpatialDim; i++) { + if (miopenStatusSuccess == miopenGetConvolutionNdDescriptor( + cdesc, i, &spatial_dim, pads.data(), strides.data(), dilations.data(), &mode)) { + spatial_dim_guessed = true; + break; + } + } + ORT_ENFORCE(spatial_dim_guessed, "Failed to guess the actual spatial dimension"); + // Remove the extra dimension + pads.resize(spatial_dim); + strides.resize(spatial_dim); + dilations.resize(spatial_dim); +#endif + (*this) << spatial_dim; + (*this) << pads; + (*this) << strides; + (*this) << dilations; + (*this) << mode; + } + + private: + uint32_t value_ = BASIS; +}; + +template +class FusedConv : public onnxruntime::rocm::Conv { + public: + using Base = onnxruntime::rocm::Conv; + FusedConv(const OpKernelInfo& info) : onnxruntime::rocm::Conv(info) { + std::string activation; + ORT_THROW_IF_ERROR(info.GetAttr("activation", &activation)); + ORT_THROW_IF_ERROR(MapMode(activation)); + MIOPEN_CALL_THROW(miopenCreateActivationDescriptor(&activation_desc_)); + MIOPEN_CALL_THROW(miopenSetActivationDescriptor(activation_desc_, activation_mode_, 0.0, 0.0, 0.0)); + MIOPEN_CALL_THROW(miopenCreateOperatorArgs(&fusion_args_)); + } + + ORT_DISALLOW_COPY_AND_ASSIGNMENT(FusedConv); + + ~FusedConv() { + if (activation_desc_) { + MIOPEN_CALL_THROW(miopenDestroyActivationDescriptor(activation_desc_)); + activation_desc_ = nullptr; + } + + if (fusion_args_) { + miopenDestroyOperatorArgs(fusion_args_); + } + } + + Status ComputeInternal(OpKernelContext* context) const override { + std::lock_guard lock(Base::s_.mutex); + + ORT_RETURN_IF_ERROR(Base::UpdateState(context, true)); + if (Base::s_.Y->Shape().Size() == 0) { + return Status::OK(); + } + + bool has_z = nullptr != Base::s_.z_data; + bool has_b = nullptr != Base::s_.b_data; + auto factory = [this](FusedConvFusionData& fusion) { + return this->DoCreateFusionDesc(this->Node().Name(), fusion); + }; + auto& cached_item = plan_cache_.FindOrCreateFusionPlanCache(Hash(), + factory); + bool should_try_fusion_api = cached_item.Validate(this->GetMiopenHandle(context)); + + typedef typename onnxruntime::rocm::ToHipType::MappedType HipT; + const auto alpha = onnxruntime::rocm::Consts::One; + const auto beta = onnxruntime::rocm::Consts::Zero; + IAllocatorUniquePtr workspace = Base::GetWorkSpace(context->GetComputeStream()); + miopenStatus_t fusion_status = miopenStatusNotInitialized; + + if (should_try_fusion_api) { + auto& fusion_info = *cached_item.fusion; + MIOPEN_RETURN_IF_ERROR(miopenSetOpArgsConvForward(fusion_args_, + fusion_info.conv_op, + &alpha, + &beta, + Base::s_.w_data)); + if (has_z) { + MIOPEN_RETURN_IF_ERROR(miopenSetOpArgsBiasForward(fusion_args_, + fusion_info.bias_z_op, + &alpha, + &beta, + Base::s_.z_data)); + } + if (has_b) { + MIOPEN_RETURN_IF_ERROR(miopenSetOpArgsBiasForward(fusion_args_, + fusion_info.bias_b_op, + &alpha, + &beta, + Base::s_.b_data)); + } + if (activation_desc_) { + const float relu_notused = 0.0; + MIOPEN_RETURN_IF_ERROR(miopenSetOpArgsActivForward(fusion_args_, + fusion_info.act_op, + &alpha, + &beta, + relu_notused, + relu_notused, + relu_notused)); + } + fusion_status = miopenExecuteFusionPlan(this->GetMiopenHandle(context), + fusion_info.plan, + Base::s_.x_tensor, + Base::s_.x_data, + Base::s_.y_tensor, + Base::s_.y_data, + fusion_args_); + } + if (miopenStatusSuccess != fusion_status) { + MIOPEN_RETURN_IF_ERROR(miopenConvolutionForward(this->GetMiopenHandle(context), + &alpha, + Base::s_.x_tensor, + Base::s_.x_data, + Base::s_.w_desc, + Base::s_.w_data, + Base::s_.conv_desc, + Base::s_.fwd_algo, + &beta, + Base::s_.y_tensor, + Base::s_.y_data, + workspace.get(), + Base::s_.workspace_bytes)); + if (has_b) { + MIOPEN_RETURN_IF_ERROR(_miopenAddTensor(this->GetMiopenHandle(context), + &alpha, Base::s_.b_tensor, Base::s_.b_data, + &alpha, Base::s_.y_tensor, Base::s_.y_data, + &beta)); + } + if (has_z) { + MIOPEN_RETURN_IF_ERROR(_miopenAddTensor(this->GetMiopenHandle(context), + &alpha, Base::s_.z_tensor, Base::s_.z_data, + &alpha, Base::s_.y_tensor, Base::s_.y_data, + &beta)); + } + MIOPEN_RETURN_IF_ERROR(miopenActivationForward(this->GetMiopenHandle(context), + activation_desc_, + &alpha, + Base::s_.y_tensor, + Base::s_.y_data, + &beta, + Base::s_.y_tensor, + Base::s_.y_data)); + } + if (Base::s_.post_slicing_required) { + ORT_RETURN_IF_ERROR(onnxruntime::rocm::SliceOutUnwantedOutputSection( + this->Stream(context), + Base::s_.y_data, + Base::s_.y_dims_with_adjusted_pads, + Base::s_.Y->MutableDataRaw(), + Base::s_.y_dims.GetDims(), + Base::s_.slice_starts, + Base::s_.slice_ends, + Base::s_.slice_axes, + Base::s_.element_size)); + } + return Status::OK(); + } + + private: + Status MapMode(const std::string& activaton_mode) { + if (activaton_mode == "Relu") { + activation_mode_ = miopenActivationMode_t::miopenActivationRELU; + } else { + return ORT_MAKE_STATUS( + StatusCategory::ONNXRUNTIME, StatusCode::INVALID_ARGUMENT, + "unsupported conv activation mode \"", activaton_mode, "\""); + } + return Status::OK(); + } + miopenActivationMode_t activation_mode_; + miopenActivationDescriptor_t activation_desc_ = nullptr; + + miopenOperatorArgs_t fusion_args_ = nullptr; + + // MIOpen Fusion API + // TODO: create one fusion descriptor shared by multiple FusedConv + // objects + // + // Considerations: + // How to determine two FusedConv objects may share the same fusion + // descriptor? Hashing x_tensor,conv_desc, etc.? + struct FusedConvFusionData { + miopenFusionPlanDescriptor_t plan = nullptr; + miopenFusionOpDescriptor_t conv_op = nullptr; + miopenFusionOpDescriptor_t bias_b_op = nullptr; + miopenFusionOpDescriptor_t bias_z_op = nullptr; + miopenFusionOpDescriptor_t act_op = nullptr; + + // TODO: There is a potential problem. miopenHandle_t may be destroyed and + // re-created later, sharing the same address. Currently there is any way + // to detect it? + mutable std::unordered_set compiled_on; + + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(FusedConvFusionData); + + FusedConvFusionData() {} + ~FusedConvFusionData() { + if (plan) { + miopenDestroyFusionPlan(plan); + } + } + }; + + struct FusionPlanCacheItem { + std::unique_ptr fusion; + Status creation_result; + // TODO: Add a timestamp for eviction + // std::chrono::time_point last_access; + + FusionPlanCacheItem() {} + + miopenStatus_t CompileOnHandle(miopenHandle_t handle) const { + if (!fusion->plan) { + return miopenStatusNotInitialized; + } + auto iter = fusion->compiled_on.find(handle); + if (iter != fusion->compiled_on.end()) { + return miopenStatusSuccess; + } + auto ret = miopenCompileFusionPlan(handle, fusion->plan); + if (miopenStatusSuccess == ret) { + fusion->compiled_on.insert(handle); + } else { + return ret; + } + return miopenStatusSuccess; + } + + bool Validate(miopenHandle_t handle) const { + if (Status::OK() != creation_result) { + return false; + } + if (!fusion || !fusion->plan) { + return false; + } + auto compiling_status = CompileOnHandle(handle); + if (miopenStatusSuccess != compiling_status) { + return false; + } + + return true; + } + }; + + struct FusionPlanCache { + mutable std::mutex mutex; + using HashKey = uint32_t; + std::unordered_map cache_directory_; + + FusionPlanCache() { + } + + FusionPlanCacheItem& FindOrCreateFusionPlanCache(HashKey key, + std::function factory) { + std::lock_guard lock(mutex); + auto iter = cache_directory_.find(key); + if (iter == cache_directory_.end()) { + cache_directory_[key].fusion = std::make_unique(); + cache_directory_[key].creation_result = factory(*cache_directory_[key].fusion); + if (Status::OK() != cache_directory_[key].creation_result) { + cache_directory_[key].fusion.reset(); + } + } + return cache_directory_[key]; + } + }; + + static FusionPlanCache plan_cache_; + + Status DoCreateFusionDesc(const std::string& node_name, FusedConvFusionData& fusion) const { + bool has_z = nullptr != Base::s_.z_data; + bool has_b = nullptr != Base::s_.b_data; + MIOPEN_RETURN_IF_ERROR(miopenCreateFusionPlan(&fusion.plan, + miopenVerticalFusion, + Base::s_.x_tensor)); + auto status = miopenCreateOpConvForward(fusion.plan, &fusion.conv_op, Base::s_.conv_desc, Base::s_.w_desc); + if (status == miopenStatusUnsupportedOp) { + auto msg = MakeString("MIOpen does not support the conv fusion for node \"", + node_name, "\", fallback to unfused implementation."); + LOGS_DEFAULT(WARNING) << msg; + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, msg); + } + MIOPEN_RETURN_IF_ERROR(status); + + if (has_z) { + MIOPEN_RETURN_IF_ERROR(miopenCreateOpBiasForward(fusion.plan, + &fusion.bias_z_op, + Base::s_.z_tensor)); + } + if (has_b) { + MIOPEN_RETURN_IF_ERROR(miopenCreateOpBiasForward(fusion.plan, + &fusion.bias_b_op, + Base::s_.b_tensor)); + } + if (activation_desc_) { + MIOPEN_RETURN_IF_ERROR(miopenCreateOpActivationForward(fusion.plan, + &fusion.act_op, + activation_mode_)); + } + return Status::OK(); + } + + uint32_t Hash() const { + FNVHash hash; + bool has_z = nullptr != Base::s_.z_data; + bool has_b = nullptr != Base::s_.b_data; + hash.HashTensor(Base::s_.x_tensor); + hash.HashConvolutionDescriptor(Base::s_.conv_desc); + hash.HashTensor(Base::s_.w_desc); + if (has_z) { + hash.HashTensor(Base::s_.z_tensor); + } + if (has_b) { + hash.HashTensor(Base::s_.b_tensor); + } + if (activation_desc_) { + hash << static_cast(activation_mode_); + } + return hash.GetValue(); + } +}; + +template +typename FusedConv::FusionPlanCache FusedConv::plan_cache_; + +#define REGISTER_KERNEL_TYPED(T) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + FusedConv, \ + kMSDomain, \ + 1, \ + T, \ + kRocmExecutionProvider, \ + (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + FusedConv); + +REGISTER_KERNEL_TYPED(float); +REGISTER_KERNEL_TYPED(MLFloat16); +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/math/gemm_float8.cu b/onnxruntime/contrib_ops/rocm/math/gemm_float8.cu new file mode 100644 index 0000000000000..3539f32252944 --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/math/gemm_float8.cu @@ -0,0 +1,213 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/common/common.h" +#include "core/common/float16.h" +#include "core/providers/rocm/rocm_kernel.h" +#include "contrib_ops/rocm/math/gemm_float8_ck.cuh" + +namespace onnxruntime { +namespace contrib { +namespace rocm { + +using namespace onnxruntime::rocm; +using namespace onnxruntime::rocm::tunable::blas; + +class GemmFloat8 final : public RocmKernel { + public: + GemmFloat8(const OpKernelInfo& info) : RocmKernel(info) { + transA_ = info.GetAttrOrDefault("transA", 0); + transB_ = info.GetAttrOrDefault("transB", 0); + dtype_ = info.GetAttrOrDefault("dtype", onnx::TensorProto_DataType_FLOAT16); + alpha_ = info.GetAttrOrDefault("alpha", 1); + beta_ = info.GetAttrOrDefault("beta", 0); + } + Status ComputeInternal(OpKernelContext* ctx) const override; + + private: +#if !defined(DISABLE_FLOAT8_TYPES) + template + Status ComputeFp8Fp16Fp16(OpKernelContext* ctx, int64_t m, int64_t n, int64_t k, + const Tensor* A, const Tensor* scaleA, const Tensor* B, Tensor* C) const; + template + Status ComputeFp16Fp8Fp16(OpKernelContext* ctx, int64_t m, int64_t n, int64_t k, + const Tensor* A, const Tensor* B, const Tensor* scaleB, Tensor* C) const; + + template + [[nodiscard]] inline auto* GetOp() const { + using OpT = GemmFloat8TunableOp; + if (tunable_op_) { + return static_cast(tunable_op_.get()); + } + + auto create = std::make_unique(); // avoid new + tunable_op_ = std::shared_ptr(create.release(), [](void* ptr) { + auto release = std::unique_ptr(); // avoid delete + release.reset(static_cast(ptr)); + }); + + return static_cast(tunable_op_.get()); + } +#endif + + float alpha_; + float beta_; + bool transA_; + bool transB_; + int64_t dtype_; + + // fully type erased + mutable std::shared_ptr tunable_op_; +}; + +Status GemmFloat8::ComputeInternal(OpKernelContext* ctx) const { +#if defined(DISABLE_FLOAT8_TYPES) + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "DISABLE_FLOAT8_TYPES"); +#else + const Tensor* A = ctx->Input(0); + const Tensor* B = ctx->Input(1); + const Tensor* C = ctx->Input(2); // bias + const Tensor* scale_a = ctx->Input(3); + const Tensor* scale_b = ctx->Input(4); + const Tensor* scale_y = ctx->Input(5); + + auto a_shape = A->Shape(); + auto b_shape = B->Shape(); + ORT_ENFORCE(a_shape.NumDimensions() == 2); + ORT_ENFORCE(b_shape.NumDimensions() == 2); + + auto m = !transA_ ? a_shape[0] : a_shape[1]; + auto k = !transA_ ? a_shape[1] : a_shape[0]; + ORT_ENFORCE(k == (!transB_ ? b_shape[0] : b_shape[1])); // k is compatible + auto n = !transB_ ? b_shape[1] : b_shape[0]; + + TensorShapeVector output_shape = {m, n}; + Tensor* Y = ctx->Output(0, output_shape); + + ORT_ENFORCE(!transA_, "ROCm GemmFloat8 does not support input A transpose"); + ORT_ENFORCE(dtype_ == onnx::TensorProto_DataType_FLOAT16, "ROCm GemmFloat8 only supports output float16"); + ORT_ENFORCE(C == nullptr, "ROCm GemmFloat8 does not support bias input"); + ORT_ENFORCE(scale_y == nullptr, "ROCm GemmFloat8 does not support output scaling"); + + if (A->IsDataType()) { + return ComputeFp8Fp16Fp16(ctx, m, n, k, A, scale_a, B, Y); + } else if (A->IsDataType()) { + return ComputeFp8Fp16Fp16(ctx, m, n, k, A, scale_a, B, Y); + } else if (B->IsDataType()) { + return ComputeFp16Fp8Fp16(ctx, m, n, k, A, B, scale_b, Y); + } else if (B->IsDataType()) { + return ComputeFp16Fp8Fp16(ctx, m, n, k, A, B, scale_b, Y); + } + + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unhandled type combination of GemmFloat8"); +#endif +} + +#if !defined(DISABLE_FLOAT8_TYPES) +template +Status GemmFloat8::ComputeFp8Fp16Fp16( + OpKernelContext* ctx, int64_t m, int64_t n, int64_t k, + const Tensor* A, const Tensor* scale_a, const Tensor* B, Tensor* C) const { + ORT_ENFORCE(A->IsDataType() && scale_a->IsDataType() && B->IsDataType()); + + onnxruntime::rocm::tunable::blas::GemmFloat8Params params{}; + params.tuning_ctx = GetTuningContext(); + params.stream = ctx->GetComputeStream(); + params.handle = GetHipblasHandle(ctx); + params.opa = transA_ ? tunable::blas::BlasOp::Trans : tunable::blas::BlasOp::NonTrans; + params.opb = transB_ ? tunable::blas::BlasOp::Trans : tunable::blas::BlasOp::NonTrans; + + params.m = m; + params.n = n; + params.k = k; + + params.a = static_cast(A->DataRaw()); + params.lda = transA_ ? m : k; + params.scale_a = alpha_; + params.scale_a_dev = static_cast(scale_a->DataRaw()); + + params.b = static_cast(B->DataRaw()); + params.ldb = transB_ ? k : n; + params.scale_b = 1.0f; // NOTE: not used + params.scale_b_dev = nullptr; // NOTE: not used + + params.c = static_cast(C->MutableDataRaw()); + params.ldc = n; + params.scale_c = 1.0f; // NOTE: not implemented + params.scale_c_dev = nullptr; // NOTE: not implemented + + if (!transA_ && !transB_) { + return (*GetOp())(¶ms); + } else if (transA_ && !transB_) { + ORT_NOT_IMPLEMENTED("transA is not implemented"); + } else if (!transA_ && transB_) { + ORT_NOT_IMPLEMENTED("transB is not implemented"); + } else if (transA_ && transB_) { + ORT_NOT_IMPLEMENTED("transA & transB is not implemented"); + } + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unreachable"); +} + +template +Status GemmFloat8::ComputeFp16Fp8Fp16( + OpKernelContext* ctx, int64_t m, int64_t n, int64_t k, + const Tensor* A, const Tensor* B, const Tensor* scale_b, Tensor* C) const { + ORT_ENFORCE(A->IsDataType() && B->IsDataType() && scale_b->IsDataType()); + + onnxruntime::rocm::tunable::blas::GemmFloat8Params params{}; + params.tuning_ctx = GetTuningContext(); + params.stream = ctx->GetComputeStream(); + params.handle = GetHipblasHandle(ctx); + params.opa = transA_ ? tunable::blas::BlasOp::Trans : tunable::blas::BlasOp::NonTrans; + params.opb = transB_ ? tunable::blas::BlasOp::Trans : tunable::blas::BlasOp::NonTrans; + + params.m = m; + params.n = n; + params.k = k; + + params.a = static_cast(A->DataRaw()); + params.lda = transA_ ? m : k; + params.scale_a = 1.0f; // NOTE: not used + params.scale_a_dev = nullptr; // NOTE: not used + + params.b = static_cast(B->DataRaw()); + params.ldb = transB_ ? k : n; + params.scale_b = alpha_; + params.scale_b_dev = static_cast(scale_b->DataRaw()); + + params.c = static_cast(C->MutableDataRaw()); + params.ldc = n; + params.scale_c = 1.0f; // NOTE: not implemented + params.scale_c_dev = nullptr; // NOTE: not implemented + + if (!transA_ && !transB_) { + return (*GetOp())(¶ms); + } else if (transA_ && !transB_) { + ORT_NOT_IMPLEMENTED("transA is not implemented"); + } else if (!transA_ && transB_) { + return (*GetOp())(¶ms); + } else if (transA_ && transB_) { + ORT_NOT_IMPLEMENTED("transA & transB is not implemented"); + } + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unreachable"); +} +#define GEMM_FLOAT8_CONSTRAINTS BuildKernelDefConstraints() +#else +#define GEMM_FLOAT8_CONSTRAINTS BuildKernelDefConstraints() +#endif + +ONNX_OPERATOR_KERNEL_EX( + GemmFloat8, + kMSDomain, + 1, + kRocmExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("TA", GEMM_FLOAT8_CONSTRAINTS) + .TypeConstraint("TB", GEMM_FLOAT8_CONSTRAINTS) + .TypeConstraint("TR", BuildKernelDefConstraints()) + .TypeConstraint("TS", BuildKernelDefConstraints()), + GemmFloat8); + +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck.cuh b/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck.cuh new file mode 100644 index 0000000000000..b545eb1f2a149 --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck.cuh @@ -0,0 +1,276 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include + +#if defined(USE_COMPOSABLE_KERNEL) + +#include "core/providers/rocm/composable_kernel_common.h" + +#include "ck/ck.hpp" +#include "ck/utility/functional3.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_splitk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#endif + +#if !defined(DISABLE_FLOAT8_TYPES) +#include "core/common/float8.h" +#endif +#include "core/providers/rocm/tunable/gemm_common.h" + +namespace onnxruntime { +namespace rocm { +namespace tunable { + +#if defined(USE_COMPOSABLE_KERNEL) && !defined(DISABLE_FLOAT8_TYPES) +using F8 = ck::f8_t; +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +constexpr bool always_false = false; + +template +struct Scale { + constexpr const static bool is_pack2_invocable = true; + constexpr const static bool is_pack4_invocable = true; + + explicit Scale(float scale_value, const float* dev_scale_ptr) : scale_value_{scale_value}, dev_scale_ptr_{dev_scale_ptr} {} + + template + __forceinline__ __host__ __device__ Y fast_type_convert(X x) const { + static_assert(always_false, "not implemented"); + (void)x; + } + + template <> + __forceinline__ __host__ __device__ ck::half_t fast_type_convert(ck::f8_t x) const { + // https://github.com/ROCmSoftwarePlatform/triton/blob/0cc3f8b84a16892396f6e08a04991034d67e32b1/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp#L220-L233 + constexpr const uint16_t mask = 0x7fff; + constexpr const uint16_t sign_mask = 0x8000; + constexpr const uint16_t exp_compensate = []() { + if constexpr (std::is_same_v) { + return 0x2000; + } else if constexpr (std::is_same_v) { + return 0x1c00; + } + }(); + + uint8_t x_u8 = reinterpret_cast(x); + uint16_t x_u16 = static_cast(x_u8) << 8; + uint16_t exp = (x_u16 & mask) >> 1; + uint16_t y = (x_u16 & sign_mask) | (exp + exp_compensate); + return reinterpret_cast(y); + } + + __forceinline__ __host__ __device__ void operator()(ck::half_t& y, const ck::f8_t& x) const { + float scale = scale_value_ * (*dev_scale_ptr_); + y = ck::type_convert(scale * fast_type_convert(x)); + } + + __forceinline__ __host__ __device__ void operator()(ck::half2_t& ys, const ck::f8x2_t& xs) const { + float scale = scale_value_ * (*dev_scale_ptr_); + constexpr const uint32_t mask = 0x7fff7fff; + constexpr const uint32_t sign_mask = 0x80008000; + constexpr const uint32_t exp_compensate = []() { + if constexpr (std::is_same_v) { + return 0x20002000; + } else if constexpr (std::is_same_v) { + return 0x1c001c00; + } + }(); + + const uchar2& x2_u8 = reinterpret_cast(xs); + uchar4 x{0, x2_u8.x, 0, x2_u8.y}; + uint32_t x_u32 = reinterpret_cast(x); + + uint32_t exp = (x_u32 & mask) >> 1; + uint32_t v = (x_u32 & sign_mask) | (exp + exp_compensate); + ys = scale * reinterpret_cast(v); + } + + __forceinline__ __host__ __device__ void operator()(ck::half4_t& ys, const ck::f8x4_t& xs) const { + float scale = scale_value_ * (*dev_scale_ptr_); + constexpr const uint32_t mask = 0x7fff7fff; + constexpr const uint32_t sign_mask = 0x80008000; + constexpr const uint32_t exp_compensate = []() { + if constexpr (std::is_same_v) { + return 0x20002000; + } else if constexpr (std::is_same_v) { + return 0x1c001c00; + } + }(); + + uint32_t xs_u32 = reinterpret_cast(xs); + uint32_t x_u32_0 = __byte_perm(xs_u32, 0, 0x1504); + uint32_t x_u32_1 = __byte_perm(xs_u32, 0, 0x3726); + uint32_t exp_0 = (x_u32_0 & mask) >> 1; + uint32_t exp_1 = (x_u32_1 & mask) >> 1; + uint32_t v_0 = (x_u32_0 & sign_mask) | (exp_0 + exp_compensate); + uint32_t v_1 = (x_u32_1 & sign_mask) | (exp_1 + exp_compensate); + uint64_t v = v_0 | uint64_t(v_1) << 32; + ys = scale * reinterpret_cast(v); + } + + float scale_value_; + const float* const dev_scale_ptr_; +}; +#endif + +namespace blas { + +template +struct GemmFloat8Params : tunable::OpParams { + std::string Signature() const override { + return MakeString(BlasOpToString(opa), BlasOpToString(opb), "_", m, "_", n, "_", k); + } + + hipblasHandle_t handle; + BlasOp opa; + BlasOp opb; + int64_t m; + int64_t n; + int64_t k; + float scale_a{}; + const float* scale_a_dev{}; + const TA* a; + int64_t lda; + float scale_b{}; + const float* scale_b_dev{}; + const TB* b; + int64_t ldb; + TC* c; + float scale_c{}; + const float* scale_c_dev{}; + int64_t ldc; +}; + +#if defined(USE_COMPOSABLE_KERNEL) && !defined(DISABLE_FLOAT8_TYPES) + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using Nop = ck::tensor_operation::element_wise::PassThrough; + +void add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances( + std::vector, Nop, Nop>>>& instances); + +void add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances( + std::vector, Nop, Nop>>>& instances); + +void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances( + std::vector, Nop>>>& instances); + +void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances( + std::vector, Nop>>>& instances); + +void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances( + std::vector, Nop>>>& instances); + +void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances( + std::vector, Nop>>>& instances); + +template +auto CreateOp(float scale, const float* dev_scale) { + if constexpr (std::is_same_v) { + return Scale(scale, dev_scale); + } else if constexpr (std::is_same_v) { + return Scale(scale, dev_scale); + } else { + return Nop{}; + } +} + +template +auto GetCKF8SplitKGemmTypeStringAndOps() { + using CKTA = typename CKDataTypeAdaptor::type; + using CKTB = typename CKDataTypeAdaptor::type; + using CKTC = typename CKDataTypeAdaptor::type; + + using CKLayoutA = typename CKBlasOpAdaptor::type; + using CKLayoutB = typename CKBlasOpAdaptor::type; + + using OpA = std::conditional_t, Scale, Nop>; + using OpB = std::conditional_t, Scale, Nop>; + using OpC = std::conditional_t, Scale, Nop>; + + using DeviceGemm = ck::tensor_operation::device::DeviceGemmSplitK< + CKLayoutA, CKLayoutB, Row, + CKTA, CKTB, CKTC, + OpA, OpB, OpC>; + + std::vector>>> ret; + + for (auto num_split : {1, 4, 16, 64}) { + std::vector> instances{}; + if constexpr (std::is_same_v && std::is_same_v && std::is_same_v && + std::is_same_v && std::is_same_v) { + add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances(instances); + } else if constexpr (std::is_same_v && std::is_same_v && std::is_same_v && + std::is_same_v && std::is_same_v) { + add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances(instances); + } else if constexpr (std::is_same_v && std::is_same_v && std::is_same_v && + std::is_same_v && std::is_same_v) { + add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances(instances); + } else { + static_assert(always_false, "no instances for the type combination"); + LOGS_DEFAULT(FATAL) << "no instances for the type combination"; + } + for (auto&& impl : instances) { + auto type_string = std::to_string(ret.size()) + "_" + impl->GetTypeString() + "_SplitK" + std::to_string(num_split); + auto invoker = impl->MakeInvokerPointer(); + auto ck_gemm_op = [num_split, impl = std::move(impl), invoker = std::move(invoker)](const GemmFloat8Params* params) -> Status { + OpA op_a = CreateOp(params->scale_a, params->scale_a_dev); + OpB op_b = CreateOp(params->scale_b, params->scale_b_dev); + OpC op_c = CreateOp(params->scale_c, params->scale_c_dev); + + auto arg = impl->MakeArgumentPointer(params->a, params->b, params->c, + params->m, params->n, params->k, + params->lda, params->ldb, params->ldc, + op_a, op_b, op_c, num_split); + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!impl->IsSupportedArgument(arg.get()), + impl->GetTypeString(), " does not support ", params->Signature()); + invoker->Run(arg.get(), StreamConfig{params->StreamHandle()}); + return Status::OK(); + }; + ret.emplace_back(std::make_pair(std::move(type_string), std::move(ck_gemm_op))); + } + } + return ret; +} + +#endif // USE_COMPOSABLE_KERNEL + +template +class GemmFloat8TunableOp : public TunableOp> { + public: + GemmFloat8TunableOp() { +#if defined(USE_COMPOSABLE_KERNEL) && !defined(DISABLE_FLOAT8_TYPES) + for (auto&& [_, op] : GetCKF8SplitKGemmTypeStringAndOps()) { + ORT_UNUSED_PARAMETER(_); + this->RegisterOp(std::move(op)); + } +#else + ORT_ENFORCE(false, "CK is required to support GemmFloat8 computing"); +#endif // USE_COMPOSABLE_KERNEL + } +}; + +} // namespace blas +} // namespace tunable +} // namespace rocm +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/add_instance.cu b/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/add_instance.cu new file mode 100644 index 0000000000000..4c691dd18f2e9 --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/add_instance.cu @@ -0,0 +1,124 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include + +#if defined(USE_COMPOSABLE_KERNEL) && !defined(DISABLE_FLOAT8_TYPES) + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp" + +#include "contrib_ops/rocm/math/gemm_float8_ck.cuh" + +namespace onnxruntime { +namespace rocm { +namespace tunable { +namespace blas { + +using F8 = ck::f8_t; +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +namespace internal { +void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ck( + std::vector, PassThrough>>>& instances); + +void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ck( + std::vector, PassThrough>>>& instances); + +void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ort( + std::vector, PassThrough>>>& instances); + +void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ort( + std::vector, PassThrough>>>& instances); +} // namespace internal + +void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances( + std::vector, PassThrough>>>& instances) { + internal::add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ck(instances); + internal::add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ort(instances); +} + +void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances( + std::vector, PassThrough>>>& instances) { + internal::add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ck(instances); + internal::add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ort(instances); +} + +namespace internal { +void add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ck( + std::vector, PassThrough, PassThrough>>>& instances); + +void add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ck( + std::vector, PassThrough, PassThrough>>>& instances); + +// TODO: The first try of derivation does not going well due to various constraints. +// void add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ort( +// std::vector, PassThrough, PassThrough>>>& instances); + +// TODO: The first try of derivation does not going well due to various constraints. +// void add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ort( +// std::vector, PassThrough, PassThrough>>>& instances); +} // namespace internal + +void add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances( + std::vector, PassThrough, PassThrough>>>& instances) { + internal::add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ck(instances); + // internal::add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ort(instances); // TODO: +} + +void add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances( + std::vector, PassThrough, PassThrough>>>& instances) { + internal::add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ck(instances); + // internal::add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ort(instances); // TODO: +} + +namespace internal { +void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances_ck( + std::vector, PassThrough>>>& instances); + +void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances_ck( + std::vector, PassThrough>>>& instances); +} // namespace internal + +void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances( + std::vector, PassThrough>>>& instances) { + internal::add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances_ck(instances); +} + +void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances( + std::vector, PassThrough>>>& instances) { + internal::add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances_ck(instances); +} + +} // namespace blas +} // namespace tunable +} // namespace rocm +} // namespace onnxruntime + +#endif diff --git a/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instance.cu b/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instance.cu new file mode 100644 index 0000000000000..49463e58886f8 --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instance.cu @@ -0,0 +1,97 @@ +// SPDX-License-Identifier: MIT +// Modifications Copyright (c) Microsoft. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#if defined(USE_COMPOSABLE_KERNEL) && !defined(DISABLE_FLOAT8_TYPES) + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +#include "contrib_ops/rocm/math/gemm_float8_ck.cuh" + +namespace onnxruntime { +namespace rocm { +namespace tunable { +namespace blas { +namespace internal { + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using ck::tensor_operation::device::DeviceGemmXdlSplitKCShuffle; + +template +using device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_generic = std::tuple< + // clang-format off + //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNKPadding, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 1, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 16, 1, 8>, 2>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNKPadding, 64, 32, 32, 4, 8, 32, 32, 1, 1, S<1, 2, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 1, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 2> + // clang-format on + >; + +// The derived version is simply double BBlockTransferSrcScalarPerVector and adjust other values correspondingly +template +using device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ort = std::tuple< + // clang-format off + //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute| + //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Type| + //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| | + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 256, 128, 8, 4, 32, 32, 4, 2, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 8, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 128, 256, 8, 4, 32, 32, 2, 4, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 8, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 128, 128, 8, 4, 32, 32, 4, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 64, 192, 8, 4, 32, 32, 1, 3, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 8, 24, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 192, 64, 8, 4, 32, 32, 3, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 128, 128, 8, 4, 32, 32, 2, 2, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 8, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 64, 128, 8, 4, 32, 32, 2, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 128, 64, 8, 4, 32, 32, 2, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 8, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, 1, 1, S<1, 16, 1, 4>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 64, 128, 8, 4, 32, 32, 1, 2, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 8, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 32, 192, 8, 4, 32, 32, 1, 3, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 8, 12, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 16, 4, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 192, 32, 8, 4, 32, 32, 3, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, 1, 1, S<1, 32, 1, 4>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 32, 64, 8, 4, 32, 32, 1, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 64, 32, 8, 4, 32, 32, 1, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, 1, 1, S<1, 32, 1, 4>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 32, 128, 8, 4, 32, 32, 1, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 128, 32, 8, 4, 32, 32, 2, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, 1, 1, S<1, 32, 1, 4>, 8, F16> + // clang-format on + >; + +void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ort( + std::vector, PassThrough>>>& instances) { + ck::tensor_operation::device::instance::add_device_operation_instances( + instances, device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ort{}); + ck::tensor_operation::device::instance::add_device_operation_instances( + instances, device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_generic{}); +} + +void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ort( + std::vector, PassThrough>>>& instances) { + ck::tensor_operation::device::instance::add_device_operation_instances( + instances, device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ort{}); + ck::tensor_operation::device::instance::add_device_operation_instances( + instances, device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_generic{}); +} + +} // namespace internal +} // namespace blas +} // namespace tunable +} // namespace rocm +} // namespace onnxruntime + +#endif diff --git a/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instance_original.cu b/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instance_original.cu new file mode 100644 index 0000000000000..236e5555051fc --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instance_original.cu @@ -0,0 +1,80 @@ +// SPDX-License-Identifier: MIT +// Modifications Copyright (c) Microsoft. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#if defined(USE_COMPOSABLE_KERNEL) && !defined(DISABLE_FLOAT8_TYPES) + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +#include "contrib_ops/rocm/math/gemm_float8_ck.cuh" + +namespace onnxruntime { +namespace rocm { +namespace tunable { +namespace blas { +namespace internal { + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; + +using ck::tensor_operation::device::DeviceGemmXdlSplitKCShuffle; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +template +using device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ck = std::tuple< + // clang-format off + //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute| + //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Type| + //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| | + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 64, 192, 4, 8, 32, 32, 1, 3, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 48, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 192, 64, 4, 8, 32, 32, 3, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 32, 192, 4, 8, 32, 32, 1, 3, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 24, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 192, 32, 4, 8, 32, 32, 3, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 32, 64, 4, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 64, 32, 4, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16> + // clang-format on + >; + +void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ck( + std::vector, PassThrough>>>& instances) { + ck::tensor_operation::device::instance::add_device_operation_instances( + instances, device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ck{}); +} + +void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ck( + std::vector, PassThrough>>>& instances) { + ck::tensor_operation::device::instance::add_device_operation_instances( + instances, device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances_ck{}); +} + +} // namespace internal +} // namespace blas +} // namespace tunable +} // namespace rocm +} // namespace onnxruntime + +#endif diff --git a/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instance.cu b/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instance.cu new file mode 100644 index 0000000000000..1a0d45df82a71 --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instance.cu @@ -0,0 +1,94 @@ +// SPDX-License-Identifier: MIT +// Modifications Copyright (c) Microsoft. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#if defined(USE_COMPOSABLE_KERNEL) && !defined(DISABLE_FLOAT8_TYPES) + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +#include "contrib_ops/rocm/math/gemm_float8_ck.cuh" + +namespace onnxruntime { +namespace rocm { +namespace tunable { +namespace blas { +namespace internal { + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using ck::tensor_operation::device::DeviceGemmXdlSplitKCShuffle; + +template +using device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances_generic = std::tuple< + // clang-format off + //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNKPadding, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 1, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 2, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNKPadding, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 1, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 2, F16> + // clang-format on + >; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +template +using device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances = std::tuple< + // clang-format off + //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 32, 1, 4>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 16, 1, 4>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 32, 1, 4>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 16, 1, 4>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, Scale, PassThrough, GemmMNPadding, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 16, 1, 4>, 8, F16> + // clang-format on + >; + +void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances_ck( + std::vector, PassThrough>>>& instances) { + ck::tensor_operation::device::instance::add_device_operation_instances( + instances, device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances{}); + ck::tensor_operation::device::instance::add_device_operation_instances( + instances, device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances_generic{}); +} + +void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances_ck( + std::vector, PassThrough>>>& instances) { + ck::tensor_operation::device::instance::add_device_operation_instances( + instances, device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances{}); + ck::tensor_operation::device::instance::add_device_operation_instances( + instances, device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances_generic{}); +} + +} // namespace internal +} // namespace blas +} // namespace tunable +} // namespace rocm +} // namespace onnxruntime + +#endif diff --git a/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instance_original.cu b/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instance_original.cu new file mode 100644 index 0000000000000..a0628802ec09e --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/math/gemm_float8_ck_impl/device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instance_original.cu @@ -0,0 +1,97 @@ +// SPDX-License-Identifier: MIT +// Modifications Copyright (c) Microsoft. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#if defined(USE_COMPOSABLE_KERNEL) && !defined(DISABLE_FLOAT8_TYPES) + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +#include "contrib_ops/rocm/math/gemm_float8_ck.cuh" + +namespace onnxruntime { +namespace rocm { +namespace tunable { +namespace blas { +namespace internal { + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using ck::tensor_operation::device::DeviceGemmXdlSplitKCShuffle; + +template +using device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_generic = std::tuple< + // clang-format off + //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNKPadding, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 1, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 16, 1, 8>, 2>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNKPadding, 64, 32, 32, 4, 8, 32, 32, 1, 1, S<1, 2, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 1, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 2> + // clang-format on + >; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +template +using device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ck = std::tuple< + // clang-format off + //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute| + //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Type| + //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| | + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNPadding, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNPadding, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNPadding, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNPadding, 256, 64, 192, 4, 8, 32, 32, 1, 3, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 48, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNPadding, 256, 192, 64, 4, 8, 32, 32, 3, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNPadding, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNPadding, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNPadding, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNPadding, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNPadding, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNPadding, 128, 32, 192, 4, 8, 32, 32, 1, 3, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 24, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNPadding, 128, 192, 32, 4, 8, 32, 32, 3, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNPadding, 128, 32, 64, 4, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNPadding, 128, 64, 32, 4, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNPadding, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, Scale, PassThrough, PassThrough, GemmMNPadding, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16> + // clang-format on + >; + +void add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ck( + std::vector, PassThrough, PassThrough>>>& instances) { + ck::tensor_operation::device::instance::add_device_operation_instances( + instances, device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ck{}); + ck::tensor_operation::device::instance::add_device_operation_instances( + instances, device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_generic{}); +} + +void add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ck( + std::vector, PassThrough, PassThrough>>>& instances) { + ck::tensor_operation::device::instance::add_device_operation_instances( + instances, device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_ck{}); + ck::tensor_operation::device::instance::add_device_operation_instances( + instances, device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances_generic{}); +} + +} // namespace internal +} // namespace blas +} // namespace tunable +} // namespace rocm +} // namespace onnxruntime + +#endif diff --git a/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc b/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc new file mode 100644 index 0000000000000..7dbb24463961e --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc @@ -0,0 +1,347 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/shared_library/provider_api.h" +#include "core/providers/rocm/rocm_common.h" + +using namespace onnxruntime::common; + +namespace onnxruntime { +namespace contrib { +namespace rocm { +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, GridSample); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, FastGelu); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, FastGelu); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, FastGelu); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, FastGelu); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, Gelu); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, Gelu); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, Gelu); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, Gelu); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BiasGelu); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, BiasSplitGelu); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, BiasSplitGelu); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, BiasAdd); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, BiasAdd); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, QuickGelu); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, QuickGelu); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, QuickGelu); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, TransposeMatMul); // backward compatibility +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, TransposeMatMul); // backward compatibility +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, TransposeMatMul); // backward compatibility +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, FusedMatMul); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, FusedMatMul); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, FusedMatMul); +// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, RelativePositionBias); +// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, RelativePositionBias); +// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, GatedRelativePositionBias); +// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, GatedRelativePositionBias); +// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, RemovePadding); +// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, RemovePadding); +// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, RestorePadding); +// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, RestorePadding); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, Rfft); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, Rfft); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, Rfft); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, Irfft); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, Irfft); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, Irfft); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, ComplexMul); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, ComplexMul); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, ComplexMulConj); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, ComplexMulConj); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BiasSoftmax); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BiasDropout); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BitmaskDropout); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BitmaskBiasDropout); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, NGramRepeatBlock); + +// These ops were experimental ops in onnx domain which have been removed now. We add them here as +// contrib ops to maintain backward compatibility +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, float, Affine); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, double, Affine); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MLFloat16, Affine); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, Attention); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, Attention); +// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, PackedAttention); +// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, PackedAttention); +// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, PackedMultiHeadAttention); +// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, PackedMultiHeadAttention); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BeamSearch); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, ConvTransposeWithDynamicPads); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, float, Crop); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, double, Crop); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MLFloat16, Crop); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, MultiHeadAttention); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, MultiHeadAttention); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, GroupQueryAttention); +// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, GroupQueryAttention); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, DecoderAttention); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, DecoderAttention); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, int32_t, DynamicSlice); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, int64_t, DynamicSlice); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, EmbedLayerNormalization); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, EmbedLayerNormalization); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, GreedySearch); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, GroupNorm); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, NhwcConv); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, NhwcConv); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, float, ImageScaler); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, double, ImageScaler); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MLFloat16, ImageScaler); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, LongformerAttention); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, LongformerAttention); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, float, ParametricSoftplus); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, double, ParametricSoftplus); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MLFloat16, ParametricSoftplus); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, Sampling); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, float, ScaledTanh); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, double, ScaledTanh); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MLFloat16, ScaledTanh); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, SkipGroupNorm); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, SkipLayerNormalization); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, SkipLayerNormalization); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, SkipSimplifiedLayerNormalization); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, SkipSimplifiedLayerNormalization); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, float, ThresholdedRelu); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, double, ThresholdedRelu); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MLFloat16, ThresholdedRelu); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 16, float_float_float, LayerNormalization); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 16, double_double_double, LayerNormalization); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 16, MLFloat16_float_MLFloat16, LayerNormalization); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 16, float_float_MLFloat16, LayerNormalization); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 16, MLFloat16_float_float, LayerNormalization); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 16, BFloat16_float_BFloat16, LayerNormalization); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, float_float_float, SimplifiedLayerNormalization); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, double_double_double, SimplifiedLayerNormalization); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MLFloat16_float_MLFloat16, SimplifiedLayerNormalization); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, float_float_MLFloat16, SimplifiedLayerNormalization); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MLFloat16_float_float, SimplifiedLayerNormalization); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, Inverse); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, MatMulNBits); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, MatMulNBits); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, Trilu); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, int8_t_MLFloat16, QuantizeLinear); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, uint8_t_MLFloat16, QuantizeLinear); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, int8_t_MLFloat16, DequantizeLinear); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, uint8_t_MLFloat16, DequantizeLinear); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float_int8_t, QAttention); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16_int8_t, QAttention); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, FusedConv); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, FusedConv); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, TransposeMatMul); // backward compatibility +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, FusedMatMul); +// class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, QOrderedMatMul); +// class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, QOrderedLayerNormalization); +// class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, QOrderedGelu); +// class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, QuantizeWithOrder); +// class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, DequantizeWithOrder); +// class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, QOrderedAttention); +// class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, QOrderedLongformerAttention); +// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, DecoderMaskedSelfAttention); +// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, DecoderMaskedSelfAttention); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, DecoderMaskedMultiHeadAttention); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, DecoderMaskedMultiHeadAttention); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, GemmFastGelu); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, GemmFastGelu); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, GemmFastGelu); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, GemmFloat8); + +#ifdef ENABLE_ATEN +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kPytorchAtenDomain, 1, ATen); +#endif + +#ifdef ENABLE_TRAINING_OPS +// Should remove the shrunken_gather include from ENABLE_TRAINING_OPS once 1). compute optimizer is enabled for inference or +// 2). this is needed by inference for other purpose. +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, ShrunkenGather); +#endif + +#ifdef ORT_USE_NCCL +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, AllReduce); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, AllGather); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, AllToAll); +#endif + +template <> +KernelCreateInfo BuildKernelCreateInfo() { + KernelCreateInfo info; + return info; +} + +// clang-format off +Status RegisterRocmContribKernels(KernelRegistry& kernel_registry) { + static const BuildKernelCreateInfoFn function_table[] = { + BuildKernelCreateInfo, // default entry to avoid the list become empty after ops-reducing + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, // backward compatibility + BuildKernelCreateInfo, // backward compatibility + BuildKernelCreateInfo, // backward compatibility + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + // These ops were experimental ops in onnx domain which have been removed now. We add them here as + // contrib ops to maintain backward compatibility + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + // TransposedMatMul is still here for backward compatibility + BuildKernelCreateInfo, // backward compatibility + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + +#ifdef ENABLE_ATEN + BuildKernelCreateInfo, +#endif + +#ifdef ENABLE_TRAINING_OPS + // Should remove the shrunken_gather include from ENABLE_TRAINING_OPS once 1). compute optimizer is enabled for inference or + // 2). this is needed by inference for other purpose. + BuildKernelCreateInfo, +#endif + +#ifdef ORT_USE_NCCL + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, +#endif + + }; + + for (auto& function_table_entry : function_table) { + KernelCreateInfo info = function_table_entry(); + if (info.kernel_def != nullptr) { // filter disabled entries where type is void + ORT_RETURN_IF_ERROR(kernel_registry.Register(std::move(info))); + } + } + + return Status::OK(); +} +// clang-format on + +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.h b/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.h new file mode 100644 index 0000000000000..db9a5d4fcd83e --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.h @@ -0,0 +1,14 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +namespace onnxruntime { +namespace contrib { +namespace rocm { + +Status RegisterRocmContribKernels(KernelRegistry& kernel_registry); + +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention.cc b/onnxruntime/contrib_ops/webgpu/bert/attention.cc index 130dd0c25a880..a5ab63d74df24 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/attention.cc @@ -165,7 +165,7 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const { << " let query_pos = m + local_id.y + past_sequence_length;\n" << " let key_pos = n + local_id.x;\n" << " if (key_pos > query_pos) {\n" - << " sum = -3.4028234663852886e+38; // Set to very negative value for masking\n" + << " sum = -3.40282e+38; // Set to very negative value for masking\n" << " }\n"; } @@ -272,7 +272,7 @@ Status InPlaceSoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const { << "let effective_seq_length = seq_causal_length;\n"; } shader.MainFunctionBody() - << "var thread_max_vector = f32_val_t(-3.4028234663852886e+38f);\n" + << "var thread_max_vector = f32_val_t(-3.402823e+38f);\n" << "for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < effective_seq_length; i++) {\n" << " let actual_pos = local_offset + i + start_offset;\n" << " if (!should_apply_local_window || actual_pos < seq_causal_length) {\n" @@ -289,7 +289,7 @@ Status InPlaceSoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const { } else if (use_smooth_softmax_) { shader.MainFunctionBody() << "var max_value: f32 = 0.0;\n"; } else { - shader.MainFunctionBody() << "var max_value = f32(-3.4028234663852886e+38f);\n"; + shader.MainFunctionBody() << "var max_value = f32(-3.402823e+38f);\n"; } shader.MainFunctionBody() << "for (var i = 0u; i < " << work_group_size_ << "; i++) {\n" diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc index 2a67dfdb07912..606dbfde15c2c 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc @@ -421,7 +421,7 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co indirect_buffer_ptr, tile_size)); Q = &query_output; } else { - ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value, tile_size, use_seqlen_k ? seqlen_k : nullptr, indirect_buffer_ptr)); + ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value, tile_size, use_indirect_dispatch ? seqlen_k : nullptr, indirect_buffer_ptr)); } if (parameters.sequence_length_ > 1) { @@ -571,8 +571,8 @@ Status RunSplitPackedQKVWithRotaryEmbeddingAndCopyKV(onnxruntime::webgpu::Comput {static_cast(params.kv_hidden_size_ / components)}, {static_cast(params.num_heads_)}, {static_cast(params.kv_num_heads_)}, - {static_cast(head_size_vec)}, - {static_cast(half_rotary_embedding_dim_vec)}, + {head_size_vec}, + {half_rotary_embedding_dim_vec}, {present_sequence_length}, {tile_size}, {static_cast(dispatch_size)}, diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.wgsl.template b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.wgsl.template index ff8e4ecc08bab..a5922ec9512fd 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.wgsl.template +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.wgsl.template @@ -26,7 +26,7 @@ fn get_total_sequence_length() -> u32 { #if is_fp16 const min_value = q_element_t(-65504.0); #else -const min_value = q_element_t(-3.4028234663852886e+38f); +const min_value = q_element_t(-3.402823e+38f); #endif // For max performance max_k_step should be the same as sg_size, however we might run out of registers diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_qkt.wgsl.template b/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_qkt.wgsl.template index ac9a157492007..c6f768beffa0f 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_qkt.wgsl.template +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_qkt.wgsl.template @@ -93,7 +93,7 @@ $MAIN { if (local_idx == 0u) { // Calculate the max and sum in current split. - var l_max = f32(-3.4028234663852886e+38f); + var l_max = f32(-3.402823e+38f); var l_sum = f32(0); for (var i = 0u; i < tile_size && (total_seq_offset + i) < total_sequence_length; i++) { l_max = max(l_max, f32(tile_qk[i])); diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_split_vx.wgsl.template b/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_split_vx.wgsl.template index a113e96130985..37cf7e8f11b1f 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_split_vx.wgsl.template +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_split_vx.wgsl.template @@ -54,7 +54,7 @@ $MAIN { // Calculate the global max and sum in qk. if (head_idx < uniforms.num_heads) { - var g_max = f32(-3.4028234663852886e+38f); + var g_max = f32(-3.402823e+38f); var g_sum = f32(0); for (var i = 0u; i < num_total_seq_length_tile; i++) { diff --git a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc index 416a895e61745..05717fd2fe686 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc @@ -128,8 +128,8 @@ Status RunSplitPackedQKVWithRotaryEmbedding(onnxruntime::webgpu::ComputeContext& {static_cast(params.kv_hidden_size_ / components)}, {static_cast(params.num_heads_)}, {static_cast(params.kv_num_heads_)}, - {static_cast(head_size_vec)}, - {static_cast(half_rotary_embedding_dim_vec)}, + {head_size_vec}, + {half_rotary_embedding_dim_vec}, {static_cast(dispatch_size)}, }) .SetDispatchGroupSize((dispatch_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE); diff --git a/onnxruntime/contrib_ops/webgpu/moe/gate.wgsl.template b/onnxruntime/contrib_ops/webgpu/moe/gate.wgsl.template index 6e0d4c7299793..1214777009a8d 100644 --- a/onnxruntime/contrib_ops/webgpu/moe/gate.wgsl.template +++ b/onnxruntime/contrib_ops/webgpu/moe/gate.wgsl.template @@ -18,7 +18,7 @@ const K: u32 = k; #if is_fp16 const MAX_FLOAT: f16 = 65504.0; #else -const MAX_FLOAT: f32 = 3.4028234663852886e+38; +const MAX_FLOAT: f32 = 3.402823466e+38; #endif var shared_vals: array; diff --git a/onnxruntime/core/framework/allocation_planner.cc b/onnxruntime/core/framework/allocation_planner.cc index 1c80d83f99feb..e77496b6e8196 100644 --- a/onnxruntime/core/framework/allocation_planner.cc +++ b/onnxruntime/core/framework/allocation_planner.cc @@ -499,7 +499,8 @@ class PlannerImpl { /*! \brief Given a tensor-type, return the size of an element of the tensor. */ static size_t GetElementSize(const DataType& tensor_type) { - MLDataType ml_data_type = DataTypeImpl::GetDataType(*tensor_type); + const TypeProto& type_proto = ONNX_NAMESPACE::Utils::DataTypeUtils::ToTypeProto(tensor_type); + MLDataType ml_data_type = DataTypeImpl::TypeFromProto(type_proto); const TensorTypeBase* tensor_type_base = ml_data_type->AsTensorType(); ORT_ENFORCE(nullptr != tensor_type_base); MLDataType elt_type = tensor_type_base->GetElementType(); diff --git a/onnxruntime/core/framework/ort_value_name_idx_map.h b/onnxruntime/core/framework/ort_value_name_idx_map.h index 6035dc4e85242..76e7e369514d4 100644 --- a/onnxruntime/core/framework/ort_value_name_idx_map.h +++ b/onnxruntime/core/framework/ort_value_name_idx_map.h @@ -33,7 +33,7 @@ class OrtValueNameIdxMap { common::Status GetIdx(std::string_view name, int& idx) const { idx = -1; - auto it = map_.find(name); + auto it = map_.find(std::string(name)); if (it == map_.end()) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Could not find OrtValue with name '", name, "'"); } diff --git a/onnxruntime/core/graph/contrib_ops/nhwc_inference_context.h b/onnxruntime/core/graph/contrib_ops/nhwc_inference_context.h index 94ef87fb069af..bc52a45adfd43 100644 --- a/onnxruntime/core/graph/contrib_ops/nhwc_inference_context.h +++ b/onnxruntime/core/graph/contrib_ops/nhwc_inference_context.h @@ -83,8 +83,7 @@ class NhwcInferenceContext : public ONNX_NAMESPACE::InferenceContext { const int rank = nchw_shape.dim_size(); // N and C dims are required. Some operators like AveragePool allow 1D input if (rank < 3) { - *nhwc_tp.mutable_tensor_type()->mutable_shape() = nchw_shape; - return; + fail_shape_inference("Output tensor must have at least 3 dimensions"); } // Convert output shape from N, C, H {, W, ...} to N, H {, W, ...}, C. @@ -106,8 +105,8 @@ class NhwcInferenceContext : public ONNX_NAMESPACE::InferenceContext { const int rank = nhwc_shape.dim_size(); // N and C dims are required. Some operators like AveragePool allow 1D input. if (rank < 3) { - *nchw_tp.mutable_tensor_type()->mutable_shape() = nhwc_shape; - return; + fail_shape_inference( + "Tensor must have at least 3 dimensions to convert between channels first and channels last."); } // Convert input shape from {N, H, W, ..., C} to {N, C, H, W, ...}. diff --git a/onnxruntime/core/platform/telemetry.cc b/onnxruntime/core/platform/telemetry.cc index 1eb03af3befa4..6cbbdd4e0a7ef 100644 --- a/onnxruntime/core/platform/telemetry.cc +++ b/onnxruntime/core/platform/telemetry.cc @@ -81,10 +81,6 @@ void Telemetry::LogSessionCreation(uint32_t session_id, int64_t ir_version, cons ORT_UNUSED_PARAMETER(captureState); } -void Telemetry::LogCompileModel(uint32_t session_id) const { - ORT_UNUSED_PARAMETER(session_id); -} - void Telemetry::LogRuntimeError(uint32_t session_id, const common::Status& status, const char* file, const char* function, uint32_t line) const { ORT_UNUSED_PARAMETER(session_id); diff --git a/onnxruntime/core/platform/telemetry.h b/onnxruntime/core/platform/telemetry.h index 9c2859f7634b6..b60345e1b8a80 100644 --- a/onnxruntime/core/platform/telemetry.h +++ b/onnxruntime/core/platform/telemetry.h @@ -66,8 +66,6 @@ class Telemetry { const std::string& loadedFrom, const std::vector& execution_provider_ids, bool use_fp16, bool captureState) const; - virtual void LogCompileModel(uint32_t session_id) const; - virtual void LogRuntimeError(uint32_t session_id, const common::Status& status, const char* file, const char* function, uint32_t line) const; diff --git a/onnxruntime/core/platform/windows/telemetry.cc b/onnxruntime/core/platform/windows/telemetry.cc index 693e265af46b1..2e5d334856278 100644 --- a/onnxruntime/core/platform/windows/telemetry.cc +++ b/onnxruntime/core/platform/windows/telemetry.cc @@ -334,20 +334,6 @@ void WindowsTelemetry::LogSessionCreation(uint32_t session_id, int64_t ir_versio } } -void WindowsTelemetry::LogCompileModel(uint32_t session_id) const { - if (global_register_count_ == 0 || enabled_ == false) - return; - - TraceLoggingWrite(telemetry_provider_handle, - "CompileModel", - TraceLoggingBool(true, "UTCReplace_AppSessionGuid"), - TelemetryPrivacyDataTag(PDT_ProductAndServiceUsage), - TraceLoggingKeyword(MICROSOFT_KEYWORD_MEASURES), - // Telemetry info - TraceLoggingUInt8(0, "schemaVersion"), - TraceLoggingUInt32(session_id, "sessionId")); -} - void WindowsTelemetry::LogRuntimeError(uint32_t session_id, const common::Status& status, const char* file, const char* function, uint32_t line) const { if (global_register_count_ == 0 || enabled_ == false) diff --git a/onnxruntime/core/platform/windows/telemetry.h b/onnxruntime/core/platform/windows/telemetry.h index 044feec071223..261d14a7fed8c 100644 --- a/onnxruntime/core/platform/windows/telemetry.h +++ b/onnxruntime/core/platform/windows/telemetry.h @@ -59,8 +59,6 @@ class WindowsTelemetry : public Telemetry { const std::string& loadedFrom, const std::vector& execution_provider_ids, bool use_fp16, bool captureState) const override; - void LogCompileModel(uint32_t session_id) const override; - void LogRuntimeError(uint32_t session_id, const common::Status& status, const char* file, const char* function, uint32_t line) const override; diff --git a/onnxruntime/core/providers/js/operators/unary.cc b/onnxruntime/core/providers/js/operators/unary.cc index 26144e6ba3995..ef977161bcc37 100644 --- a/onnxruntime/core/providers/js/operators/unary.cc +++ b/onnxruntime/core/providers/js/operators/unary.cc @@ -126,7 +126,7 @@ JSEP_ELEMENTWISE_TYPED_KERNEL(Not, 1, bool, Not) // activation -JSEP_CLASS_IMPL_ATTRIBUTE_FLOAT_2_DEFAULT(ClipV10, Clip, min, 3.4028234663852886e+38f, max, -3.4028234663852886e+38f) +JSEP_CLASS_IMPL_ATTRIBUTE_FLOAT_2_DEFAULT(ClipV10, Clip, min, 3.402823e+38f, max, -3.402823e+38f) JSEP_ELEMENTWISE_VERSIONED_KERNEL(Clip, 6, 10, ClipV10) JSEP_KERNEL_IMPL(Clip, Clip) ONNX_OPERATOR_VERSIONED_KERNEL_EX(Clip, kOnnxDomain, 11, 11, kJsExecutionProvider, diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc index d148c4191d5d7..e2a8005aba1da 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc @@ -1407,30 +1407,9 @@ std::unique_ptr NvExecutionProvider::GetSubGraph(SubGraph_t gra } // Find inputs and outputs of the subgraph - std::unique_ptr sub_graph = onnxruntime::IndexedSubGraph::Create(); - std::unordered_map original_inputs; - - // These maps store the inputs and outputs of the subgraph. - // Please note that the inputs and outputs of the maps will be dynamically updated during node iteration - // to determine the final inputs and outputs of the subgraph. - std::unordered_map fused_inputs, fused_outputs; - - // This map stores the node's output that will be consumed by another node outside of this subgraph. - // So the node's output should be put into the subgraph's output list. - std::unordered_map fused_outputs_to_add; - - // This map stores the node's output that is original graph's output. - // So the node's output should be put into the subgraph's output list. - std::unordered_map graph_outputs_to_add; - + std::unordered_map original_inputs, fused_inputs, fused_outputs, fused_outputs_to_add, graph_outputs_to_add; std::unordered_set erased; - - // This is the relative ordering that ensures node's input or output being added to the 'fused_inputs', - // 'fused_outputs', 'fused_outputs_to_add' and 'graph_outputs_to_add' maps is associated with a relative order index. - // Items added earlier receive a smaller order index than items added later. - // When constructing the final sub_graph's input or output lists, entries with smaller - // order indices will appear before those with larger indices. int input_order = 0; int output_order = 0; @@ -1449,7 +1428,7 @@ std::unique_ptr NvExecutionProvider::GetSubGraph(SubGraph_t gra erased.insert(input); } else if (erased.find(input) == erased.end()) { // Only when input is neither in output list nor erased list, add the input to input list - fused_inputs.insert({input, input_order++}); + fused_inputs[input] = input_order++; } } @@ -1464,7 +1443,7 @@ std::unique_ptr NvExecutionProvider::GetSubGraph(SubGraph_t gra erased.insert(input); } else if (erased.find(input) == erased.end()) { // Only when input is neither in output list nor erased list, add the input to input list - fused_inputs.insert({input, input_order++}); + fused_inputs[input] = input_order++; } } @@ -1485,32 +1464,38 @@ std::unique_ptr NvExecutionProvider::GetSubGraph(SubGraph_t gra } else { output = (it->GetNode()).ImplicitInputDefs()[it->GetDstArgIndex() - static_cast(it->GetNode().InputDefs().size())]; } - - if (node_set.find(node_idx) == node_set.end()) { - // This output will be consumed by another node outside of this subgraph. - // So the output should be put into the subgraph's output list. - fused_outputs_to_add.insert({output, output_order++}); + if (node_set.find(node_idx) != node_set.end()) { + const auto& iter = fused_inputs.find(output); + if (iter != fused_inputs.end()) { + fused_inputs.erase(iter); + erased.insert(output); + } else if (erased.find(output) == erased.end()) { + if (graph_output_names.find(output->Name()) != graph_output_names.end()) { + graph_outputs_to_add[output] = output_order; + } + fused_outputs[output] = output_order++; + } + } else { + fused_outputs_to_add[output] = output_order++; } } - } - - for (const auto& output : node->OutputDefs()) { - const auto& it = fused_inputs.find(output); - if (it != fused_inputs.end()) { - fused_inputs.erase(it); - erased.insert(output); - } else if (erased.find(output) == erased.end()) { - if (graph.GetGraph().GetConsumerNodes(output->Name()).size() > 0) { - // Only when output is neither in input list nor erased list, - // and the output is consumed by another node, add the output to output list - fused_outputs.insert({output, output_order++}); + } else { + for (const auto& output : node->OutputDefs()) { + const auto& it = fused_inputs.find(output); + if (it != fused_inputs.end()) { + fused_inputs.erase(it); + erased.insert(output); } - } + // Only when output is neither in input list nor erased list, and the output is consumed by another node, add the output to output list + else if (erased.find(output) == erased.end()) { + if (graph_output_names.find(output->Name()) != graph_output_names.end()) { + graph_outputs_to_add[output] = output_order; + } - if (graph_output_names.find(output->Name()) != graph_output_names.end()) { - // This output is the graph's output. - // So the output should be put into the subgraph's output list. - graph_outputs_to_add.insert({output, output_order++}); + if (graph.GetGraph().GetConsumerNodes(output->Name()).size() > 0) { + fused_outputs[output] = output_order++; + } + } } } } diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc index 0bb3accb4d754..4d183b95bd938 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc @@ -76,9 +76,6 @@ Status BaseOpBuilder::ProcessDataTypes(QnnModelWrapper& qnn_model_wrapper, return CheckHtpDataTypes(input_qnn_dtypes, output_qnn_dtypes); } else if (IsGpuBackend(qnn_model_wrapper.GetQnnBackendType())) { return CheckGpuDataTypes(input_qnn_dtypes, output_qnn_dtypes); - } else if (IsIrBackend(qnn_model_wrapper.GetQnnBackendType())) { - // TODO: CheckIrDataTypes - return Status::OK(); } return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Only support backend: CPU, HTP and GPU"); } diff --git a/onnxruntime/core/providers/qnn/builder/qnn_def.cc b/onnxruntime/core/providers/qnn/builder/qnn_def.cc index 9f28e2609faa1..f3d81d7d2fdd7 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_def.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_def.cc @@ -574,10 +574,6 @@ bool QnnOpConfigWrapper::CreateQnnGraphOp(const QNN_INTERFACE_VER_TYPE& qnn_inte return true; } -bool IsIrBackend(QnnBackendType backend_type) { - return backend_type == QnnBackendType::SERIALIZER; -} - bool IsNpuBackend(QnnBackendType backend_type) { return backend_type == QnnBackendType::HTP || backend_type == QnnBackendType::DSP; } diff --git a/onnxruntime/core/providers/qnn/builder/qnn_def.h b/onnxruntime/core/providers/qnn/builder/qnn_def.h index 77508f3934a20..42f4d7bb60f34 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_def.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_def.h @@ -96,8 +96,6 @@ enum class QnnBackendType : uint8_t { SERIALIZER, }; -bool IsIrBackend(QnnBackendType backend_type); - bool IsCpuBackend(QnnBackendType backend_type); bool IsNpuBackend(QnnBackendType backend_type); diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model.cc b/onnxruntime/core/providers/qnn/builder/qnn_model.cc index 8973a4efa8ba1..85901ab6fdfec 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_model.cc @@ -222,14 +222,14 @@ Status QnnModel::SetupQnnInputOutput(const logging::Logger& logger) { auto result = SetupTensors(qnn_input_infos_, graph_info_->InputTensors()); if (Status::OK() != result) { - const std::string message = "Failed to setup QNN input tensors for graph: " + graph_info_->Name() + ". " + result.ErrorMessage(); + const std::string message = "Failed to setup QNN input tensors for graph: " + graph_info_->Name(); LOGS(logger, ERROR) << message; return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, message); } result = SetupTensors(qnn_output_infos_, graph_info_->OutputTensors(), false); if (Status::OK() != result) { - const std::string message = "Failed to setup QNN output tensors for graph: " + graph_info_->Name() + ". " + result.ErrorMessage(); + const std::string message = "Failed to setup QNN output tensors for graph: " + graph_info_->Name(); LOGS(logger, ERROR) << message; return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, message); } diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index e5b48da33fbc3..cd0c0e4bffdb5 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -2035,30 +2035,9 @@ std::unique_ptr TensorrtExecutionProvider::GetSubGraph(SubGraph } // Find inputs and outputs of the subgraph - std::unique_ptr sub_graph = onnxruntime::IndexedSubGraph::Create(); - std::unordered_map original_inputs; - - // These maps store the inputs and outputs of the subgraph. - // Please note that the inputs and outputs of the maps will be dynamically updated during node iteration - // to determine the final inputs and outputs of the subgraph. - std::unordered_map fused_inputs, fused_outputs; - - // This map stores the node's output that will be consumed by another node outside of this subgraph. - // So the node's output should be put into the subgraph's output list. - std::unordered_map fused_outputs_to_add; - - // This map stores the node's output that is original graph's output. - // So the node's output should be put into the subgraph's output list. - std::unordered_map graph_outputs_to_add; - + std::unordered_map original_inputs, fused_inputs, fused_outputs, fused_outputs_to_add, graph_outputs_to_add; std::unordered_set erased; - - // This is the relative ordering that ensures node's input or output being added to the 'fused_inputs', - // 'fused_outputs', 'fused_outputs_to_add' and 'graph_outputs_to_add' maps is associated with a relative order index. - // Items added earlier receive a smaller order index than items added later. - // When constructing the final sub_graph's input or output lists, entries with smaller - // order indices will appear before those with larger indices. int input_order = 0; int output_order = 0; @@ -2077,7 +2056,7 @@ std::unique_ptr TensorrtExecutionProvider::GetSubGraph(SubGraph erased.insert(input); } else if (erased.find(input) == erased.end()) { // Only when input is neither in output list nor erased list, add the input to input list - fused_inputs.insert({input, input_order++}); + fused_inputs[input] = input_order++; } } @@ -2092,7 +2071,7 @@ std::unique_ptr TensorrtExecutionProvider::GetSubGraph(SubGraph erased.insert(input); } else if (erased.find(input) == erased.end()) { // Only when input is neither in output list nor erased list, add the input to input list - fused_inputs.insert({input, input_order++}); + fused_inputs[input] = input_order++; } } @@ -2113,32 +2092,38 @@ std::unique_ptr TensorrtExecutionProvider::GetSubGraph(SubGraph } else { output = (it->GetNode()).ImplicitInputDefs()[it->GetDstArgIndex() - static_cast(it->GetNode().InputDefs().size())]; } - - if (node_set.find(node_idx) == node_set.end()) { - // This output will be consumed by another node outside of this subgraph. - // So the output should be put into the subgraph's output list. - fused_outputs_to_add.insert({output, output_order++}); + if (node_set.find(node_idx) != node_set.end()) { + const auto& iter = fused_inputs.find(output); + if (iter != fused_inputs.end()) { + fused_inputs.erase(iter); + erased.insert(output); + } else if (erased.find(output) == erased.end()) { + if (graph_output_names.find(output->Name()) != graph_output_names.end()) { + graph_outputs_to_add[output] = output_order; + } + fused_outputs[output] = output_order++; + } + } else { + fused_outputs_to_add[output] = output_order++; } } - } - - for (const auto& output : node->OutputDefs()) { - const auto& it = fused_inputs.find(output); - if (it != fused_inputs.end()) { - fused_inputs.erase(it); - erased.insert(output); - } else if (erased.find(output) == erased.end()) { - if (graph.GetGraph().GetConsumerNodes(output->Name()).size() > 0) { - // Only when output is neither in input list nor erased list, - // and the output is consumed by another node, add the output to output list - fused_outputs.insert({output, output_order++}); + } else { + for (const auto& output : node->OutputDefs()) { + const auto& it = fused_inputs.find(output); + if (it != fused_inputs.end()) { + fused_inputs.erase(it); + erased.insert(output); } - } + // Only when output is neither in input list nor erased list, and the output is consumed by another node, add the output to output list + else if (erased.find(output) == erased.end()) { + if (graph_output_names.find(output->Name()) != graph_output_names.end()) { + graph_outputs_to_add[output] = output_order; + } - if (graph_output_names.find(output->Name()) != graph_output_names.end()) { - // This output is the graph's output. - // So the output should be put into the subgraph's output list. - graph_outputs_to_add.insert({output, output_order++}); + if (graph.GetGraph().GetConsumerNodes(output->Name()).size() > 0) { + fused_outputs[output] = output_order++; + } + } } } } diff --git a/onnxruntime/core/providers/vsinpu/builders/impl/clip_op_builder.cc b/onnxruntime/core/providers/vsinpu/builders/impl/clip_op_builder.cc index 9948069c6779b..85096d0e262d7 100644 --- a/onnxruntime/core/providers/vsinpu/builders/impl/clip_op_builder.cc +++ b/onnxruntime/core/providers/vsinpu/builders/impl/clip_op_builder.cc @@ -78,8 +78,8 @@ bool ClipOpBuilder::HandleBuildOp(vsi::npu::GraphEP* graph_ep, LOGS_DEFAULT(INFO) << "Creating Clip Op."; if (node_unit.SinceVersion() <= 6) { NodeAttrHelper helper(node_unit.GetNode()); - auto min = helper.Get("min", -3.4028234663852886e+38f); - auto max = helper.Get("max", 3.4028234663852886e+38f); + auto min = helper.Get("min", -3.402e+38f); + auto max = helper.Get("max", 3.402e+38f); auto op = graph_ep->GetGraph()->CreateOperation(min, max); (*op).BindInputs(inputs).BindOutputs(outputs); graph_ep->GetOps().push_back(std::move(op)); diff --git a/onnxruntime/core/providers/webgpu/allocator.cc b/onnxruntime/core/providers/webgpu/allocator.cc index 3e1b87821fe2f..b3eb4b5061423 100644 --- a/onnxruntime/core/providers/webgpu/allocator.cc +++ b/onnxruntime/core/providers/webgpu/allocator.cc @@ -13,7 +13,7 @@ GpuBufferAllocator::GpuBufferAllocator(const BufferManager& buffer_manager, bool OrtMemoryInfo(WEBGPU_BUFFER, is_read_only_allocator ? OrtAllocatorType::OrtReadOnlyAllocator : OrtAllocatorType::OrtDeviceAllocator, - WebGpuDevice, + OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NONE, 0), OrtMemTypeDefault)), buffer_manager_{buffer_manager}, mapped_at_creation_{is_read_only_allocator && buffer_manager.SupportsUMA()} { diff --git a/onnxruntime/core/providers/webgpu/allocator.h b/onnxruntime/core/providers/webgpu/allocator.h index 74b3d669fcf3b..7c38b4557e078 100644 --- a/onnxruntime/core/providers/webgpu/allocator.h +++ b/onnxruntime/core/providers/webgpu/allocator.h @@ -11,11 +11,6 @@ namespace webgpu { class BufferManager; -inline constexpr OrtDevice WebGpuDevice{OrtDevice::GPU, - OrtDevice::MemType::DEFAULT, - OrtDevice::VendorIds::NONE, - 0}; - class GpuBufferAllocator : public IAllocator { public: GpuBufferAllocator(const BufferManager& buffer_manager, bool is_read_only_allocator); diff --git a/onnxruntime/core/providers/webgpu/compute_context.cc b/onnxruntime/core/providers/webgpu/compute_context.cc index d1a2011c8e191..ebe71c6ccfacd 100644 --- a/onnxruntime/core/providers/webgpu/compute_context.cc +++ b/onnxruntime/core/providers/webgpu/compute_context.cc @@ -6,25 +6,22 @@ namespace onnxruntime { namespace webgpu { - -ComputeContextBase::ComputeContextBase(WebGpuContext& webgpu_context, - const WebGpuExecutionProvider& ep, - const OpKernel& op_kernel) +ComputeContext::ComputeContext(OpKernelContext& kernel_context, + const OpKernel& op_kernel, + const WebGpuExecutionProvider& ep, + WebGpuContext& webgpu_context) : webgpu_context_{webgpu_context}, - ep_{ep}, - op_kernel_{op_kernel} { + kernel_context_{kernel_context}, + op_kernel_{op_kernel}, + ep_{ep} { } -const webgpu::BufferManager& ComputeContextBase::BufferManagerAccessor::Get(const ComputeContextBase& context) { +const webgpu::BufferManager& ComputeContext::BufferManagerAccessor::Get(const ComputeContext& context) { return context.ep_.BufferManager(); } -ComputeContext::ComputeContext(WebGpuContext& webgpu_context, - const WebGpuExecutionProvider& ep, - const OpKernel& op_kernel, - OpKernelContext& kernel_context) - : ComputeContextBase(webgpu_context, ep, op_kernel), - kernel_context_{kernel_context} { +const SplitKConfig& ComputeContext::GetSplitKConfig() { + return webgpu_context_.GetSplitKConfig(); } } // namespace webgpu diff --git a/onnxruntime/core/providers/webgpu/compute_context.h b/onnxruntime/core/providers/webgpu/compute_context.h index fdf89854469d6..ed16f2f0a1345 100644 --- a/onnxruntime/core/providers/webgpu/compute_context.h +++ b/onnxruntime/core/providers/webgpu/compute_context.h @@ -24,13 +24,7 @@ namespace webgpu { class WebGpuContext; class BufferManager; -// -// Class ComputeContextBase is designed to provide basic context information -// for running a compute shader program. -// -// An instance of ComputeContextBase does not depend on OpKernelContext, which needs an execution frame to be created. -// -class ComputeContextBase { +class ComputeContext final { public: // Nested accessor class to provide controlled access to BufferManager class BufferManagerAccessor { @@ -40,31 +34,18 @@ class ComputeContextBase { friend class WebGpuContext; private: - static const webgpu::BufferManager& Get(const ComputeContextBase& context); + static const webgpu::BufferManager& Get(const ComputeContext& context); }; - ComputeContextBase(WebGpuContext& webgpu_context, - const WebGpuExecutionProvider& ep, - const OpKernel& op_kernel); - - ~ComputeContextBase() = default; - - // - // Get the node name. - // - inline decltype(auto) NodeName() const { - return op_kernel_.Node().Name(); - } + ComputeContext(OpKernelContext& kernel_context, + const OpKernel& op_kernel, + const WebGpuExecutionProvider& ep, + WebGpuContext& webgpu_context); - // - // Get the operator type. - // - inline decltype(auto) OpType() const { - return op_kernel_.Node().OpType(); - } + ~ComputeContext() = default; // - // Get various information from the WebGPU context. + // Get various information from the context. // inline const wgpu::AdapterInfo& AdapterInfo() const { @@ -76,6 +57,9 @@ class ComputeContextBase { inline bool HasFeature(wgpu::FeatureName feature) const { return webgpu_context_.DeviceHasFeature(feature); } + inline bool IsGraphCaptureEnabled() const { + return ep_.IsGraphCaptureEnabled(); + } #if !defined(__wasm__) inline const wgpu::AdapterPropertiesSubgroupMatrixConfigs& SubgroupMatrixConfigs() const { return webgpu_context_.SubgroupMatrixConfigs(); @@ -83,57 +67,17 @@ class ComputeContextBase { #endif // - // Get Split-K configuration. - // - inline const SplitKConfig& GetSplitKConfig() const { - return webgpu_context_.GetSplitKConfig(); - } - - // - // Get whether graph capture is enabled. + // Get the kernel context. // - inline bool IsGraphCaptureEnabled() const { - return ep_.IsGraphCaptureEnabled(); + inline OpKernelContext& KernelContext() { + return kernel_context_; } // // Get the logger. // inline const logging::Logger& Logger() const { - return *ep_.GetLogger(); - } - - // - // Run a compute shader program. - // - inline Status RunProgram(const ProgramBase& program) { - return webgpu_context_.Run(*this, program); - } - - protected: - WebGpuContext& webgpu_context_; - const WebGpuExecutionProvider& ep_; - const OpKernel& op_kernel_; -}; - -// -// Class ComputeContext provides all information a `ComputeContextBase` provides, and also -// access to `OpKernelContext` for input and output tensors. -// -class ComputeContext final : public ComputeContextBase { - public: - ComputeContext(WebGpuContext& webgpu_context, - const WebGpuExecutionProvider& ep, - const OpKernel& op_kernel, - OpKernelContext& kernel_context); - - ~ComputeContext() = default; - - // - // Get the kernel context. - // - inline OpKernelContext& KernelContext() { - return kernel_context_; + return kernel_context_.Logger(); } // @@ -201,8 +145,25 @@ class ComputeContext final : public ComputeContextBase { return op_kernel_.Info().GetDataTransferManager().CopyTensor(src, dst); } + // + // Run a compute shader program. + // + inline Status RunProgram(const ProgramBase& program) { + return webgpu_context_.Run(*this, program); + } + + // + // Get Split-K configuration. + // + // `split_k_config_` won't be initialized until the first call to this method. + // + const SplitKConfig& GetSplitKConfig(); + private: + WebGpuContext& webgpu_context_; OpKernelContext& kernel_context_; + const OpKernel& op_kernel_; + const WebGpuExecutionProvider& ep_; }; } // namespace webgpu diff --git a/onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.cc b/onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.cc index 3c974ef5133c0..82645e30082e6 100644 --- a/onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.cc +++ b/onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.cc @@ -322,14 +322,11 @@ std::string GetPowImpl(int lhs_element_type, int /* rhs_element_type */) { if (lhs_element_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32) { round_str = "round"; } - std::string use_pow_shortcut; + std::string use_sqrt_for_pow; if (lhs_element_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT || lhs_element_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16) { - // use multiplication instead of pow when base (a) is a float and exponent (b) is 2.0 // use sqrt instead of pow when base (a) is a positive float and exponent (b) is 0.5 - use_pow_shortcut = - " else if (b == 2.0) {\n" - " return a * a;\n" - " } else if (a >= input_a_element_t(0.0) && b == 0.5) {\n" + use_sqrt_for_pow = + " else if (a >= input_a_element_t(0.0) && b == 0.5) {\n" " return sqrt(a);\n" " }\n"; } @@ -340,7 +337,7 @@ std::string GetPowImpl(int lhs_element_type, int /* rhs_element_type */) { " } else if (a < input_a_element_t(0.0) && b != floor(b)) {\n" " return input_a_element_t(pow(f32(a), b)); // NaN\n" " }\n" - << use_pow_shortcut + << use_sqrt_for_pow << " return select(sign(a), input_a_element_t(1.0), round(abs(b) % 2.0) != 1.0) * input_a_element_t(" << round_str << "(pow(f32(abs(a)), b)));\n" << "}\n" "fn pow_v(a : vec4, b : vec4) -> vec4 {\n" diff --git a/onnxruntime/core/providers/webgpu/math/gemm_packed.cc b/onnxruntime/core/providers/webgpu/math/gemm_packed.cc index c26b58a7af1f4..6aefa90a59285 100644 --- a/onnxruntime/core/providers/webgpu/math/gemm_packed.cc +++ b/onnxruntime/core/providers/webgpu/math/gemm_packed.cc @@ -93,21 +93,18 @@ Status ApplyGemmPacked(const Tensor* a, } const uint32_t TILE_SIZE = 32; - const uint32_t dispatch_x = (N + TILE_SIZE - 1) / TILE_SIZE; - const uint32_t dispatch_y = (M + TILE_SIZE - 1) / TILE_SIZE; + const uint32_t num_tile_n = (N + TILE_SIZE - 1) / TILE_SIZE; + const uint32_t num_tile_m = (M + TILE_SIZE - 1) / TILE_SIZE; program.CacheHint(alpha, transA, transB, c_is_scalar) .AddOutputs({{y, ProgramTensorMetadataDependency::TypeAndRank, output_components}}) - .SetDispatchGroupSize(dispatch_x, dispatch_y, 1u) + .SetDispatchGroupSize(num_tile_n, num_tile_m, 1) .SetWorkgroupSize(GemmProgram::MATMUL_PACKED_WORKGROUP_SIZE_X, GemmProgram::MATMUL_PACKED_WORKGROUP_SIZE_Y, GemmProgram::MATMUL_PACKED_WORKGROUP_SIZE_Z) .AddUniformVariables({{alpha}, {beta}, - {M}, /* dim_a_outer */ - {N}, /* dim_b_outer */ - {K}, /*dim_inner */ - {dispatch_x}, /* logical_dispatch_x */ - {dispatch_y}, /* logical_dispatch_y */ - {1u}} /* logical_dispatch_z */ + {M}, /* dim_a_outer */ + {N}, /* dim_b_outer */ + {K}} /*dim_inner */ ); return context.RunProgram(program); diff --git a/onnxruntime/core/providers/webgpu/math/gemm_packed.h b/onnxruntime/core/providers/webgpu/math/gemm_packed.h index cb89ccefba313..dce5164693aa8 100644 --- a/onnxruntime/core/providers/webgpu/math/gemm_packed.h +++ b/onnxruntime/core/providers/webgpu/math/gemm_packed.h @@ -32,10 +32,7 @@ class GemmProgram final : public Program { {"beta", ProgramUniformVariableDataType::Float32}, {"dim_a_outer", ProgramUniformVariableDataType::Uint32}, {"dim_b_outer", ProgramUniformVariableDataType::Uint32}, - {"dim_inner", ProgramUniformVariableDataType::Uint32}, - {"logical_dispatch_x", ProgramUniformVariableDataType::Uint32}, - {"logical_dispatch_y", ProgramUniformVariableDataType::Uint32}, - {"logical_dispatch_z", ProgramUniformVariableDataType::Uint32}); + {"dim_inner", ProgramUniformVariableDataType::Uint32}); constexpr static uint32_t MATMUL_PACKED_WORKGROUP_SIZE_X = 8; constexpr static uint32_t MATMUL_PACKED_WORKGROUP_SIZE_Y = 8; diff --git a/onnxruntime/core/providers/webgpu/math/gemm_utils.cc b/onnxruntime/core/providers/webgpu/math/gemm_utils.cc index 89718149cea88..7cbc7f6a4a821 100644 --- a/onnxruntime/core/providers/webgpu/math/gemm_utils.cc +++ b/onnxruntime/core/providers/webgpu/math/gemm_utils.cc @@ -117,20 +117,6 @@ void HandleMatMulWithSplitK( } } -// Compute `logical_workgroup_id` and `logical_global_id` because the dispatch workgroup size in -// `ProgramBase.SetDispatchGroupSize()` may be normalized in -// `ProgramManager::NormalizeDispatchGroupSize()`. In the shader we should always use -// `logical_workgroup_id` and `logical_global_id` instead of `workgroup_id` and `global_id`. -void InitializeLogicalWorkgroupIDAndGlobalID(ShaderHelper& shader) { - shader.MainFunctionBody() - << " let logical_workgroup_id_z = workgroup_idx / (uniforms.logical_dispatch_x * uniforms.logical_dispatch_y);\n" - << " let logical_workgroup_id_y = (workgroup_idx % (uniforms.logical_dispatch_x * uniforms.logical_dispatch_y)) / uniforms.logical_dispatch_x;\n" - << " let logical_workgroup_id_x = (workgroup_idx % (uniforms.logical_dispatch_x * uniforms.logical_dispatch_y)) % uniforms.logical_dispatch_x;\n" - << " let logical_workgroup_id = vec3u(logical_workgroup_id_x, logical_workgroup_id_y, logical_workgroup_id_z);\n" - << " const workgroupSize = vec3u(workgroup_size_x, workgroup_size_y, workgroup_size_z);\n" - << " let logical_global_id = logical_workgroup_id * workgroupSize + local_id;\n"; -} - } // namespace void MatMulReadFnSource(ShaderHelper& shader, @@ -288,22 +274,20 @@ Status MakeMatMulPackedVec4Source(ShaderHelper& shader, << "const innerElementSize = " << inner_elements_size << ";\n" << "const tileInner = " << tile_inner << ";\n"; - InitializeLogicalWorkgroupIDAndGlobalID(shader); - shader.MainFunctionBody() << " let localRow = i32(local_id.y);\n" << " let tileRow = localRow * rowPerThread;\n" << " let tileCol = i32(local_id.x);\n" - << " let globalRow = i32(logical_global_id.y) * rowPerThread;\n" - << " let globalCol = i32(logical_global_id.x);\n" - << " let globalRowStart = i32(logical_workgroup_id.y) * " << tile_a_outer << ";\n" - << " let globalColStart = i32(logical_workgroup_id.x) * " << tile_b_outer << ";\n" + << " let globalRow = i32(global_id.y) * rowPerThread;\n" + << " let globalCol = i32(global_id.x);\n" + << " let globalRowStart = i32(workgroup_id.y) * " << tile_a_outer << ";\n" + << " let globalColStart = i32(workgroup_id.x) * " << tile_b_outer << ";\n" << " var acc: array, rowPerThread>;\n"; if (split_k) { // With Split-K, the original "workgroup" (with dispatch_z == 1 in API side) is split into // multiple ones, and in the current workgroup we only compute `kSplitK` elements starting from - // `kSplitK * i32(logical_global_id.z)`. + // `kSplitK * i32(global_id.z)`. // // For example: considering computing Y = (X * W + B) in one workgroup. // Let kSplitK = 2, B = [d1, d2] @@ -321,15 +305,15 @@ Status MakeMatMulPackedVec4Source(ShaderHelper& shader, // Workgroup1: compute (A1 * A2) Workgroup2: compute (B1 * B2) // Workgroup3: compute (C1 * C2) // In each workgroup: - // - `num_tiles` is computed with `kSplitK`, and `kStart` is computed with `logical_global_id.z` + // - `num_tiles` is computed with `kSplitK`, and `kStart` is computed with `global_id.z` // - When the computation in each workgroup is completed, add the result to Y with several // atomic built-in functions in `HandleMatMulWithSplitK()`. shader.MainFunctionBody() << "const kSplitK = " << split_dim_inner << ";\n" << " let num_tiles = (kSplitK - 1) / tileInner + 1;\n" - << " var kStart = kSplitK * i32(logical_global_id.z);\n" + << " var kStart = kSplitK * i32(global_id.z);\n" - // When Split-K is used, `batch` should always be 0 and `logical_global_id.z` is used to indicate + // When Split-K is used, `batch` should always be 0 and `global_id.z` is used to indicate // the index of split-k instead of batch. << " let batch = 0;\n" << " let batchIndices = 0u;\n"; @@ -337,7 +321,7 @@ Status MakeMatMulPackedVec4Source(ShaderHelper& shader, shader.MainFunctionBody() << " let num_tiles = (uniforms.dim_inner - 1) / tileInner + 1;\n" << " var kStart = 0;\n" - << " let batch = i32(logical_global_id.z);\n" + << " let batch = i32(global_id.z);\n" << (nullptr != batch_dims ? " let batchIndices = " + batch_dims->OffsetToIndices("u32(batch)") + ";\n" : ""); } @@ -514,9 +498,7 @@ Status MakeMatMulPackedSource(ShaderHelper& shader, << "const colPerThread = " << elements_per_thread_x << ";\n" << "const tileInner = " << tile_inner << ";\n"; - InitializeLogicalWorkgroupIDAndGlobalID(shader); - - shader.MainFunctionBody() << " let batch = i32(logical_global_id.z);\n" + shader.MainFunctionBody() << " let batch = i32(global_id.z);\n" << (nullptr != batch_dims ? " let batchIndices = " + batch_dims->OffsetToIndices("u32(batch)") + ";\n" : "") << " let num_tiles = (uniforms.dim_inner - 1) / tileInner + 1;\n" << " var kStart = 0;\n" @@ -525,10 +507,10 @@ Status MakeMatMulPackedSource(ShaderHelper& shader, shader.MainFunctionBody() << "let tileRow = i32(local_id.y) * rowPerThread;\n" << "let tileCol = i32(local_id.x) * colPerThread;\n" - << "let globalRow = i32(logical_global_id.y) * rowPerThread;\n" - << "let globalCol = i32(logical_global_id.x) * colPerThread;\n" - << "let globalRowStart = i32(logical_workgroup_id.y) * " << tile_a_outer << ";\n" - << "let globalColStart = i32(logical_workgroup_id.x) * " << tile_b_outer << ";\n" + << "let globalRow = i32(global_id.y) * rowPerThread;\n" + << "let globalCol = i32(global_id.x) * colPerThread;\n" + << "let globalRowStart = i32(workgroup_id.y) * " << tile_a_outer << ";\n" + << "let globalColStart = i32(workgroup_id.x) * " << tile_b_outer << ";\n" << "let tileRowA = i32(local_id.y) * " << row_per_thread_a << ";\n" << "let tileColA = i32(local_id.x) * " << col_per_thread_a << ";\n" << "let tileRowB = i32(local_id.y) * " << row_per_thread_b << ";\n"; diff --git a/onnxruntime/core/providers/webgpu/math/matmul.cc b/onnxruntime/core/providers/webgpu/math/matmul.cc index 72dd235eb820a..55c2c5773cc1f 100644 --- a/onnxruntime/core/providers/webgpu/math/matmul.cc +++ b/onnxruntime/core/providers/webgpu/math/matmul.cc @@ -256,6 +256,8 @@ Status ComputeMatMul(ComputeContext* context, // With Split-K, `dim_inner` will be split into multiple parts and `dispatch_z` will be the // number of splits along `dim_inner`. + // TODO: avoid using `global_id.xxx` or `workgroup_id.xxx` in `MatMulProgram` when we normalize + // the dispatch size with `ProgramManager::NormalizeDispatchGroupSize()` for `MatMulProgram`. split_dim_inner = split_k_config.GetSplitDimInner(); dispatch_z = (dim_inner + split_dim_inner - 1) / split_dim_inner; @@ -269,7 +271,7 @@ Status ComputeMatMul(ComputeContext* context, .CacheHint(activation.ToString(), absl::StrJoin(elements_per_thread, "-"), std::to_string(is_vec4), components, is_channels_last, split_dim_inner) .AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, a_shape_temp, components}, {b, ProgramTensorMetadataDependency::TypeAndRank, b_shape_temp, components}}) - .AddUniformVariables({{dim_a_outer}, {dim_b_outer}, {dim_inner}, {dispatch_x}, {dispatch_y}, {dispatch_z}}) + .AddUniformVariables({{dim_a_outer}, {dim_b_outer}, {dim_inner}}) .AddIndices(outer_dims) .SetDispatchGroupSize(dispatch_x, dispatch_y, dispatch_z) .SetWorkgroupSize(MatMul::MATMUL_PACKED_WORKGROUP_SIZE_X, MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Y, MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Z) diff --git a/onnxruntime/core/providers/webgpu/math/matmul_packed.h b/onnxruntime/core/providers/webgpu/math/matmul_packed.h index dbd193bc38f58..143ba61c99e13 100644 --- a/onnxruntime/core/providers/webgpu/math/matmul_packed.h +++ b/onnxruntime/core/providers/webgpu/math/matmul_packed.h @@ -24,10 +24,7 @@ class MatMulProgram final : public Program { Status GenerateShaderCode(ShaderHelper& sh) const override; WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"dim_a_outer", ProgramUniformVariableDataType::Uint32}, {"dim_b_outer", ProgramUniformVariableDataType::Uint32}, - {"dim_inner", ProgramUniformVariableDataType::Uint32}, - {"logical_dispatch_x", ProgramUniformVariableDataType::Uint32}, - {"logical_dispatch_y", ProgramUniformVariableDataType::Uint32}, - {"logical_dispatch_z", ProgramUniformVariableDataType::Uint32}); + {"dim_inner", ProgramUniformVariableDataType::Uint32}); bool NeedSplitK() const; diff --git a/onnxruntime/core/providers/webgpu/math/softmax.cc b/onnxruntime/core/providers/webgpu/math/softmax.cc index bf3bb53341418..2f34aa21c8309 100644 --- a/onnxruntime/core/providers/webgpu/math/softmax.cc +++ b/onnxruntime/core/providers/webgpu/math/softmax.cc @@ -64,7 +64,7 @@ Status SoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const { int components = input.NumComponents(); const std::string thread_max_decl = is_fp32_ - ? "var thread_max = x_value_t(-3.4028234663852886e+38f);\n" + ? "var thread_max = x_value_t(-3.402823e+38f);\n" : "var thread_max = x_value_t(-65504.0h);\n"; // Define shared memory for row max and row sum diff --git a/onnxruntime/core/providers/webgpu/nn/conv.cc b/onnxruntime/core/providers/webgpu/nn/conv.cc index 4fff736fd2f32..77fa46cb87518 100644 --- a/onnxruntime/core/providers/webgpu/nn/conv.cc +++ b/onnxruntime/core/providers/webgpu/nn/conv.cc @@ -216,46 +216,6 @@ Status Conv::ComputeInternal(ComputeContext& context return context.RunProgram(conv2d_mm_program); } -template -Status Conv::PrePackInternal(ComputeContextBase& /* context */, - const Tensor& tensor, - int input_idx, - AllocatorPtr /* alloc */, - /*out*/ bool& is_packed) { - is_packed = false; - - if constexpr (is_channels_last) { - if (input_idx == 1 && tensor.Shape().NumDimensions() == 4) { - // only deal with 4D NHWC weights - - // TODO: implement weight transpose for pre-pack here - // Conv::ComputeInternal() should be updated to reflect the change: - // - if the initializer is packed, `context.Input(1)` will be nullptr. - // - in this case, use `transposed_kernel_` instead. - - // // Step.1 - calculate transposed weight shape - // TensorShape transposed_kernel_shape{tensor.Shape()[2], - // tensor.Shape()[3], - // tensor.Shape()[1], - // tensor.Shape()[0]}; - - // // Step.2 - create transposed weight tensor - // transposed_kernel_ = std::make_unique(tensor.DataType(), transposed_kernel_shape, alloc); - - // // Step.3 - do transpose - // size_t perm[] = {2, 3, 1, 0}; - // ORT_RETURN_IF_ERROR(Transpose::DoTranspose(context, - // perm, - // tensor, - // *transposed_kernel_)); - - // is_packed = true; // set this flag to true so that ORT will release the initializer tensor - } - } - - return Status::OK(); -} - // Explicit template instantiation for FusedConv template class Conv; template class Conv; diff --git a/onnxruntime/core/providers/webgpu/nn/conv.h b/onnxruntime/core/providers/webgpu/nn/conv.h index 5bf94a459a44a..cafaa272c0613 100644 --- a/onnxruntime/core/providers/webgpu/nn/conv.h +++ b/onnxruntime/core/providers/webgpu/nn/conv.h @@ -23,16 +23,9 @@ class Conv : public WebGpuKernel { } Status ComputeInternal(ComputeContext& context) const override; - Status PrePackInternal(ComputeContextBase& context, - const Tensor& tensor, - int input_idx, - AllocatorPtr alloc, - /*out*/ bool& is_packed) override; - protected: ConvAttributes conv_attrs_; Activation activation_; - std::unique_ptr transposed_kernel_; // should only have value when `is_initializer` AND `is_4D` AND `is_NHWC` }; Status TransposeKernel(ComputeContext& context, const Tensor* kernel, const TensorShape& kernel_shape, Tensor* transposed_kernel, const InlinedVector& perm); diff --git a/onnxruntime/core/providers/webgpu/nn/conv2d_mm.cc b/onnxruntime/core/providers/webgpu/nn/conv2d_mm.cc index c66f2cbd582d9..2d5424c52a3f2 100644 --- a/onnxruntime/core/providers/webgpu/nn/conv2d_mm.cc +++ b/onnxruntime/core/providers/webgpu/nn/conv2d_mm.cc @@ -226,10 +226,7 @@ Conv2dMMProgram CreateConv2dMMProgram(const Activation& activation, const std::v {static_cast(dim_inner)}, {pads}, {strides}, - {dilations}, - {dispatch[0]}, - {dispatch[1]}, - {dispatch[2]}}); + {dilations}}); return program; } diff --git a/onnxruntime/core/providers/webgpu/nn/conv2d_mm.h b/onnxruntime/core/providers/webgpu/nn/conv2d_mm.h index e161bffb0c503..d7cc08aae26f3 100644 --- a/onnxruntime/core/providers/webgpu/nn/conv2d_mm.h +++ b/onnxruntime/core/providers/webgpu/nn/conv2d_mm.h @@ -38,10 +38,7 @@ class Conv2dMMProgram final : public Program { {"dim_inner", ProgramUniformVariableDataType::Uint32}, {"pads", ProgramUniformVariableDataType::Uint32}, {"strides", ProgramUniformVariableDataType::Uint32}, - {"dilations", ProgramUniformVariableDataType::Uint32}, - {"logical_dispatch_x", ProgramUniformVariableDataType::Uint32}, - {"logical_dispatch_y", ProgramUniformVariableDataType::Uint32}, - {"logical_dispatch_z", ProgramUniformVariableDataType::Uint32}); + {"dilations", ProgramUniformVariableDataType::Uint32}); private: const Activation& activation_; diff --git a/onnxruntime/core/providers/webgpu/tensor/slice.cc b/onnxruntime/core/providers/webgpu/tensor/slice.cc index 5f59fecc425e2..7e8b434431781 100644 --- a/onnxruntime/core/providers/webgpu/tensor/slice.cc +++ b/onnxruntime/core/providers/webgpu/tensor/slice.cc @@ -92,28 +92,14 @@ Status SliceProgram::GenerateShaderCode(ShaderHelper& shader) const { return Status::OK(); } -static std::vector getInt64Input(const Tensor* tensor) { - if (tensor->IsDataType()) { - return std::vector(tensor->DataAsSpan().begin(), tensor->DataAsSpan().end()); - } - ORT_ENFORCE(tensor->IsDataType(), "Expected tensor of type int32 or int64"); - std::vector result; - auto span = tensor->DataAsSpan(); - result.reserve(span.size()); - for (auto v : span) { - result.push_back(static_cast(v)); - } - return result; -} - Status Slice::ComputeInternal(ComputeContext& context) const { // READ INPUTS const Tensor* input_tensor = context.Input(0); const TensorShape& input_shape = input_tensor->Shape(); auto input_rank = input_shape.NumDimensions(); - auto starts_raw = attr_starts_.empty() ? getInt64Input(context.Input(1)) : attr_starts_; - auto ends_raw = attr_ends_.empty() ? getInt64Input(context.Input(2)) : attr_ends_; + auto starts_raw = attr_starts_.empty() ? context.Input(1)->DataAsSpan() : gsl::make_span(attr_starts_); + auto ends_raw = attr_ends_.empty() ? context.Input(2)->DataAsSpan() : gsl::make_span(attr_ends_); ORT_ENFORCE(starts_raw.size() == ends_raw.size(), "starts and ends must have the same size"); @@ -140,7 +126,7 @@ Status Slice::ComputeInternal(ComputeContext& context) const { axes_default.push_back(i); } } - auto axes_raw = attr_axes_.empty() ? (axes_tensor == nullptr ? axes_default : getInt64Input(axes_tensor)) : attr_axes_; + auto axes_raw = attr_axes_.empty() ? (axes_tensor == nullptr ? gsl::make_span(axes_default) : axes_tensor->DataAsSpan()) : gsl::make_span(attr_axes_); std::vector steps_default; if (steps_tensor == nullptr) { @@ -149,7 +135,7 @@ Status Slice::ComputeInternal(ComputeContext& context) const { steps_default.push_back(1); } } - auto steps_raw = steps_tensor == nullptr ? steps_default : getInt64Input(steps_tensor); + auto steps_raw = steps_tensor == nullptr ? gsl::make_span(steps_default) : steps_tensor->DataAsSpan(); // get final axes std::vector axes, axes_fixed; diff --git a/onnxruntime/core/providers/webgpu/tensor/transpose.cc b/onnxruntime/core/providers/webgpu/tensor/transpose.cc index 5415d4a5ead5b..cec321d0da80e 100644 --- a/onnxruntime/core/providers/webgpu/tensor/transpose.cc +++ b/onnxruntime/core/providers/webgpu/tensor/transpose.cc @@ -108,7 +108,7 @@ Status TransposeProgram::GenerateShaderCode(ShaderHelper& shader) const { return Status::OK(); } -Status Transpose::DoTranspose(onnxruntime::webgpu::ComputeContextBase& context, +Status Transpose::DoTranspose(onnxruntime::webgpu::ComputeContext& context, gsl::span permutations, const Tensor& input, Tensor& output) { const auto& input_shape = input.Shape(); diff --git a/onnxruntime/core/providers/webgpu/tensor/transpose.h b/onnxruntime/core/providers/webgpu/tensor/transpose.h index 5e9ccc6750cd6..b62a419fa12bc 100644 --- a/onnxruntime/core/providers/webgpu/tensor/transpose.h +++ b/onnxruntime/core/providers/webgpu/tensor/transpose.h @@ -16,7 +16,7 @@ class Transpose final : public WebGpuKernel, public TransposeBase { Transpose(const OpKernelInfo& info) : WebGpuKernel{info}, TransposeBase{info} { } Status ComputeInternal(ComputeContext& context) const override; - static Status DoTranspose(onnxruntime::webgpu::ComputeContextBase& context, gsl::span permutations, const Tensor& input, Tensor& output); + static Status DoTranspose(onnxruntime::webgpu::ComputeContext& context, gsl::span permutations, const Tensor& input, Tensor& output); constexpr static uint32_t TILE_SIZE = 16; }; diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.cc b/onnxruntime/core/providers/webgpu/webgpu_context.cc index b8d5adc421124..28decb076951e 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_context.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_context.cc @@ -147,9 +147,6 @@ void WebGpuContext::Initialize(const WebGpuBufferCacheConfig& buffer_cache_confi // create program manager program_mgr_ = std::make_unique(*this); - // create split-k config - split_k_config_ = std::make_unique(adapter_info_); - // set query type #if !defined(__wasm__) if (DeviceHasFeature(wgpu::FeatureName::ChromiumExperimentalTimestampQueryInsidePasses)) { @@ -181,7 +178,7 @@ Status WebGpuContext::Wait(wgpu::Future f) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to wait for the operation:", uint32_t(status)); } -Status WebGpuContext::Run(ComputeContextBase& context, const ProgramBase& program) { +Status WebGpuContext::Run(ComputeContext& context, const ProgramBase& program) { const auto& inputs = program.Inputs(); const auto& outputs = program.Outputs(); @@ -291,8 +288,8 @@ Status WebGpuContext::Run(ComputeContextBase& context, const ProgramBase& progra auto key = CalculateProgramCacheKey(program, inputs_segments, outputs_segments, is_1d_dispatch); if (is_profiling_) { - PendingKernelInfo pending_kernel_info(context.NodeName(), - context.OpType(), + PendingKernelInfo pending_kernel_info(context.KernelContext().GetNodeName(), + context.KernelContext().GetOpType(), program.Name(), key, inputs, @@ -445,7 +442,7 @@ Status WebGpuContext::Run(ComputeContextBase& context, const ProgramBase& progra const size_t uniform_buffer_total_size = (current_offset + max_alignment_of_field - 1) / max_alignment_of_field * max_alignment_of_field; WGPUBuffer uniform_buffer = nullptr; - const webgpu::BufferManager& buffer_mgr = ComputeContextBase::BufferManagerAccessor::Get(context); + const webgpu::BufferManager& buffer_mgr = ComputeContext::BufferManagerAccessor::Get(context); if (uniform_buffer_total_size > 0) { std::vector uniform_data_buffer(uniform_buffer_total_size); @@ -913,6 +910,13 @@ void WebGpuContext::ReleaseGraphResources(std::vector WebGpuContextFactory::contexts_; std::mutex WebGpuContextFactory::mutex_; std::once_flag WebGpuContextFactory::init_default_flag_; diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.h b/onnxruntime/core/providers/webgpu/webgpu_context.h index 84dfb47ef4687..bd7dae75f2e2d 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_context.h +++ b/onnxruntime/core/providers/webgpu/webgpu_context.h @@ -5,6 +5,7 @@ #include #include +#include #include "core/providers/webgpu/webgpu_external_header.h" @@ -22,7 +23,7 @@ class Tensor; namespace webgpu { class WebGpuContext; -class ComputeContextBase; +class ComputeContext; class ProgramBase; // Definition for CapturedCommandInfo in the webgpu namespace @@ -151,13 +152,6 @@ class WebGpuContext final { return validation_mode_; } - // - // Get Split-K configuration. - // - const SplitKConfig& GetSplitKConfig() const { - return *split_k_config_; - } - void StartProfiling(); void CollectProfilingData(profiling::Events& events); void EndProfiling(TimePoint, profiling::Events& events, profiling::Events& cached_events); @@ -176,9 +170,16 @@ class WebGpuContext final { // Status PopErrorScope(); - Status Run(ComputeContextBase& context, const ProgramBase& program); + Status Run(ComputeContext& context, const ProgramBase& program); void OnRunEnd(); + // + // Get Split-K configuration. + // + // `split_k_config_` won't be initialized until the first call to this method. + // + const SplitKConfig& GetSplitKConfig(); + private: enum class TimestampQueryType { None = 0, @@ -276,7 +277,7 @@ class WebGpuContext final { uint32_t num_pending_dispatches_ = 0; const uint32_t max_num_pending_dispatches_ = 16; - std::unique_ptr split_k_config_; + std::optional split_k_config_; // profiling TimestampQueryType query_type_; diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc index 6b764d51bcf75..e0b84fef51f1f 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc @@ -794,7 +794,8 @@ using namespace webgpu; WebGpuExecutionProvider::WebGpuExecutionProvider(int context_id, WebGpuContext& context, WebGpuExecutionProviderConfig&& config) - : IExecutionProvider{kWebGpuExecutionProvider, WebGpuDevice}, + : IExecutionProvider{kWebGpuExecutionProvider, + OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NONE, 0)}, context_id_{context_id}, context_{context}, preferred_data_layout_{config.data_layout}, @@ -934,14 +935,13 @@ std::unique_ptr WebGpuExecutionProvider::GetEx std::optional WebGpuExecutionProvider::ShouldConvertDataLayoutForOp(std::string_view node_domain, std::string_view node_op_type, DataLayout target_data_layout) const { - // NHWC for Resize operator is not implemented on kWebGpuExecutionProvider - if (node_domain == kOnnxDomain && node_op_type == "Resize") { - return target_data_layout != DataLayout::NHWC; + if (target_data_layout != DataLayout::NHWC) { + return std::nullopt; } - // WebGPU perfer NCHW for InstanceNormalization due to a better performance - if (node_domain == kOnnxDomain && node_op_type == "InstanceNormalization") { - return target_data_layout != DataLayout::NHWC; + // NHWC for Resize operator is not implemented on kWebGpuExecutionProvider + if (node_domain == kOnnxDomain && node_op_type == "Resize") { + return false; } return std::nullopt; diff --git a/onnxruntime/core/providers/webgpu/webgpu_kernel.cc b/onnxruntime/core/providers/webgpu/webgpu_kernel.cc index ea38e9415e1fe..8d6ae6caeaf83 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_kernel.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_kernel.cc @@ -11,58 +11,25 @@ namespace webgpu { WebGpuKernel::WebGpuKernel(const OpKernelInfo& info) : OpKernel(info), - ep_(*static_cast(info.GetExecutionProvider())), - webgpu_context_(WebGpuContextFactory::GetContext(ep_.GetDeviceId())) { + ep_(*static_cast(info.GetExecutionProvider())) { } Status WebGpuKernel::Compute(OpKernelContext* p_op_kernel_context) const { - ComputeContext context{webgpu_context_, - ep_, - *this, - *p_op_kernel_context}; + WebGpuContext& webgpu_context = WebGpuContextFactory::GetContext(ep_.GetDeviceId()); + ComputeContext context{*p_op_kernel_context, *this, ep_, webgpu_context}; - if (webgpu_context_.ValidationMode() >= ValidationMode::Full) { - webgpu_context_.PushErrorScope(); + if (webgpu_context.ValidationMode() >= ValidationMode::Full) { + webgpu_context.PushErrorScope(); } Status s = ComputeInternal(context); - if (webgpu_context_.ValidationMode() >= ValidationMode::Full) { - ORT_RETURN_IF_ERROR(webgpu_context_.PopErrorScope()); + if (webgpu_context.ValidationMode() >= ValidationMode::Full) { + ORT_RETURN_IF_ERROR(webgpu_context.PopErrorScope()); } return s; } -Status WebGpuKernel::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, - /*out*/ bool& is_packed, /*out*/ PrePackedWeights* /* prepacked_weights */) { - ComputeContextBase context{webgpu_context_, ep_, *this}; - - if (webgpu_context_.ValidationMode() >= ValidationMode::Full) { - webgpu_context_.PushErrorScope(); - } - - // Currently, ORT does not allow using prepacked weights in non-CPU EPs. - // So we do not pass prepacked_weights to PrePackInternal. - // Kernel implementation that supports prepacking should manage its own storage. - - Status s = PrePackInternal(context, tensor, input_idx, alloc, is_packed); - - if (webgpu_context_.ValidationMode() >= ValidationMode::Full) { - ORT_RETURN_IF_ERROR(webgpu_context_.PopErrorScope()); - } - - return s; -} - -Status WebGpuKernel::PrePackInternal(ComputeContextBase& /*context*/, - const Tensor& /*tensor*/, - int /*input_idx*/, - AllocatorPtr /*alloc*/, - /*out*/ bool& is_packed) { - is_packed = false; - return Status::OK(); -} - } // namespace webgpu } // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/webgpu_kernel.h b/onnxruntime/core/providers/webgpu/webgpu_kernel.h index 2c57991c6ee35..3c750e305421c 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_kernel.h +++ b/onnxruntime/core/providers/webgpu/webgpu_kernel.h @@ -23,41 +23,8 @@ class WebGpuKernel : public OpKernel { virtual Status ComputeInternal(ComputeContext& context) const = 0; - // Overrides OpKernel::PrePack to handle constant tensor pre-processing for WebGPU kernels. - // This method creates a ComputeContextBase and delegates to PrePackInternal. - // - // NOTE: Currently, ORT does not allow using prepacked weights in non-CPU EPs, so the - // prepacked_weights parameter is not passed to PrePackInternal. Kernel implementations - // that support prepacking should manage their own storage. - Status PrePack(const Tensor& tensor, - int input_idx, - AllocatorPtr alloc, - /*out*/ bool& is_packed, - /*out*/ PrePackedWeights* prepacked_weights) override; - - // Virtual method that allows derived kernels to pre-process constant tensors during initialization. - // - // This method is called during kernel initialization when constant tensors are available, - // allowing kernels to perform operations like tensor transposition or format conversion - // before the first Compute call. - // - // @param context The WebGPU compute context base providing access to the execution environment. - // @param tensor The constant tensor to potentially pre-process. - // @param input_idx The index of this input in the kernel's input list. - // @param alloc The allocator to use for any new tensor allocations. - // @param is_packed Output parameter. Set to true if the tensor was pre-packed/processed, - // false otherwise. The default implementation sets this to false. - // - // @return Status::OK() on success, or an error status on failure. - virtual Status PrePackInternal(ComputeContextBase& context, - const Tensor& tensor, - int input_idx, - AllocatorPtr alloc, - /*out*/ bool& is_packed); - private: const WebGpuExecutionProvider& ep_; - WebGpuContext& webgpu_context_; }; } // namespace webgpu diff --git a/onnxruntime/core/providers/webgpu/webgpu_utils.cc b/onnxruntime/core/providers/webgpu/webgpu_utils.cc index 5fd24b2bff037..568d29a96cb88 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_utils.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_utils.cc @@ -21,24 +21,27 @@ TensorShape ReduceShapeByComponents(const TensorShape& shape, int64_t components return TensorShape(shape_vector); } -SplitKConfig::SplitKConfig(const wgpu::AdapterInfo& adapter_info) { +SplitKConfig SplitKConfig::GetSplitKConfig(const wgpu::AdapterInfo& adapter_info) { + SplitKConfig config = {}; + if (adapter_info.vendor == std::string_view{"intel"}) { if (adapter_info.architecture == std::string_view{"xe-2lpg"} || adapter_info.architecture == std::string_view{"xe-2hpg"} || adapter_info.architecture == std::string_view{"xe-lpg"} || adapter_info.architecture == std::string_view{"gen-12hp"}) { - enable_split_k_ = true; + config.enable_split_k_ = true; // Below thresholds are only verified on the above Intel GPUs without any regressions. The // proper value of `max_dim_a_outer_multiplies_dim_b_outer_divides_dim_inner_` may be // reduced when we support a larger `dim_inner` because larger `dim_inner` will bring more // atomic calls for each output value. - split_dim_inner_ = 256; - min_dim_inner_with_split_k_ = split_dim_inner_ * 2; - max_dim_inner_with_split_k_ = split_dim_inner_ * 9; - max_dim_a_outer_multiplies_dim_b_outer_divides_dim_inner_ = 35.0f; + config.split_dim_inner_ = 256; + config.min_dim_inner_with_split_k_ = config.split_dim_inner_ * 2; + config.max_dim_inner_with_split_k_ = config.split_dim_inner_ * 9; + config.max_dim_a_outer_multiplies_dim_b_outer_divides_dim_inner_ = 35.0f; } } + return config; } bool SplitKConfig::UseSplitK( diff --git a/onnxruntime/core/providers/webgpu/webgpu_utils.h b/onnxruntime/core/providers/webgpu/webgpu_utils.h index 7d5ab5fea8006..d45b9bf4dd119 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_utils.h +++ b/onnxruntime/core/providers/webgpu/webgpu_utils.h @@ -91,12 +91,9 @@ inline Tensor CreateTensorView(const Tensor& tensor, MLDataType new_data_type, c return {new_data_type, new_shape, const_cast(tensor.DataRaw()), tensor.Location()}; } -/** - * Configuration for Split-K optimization (Conv|MatMul). - */ class SplitKConfig { public: - explicit SplitKConfig(const wgpu::AdapterInfo& adapter_info); + static SplitKConfig GetSplitKConfig(const wgpu::AdapterInfo& adapter_info); bool UseSplitK( bool is_vec4, ActivationKind activation_kind, uint64_t batch_size, diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index ab3932e7abfb4..4d4dea9cb444c 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -2943,8 +2943,6 @@ Status InferenceSession::Run(const RunOptions& run_options, << cached_execution_provider_for_graph_replay_.Type() << " CUDA Graph for this model with tag: " << run_options.run_tag << " with graph annotation id: " << graph_annotation_id; - // log evaluation start to trace logging provider - env.GetTelemetryProvider().LogEvaluationStart(session_id_); ORT_RETURN_IF_ERROR_SESSIONID_(cached_execution_provider_for_graph_replay_.ReplayGraph(graph_annotation_id)); } else { InlinedVector exec_providers_to_stop; diff --git a/onnxruntime/core/session/utils.cc b/onnxruntime/core/session/utils.cc index 4cb21b80109c8..6189e6ca7f012 100644 --- a/onnxruntime/core/session/utils.cc +++ b/onnxruntime/core/session/utils.cc @@ -404,7 +404,6 @@ Status CompileModel(const Environment& env, const ModelCompilationOptions& model session))); } - Env::Default().GetTelemetryProvider().LogCompileModel(session->GetCurrentSessionId()); ORT_RETURN_IF_ERROR(ToStatusAndRelease(InitializeSession(session_options, *session))); return Status::OK(); } diff --git a/onnxruntime/test/providers/cuda/cuda_mempool_arena_test.cc b/onnxruntime/test/providers/cuda/cuda_mempool_arena_test.cc index 5deef01cd783e..70c7a5b2bcdcb 100644 --- a/onnxruntime/test/providers/cuda/cuda_mempool_arena_test.cc +++ b/onnxruntime/test/providers/cuda/cuda_mempool_arena_test.cc @@ -22,17 +22,10 @@ namespace test { // --------- Helpers --------- -// cuda errors are sticky and may affect subsequent API calls. -// we want to clear the error if when supported check fails. -void ClearCudaError() { - ORT_IGNORE_RETURN_VALUE(::cudaGetLastError()); -} - static bool IsCudaMemPoolSupported() { int ort_cuda_rt_version = 0; cudaError_t cuda_status = cudaRuntimeGetVersion(&ort_cuda_rt_version); if (cuda_status != cudaSuccess) { - ClearCudaError(); return false; } @@ -43,7 +36,6 @@ static bool IsCudaMemPoolSupported() { int ort_cuda_driver_version = 0; cuda_status = cudaDriverGetVersion(&ort_cuda_driver_version); if (cuda_status != cudaSuccess) { - ClearCudaError(); return false; } @@ -73,10 +65,9 @@ static bool IsCudaMemPoolSupported() { cudaMemPool_t pool; auto cuda_error = cudaMemPoolCreate(&pool, &props); if (cuda_error != cudaSuccess) { - ClearCudaError(); return false; } - ORT_IGNORE_RETURN_VALUE(cudaMemPoolDestroy(pool)); + cuda_error = cudaMemPoolDestroy(pool); return true; } @@ -89,9 +80,7 @@ static ::cudaStream_t NewCudaStream() { } static void DestroyCudaStream(::cudaStream_t s) { - if (s) { - EXPECT_EQ(cudaSuccess, ::cudaStreamDestroy(s)); - } + if (s) (void)::cudaStreamDestroy(s); } static void TouchDevice(void* p, size_t bytes, ::cudaStream_t s, unsigned char value = 0xAB) { diff --git a/onnxruntime/test/providers/nv_tensorrt_rtx/nv_basic_test.cc b/onnxruntime/test/providers/nv_tensorrt_rtx/nv_basic_test.cc index af9706855ee3c..d8cc56d738175 100644 --- a/onnxruntime/test/providers/nv_tensorrt_rtx/nv_basic_test.cc +++ b/onnxruntime/test/providers/nv_tensorrt_rtx/nv_basic_test.cc @@ -203,48 +203,6 @@ TEST_P(TypeTests, IOTypes) { } } -TEST(NvExecutionProviderTest, TestSessionOutputs) { - /* - * Model #1: - * - * "input" ---> TopK --- - * |---> "scores" - * |--- Less ---> "Less_output_0" - * |--- Div ---> "Div_output_0" - * |--- Mod ---> "labels" - */ - { - Ort::SessionOptions session_options; - session_options.AppendExecutionProvider(kNvTensorRTRTXExecutionProvider, {}); - - auto model_path = ORT_TSTR("testdata/topk_and_multiple_graph_outputs.onnx"); - Ort::Session session(*ort_env, model_path, session_options); - - size_t output_count = session.GetOutputCount(); - ASSERT_TRUE(output_count == 4); - } - - /* - * Model #2: - * - * "X" ---> Dropout ---> MatMul ---> "Y" - * ^ | - * | | - * "W" ------ ----> Can't be graph's output - * - */ - { - Ort::SessionOptions session_options; - session_options.AppendExecutionProvider(kNvTensorRTRTXExecutionProvider, {}); - - auto model_path = ORT_TSTR("testdata/node_output_not_used.onnx"); - Ort::Session session(*ort_env, model_path, session_options); - - size_t output_count = session.GetOutputCount(); - ASSERT_TRUE(output_count == 1); - } -} - INSTANTIATE_TEST_SUITE_P(NvExecutionProviderTest, TypeTests, ::testing::Values(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16, diff --git a/onnxruntime/test/providers/qnn/README.md b/onnxruntime/test/providers/qnn/README.md deleted file mode 100644 index c3d0c720a1aa4..0000000000000 --- a/onnxruntime/test/providers/qnn/README.md +++ /dev/null @@ -1,70 +0,0 @@ -# ONNX Runtime QNN Execution Provider Tests -## Overview -1. The `onnxruntime/test/providers/qnn` directory contains integration tests for the Qualcomm Neural Network (QNN) execution provider. -2. Most testcases run an ONNX model through the QNN-EP, then verifies the inference result against the one on CPU-EP - -## Building the Tests -The tests are built as part of the regular ONNX Runtime build. After a successful build you will have an executable named -- onnxruntime_provider_test.exe (Windows) -- onnxruntime_provider_test (Linux/macOS) - -## Running the Tests -1. QNN supports several backends. You can use the standard Google‑Test syntax for filtering: - - `onnxruntime_provider_test.exe --gtest_filter=QnnCPUBackendTests.*` - - `onnxruntime_provider_test.exe --gtest_filter=QnnHTPBackendTests.*` - - `onnxruntime_provider_test.exe --gtest_filter=QnnGPUBackendTests.*` - - `onnxruntime_provider_test.exe --gtest_filter=QnnIRBackendTests.*` -2. Saving Test Artifacts - - For debugging it is often helpful to keep the intermediate files that the tests generate. The following environment - variables are recognized by the test binary: - - `QNN_DUMP_ONNX`: Saves the input ONNX model used for the test - - `QNN_DUMP_JSON`: Save json qnn graph with provider_option `dump_json_qnn_graph` - - `QNN_DUMP_DLC`: Saves the compiled QNN DLC file by specifying the provider_option `backend_path` to `QnnIr.dll` - - The artifacts will be saved to a directory named with `_` - ``` - . - ├── QnnCPUBackendTests_BatchNorm2D_fp32 # RunQnnModelTest - │ ├── dumped_f32_model.onnx # float32 ONNX model - │ ├── QNNExecutionProvider_QNN_XXXX_X_X.dlc - │ └── QNNExecutionProvider_QNN_XXXX_X_X.json - ├── QnnHTPBackendTests_BatchNorm_FP16 # TestFp16ModelAccuracy - │ ├── dumped_f16_model.onnx # float16 ONNX model - │ ├── dumped_f32_model.onnx # float32 ONNX model - │ ├── QNNExecutionProvider_QNN_XXXX_X_X.dlc - │ └── QNNExecutionProvider_QNN_XXXX_X_X.json - └── QnnHTPBackendTests_BatchNorm2D_U8U8S32 # TestQDQModelAccuracy - ├── dumped_f32_model.onnx # float32 ONNX model - ├── dumped_qdq_model.onnx # QDQ ONNX model - ├── QNNExecutionProvider_QNN_XXXX_X_X.dlc - └── QNNExecutionProvider_QNN_XXXX_X_X.json - - # All artifact files are placed under the current working directory from which the test binary is invoked. - ``` -3. Verbose - - `QNN_VERBOSE`: Sets the ONNX Runtime log level to `ORT_LOGGING_LEVEL_VERBOSE` - -4. You can enable any combination of these environment variables, for example: - - On Linux/macOS - ```bash - export QNN_DUMP_ONNX=1 - export QNN_DUMP_JSON=1 - export QNN_DUMP_DLC=1 - export QNN_VERBOSE=1 - ``` - - On Windows - ```cmd - set QNN_DUMP_ONNX=1 - set QNN_DUMP_JSON=1 - set QNN_DUMP_DLC=1 - set QNN_VERBOSE=1 - ``` - ```ps1 - $Env:QNN_DUMP_ONNX = "1" - $Env:QNN_DUMP_JSON = "1" - $Env:QNN_DUMP_DLC = "1" - $Env:QNN_VERBOSE = "1" - ``` - -# Note -- An issue on QNN backends can prevent the test artifacts from being successfully saved. -- The `onnxruntime_provider_test.exe` does not automatically delete the artifact directories, so you may want to prune them after a debugging session. diff --git a/onnxruntime/test/providers/qnn/qnn_test_utils.cc b/onnxruntime/test/providers/qnn/qnn_test_utils.cc index 15a9132aaa16c..1c70f4012090e 100644 --- a/onnxruntime/test/providers/qnn/qnn_test_utils.cc +++ b/onnxruntime/test/providers/qnn/qnn_test_utils.cc @@ -101,12 +101,6 @@ void RunQnnModelTest(const GetTestModelFn& build_test_case, ProviderOptions prov int opset_version, ExpectedEPNodeAssignment expected_ep_assignment, float fp32_abs_err, logging::Severity log_severity, bool verify_outputs, std::function* ep_graph_checker) { - std::filesystem::path output_dir; - if (QNNTestEnvironment::GetInstance().dump_onnx() || - QNNTestEnvironment::GetInstance().dump_json() || - QNNTestEnvironment::GetInstance().dump_dlc()) { - output_dir = QNNTestEnvironment::GetInstance().CreateTestcaseDirs(); - } EPVerificationParams verification_params; verification_params.ep_node_assignment = expected_ep_assignment; verification_params.fp32_abs_err = fp32_abs_err; @@ -116,10 +110,6 @@ void RunQnnModelTest(const GetTestModelFn& build_test_case, ProviderOptions prov auto& logging_manager = DefaultLoggingManager(); logging_manager.SetDefaultLoggerSeverity(log_severity); - if (QNNTestEnvironment::GetInstance().verbose()) { - logging_manager.RemoveSink(logging::SinkType::EtwSink); - logging_manager.SetDefaultLoggerSeverity(logging::Severity::kVERBOSE); - } onnxruntime::Model model("QNN_EP_TestModel", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, {}, @@ -133,27 +123,7 @@ void RunQnnModelTest(const GetTestModelFn& build_test_case, ProviderOptions prov // Serialize the model to a string. std::string model_data; model.ToProto().SerializeToString(&model_data); - - if (QNNTestEnvironment::GetInstance().dump_onnx()) { - auto dump_path = output_dir / ToPathString("dumped_f32_model.onnx"); - LOGS(logging_manager.DefaultLogger(), VERBOSE) << "Save onnx model at: " << dump_path; - ASSERT_STATUS_OK(onnxruntime::Model::Save(model, dump_path)); - } - TryEnableQNNSaver(provider_options); - if (QNNTestEnvironment::GetInstance().dump_dlc()) { - provider_options["dump_qnn_ir_dlc"] = "1"; - provider_options["dump_qnn_ir_dlc_dir"] = output_dir.string(); -#if defined(_WIN32) - provider_options["qnn_ir_backend_path"] = "QnnIr.dll"; -#else - provider_options["qnn_ir_backend_path"] = "libQnnIr.so"; -#endif // defined(_WIN32) - } - if (QNNTestEnvironment::GetInstance().dump_json()) { - provider_options["dump_json_qnn_graph"] = "1"; - provider_options["json_qnn_graph_dir"] = output_dir.string(); - } RunAndVerifyOutputsWithEP(AsByteSpan(model_data.data(), model_data.size()), "QNN_EP_TestLogID", QnnExecutionProviderWithOptions(provider_options), helper.feeds_, verification_params, @@ -164,21 +134,11 @@ void RunQnnModelTestHTPNoVerify(const GetTestModelFn& build_test_case, ProviderO int opset_version, ExpectedEPNodeAssignment expected_ep_assignment, logging::Severity log_severity, std::function* ep_graph_checker) { - std::filesystem::path output_dir; - if (QNNTestEnvironment::GetInstance().dump_onnx() || - QNNTestEnvironment::GetInstance().dump_dlc() || - QNNTestEnvironment::GetInstance().dump_json()) { - output_dir = QNNTestEnvironment::GetInstance().CreateTestcaseDirs(); - } // Add kMSDomain to cover contrib op like Gelu const std::unordered_map domain_to_version = {{"", opset_version}, {kMSDomain, 1}}; auto& logging_manager = DefaultLoggingManager(); logging_manager.SetDefaultLoggerSeverity(log_severity); - if (QNNTestEnvironment::GetInstance().verbose()) { - logging_manager.RemoveSink(logging::SinkType::EtwSink); - logging_manager.SetDefaultLoggerSeverity(logging::Severity::kVERBOSE); - } onnxruntime::Model model("QNN_EP_TestModel", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, {}, @@ -192,27 +152,7 @@ void RunQnnModelTestHTPNoVerify(const GetTestModelFn& build_test_case, ProviderO // Serialize the model to a string. std::string model_data; model.ToProto().SerializeToString(&model_data); - - if (QNNTestEnvironment::GetInstance().dump_onnx()) { - auto dump_path = output_dir / ToPathString("dumped_f32_model.onnx"); - LOGS(logging_manager.DefaultLogger(), VERBOSE) << "Save onnx model at: " << dump_path; - ASSERT_STATUS_OK(onnxruntime::Model::Save(model, dump_path)); - } - TryEnableQNNSaver(provider_options); - if (QNNTestEnvironment::GetInstance().dump_dlc()) { - provider_options["dump_qnn_ir_dlc"] = "1"; - provider_options["dump_qnn_ir_dlc_dir"] = output_dir.string(); -#if defined(_WIN32) - provider_options["qnn_ir_backend_path"] = "QnnIr.dll"; -#else - provider_options["qnn_ir_backend_path"] = "libQnnIr.so"; -#endif // defined(_WIN32) - } - if (QNNTestEnvironment::GetInstance().dump_json()) { - provider_options["dump_json_qnn_graph"] = "1"; - provider_options["json_qnn_graph_dir"] = output_dir.string(); - } SessionOptions so; so.session_logid = "QNN_EP_TestLogID"; diff --git a/onnxruntime/test/providers/qnn/qnn_test_utils.h b/onnxruntime/test/providers/qnn/qnn_test_utils.h index 4d4f795d161b1..aeb3a9a114871 100644 --- a/onnxruntime/test/providers/qnn/qnn_test_utils.h +++ b/onnxruntime/test/providers/qnn/qnn_test_utils.h @@ -499,77 +499,6 @@ struct QDQTolerance { float value; }; -class QNNTestEnvironment { - public: - // Delete copy constructor and assignment operator - QNNTestEnvironment(const QNNTestEnvironment&) = delete; - QNNTestEnvironment& operator=(const QNNTestEnvironment&) = delete; - - // Static method to get the singleton instance - static QNNTestEnvironment& GetInstance() { - static QNNTestEnvironment instance; - return instance; - } - - bool dump_onnx() const { return dump_onnx_; } - bool dump_json() const { return dump_json_; } - bool dump_dlc() const { return dump_dlc_; } - bool verbose() const { return verbose_; } - - std::filesystem::path CreateTestcaseDirs() { - std::string test_suite_name = ::testing::UnitTest::GetInstance()->current_test_info()->test_suite_name(); - std::string test_name = ::testing::UnitTest::GetInstance()->current_test_info()->name(); - std::filesystem::path output_dir = std::filesystem::current_path() / (test_suite_name + "_" + test_name); - std::filesystem::create_directories(output_dir); - - return output_dir; - } - - private: - // Private constructor for singleton - QNNTestEnvironment() { - ParseEnvironmentVars(); - } - - // Helper function to check if an environment variable is set - bool IsEnvVarSet(const char* name) { - const char* value = std::getenv(name); - if (value == nullptr) { - return false; - } - - // Consider the variable set if it's not empty and not "0" - return *value != '\0' && *value != '0'; - } - - void ParseEnvironmentVars() { - if (IsEnvVarSet("QNN_DUMP_ONNX")) { - std::cout << "[QNN only] ONNX model dumping enabled via environment variable." << std::endl; - dump_onnx_ = true; - } - - if (IsEnvVarSet("QNN_DUMP_JSON")) { - std::cout << "[QNN only] Json QNN Graph dumping enabled via environment variable." << std::endl; - dump_json_ = true; - } - - if (IsEnvVarSet("QNN_DUMP_DLC")) { - std::cout << "[QNN only] DLC dumping enabled via environment variable." << std::endl; - dump_dlc_ = true; - } - - if (IsEnvVarSet("QNN_VERBOSE")) { - std::cout << "Verbose enabled via environment variable." << std::endl; - verbose_ = true; - } - } - - bool dump_onnx_ = false; - bool dump_json_ = false; - bool dump_dlc_ = false; - bool verbose_ = false; -}; - /** * Tests the accuracy of a QDQ model on QNN EP by runnning 3 inferences: * @@ -600,21 +529,15 @@ inline void TestQDQModelAccuracy(const GetTestModelFn& f32_model_fn, const GetTe const std::string& qnn_ctx_model_path = "", const std::unordered_map& session_option_pairs = {}, std::function* qnn_ep_graph_checker = nullptr) { - std::filesystem::path output_dir; - if (QNNTestEnvironment::GetInstance().dump_onnx() || - QNNTestEnvironment::GetInstance().dump_dlc() || - QNNTestEnvironment::GetInstance().dump_json()) { - output_dir = QNNTestEnvironment::GetInstance().CreateTestcaseDirs(); - } // Add kMSDomain to cover contrib op like Gelu const std::unordered_map domain_to_version = {{"", opset_version}, {kMSDomain, 1}}; auto& logging_manager = DefaultLoggingManager(); + + // Uncomment to dump LOGGER() output to stdout. + // logging_manager.RemoveSink(logging::SinkType::EtwSink); + logging_manager.SetDefaultLoggerSeverity(log_severity); - if (QNNTestEnvironment::GetInstance().verbose()) { - logging_manager.RemoveSink(logging::SinkType::EtwSink); - logging_manager.SetDefaultLoggerSeverity(logging::Severity::kVERBOSE); - } // Create float model and serialize it to a string. onnxruntime::Model f32_model("f32_model", false, ModelMetaData(), PathString(), @@ -628,11 +551,8 @@ inline void TestQDQModelAccuracy(const GetTestModelFn& f32_model_fn, const GetTe ASSERT_STATUS_OK(f32_model.MainGraph().Resolve()); f32_model.ToProto().SerializeToString(&f32_model_data); - if (QNNTestEnvironment::GetInstance().dump_onnx()) { - auto dump_path = output_dir / ToPathString("dumped_f32_model.onnx"); - LOGS(logging_manager.DefaultLogger(), VERBOSE) << "Save onnx float32 model at: " << dump_path; - ASSERT_STATUS_OK(onnxruntime::Model::Save(f32_model, dump_path)); - } + // Uncomment to save f32 model to disk for debugging. + // ASSERT_STATUS_OK(onnxruntime::Model::Save(f32_model, ToPathString("cmp_accuracy.f32.onnx"))); // Run f32 model on CPU EP and collect outputs. std::vector cpu_f32_outputs; @@ -674,27 +594,11 @@ inline void TestQDQModelAccuracy(const GetTestModelFn& f32_model_fn, const GetTe ASSERT_STATUS_OK(qdq_model.MainGraph().Resolve()); qdq_model.ToProto().SerializeToString(&qdq_model_data); - if (QNNTestEnvironment::GetInstance().dump_onnx()) { - auto dump_path = output_dir / ToPathString("dumped_qdq_model.onnx"); - LOGS(logging_manager.DefaultLogger(), VERBOSE) << "Save onnx QDQ model at: " << dump_path; - ASSERT_STATUS_OK(onnxruntime::Model::Save(qdq_model, dump_path)); - } + // Uncomment to save QDQ model to disk for debugging. + // ASSERT_STATUS_OK(onnxruntime::Model::Save(qdq_model, ToPathString("cmp_accuracy.qdq.onnx"))); bool is_qnn_ep = true; TryEnableQNNSaver(qnn_options); - if (QNNTestEnvironment::GetInstance().dump_dlc()) { - qnn_options["dump_qnn_ir_dlc"] = "1"; - qnn_options["dump_qnn_ir_dlc_dir"] = output_dir.string(); -#if defined(_WIN32) - qnn_options["qnn_ir_backend_path"] = "QnnIr.dll"; -#else - qnn_options["qnn_ir_backend_path"] = "libQnnIr.so"; -#endif // defined(_WIN32) - } - if (QNNTestEnvironment::GetInstance().dump_json()) { - qnn_options["dump_json_qnn_graph"] = "1"; - qnn_options["json_qnn_graph_dir"] = output_dir.string(); - } std::vector qnn_qdq_outputs; if (!qnn_ctx_model_path.empty()) { onnx::ModelProto model_proto; @@ -839,21 +743,11 @@ inline void TestFp16ModelAccuracy(const GetTestModelFn& f32_model_fn, logging::Severity log_severity = logging::Severity::kERROR, const std::string& qnn_ctx_model_path = "", const std::unordered_map& session_option_pairs = {}) { - std::filesystem::path output_dir; - if (QNNTestEnvironment::GetInstance().dump_onnx() || - QNNTestEnvironment::GetInstance().dump_dlc() || - QNNTestEnvironment::GetInstance().dump_json()) { - output_dir = QNNTestEnvironment::GetInstance().CreateTestcaseDirs(); - } // Add kMSDomain to cover contrib op like Gelu const std::unordered_map domain_to_version = {{"", opset_version}, {kMSDomain, 1}}; auto& logging_manager = DefaultLoggingManager(); logging_manager.SetDefaultLoggerSeverity(log_severity); - if (QNNTestEnvironment::GetInstance().verbose()) { - logging_manager.RemoveSink(logging::SinkType::EtwSink); - logging_manager.SetDefaultLoggerSeverity(logging::Severity::kVERBOSE); - } // Create float model and serialize it to a string. onnxruntime::Model f32_model("f32_model", false, ModelMetaData(), PathString(), @@ -866,12 +760,6 @@ inline void TestFp16ModelAccuracy(const GetTestModelFn& f32_model_fn, ASSERT_STATUS_OK(f32_model.MainGraph().Resolve()); f32_model.ToProto().SerializeToString(&f32_model_data); - if (QNNTestEnvironment::GetInstance().dump_onnx()) { - auto dump_path = output_dir / ToPathString("dumped_f32_model.onnx"); - LOGS(logging_manager.DefaultLogger(), VERBOSE) << "Save onnx float32 model at: " << dump_path; - ASSERT_STATUS_OK(onnxruntime::Model::Save(f32_model, dump_path)); - } - // Run f32 model on CPU EP and collect outputs. std::vector cpu_f32_outputs; InferenceModel(f32_model_data, "f32_model_logger", {}, ExpectedEPNodeAssignment::All, @@ -908,27 +796,8 @@ inline void TestFp16ModelAccuracy(const GetTestModelFn& f32_model_fn, ASSERT_STATUS_OK(f16_model.MainGraph().Resolve()); f16_model.ToProto().SerializeToString(&f16_model_data); - if (QNNTestEnvironment::GetInstance().dump_onnx()) { - auto dump_path = output_dir / ToPathString("dumped_f16_model.onnx"); - LOGS(logging_manager.DefaultLogger(), VERBOSE) << "Save onnx float16 model at: " << dump_path; - ASSERT_STATUS_OK(onnxruntime::Model::Save(f16_model, dump_path)); - } - bool is_qnn_ep = true; TryEnableQNNSaver(qnn_options); - if (QNNTestEnvironment::GetInstance().dump_dlc()) { - qnn_options["dump_qnn_ir_dlc"] = "1"; - qnn_options["dump_qnn_ir_dlc_dir"] = output_dir.string(); -#if defined(_WIN32) - qnn_options["qnn_ir_backend_path"] = "QnnIr.dll"; -#else - qnn_options["qnn_ir_backend_path"] = "libQnnIr.so"; -#endif // defined(_WIN32) - } - if (QNNTestEnvironment::GetInstance().dump_json()) { - qnn_options["dump_json_qnn_graph"] = "1"; - qnn_options["json_qnn_graph_dir"] = output_dir.string(); - } std::vector qnn_f16_outputs; if (!qnn_ctx_model_path.empty()) { onnx::ModelProto model_proto; diff --git a/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc b/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc index dce0d570ec238..6a6545c68cb4f 100644 --- a/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc +++ b/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc @@ -1,6 +1,5 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "onnxruntime_cxx_api.h" #include "core/graph/onnx_protobuf.h" #include "core/session/inference_session.h" #include "test/providers/provider_test_utils.h" @@ -19,8 +18,6 @@ using namespace std; using namespace ONNX_NAMESPACE; using namespace ::onnxruntime::logging; -extern std::unique_ptr ort_env; - namespace onnxruntime { namespace test { @@ -1363,49 +1360,5 @@ TEST(TensorrtExecutionProviderTest, RemoveCycleTest) { ASSERT_STATUS_OK(session_object.Run(run_options, feeds, output_names, &fetches)); VerifyOutputs(fetches, expected_dims_mul_m, expected_values_mul_m); } - -TEST(TensorrtExecutionProviderTest, TestSessionOutputs) { - /* - * Model #1: - * - * "input" ---> TopK --- - * |---> "scores" - * |--- Less ---> "Less_output_0" - * |--- Div ---> "Div_output_0" - * |--- Mod ---> "labels" - */ - { - OrtTensorRTProviderOptionsV2 provider_options; - Ort::SessionOptions session_options; - session_options.AppendExecutionProvider_TensorRT_V2(provider_options); - - auto model_path = ORT_TSTR("testdata/topk_and_multiple_graph_outputs.onnx"); - Ort::Session session(*ort_env, model_path, session_options); - - size_t output_count = session.GetOutputCount(); - ASSERT_TRUE(output_count == 4); - } - - /* - * Model #2: - * - * "X" ---> Dropout ---> MatMul ---> "Y" - * ^ | - * | | - * "W" ------ ----> Can't be graph's output - * - */ - { - OrtTensorRTProviderOptionsV2 provider_options; - Ort::SessionOptions session_options; - session_options.AppendExecutionProvider_TensorRT_V2(provider_options); - - auto model_path = ORT_TSTR("testdata/node_output_not_used.onnx"); - Ort::Session session(*ort_env, model_path, session_options); - - size_t output_count = session.GetOutputCount(); - ASSERT_TRUE(output_count == 1); - } -} } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/testdata/node_output_not_used.onnx b/onnxruntime/test/testdata/node_output_not_used.onnx deleted file mode 100644 index e2726182fddc2c265752e46346735c26e33add4b..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 189 zcmd=lo3kgAWK)CKji3J%^!XPX8xOg}ig*dpFIGBN$2_zVfB*+Ak RNCFB*q6<2)a4`t*0ss-ID|-L{ diff --git a/onnxruntime/test/testdata/node_output_not_used.py b/onnxruntime/test/testdata/node_output_not_used.py deleted file mode 100644 index d36d5e9cfd2f8..0000000000000 --- a/onnxruntime/test/testdata/node_output_not_used.py +++ /dev/null @@ -1,43 +0,0 @@ -import onnx -from onnx import TensorProto, helper - - -def create_model_with_node_output_not_used(model_path): - # Create graph - x = helper.make_tensor_value_info("X", TensorProto.FLOAT, [3, 2]) - w = helper.make_tensor_value_info("W", TensorProto.FLOAT, [2, 3]) - y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 3]) - - # Dropout node (two outputs) - dropout_node = helper.make_node( - "Dropout", - inputs=["X"], - outputs=["dropout_out", "dropout_mask"], - name="DropoutNode", - ) - - # MatMul node - matmul_node = helper.make_node( - "MatMul", - inputs=["dropout_out", "W"], - outputs=["Y"], - name="MatMulNode", - ) - - graph = helper.make_graph( - nodes=[dropout_node, matmul_node], - name="DropoutMatMulGraph", - inputs=[x, w], - outputs=[y], - ) - - model = helper.make_model(graph, opset_imports=[helper.make_operatorsetid("", 13)]) - - onnx.checker.check_model(model) - onnx.save(model, model_path) - - print(f"Model saved to: {model_path}") - - -if __name__ == "__main__": - create_model_with_node_output_not_used("node_output_not_used.onnx") diff --git a/onnxruntime/test/testdata/topk_and_multiple_graph_outputs.onnx b/onnxruntime/test/testdata/topk_and_multiple_graph_outputs.onnx deleted file mode 100644 index 340c3d420d5746844be0bd3769a174b4e69de801..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 393 zcmdW?8(U zIYJ{dP(TSp09}P@53)A4oW!KmoMI_v-~1FM5Fx|~a-n-sVnK!$HwU8tyA{(KCMQO3 zEp8x_k--V -B -V [-H ] " 1>&2; exit 1; } + +ROCM_HOME=/opt/rocm + +while getopts S:B:V:H:I:P: parameter_Option; do + case "${parameter_Option}" in + S) SOURCE_DIR=${OPTARG};; + B) BINARY_DIR=${OPTARG};; + V) ROCM_VERSION=${OPTARG};; + H) ROCM_HOME=${OPTARG};; + I) IMAGE=${OPTARG};; + P) PYTHON_BIN=${OPTARG};; + *) usage ;; + esac +done + +EXIT_CODE=1 + +docker run -e SYSTEM_COLLECTIONURI --rm \ + --security-opt seccomp=unconfined \ + --shm-size=1024m \ + --user $UID:$(id -g $USER) \ + -e NIGHTLY_BUILD \ + --volume $SOURCE_DIR:/onnxruntime_src \ + --volume $BINARY_DIR:/build \ + --volume /data/models:/build/models:ro \ + --volume /data/onnx:/data/onnx:ro \ + --workdir /onnxruntime_src \ + $IMAGE \ + /bin/bash -c "${PYTHON_BIN:-python} /onnxruntime_src/tools/ci_build/build.py --config Release --build_dir /build --parallel --use_rocm --use_binskim_compliant_compile_flags --rocm_version=$ROCM_VERSION --rocm_home $ROCM_HOME --nccl_home $ROCM_HOME --build_shared_lib --skip_submodule_sync --skip_tests --cmake_extra_defines FETCHCONTENT_TRY_FIND_PACKAGE_MODE=NEVER && cd /build/Release && make install DESTDIR=/build/installed" + + +EXIT_CODE=$? + +set -e +exit $EXIT_CODE diff --git a/tools/ci_build/github/linux/docker/scripts/setup_rocm_yum_repo.sh b/tools/ci_build/github/linux/docker/scripts/setup_rocm_yum_repo.sh new file mode 100755 index 0000000000000..0be64d96f3a34 --- /dev/null +++ b/tools/ci_build/github/linux/docker/scripts/setup_rocm_yum_repo.sh @@ -0,0 +1,43 @@ +#!/bin/bash +set -e -x + +# version +ROCM_VERSION=6.2.3 + +while getopts "r:" parameter_Option +do case "${parameter_Option}" +in +r) ROCM_VERSION=${OPTARG};; +esac +done + +tee /etc/yum.repos.d/amdgpu.repo < Date: Tue, 9 Dec 2025 07:46:33 -0800 Subject: [PATCH 138/138] disable QDQ scale-strip for 2025.0 --- .../providers/openvino/backend_manager.cc | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/onnxruntime/core/providers/openvino/backend_manager.cc b/onnxruntime/core/providers/openvino/backend_manager.cc index 3426a2781bbc6..712f3c5faafbe 100644 --- a/onnxruntime/core/providers/openvino/backend_manager.cc +++ b/onnxruntime/core/providers/openvino/backend_manager.cc @@ -322,6 +322,7 @@ static bool IsQDQGraph(const onnxruntime::GraphViewer& graph_viewer) { return false; } +#if ((OPENVINO_VERSION_MAJOR < 2025) || ((OPENVINO_VERSION_MAJOR == 2025) && (OPENVINO_VERSION_MINOR < 0))) static bool Is16BitTensor(const onnxruntime::NodeArg* node_arg) { const auto* type_proto = node_arg ? node_arg->TypeAsProto() : nullptr; return type_proto && type_proto->has_tensor_type() && @@ -359,6 +360,7 @@ static bool IsQDQGraphWithUint16OrInt16(const onnxruntime::GraphViewer& graph_vi } return false; } +#endif static void DumpOpenVINOEPModel([[maybe_unused]] const std::filesystem::path& onnx_model_path_name, [[maybe_unused]] ONNX_NAMESPACE::ModelProto* model_proto, @@ -492,10 +494,6 @@ BackendManager::GetModelProtoFromFusedNode(const onnxruntime::Node& fused_node, } #endif - // Check if the graph is QDQ and has int16 or uint16 quantization - // If so, we will apply the QDQ scales fix transformation (for GPU device only) - bool is_qdq_graph_uint16_or_int16 = IsQDQGraphWithUint16OrInt16(subgraph); - const auto& onnx_model_path_name = subgraph.ModelPath(); // QDQ stripping enabled only for the NPU and experimentally on the GPU if ((session_context_.device_type.find("NPU") != std::string::npos) && @@ -508,8 +506,11 @@ BackendManager::GetModelProtoFromFusedNode(const onnxruntime::Node& fused_node, DumpOpenVINOEPModel(onnx_model_path_name, model_proto.get(), fused_node); ORT_ENFORCE(status.IsOK(), status.ErrorMessage()); return model_proto; - } else if ((session_context_.device_type.find("GPU") != std::string::npos) && - is_qdq_graph_uint16_or_int16) { + } +#if ((OPENVINO_VERSION_MAJOR < 2025) || ((OPENVINO_VERSION_MAJOR == 2025) && (OPENVINO_VERSION_MINOR < 0))) + // Enable OVEP-level QDQ stripping only for OV versions that don't have it + else if ((session_context_.device_type.find("GPU") != std::string::npos) && + IsQDQGraphWithUint16OrInt16(subgraph)) { // Create a copy of the model std::unique_ptr model; Status status = qdq_scales_fix::Transform(subgraph, logger, model); @@ -519,7 +520,9 @@ BackendManager::GetModelProtoFromFusedNode(const onnxruntime::Node& fused_node, DumpOpenVINOEPModel(onnx_model_path_name, model_proto.get(), fused_node); ORT_ENFORCE(status.IsOK(), status.ErrorMessage()); return model_proto; - } else { + } +#endif + else { LOGS_DEFAULT(INFO) << "[OpenVINO-EP] OVEP QDQ optimization pass is disabled"; // scan ext initializers: @@ -845,4 +848,4 @@ void BackendManager::RewindKVCache(size_t index) { } } // namespace openvino_ep -} // namespace onnxruntime +} // namespace onnxruntime \ No newline at end of file