diff --git a/.github/workflows/minimal_test.yml b/.github/workflows/minimal_test.yml index f732001..e1c48bd 100644 --- a/.github/workflows/minimal_test.yml +++ b/.github/workflows/minimal_test.yml @@ -106,6 +106,7 @@ jobs: uses: actions/checkout@v4 - name: Checkout PR + shell: bash run: | if [[ "${{ github.event_name }}" == "workflow_dispatch" ]]; then gh pr checkout ${{ github.event.inputs.pr_number }} @@ -195,6 +196,7 @@ jobs: uses: actions/checkout@v4 - name: Checkout PR + shell: bash run: | if [[ "${{ github.event_name }}" == "workflow_dispatch" ]]; then gh pr checkout ${{ github.event.inputs.pr_number }} diff --git a/torchmeter/statistic.py b/torchmeter/statistic.py index 4802e9c..a0e3c6f 100644 --- a/torchmeter/statistic.py +++ b/torchmeter/statistic.py @@ -868,7 +868,7 @@ def __hook_func( cuda_sync() # WAIT FOR GPU SYNC it = start_event.elapsed_time(end_event) * 1e-3 # ms -> s # type: ignore - tp = ipt[0].shape[0] / it # TODO: batch infer + tp = 1 / it self.__InferTime.append(it) self.__Throughput.append(tp)